In [None]:
from sequence_aligner.labelset import LabelSet
from sequence_aligner.dataset import TrainingDataset
from sequence_aligner.containers import TrainingBatch
import json
from transformers import XLMRobertaTokenizerFast, XLMRobertaForTokenClassification, TrainingArguments, Trainer
from torch.utils.data import DataLoader
import torch
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

In [None]:
# Load the annotated training examples. The context manager closes the file
# handle deterministically (the previous bare open() left it to the garbage
# collector), and the encoding is pinned so the read does not depend on the
# machine's locale.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
with open('/home/luca/Scrivania/CheckThat/annotated_train.json', encoding='utf-8') as annotated_file:
    data = json.load(annotated_file)

# Hold out 20% of the examples for validation; the fixed random_state keeps
# the split identical across runs.
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')
# Label names handed to LabelSet. The order is kept exactly as written, since
# the label-to-id mapping built by LabelSet may depend on it.
label_set = LabelSet(labels=["Appeal_to_Authority", "Appeal_to_Popularity", "Appeal_to_Values", "Appeal_to_Fear-Prejudice", "Flag_Waving",
                             "Causal_Oversimplification", "False_Dilemma-No_Choice", "Consequential_Oversimplification", "Straw_Man",
                             "Red_Herring", "Whataboutism", "Slogans", "Appeal_to_Time", "Conversation_Killer", "Loaded_Language",
                             "Repetition", "Exaggeration-Minimisation", "Obfuscation-Vagueness-Confusion", "Name_Calling-Labeling", "Doubt",
                             "Guilt_by_Association", "Appeal_to_Hypocrisy", "Questioning_the_Reputation"])

# Both splits share the same tokenizer and label set.
train_dataset = TrainingDataset(data=train_data,tokenizer=tokenizer,label_set=label_set)
val_dataset = TrainingDataset(data=val_data,tokenizer=tokenizer,label_set=label_set)

# TrainingBatch is used as the collate function; only the training loader shuffles.
train_loader = DataLoader(train_dataset, collate_fn=TrainingBatch, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, collate_fn=TrainingBatch, batch_size=4, shuffle=False)

In [None]:
# Seed torch's global RNG before any model construction or training so that
# operations drawing from it (randomly initialised weights, dropout, ...) are
# repeatable across runs; the train/validation split is already pinned with
# random_state=42. GPU kernels may still introduce some nondeterminism.
torch.manual_seed(42)

# Token-classification model with one output class per entry in the dataset's label set.
model = XLMRobertaForTokenClassification.from_pretrained("xlm-roberta-base", num_labels=len(train_dataset.label_set.ids_to_label.values()))
lr = 5e-6
optimizer = AdamW(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
epochs = 5

model.to(device)
criterion.to(device)

def _forward_batch(batch, model, criterion, device):
    """Run one batch through the model and score it.

    Args:
        batch: mapping-like object exposing 'input_ids', 'attention_masks'
            and 'labels' tensors.
        model: token-classification model whose output exposes 'logits'.
        criterion: loss callable applied to the transposed logits and labels.
        device: torch device the batch tensors are moved to.

    Returns:
        A tuple ``(loss, n_correct, n_tokens)``: the criterion output (a
        tensor, still attached to the graph when gradients are enabled), the
        number of non-padding positions whose argmax prediction equals the
        label, and the number of non-padding positions.
    """
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_masks'].to(device)
    labels = batch['labels'].to(device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
    # Classes live on dim 2 of the logits (see the argmax below);
    # CrossEntropyLoss wants the class dimension second, hence the transpose.
    loss = criterion(logits.transpose(1, 2), labels)

    # NOTE(review): assumes attention_masks is 1 on real tokens and 0 on
    # padding, i.e. the convention of the model's attention_mask argument.
    real_tokens = attention_mask == 1
    preds = torch.argmax(logits, 2)
    n_correct = ((preds == labels) & real_tokens).sum().item()
    n_tokens = real_tokens.sum().item()
    return loss, n_correct, n_tokens


# Training loop: one optimisation pass over train_loader followed by one
# evaluation pass over val_loader per epoch.
for epoch in range(epochs):
    # Re-enter training mode at the start of every epoch: the validation pass
    # below switches the model to eval mode, which would otherwise leave
    # dropout disabled for every epoch after the first.
    model.train()
    tot_loss_train = 0
    tot_acc_train = 0
    tot_tokens_train = 0
    for batch in tqdm(train_loader):
        loss, n_correct, n_tokens = _forward_batch(batch, model, criterion, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tot_loss_train += loss.item()
        tot_acc_train += n_correct
        tot_tokens_train += n_tokens

    # Validation
    model.eval()
    tot_loss_val = 0
    tot_acc_val = 0
    tot_tokens_val = 0
    with torch.no_grad():
        for batch in val_loader:
            loss, n_correct, n_tokens = _forward_batch(batch, model, criterion, device)
            tot_loss_val += loss.item()
            tot_acc_val += n_correct
            tot_tokens_val += n_tokens

    # Print statistics
    # The running loss is a sum of per-batch values, so it is averaged over the
    # number of batches; accuracy is the fraction of scored token positions
    # predicted correctly. max(..., 1) avoids a ZeroDivisionError on an empty loader.
    print(f'Epoch: {epoch + 1}/{epochs}')
    print(f'Train Loss: {tot_loss_train / max(len(train_loader), 1)}')
    print(f'Train Accuracy: {tot_acc_train / max(tot_tokens_train, 1)}')
    print(f'Validation Loss: {tot_loss_val / max(len(val_loader), 1)}')
    print(f'Validation Accuracy: {tot_acc_val / max(tot_tokens_val, 1)}')