In [1]:
from collections import deque

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import nltk
from torch.utils.data import Dataset
import pickle

from utils.utils import *
from utils.label_decoding import *
from utils.HierarchicalLoss import *

# SubTask 1

In [2]:
class DataSet(Dataset):
    """Labelled dataset: precomputed text features plus one target vector per hierarchy level.

    Args:
        df: DataFrame with the columns ``id``, ``cleaned_text`` and
            ``Level 1`` .. ``Level 5`` (each level cell is iterated as a
            collection of label names).
        labels_at_level: mapping ``'Level N'`` -> ``{label name: index}``.
        features_file: path to a pickle file whose content is indexed by
            sample ``id``.
    """

    def __init__(self, df, labels_at_level, features_file):
        super(DataSet, self).__init__()
        self.data_df = df
        self.labels_at_level = labels_at_level
        self.features_file = features_file
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
        with open(features_file, 'rb') as handle:
            self.features_dict = pickle.load(handle)

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.data_df)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx`` as a dict.

        Keys: ``id``, ``text``, ``text_features`` and
        ``level_1_target`` .. ``level_5_target``.
        """
        row = self.data_df.iloc[idx]
        sample_id = row['id']
        sample_text = row['cleaned_text']

        # Targets are encoded before the feature lookup, one vector per level.
        level_targets = {
            f'level_{depth}_target': self.encode(row[f'Level {depth}'], depth)
            for depth in range(1, 6)
        }

        sample = {'id': sample_id,
                  'text': sample_text,
                  'text_features': self.features_dict[sample_id]}
        sample.update(level_targets)
        return sample

    def encode(self, labels, level):
        """Build the multi-hot target vector for ``labels`` at ``level``.

        The vector has one slot per known label of that level plus one extra
        trailing slot, which is set when ``labels`` is empty.
        """
        label_to_index = self.labels_at_level[f'Level {level}']
        one_hot = torch.zeros(len(label_to_index) + 1)

        for name in labels:
            one_hot[label_to_index[name]] = 1

        # No label at this level: flag the extra trailing slot instead.
        if not len(labels):
            one_hot[-1] = 1

        return one_hot

In [3]:
class TestDataSet(Dataset):
    """Unlabelled dataset: sample id, cleaned text and precomputed text features.

    Args:
        df: DataFrame with the columns ``id`` and ``cleaned_text``.
        features_file: path to a pickle file whose content is indexed by
            sample ``id``.
    """

    def __init__(self, df, features_file):
        super(TestDataSet, self).__init__()
        self.data_df = df
        self.features_file = features_file
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
        with open(features_file, 'rb') as handle:
            self.features_dict = pickle.load(handle)

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.data_df)

    def __getitem__(self, idx):
        """Return ``{'id', 'text', 'text_features'}`` for the row at position ``idx``."""
        row = self.data_df.iloc[idx]
        sample_id = row['id']
        sample_text = row['cleaned_text']

        return {'id': sample_id,
                'text': sample_text,
                'text_features': self.features_dict[sample_id]}

In [4]:
def evaluate_model(model, dataloader, pred_file_path, gold_file_path,
                   evaluator_script_path, id2leaf_label, format=None, validation=False, threshold=0.3):
    """Run ``model`` over ``dataloader``, write predictions to disk and optionally score them.

    Relies on names defined elsewhere in the notebook: ``HierarchicalLoss``,
    ``id2label_1``, ``hierarchy_1``, ``persuasion_techniques_1``, ``device``,
    ``get_labels``, ``tqdm``, ``json`` and ``subprocess``.

    Args:
        model: module returning five prediction tensors (levels 1-5) for a
            batch of text features.
        dataloader: yields dict batches with ``id`` and ``text_features``;
            when ``validation`` is true also ``level_1_target`` ..
            ``level_5_target``.
        pred_file_path: where the predictions are written as JSON.
        gold_file_path: gold labels passed to the scorer; ``None`` skips scoring.
        evaluator_script_path: scorer script, run with ``python3``.
        id2leaf_label: forwarded to ``get_labels``.
        format: forwarded to ``get_labels``.
        validation: when true, also accumulate the hierarchical loss.
        threshold: predictions strictly above this value count as positive.

    Returns:
        The loss averaged over batches when ``validation`` is true, else ``None``.
    """
    model.eval()
    predictions = []
    total_loss = 0

    # The loss object is only needed for validation runs, so test-only calls skip it.
    # NOTE(review): it is built without alpha/beta/threshold, unlike the
    # training-time HL object, so it uses whatever defaults HierarchicalLoss
    # defines - confirm this is intended.
    loss_fn = None
    if validation:
        loss_fn = HierarchicalLoss(id2label=id2label_1, hierarchical_labels=hierarchy_1,
                                   persuasion_techniques=persuasion_techniques_1, device=device)

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Ids arrive either as a tensor or as a plain list, depending on the batch.
            if not isinstance(batch['id'], list):
                ids = batch['id'].detach().numpy().tolist()
            else:
                ids = batch['id']

            embeddings = batch['text_features'].to(device)
            pred_1, pred_2, pred_3, pred_4, pred_5 = model(embeddings)

            if validation:
                preds = [pred_1, pred_2, pred_3, pred_4, pred_5]
                targets = [batch[f'level_{level}_target'].to(device) for level in range(1, 6)]

                dloss = loss_fn.calculate_dloss(preds, targets)
                lloss = loss_fn.calculate_lloss(preds, targets)

                total_loss += (dloss + lloss).detach().cpu().item()

            # Only levels 3-5 are turned into label predictions.
            pred_3 = (pred_3.cpu().detach().numpy() > threshold).astype(int)
            pred_4 = (pred_4.cpu().detach().numpy() > threshold).astype(int)
            pred_5 = (pred_5.cpu().detach().numpy() > threshold).astype(int)

            predictions += get_labels(id2leaf_label, ids, pred_3, pred_4, pred_5, format)

    # Writing JSON data
    with open(pred_file_path, 'w') as f:
        json.dump(predictions, f, indent=4)

    if gold_file_path is not None:
        command = [
            "python3", evaluator_script_path,
            "--gold_file_path", gold_file_path,
            "--pred_file_path", pred_file_path
        ]

        result = subprocess.run(command, capture_output=True, text=True)

        if result.returncode == 0:
            print("Output:\n", result.stdout)
        else:
            print("Error:\n", result.stderr)

    # Previously a missing gold file returned None here even when validation
    # was requested; the loss is now returned either way.
    if validation:
        return total_loss / len(dataloader)

In [5]:
from torch.utils.data import DataLoader

# Parse the labelled train / validation splits.
train_data = process_json('./semeval2024_dev_release/subtask1/train.json', techniques_to_level_1, hierarchy_1)
validation_data = process_json('./semeval2024_dev_release/subtask1/validation.json', techniques_to_level_1, hierarchy_1)

# Pair each split with its precomputed mBERT text features.
training_dataset = DataSet(train_data, indexed_persuasion_techniques_1, 
                           './TextFeatures/subtask1a/mBERT/train_text_features.pkl')
validation_dataset = DataSet(validation_data, indexed_persuasion_techniques_1, 
                             './TextFeatures/subtask1a/mBERT/validation_text_features.pkl')

batch_size = 256

train_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
# Validation is evaluation-only: a fixed order keeps batch composition, and so
# the batch-averaged validation loss, identical from one epoch to the next.
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)


In [6]:
# Passed to HierarchicalLoss in the next cell.
alpha = 0.7764469620072395
beta = 0.95
threshold = 0.8256232754296409

# NOTE(review): learning_rate and beta1 are not referenced by the visible
# optimizer call, which passes its own literal values - confirm which set is
# intended.
learning_rate = 3.906930058023181e-05
beta1 = 0.9094170903394552

# Same value as the one already set in the data-loading cell.
batch_size = 256

In [7]:
num_epochs = 100

# NOTE(review): the get_device() result is immediately replaced by an explicit
# CPU device, so everything below runs on CPU - confirm this is deliberate.
device = get_device()
device = torch.device("cpu")

# Training-time hierarchical loss, configured with the hyperparameters above.
HL = HierarchicalLoss(
    id2label=id2label_1,
    hierarchical_labels=hierarchy_1,
    persuasion_techniques=persuasion_techniques_1,
    device=device,
    alpha=alpha,
    beta=beta,
    threshold=threshold,
)



Using MPS


### Model

In [8]:
from modules.nn.mBERT import mBERT

# Build the classifier and place it on the selected device in one step.
model = mBERT().to(device)
model

mBERT(
  (linear_level1): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.15, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.15, inplace=False)
    (6): Linear(in_features=512, out_features=128, bias=True)
    (7): ReLU()
  )
  (linear_level2): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.15, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.15, inplace=False)
    (6): Linear(in_features=512, out_features=128, bias=True)
    (7): ReLU()
  )
  (linear_level3): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.15, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.15, inplace=False)
    (6): Linear(in_fea

In [10]:
from tqdm import tqdm
import json
import subprocess

# NOTE(review): lr / betas are literal values here rather than the
# `learning_rate` / `beta1` variables defined earlier - confirm which is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004306099142228309, betas=(0.8923286832300139, 0.999))
min_val_loss = float('inf')
best_epoch = None

train_loss_history = []
val_loss_history = []

# Loop-invariant evaluation settings.
val_pred_file_path = './Predictions/val_predictions_subtask1.json'
val_gold_file_path = './semeval2024_dev_release/subtask1/validation.json'
evaluator_script = './scorer-baseline/subtask_1_2a.py'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0   # loss over the current 20-batch logging window
    epoch_loss = 0.0     # loss over the whole epoch, never reset mid-epoch
    for batch_idx, batch in enumerate(train_dataloader):
        embeddings = batch['text_features'].to(device)
        targets = [batch[f'level_{level}_target'].to(device) for level in range(1, 6)]

        optimizer.zero_grad()
        pred_1, pred_2, pred_3, pred_4, pred_5 = model(embeddings)
        preds = [pred_1, pred_2, pred_3, pred_4, pred_5]

        dloss = HL.calculate_dloss(preds, targets)
        lloss = HL.calculate_lloss(preds, targets)

        total_loss = lloss + dloss
        total_loss.backward()
        optimizer.step()

        batch_loss = total_loss.detach().item()
        running_loss += batch_loss
        epoch_loss += batch_loss

        if batch_idx % 20 == 19:
            print(f"[{epoch + 1}, {batch_idx + 1}] loss: {running_loss / 20:.3f}")
            running_loss = 0.0

    # Average over every batch of the epoch. The windowed `running_loss` is
    # reset every 20 batches, so it cannot serve as the epoch total.
    epoch_loss /= len(train_dataloader)

    # No `threshold` is passed, so evaluate_model applies its own default.
    validation_loss = evaluate_model(model, validation_dataloader, val_pred_file_path, 
                                     val_gold_file_path, evaluator_script, id2leaf_label,
                                     validation=True)

    train_loss_history.append(epoch_loss)
    val_loss_history.append(validation_loss)

    if validation_loss < min_val_loss:
        min_val_loss = validation_loss
        best_epoch = epoch
        # NOTE(review): no checkpoint is stored, so later cells use the weights
        # from the final epoch, not from `best_epoch`.
        # torch.save(model.state_dict(), './models/best_subtask1_baseline.pth')

# `best_epoch` is zero-based, whereas the progress lines above print epoch + 1.
print(f'best validation loss occurred in epoch {best_epoch} ')
        

[1, 20] loss: 2361.058


100%|██████████| 2/2 [00:00<00:00, 15.57it/s]


Output:
 f1_h=0.42557	prec_h=0.54414	rec_h=0.34942
[2, 20] loss: 2276.223


100%|██████████| 2/2 [00:00<00:00, 15.57it/s]


Output:
 f1_h=0.35059	prec_h=0.59283	rec_h=0.24889
[3, 20] loss: 2262.271


100%|██████████| 2/2 [00:00<00:00, 15.78it/s]


Output:
 f1_h=0.39125	prec_h=0.60463	rec_h=0.28919
[4, 20] loss: 2188.651


100%|██████████| 2/2 [00:00<00:00, 15.71it/s]


Output:
 f1_h=0.31909	prec_h=0.64770	rec_h=0.21169
[5, 20] loss: 2168.533


100%|██████████| 2/2 [00:00<00:00, 15.81it/s]


Output:
 f1_h=0.38519	prec_h=0.64749	rec_h=0.27414
[6, 20] loss: 2142.512


100%|██████████| 2/2 [00:00<00:00, 15.96it/s]


Output:
 f1_h=0.45445	prec_h=0.56039	rec_h=0.38220
[7, 20] loss: 2096.203


100%|██████████| 2/2 [00:00<00:00, 16.16it/s]


Output:
 f1_h=0.45274	prec_h=0.58811	rec_h=0.36802
[8, 20] loss: 2085.452


100%|██████████| 2/2 [00:00<00:00, 16.00it/s]


Output:
 f1_h=0.40843	prec_h=0.63842	rec_h=0.30027
[9, 20] loss: 2068.049


100%|██████████| 2/2 [00:00<00:00, 16.19it/s]


Output:
 f1_h=0.41743	prec_h=0.63139	rec_h=0.31178
[10, 20] loss: 2055.489


100%|██████████| 2/2 [00:00<00:00, 16.13it/s]


Output:
 f1_h=0.47423	prec_h=0.55562	rec_h=0.41364
[11, 20] loss: 2051.780


100%|██████████| 2/2 [00:00<00:00, 16.07it/s]


Output:
 f1_h=0.44603	prec_h=0.62781	rec_h=0.34588
[12, 20] loss: 2057.175


100%|██████████| 2/2 [00:00<00:00, 15.92it/s]


Output:
 f1_h=0.49623	prec_h=0.55046	rec_h=0.45173
[13, 20] loss: 2028.081


100%|██████████| 2/2 [00:00<00:00, 15.89it/s]


Output:
 f1_h=0.45399	prec_h=0.60058	rec_h=0.36492
[14, 20] loss: 2008.560


100%|██████████| 2/2 [00:00<00:00, 16.15it/s]


Output:
 f1_h=0.45885	prec_h=0.61288	rec_h=0.36670
[15, 20] loss: 1988.420


100%|██████████| 2/2 [00:00<00:00, 15.95it/s]


Output:
 f1_h=0.50911	prec_h=0.55486	rec_h=0.47033
[16, 20] loss: 2004.613


100%|██████████| 2/2 [00:00<00:00, 16.17it/s]


Output:
 f1_h=0.48731	prec_h=0.56297	rec_h=0.42958
[17, 20] loss: 2015.623


100%|██████████| 2/2 [00:00<00:00, 16.05it/s]


Output:
 f1_h=0.47754	prec_h=0.58683	rec_h=0.40257
[18, 20] loss: 2004.623


100%|██████████| 2/2 [00:00<00:00, 16.54it/s]


Output:
 f1_h=0.50077	prec_h=0.60025	rec_h=0.42958
[19, 20] loss: 1996.413


100%|██████████| 2/2 [00:00<00:00, 16.25it/s]


Output:
 f1_h=0.50840	prec_h=0.55503	rec_h=0.46900
[20, 20] loss: 1994.206


100%|██████████| 2/2 [00:00<00:00, 16.43it/s]


Output:
 f1_h=0.48608	prec_h=0.56738	rec_h=0.42516
[21, 20] loss: 1991.283


100%|██████████| 2/2 [00:00<00:00, 16.27it/s]


Output:
 f1_h=0.53380	prec_h=0.56569	rec_h=0.50531
[22, 20] loss: 1983.247


100%|██████████| 2/2 [00:00<00:00, 16.06it/s]


Output:
 f1_h=0.51082	prec_h=0.59149	rec_h=0.44951
[23, 20] loss: 1979.621


100%|██████████| 2/2 [00:00<00:00, 15.54it/s]


Output:
 f1_h=0.53300	prec_h=0.53867	rec_h=0.52746
[24, 20] loss: 1956.797


100%|██████████| 2/2 [00:00<00:00, 15.76it/s]


Output:
 f1_h=0.47574	prec_h=0.58801	rec_h=0.39947
[25, 20] loss: 1966.459


100%|██████████| 2/2 [00:00<00:00, 16.06it/s]


Output:
 f1_h=0.53333	prec_h=0.55449	rec_h=0.51373
[26, 20] loss: 1962.987


100%|██████████| 2/2 [00:00<00:00, 16.45it/s]


Output:
 f1_h=0.52594	prec_h=0.56842	rec_h=0.48937
[27, 20] loss: 1925.512


100%|██████████| 2/2 [00:00<00:00, 15.60it/s]


Output:
 f1_h=0.54212	prec_h=0.52909	rec_h=0.55580
[28, 20] loss: 1946.205


100%|██████████| 2/2 [00:00<00:00, 16.39it/s]


Output:
 f1_h=0.53128	prec_h=0.55263	rec_h=0.51151
[29, 20] loss: 1946.517


100%|██████████| 2/2 [00:00<00:00, 15.71it/s]


Output:
 f1_h=0.51688	prec_h=0.57235	rec_h=0.47121
[30, 20] loss: 1934.522


100%|██████████| 2/2 [00:00<00:00, 16.37it/s]


Output:
 f1_h=0.52071	prec_h=0.58732	rec_h=0.46767
[31, 20] loss: 1903.017


100%|██████████| 2/2 [00:00<00:00, 16.40it/s]


Output:
 f1_h=0.49436	prec_h=0.60657	rec_h=0.41718
[32, 20] loss: 1918.734


100%|██████████| 2/2 [00:00<00:00, 15.58it/s]


Output:
 f1_h=0.50683	prec_h=0.53133	rec_h=0.48450
[33, 20] loss: 1904.203


100%|██████████| 2/2 [00:00<00:00, 15.89it/s]


Output:
 f1_h=0.54469	prec_h=0.57738	rec_h=0.51550
[34, 20] loss: 1916.937


100%|██████████| 2/2 [00:00<00:00, 16.24it/s]


Output:
 f1_h=0.51063	prec_h=0.57774	rec_h=0.45748
[35, 20] loss: 1880.805


100%|██████████| 2/2 [00:00<00:00, 15.85it/s]


Output:
 f1_h=0.54100	prec_h=0.53863	rec_h=0.54340
[36, 20] loss: 1882.450


100%|██████████| 2/2 [00:00<00:00, 16.25it/s]


Output:
 f1_h=0.52564	prec_h=0.59288	rec_h=0.47210
[37, 20] loss: 1929.659


100%|██████████| 2/2 [00:00<00:00, 16.13it/s]


Output:
 f1_h=0.53739	prec_h=0.54348	rec_h=0.53144
[38, 20] loss: 1859.022


100%|██████████| 2/2 [00:00<00:00, 15.81it/s]


Output:
 f1_h=0.53699	prec_h=0.55419	rec_h=0.52081
[39, 20] loss: 1874.086


100%|██████████| 2/2 [00:00<00:00, 16.00it/s]


Output:
 f1_h=0.54406	prec_h=0.53990	rec_h=0.54827
[40, 20] loss: 1877.268


100%|██████████| 2/2 [00:00<00:00, 15.75it/s]


Output:
 f1_h=0.54984	prec_h=0.54400	rec_h=0.55580
[41, 20] loss: 1839.572


100%|██████████| 2/2 [00:00<00:00, 15.77it/s]


Output:
 f1_h=0.53027	prec_h=0.55675	rec_h=0.50620
[42, 20] loss: 1858.456


100%|██████████| 2/2 [00:00<00:00, 16.20it/s]


Output:
 f1_h=0.53526	prec_h=0.58619	rec_h=0.49247
[43, 20] loss: 1835.846


100%|██████████| 2/2 [00:00<00:00, 15.88it/s]


Output:
 f1_h=0.54999	prec_h=0.52608	rec_h=0.57617
[44, 20] loss: 1800.334


100%|██████████| 2/2 [00:00<00:00, 16.02it/s]


Output:
 f1_h=0.52205	prec_h=0.56173	rec_h=0.48760
[45, 20] loss: 1828.131


100%|██████████| 2/2 [00:00<00:00, 15.90it/s]


Output:
 f1_h=0.52745	prec_h=0.58718	rec_h=0.47874
[46, 20] loss: 1807.787


100%|██████████| 2/2 [00:00<00:00, 15.77it/s]


Output:
 f1_h=0.54604	prec_h=0.52861	rec_h=0.56466
[47, 20] loss: 1808.861


100%|██████████| 2/2 [00:00<00:00, 15.66it/s]


Output:
 f1_h=0.53764	prec_h=0.55259	rec_h=0.52347
[48, 20] loss: 1801.288


100%|██████████| 2/2 [00:00<00:00, 15.61it/s]


Output:
 f1_h=0.52926	prec_h=0.57995	rec_h=0.48671
[49, 20] loss: 1780.822


100%|██████████| 2/2 [00:00<00:00, 16.09it/s]


Output:
 f1_h=0.53803	prec_h=0.52816	rec_h=0.54827
[50, 20] loss: 1796.036


100%|██████████| 2/2 [00:00<00:00, 15.92it/s]


Output:
 f1_h=0.54396	prec_h=0.54409	rec_h=0.54384
[51, 20] loss: 1720.315


100%|██████████| 2/2 [00:00<00:00, 15.87it/s]


Output:
 f1_h=0.51830	prec_h=0.57717	rec_h=0.47033
[52, 20] loss: 1711.140


100%|██████████| 2/2 [00:00<00:00, 16.19it/s]


Output:
 f1_h=0.51927	prec_h=0.55708	rec_h=0.48627
[53, 20] loss: 1734.596


100%|██████████| 2/2 [00:00<00:00, 16.28it/s]


Output:
 f1_h=0.54614	prec_h=0.51765	rec_h=0.57795
[54, 20] loss: 1758.797


100%|██████████| 2/2 [00:00<00:00, 16.48it/s]


Output:
 f1_h=0.53106	prec_h=0.56176	rec_h=0.50354
[55, 20] loss: 1734.780


100%|██████████| 2/2 [00:00<00:00, 16.23it/s]


Output:
 f1_h=0.52729	prec_h=0.55179	rec_h=0.50487
[56, 20] loss: 1742.948


100%|██████████| 2/2 [00:00<00:00, 16.39it/s]


Output:
 f1_h=0.53458	prec_h=0.52521	rec_h=0.54429
[57, 20] loss: 1729.754


100%|██████████| 2/2 [00:00<00:00, 15.71it/s]


Output:
 f1_h=0.54190	prec_h=0.55618	rec_h=0.52834
[58, 20] loss: 1738.156


100%|██████████| 2/2 [00:00<00:00, 16.34it/s]


Output:
 f1_h=0.54682	prec_h=0.54537	rec_h=0.54827
[59, 20] loss: 1678.507


100%|██████████| 2/2 [00:00<00:00, 16.21it/s]


Output:
 f1_h=0.51916	prec_h=0.57403	rec_h=0.47387
[60, 20] loss: 1684.663


100%|██████████| 2/2 [00:00<00:00, 16.10it/s]


Output:
 f1_h=0.54302	prec_h=0.52185	rec_h=0.56599
[61, 20] loss: 1664.171


100%|██████████| 2/2 [00:00<00:00, 16.35it/s]


Output:
 f1_h=0.52592	prec_h=0.55632	rec_h=0.49867
[62, 20] loss: 1674.884


100%|██████████| 2/2 [00:00<00:00, 16.45it/s]


Output:
 f1_h=0.55185	prec_h=0.50515	rec_h=0.60806
[63, 20] loss: 1671.549


100%|██████████| 2/2 [00:00<00:00, 16.56it/s]


Output:
 f1_h=0.51659	prec_h=0.56033	rec_h=0.47919
[64, 20] loss: 1613.896


100%|██████████| 2/2 [00:00<00:00, 16.41it/s]


Output:
 f1_h=0.55532	prec_h=0.53440	rec_h=0.57795
[65, 20] loss: 1627.136


100%|██████████| 2/2 [00:00<00:00, 16.38it/s]


Output:
 f1_h=0.52936	prec_h=0.55368	rec_h=0.50709
[66, 20] loss: 1625.361


100%|██████████| 2/2 [00:00<00:00, 16.56it/s]


Output:
 f1_h=0.53776	prec_h=0.51740	rec_h=0.55979
[67, 20] loss: 1607.981


100%|██████████| 2/2 [00:00<00:00, 14.92it/s]


Output:
 f1_h=0.52404	prec_h=0.55105	rec_h=0.49956
[68, 20] loss: 1573.503


100%|██████████| 2/2 [00:00<00:00, 16.30it/s]


Output:
 f1_h=0.54100	prec_h=0.54902	rec_h=0.53322
[69, 20] loss: 1590.811


100%|██████████| 2/2 [00:00<00:00, 16.08it/s]


Output:
 f1_h=0.51841	prec_h=0.58150	rec_h=0.46767
[70, 20] loss: 1588.467


100%|██████████| 2/2 [00:00<00:00, 16.22it/s]


Output:
 f1_h=0.52148	prec_h=0.50626	rec_h=0.53764
[71, 20] loss: 1593.084


100%|██████████| 2/2 [00:00<00:00, 16.40it/s]


Output:
 f1_h=0.55410	prec_h=0.55374	rec_h=0.55447
[72, 20] loss: 1539.544


100%|██████████| 2/2 [00:00<00:00, 16.29it/s]


Output:
 f1_h=0.51691	prec_h=0.55928	rec_h=0.48051
[73, 20] loss: 1544.549


100%|██████████| 2/2 [00:00<00:00, 16.49it/s]


Output:
 f1_h=0.55416	prec_h=0.52674	rec_h=0.58459
[74, 20] loss: 1560.376


100%|██████████| 2/2 [00:00<00:00, 15.96it/s]


Output:
 f1_h=0.52454	prec_h=0.56457	rec_h=0.48981
[75, 20] loss: 1542.690


100%|██████████| 2/2 [00:00<00:00, 15.66it/s]


Output:
 f1_h=0.54249	prec_h=0.52011	rec_h=0.56687
[76, 20] loss: 1562.019


100%|██████████| 2/2 [00:00<00:00, 16.34it/s]


Output:
 f1_h=0.54346	prec_h=0.56047	rec_h=0.52746
[77, 20] loss: 1527.017


100%|██████████| 2/2 [00:00<00:00, 16.20it/s]


Output:
 f1_h=0.55581	prec_h=0.52028	rec_h=0.59655
[78, 20] loss: 1541.039


100%|██████████| 2/2 [00:00<00:00, 16.30it/s]


Output:
 f1_h=0.53649	prec_h=0.53102	rec_h=0.54207
[79, 20] loss: 1490.422


100%|██████████| 2/2 [00:00<00:00, 15.99it/s]


Output:
 f1_h=0.50542	prec_h=0.57008	rec_h=0.45394
[80, 20] loss: 1493.900


100%|██████████| 2/2 [00:00<00:00, 16.20it/s]


Output:
 f1_h=0.54205	prec_h=0.51139	rec_h=0.57662
[81, 20] loss: 1458.962


100%|██████████| 2/2 [00:00<00:00, 16.29it/s]


Output:
 f1_h=0.53985	prec_h=0.52559	rec_h=0.55492
[82, 20] loss: 1482.395


100%|██████████| 2/2 [00:00<00:00, 16.24it/s]


Output:
 f1_h=0.53927	prec_h=0.52449	rec_h=0.55492
[83, 20] loss: 1442.866


100%|██████████| 2/2 [00:00<00:00, 16.30it/s]


Output:
 f1_h=0.52452	prec_h=0.54628	rec_h=0.50443
[84, 20] loss: 1423.350


100%|██████████| 2/2 [00:00<00:00, 16.32it/s]


Output:
 f1_h=0.53286	prec_h=0.53608	rec_h=0.52967
[85, 20] loss: 1429.754


100%|██████████| 2/2 [00:00<00:00, 16.04it/s]


Output:
 f1_h=0.53318	prec_h=0.54762	rec_h=0.51949
[86, 20] loss: 1397.815


100%|██████████| 2/2 [00:00<00:00, 15.96it/s]


Output:
 f1_h=0.54089	prec_h=0.54282	rec_h=0.53897
[87, 20] loss: 1413.906


100%|██████████| 2/2 [00:00<00:00, 15.67it/s]


Output:
 f1_h=0.50345	prec_h=0.56856	rec_h=0.45173
[88, 20] loss: 1441.654


100%|██████████| 2/2 [00:00<00:00, 16.01it/s]


Output:
 f1_h=0.53893	prec_h=0.54616	rec_h=0.53189
[89, 20] loss: 1389.432


100%|██████████| 2/2 [00:00<00:00, 15.97it/s]


Output:
 f1_h=0.53705	prec_h=0.54140	rec_h=0.53277
[90, 20] loss: 1391.930


100%|██████████| 2/2 [00:00<00:00, 15.90it/s]


Output:
 f1_h=0.54607	prec_h=0.52675	rec_h=0.56687
[91, 20] loss: 1380.029


100%|██████████| 2/2 [00:00<00:00, 16.32it/s]


Output:
 f1_h=0.53937	prec_h=0.53284	rec_h=0.54606
[92, 20] loss: 1378.125


100%|██████████| 2/2 [00:00<00:00, 16.07it/s]


Output:
 f1_h=0.52558	prec_h=0.53750	rec_h=0.51417
[93, 20] loss: 1384.092


100%|██████████| 2/2 [00:00<00:00, 16.23it/s]


Output:
 f1_h=0.52955	prec_h=0.52944	rec_h=0.52967
[94, 20] loss: 1336.486


100%|██████████| 2/2 [00:00<00:00, 15.79it/s]


Output:
 f1_h=0.53876	prec_h=0.52712	rec_h=0.55093
[95, 20] loss: 1361.782


100%|██████████| 2/2 [00:00<00:00, 15.88it/s]


Output:
 f1_h=0.54210	prec_h=0.53692	rec_h=0.54739
[96, 20] loss: 1348.853


100%|██████████| 2/2 [00:00<00:00, 16.16it/s]


Output:
 f1_h=0.53036	prec_h=0.54511	rec_h=0.51639
[97, 20] loss: 1316.197


100%|██████████| 2/2 [00:00<00:00, 14.50it/s]


Output:
 f1_h=0.53313	prec_h=0.51224	rec_h=0.55580
[98, 20] loss: 1345.953


100%|██████████| 2/2 [00:00<00:00, 15.52it/s]


Output:
 f1_h=0.51372	prec_h=0.52624	rec_h=0.50177
[99, 20] loss: 1299.464


100%|██████████| 2/2 [00:00<00:00, 16.37it/s]


Output:
 f1_h=0.53163	prec_h=0.53679	rec_h=0.52657
[100, 20] loss: 1265.277


100%|██████████| 2/2 [00:00<00:00, 16.02it/s]


Output:
 f1_h=0.52133	prec_h=0.52053	rec_h=0.52214

best validation loss occurred in epoch 28 


### Evaluation

#### Bulgarian

In [11]:
from tqdm import tqdm
import json
import subprocess

bulgarian_pred_file_path = './Predictions/bulgarian_predictions_subtask1.txt'
bulgarian_gold_file_path = './test_labels_ar_bg_md_version2/test_subtask1_bg.json'
evaluator_script = './scorer-baseline/subtask_1_2a.py'

bg_test_data = process_test_json(bulgarian_gold_file_path)

bg_test_dataset = TestDataSet(bg_test_data, './TextFeatures/subtask1a/mBERT/bg_test_text_features.pkl')
# Evaluation only: keep the dataset order so the prediction file is written in
# the same order on every run.
bg_test_dataloader = DataLoader(bg_test_dataset, batch_size=64, shuffle=False)

evaluate_model(model, bg_test_dataloader, bulgarian_pred_file_path, bulgarian_gold_file_path,
               evaluator_script, id2leaf_label, validation=False, threshold=0.3)

100%|██████████| 7/7 [00:00<00:00, 168.21it/s]


Output:
 f1_h=0.40142	prec_h=0.42295	rec_h=0.38197


#### Macedonian

In [12]:
macedonian_pred_file_path = './Predictions/macedonian_predictions_subtask1.txt'
macedonian_gold_file_path = './test_labels_ar_bg_md_version2/test_subtask1_md.json'

md_test_data = process_test_json(macedonian_gold_file_path)

md_test_dataset = TestDataSet(md_test_data, './TextFeatures/subtask1a/mBERT/md_test_text_features.pkl')
# Evaluation only: keep the dataset order so the prediction file is written in
# the same order on every run.
md_test_dataloader = DataLoader(md_test_dataset, batch_size=64, shuffle=False)

evaluate_model(model, md_test_dataloader, macedonian_pred_file_path, macedonian_gold_file_path,
               evaluator_script, id2leaf_label, validation=False, threshold=0.3)

100%|██████████| 5/5 [00:00<00:00, 169.97it/s]


Output:
 f1_h=0.40495	prec_h=0.49643	rec_h=0.34194


#### Arabic

In [13]:
arabian_pred_file_path = './Predictions/arabian_predictions_subtask1.txt'
arabian_gold_file_path = './test_labels_ar_bg_md_version2/test_subtask1_ar.json'

ar_test_data = process_test_json(arabian_gold_file_path)

ar_test_dataset = TestDataSet(ar_test_data, './TextFeatures/subtask1a/mBERT/ar_test_text_features.pkl')
# Evaluation only: keep the dataset order so the prediction file is written in
# the same order on every run.
ar_test_dataloader = DataLoader(ar_test_dataset, batch_size=128, shuffle=False)

# Unlike the other languages, this call passes format=5 through to get_labels.
evaluate_model(model, ar_test_dataloader, arabian_pred_file_path, arabian_gold_file_path, evaluator_script, 
               id2leaf_label, format=5, validation=False, threshold=0.3)

100%|██████████| 1/1 [00:00<00:00, 86.21it/s]


Output:
 f1_h=0.30508	prec_h=0.38503	rec_h=0.25263


#### English

In [14]:
en_pred_file_path = './Predictions/en_predictions_subtask1.txt'

en_test_data = process_test_json('./test_data/english/en_subtask1_test_unlabeled.json')

en_test_dataset = TestDataSet(en_test_data, './TextFeatures/subtask1a/mBERT/en_test_text_features.pkl')
# Evaluation only: keep the dataset order so the prediction file is written in
# the same order on every run.
en_test_dataloader = DataLoader(en_test_dataset, batch_size=16, shuffle=False)

# No gold file is passed for this split, so predictions are written but not scored.
evaluate_model(model, en_test_dataloader, en_pred_file_path, None, evaluator_script, id2leaf_label, validation=False)

100%|██████████| 94/94 [00:00<00:00, 176.52it/s]
