# Imports and installation of libraries

In [None]:
#INSTALL LIBRARIES------------------------------------------
!pip install transformers scikit-learn datasets wandb  torch_geometric word2number nltk num2words


In [None]:
#IMPORTS-----------------------------
from pprint import pprint
from datasets import load_dataset
from transformers import RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from torch.nn import Linear, ReLU
import pdb
import numpy as np, torch, random as rnd, torch.nn as nn, wandb
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import sys, os, json
from transformers import AutoModelForQuestionAnswering
from torch.nn.functional import cosine_similarity
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score,  ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import os, time
import random
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn.norm import LayerNorm, BatchNorm
import torch.nn.init as init
from torch.optim.lr_scheduler import ReduceLROnPlateau
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize

import importlib.util

# Load the data-augmentation helpers from a standalone script next to the notebook.
# The file name starts with a digit and contains a hyphen, so it is not a valid
# identifier for an `import` statement; importlib loads it under the name
# "augmentation_module" instead.
file_path = './1883922-augmentation.py'
spec = importlib.util.spec_from_file_location("augmentation_module", file_path)
augmentation_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(augmentation_module)

# NLTK resources needed by wn.synset(...) (WordNet) and word_tokenize (punkt).
nltk.download('wordnet')
nltk.download('punkt')
# Preferred compute device: CUDA when available, otherwise CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


In [6]:
# Function to print a progress bar
def print_progress_bar(percentuale: float, lunghezza_barra: int = 30, text: str="") -> None:
    """Draw a one-line text progress bar on stdout, overwriting the previous one.

    Args:
        percentuale: completed fraction, nominally in [0, 1]; printed as a percentage.
        lunghezza_barra: width of the bar in characters, excluding the brackets.
        text: extra text appended after the percentage.
    """
    # Clamp only for drawing, so the bar never overflows its width; the printed
    # percentage still shows the value that was passed in.
    fraction = min(max(percentuale, 0.0), 1.0)
    # The ">" head always occupies one cell. Without the max(), 0% drew a bar one
    # character wider than every other state.
    filled = max(int(lunghezza_barra * fraction), 1)
    bar = "[" + "=" * (filled - 1) + ">" + " " * (lunghezza_barra - filled) + "]"
    sys.stdout.write(f"\r{bar} {percentuale * 100:.2f}% complete " + text)
    sys.stdout.flush()

# Definition of Dataset classes

In [7]:



class NLIDatasetGnn(Dataset):
    """NLI dataset that turns each premise/hypothesis pair into a small SRL graph.

    For every sample, the premise and hypothesis SRL annotations are flattened into one
    string of " [SEP] "-terminated "sentences" plus a graph description (``info_sample``):

    * ``edges``: node-index pairs, every edge stored in both directions;
    * ``ids_verbs``: for each predicate, the tuple of sentence indices representing it
      (its non-empty WordNet examples, or otherwise the predicate word itself);
    * ``num_nodes`` / ``num_sentences``: totals used to offset indices when batching.

    Each predicate gets a node, followed by one node per argument span not yet seen
    in the same text. Processed data is written to ``file_name`` (when given). It is
    read back from that file instead when ``load`` is True and the file exists.
    """

    def __init__(self, data, file_name='',load = False, adversarial=False, do_remove_stopwords=False,
                 do_remove_punctuation=False, do_use_similarities=False, base_set=True, do_lemmatization=False):
        self.sentence_info = None
        self.labels = None
        self.sentences = None
        self.premise_hypotesis = None
        self.load = load
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.encode_labels = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
        # This class never used a RoBERTa tokenizer/model (they were loaded and then
        # discarded), so they are no longer downloaded. The attributes stay for compatibility.
        self.tokenizer = None
        self.model = None
        self.distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
        self.file_name = file_name
        self.do_use_similarities = do_use_similarities
        self.do_lemmatization = do_lemmatization
        self.do_remove_punctuation = do_remove_punctuation
        self.base_set = base_set
        self.do_remove_stopwords = do_remove_stopwords
        self.adversarial = adversarial
        self.organize_data(data)
        # Number of samples dropped by collate() since the last get_deleted_per_epoch().
        self.count = 0

    @staticmethod
    def _add_text_graph(annotations, wsd_tokens, tokens, sentence, info_sample, idx, edge_idx):
        """Append one text's (premise or hypothesis) SRL graph to ``sentence``/``info_sample``.

        Args:
            annotations: SRL predicate annotations of the text.
            wsd_tokens: WSD token records of the text (``nltkSynset``/``pos`` keys are read).
            tokens: raw token strings of the same text, used for predicate/argument text.
            sentence: list receiving the " [SEP] "-terminated sentences (mutated).
            info_sample: dict whose "ids_verbs"/"edges" lists are extended (mutated).
            idx: index of the next sentence to be written.
            edge_idx: index of the next graph node to be created.

        Returns:
            The updated ``(idx, edge_idx)`` counters.
        """
        verb_edge_idx = -1
        repeated_nodes = {}  # (start, end) argument span -> node index within this text
        for annotation in annotations:
            # Chain consecutive predicates of the same text together.
            if verb_edge_idx != -1:
                info_sample["edges"].append((verb_edge_idx, edge_idx))
                info_sample["edges"].append((edge_idx, verb_edge_idx))
            verb_edge_idx = edge_idx
            edge_idx += 1

            token_index = annotation['tokenIndex']
            id_verb = []
            if (len(wsd_tokens) > token_index
                    and wsd_tokens[token_index]['nltkSynset'] != 'O'
                    and wsd_tokens[token_index]['pos'] == 'VERB'):
                for example in wn.synset(wsd_tokens[token_index]['nltkSynset']).examples():
                    if example != '':
                        id_verb.append(idx)
                        idx += 1
                        sentence.append(example + ' [SEP] ')
            if id_verb:
                info_sample["ids_verbs"].append(tuple(id_verb))
            else:
                # No usable WordNet example: the predicate word itself is its only sentence.
                # (Previously an extra empty tuple was also recorded in this case.)
                info_sample["ids_verbs"].append((idx,))
                idx += 1
                sentence.append(tokens[token_index] + ' [SEP] ')

            for element in annotation['verbatlas']['roles']:
                if not element['span']:
                    continue
                span = (element['span'][0], element['span'][1])
                if span in repeated_nodes:
                    # The argument already has a node: only link it to this predicate.
                    info_sample["edges"].append((repeated_nodes[span], verb_edge_idx))
                    info_sample["edges"].append((verb_edge_idx, repeated_nodes[span]))
                else:
                    info_sample["edges"].append((verb_edge_idx, edge_idx))
                    info_sample["edges"].append((edge_idx, verb_edge_idx))
                    # Record the index of THIS node, before advancing the counter.
                    repeated_nodes[span] = edge_idx
                    edge_idx += 1
                    idx += 1
                    sentence.append(' '.join(tokens[span[0]:span[1]]) + ' [SEP] ')
        return idx, edge_idx

    def organize_data(self, data):
        """Populate sentences, graph info, labels and raw premise/hypothesis pairs.

        They are read from ``self.file_name`` when ``self.load`` is set and the file
        exists. Otherwise they are built from ``data`` and cached to ``self.file_name``
        if one is set. Caches produced before the fixes in ``_add_text_graph``
        (shared-argument node index, hypothesis argument text) should be rebuilt.
        """
        if self.load and os.path.isfile(self.file_name):
            with open(self.file_name, 'r') as file:
                loaded_data = json.load(file)
            sentences = loaded_data['sentences']
            info_samples = loaded_data['info_samples']
            labels = loaded_data['labels']
            premise_hypotesis = loaded_data['premise_hypotesis']
        else:
            sentences, info_samples, labels, premise_hypotesis = [], [], [], []
            for num_sample, sample in enumerate(data):
                print_progress_bar(num_sample / len(data), text=" | loading data")
                labels.append(self.encode_labels[sample["label"]])
                premise_hypotesis.append(sample['premise'] + " [SEP] " + sample['hypothesis'])
                sentence = []
                info_sample = {"ids_verbs": [], "edges": []}
                idx, edge_idx = 0, 0
                # Premise and hypothesis subgraphs are built independently (no edges
                # between them); node and sentence indices continue across both.
                for side in ('premise', 'hypothesis'):
                    tokens = [token["rawText"] for token in sample['srl'][side]['tokens']]
                    idx, edge_idx = self._add_text_graph(
                        sample['srl'][side]['annotations'], sample['wsd'][side], tokens,
                        sentence, info_sample, idx, edge_idx)
                info_sample["num_nodes"] = edge_idx
                info_sample["num_sentences"] = idx
                sentences.append(' '.join(sentence))
                info_samples.append(info_sample)
            if self.file_name != '':
                directory = os.path.dirname(self.file_name)
                if directory:
                    os.makedirs(directory, exist_ok=True)
                with open(self.file_name, 'w') as file:
                    json.dump({'sentences': sentences, 'info_samples': info_samples, 'labels': labels,
                               'premise_hypotesis': premise_hypotesis}, file)

        self.sentences = sentences
        self.sentence_info = info_samples
        self.labels = labels
        self.premise_hypotesis = premise_hypotesis

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(node_text, label, graph_info, premise_hypothesis_text)``."""
        return self.sentences[idx], self.labels[idx], self.sentence_info[idx], self.premise_hypotesis[idx]

    def get_deleted_per_epoch(self, num_epochs):
        """Return the average number of samples dropped per epoch and reset the counter."""
        out = self.count / num_epochs
        self.count = 0
        return out

    def collate(self, batch):
        """Tokenize a batch and merge the per-sample graphs into one disjoint graph.

        Samples whose node text has 500 or more ``word_tokenize`` tokens are dropped and
        counted in ``self.count``.

        Returns:
            ``(input_ids, attention_mask, labels, edges, verb_ids, classic_inputs)``.
            ``edges`` is a ``(2, E)`` long tensor with node indices offset per sample,
            ``verb_ids`` holds one ``ids_verbs`` list per kept sample, and
            ``classic_inputs`` is the tokenizer output for the plain premise/hypothesis text.
        """
        inputs = []
        gold_outputs = []
        edges = []
        verb_ids = []
        classic_inputs = []
        offset_edges = 0
        for input_batch, gold_outputs_batch, sample_info_batch, premise_hypotesis in batch:
            if len(word_tokenize(input_batch)) >= 500:
                self.count += 1
                continue
            inputs.append(input_batch)
            gold_outputs.append(gold_outputs_batch)
            classic_inputs.append(premise_hypotesis)
            # Edges may be tuples or (after a JSON round-trip) lists; unpacking handles both.
            edges.extend((x + offset_edges, y + offset_edges) for x, y in sample_info_batch["edges"])
            offset_edges += sample_info_batch["num_nodes"]
            verb_ids.append(sample_info_batch["ids_verbs"])

        inputs = self.distilbert_tokenizer(
            inputs,
            max_length=512,
            return_offsets_mapping=True,
            padding="max_length",
            return_tensors="pt",
            truncation=True
        )
        classic_inputs = self.distilbert_tokenizer(
            classic_inputs,
            max_length=512,
            return_offsets_mapping=True,
            padding="max_length",
            return_tensors="pt",
            truncation=True
        )

        gold_outputs = torch.tensor(gold_outputs, dtype=torch.long)
        # reshape(-1, 2) keeps the (2, E) layout even when there are no edges.
        edges = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

        return inputs['input_ids'].to(self.device), inputs['attention_mask'].to(self.device), gold_outputs.to(self.device), \
            edges.to(self.device), verb_ids, classic_inputs.to(self.device)

    def get_dataloader(self, batch_size):
        """Return a shuffling DataLoader that batches with :meth:`collate`."""
        return DataLoader(self, batch_size=batch_size, shuffle=True, collate_fn=self.collate)





In [8]:



class NLIDataset(Dataset):
    """Plain NLI dataset of "premise[SEP]hypothesis" strings with integer labels.

    When ``do_use_similarities`` is set, each pair also carries the cosine similarity
    between the mean-pooled RoBERTa embeddings of premise and hypothesis. When
    ``do_remove_stopwords`` or ``do_remove_punctuation`` is set, the text is rebuilt from
    the sample's ``wsd`` tokens with filtering. This only happens for the base set;
    other samples are skipped. Processed data is cached as JSON at ``file_name`` and
    reused when that file exists.
    """

    def __init__(self, data, file_name = '', adversarial = False, do_remove_stopwords = False,
                 do_remove_punctuation = False, do_use_similarities = False, base_set = True, do_lemmatization = False):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.encode_labels = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
        self.distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
        self.file_name = file_name
        self.do_use_similarities = do_use_similarities
        self.do_lemmatization = do_lemmatization
        self.do_remove_punctuation = do_remove_punctuation
        self.base_set = base_set
        self.do_remove_stopwords = do_remove_stopwords
        self.adversarial = adversarial
        # RoBERTa is only used by embed_sentence while similarities are being computed.
        # Skip the download/GPU load when similarities are off or will come from the cache.
        cached = file_name != '' and os.path.isfile(file_name)
        if do_use_similarities and not cached:
            self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
            self.model = RobertaModel.from_pretrained('roberta-base').to(self.device)
        else:
            self.tokenizer = None
            self.model = None
        self.preprocess_function(data)
        self.tokenizer = None
        self.model = None

    def remove_stopwords(self, sample):
        """Rebuild a sentence from token records, dropping stopwords and some POS classes.

        With ``base_set``, each token is a dict with 'pos', 'lemma' and 'text'.
        Filtering rules:

        * Tokens tagged PUNCT/CCONJ/DET/AUX are dropped.
        * Stopwords, matched on the lemma, are dropped when ``do_remove_stopwords`` is set.
        * A kept token contributes its lemma (``do_lemmatization``) or its surface text.
        """
        stop_words = set(stopwords.words('english')) if self.do_remove_stopwords else set()
        # NOTE(review): the POS filter runs whenever base_set is True, even if
        # do_remove_punctuation is False; confirm this is intended.
        kept = []
        for data in sample:
            word = data
            is_punct = False
            if self.base_set:
                is_punct = data['pos'] in ['PUNCT', 'CCONJ', 'DET', 'AUX']
                word = data['lemma']
            if word not in stop_words and not is_punct:
                next_word = word
                if self.do_lemmatization and self.base_set:
                    next_word = data['lemma']
                elif self.base_set:
                    next_word = data['text']
                kept.append(next_word)
        return ' '.join(kept).strip()

    def preprocess_function(self, examples):
        """Build (or load from the JSON cache) texts, labels and optional similarities into ``self.data``.

        The cache is read from and written to the same path, ``self.file_name``. Any
        missing parent directory is created. Previously the cache was written under
        "data/" but read from ``file_name``, so it was never reused.
        """
        answers = []
        premise_hypothesis = []
        ordered_similarities = []
        file_exists = self.file_name != '' and os.path.isfile(self.file_name)
        if file_exists:
            with open(self.file_name, "r") as f:
                data_loaded = json.load(f)
            premise_hypothesis = data_loaded["premise_hypotesis"]
            answers = data_loaded["labels"]
            ordered_similarities = data_loaded.get("similarities", [])
        else:
            for i, example in enumerate(examples):
                print_progress_bar(i / len(examples), text=" | preprocessing")
                if self.do_remove_stopwords or self.do_remove_punctuation:
                    if not self.base_set:
                        # Filtering is only implemented for base-set token records.
                        continue
                    pair = (self.remove_stopwords(example['wsd']["premise"]) + '[SEP]'
                            + self.remove_stopwords(example['wsd']["hypothesis"]))
                else:
                    pair = example["premise"].strip() + '[SEP]' + example["hypothesis"].strip()
                premise_hypothesis.append(pair)
                answers.append(self.encode_labels[example["label"]])

                if self.do_use_similarities:
                    # The similarity does not depend on self.adversarial (both former branches were identical).
                    s1 = self.embed_sentence(example["premise"].strip())
                    s2 = self.embed_sentence(example["hypothesis"].strip())
                    ordered_similarities.append(cosine_similarity(s1, s2).item())

        inputs = {'inputs': premise_hypothesis}
        data_to_save = {"premise_hypotesis": premise_hypothesis, "labels": answers}
        if self.do_use_similarities:
            data_to_save["similarities"] = ordered_similarities
            inputs["similarity"] = torch.tensor(ordered_similarities)
        if not file_exists and self.file_name != '':
            directory = os.path.dirname(self.file_name)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(self.file_name, "w") as f:
                json.dump(data_to_save, f, indent=4)
        inputs["label"] = torch.tensor(answers)
        self.data = inputs

    def embed_sentence(self, sentence):
        """Return the RoBERTa embedding of ``sentence``: last hidden states averaged over tokens."""
        inputs = self.tokenizer(sentence, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.model(**inputs.to(self.device))
        return outputs.last_hidden_state.mean(dim=1)

    def __len__(self):
        return len(self.data["label"])

    def __getitem__(self, idx):
        """Return ``(text, label, similarity)``; similarity is ``torch.zeros(1)`` when not computed."""
        if "similarity" in self.data:
            return self.data['inputs'][idx], self.data["label"][idx], self.data["similarity"][idx]
        else:
            return self.data['inputs'][idx], self.data["label"][idx], torch.zeros(1)

    def collate(self, batch):
        """Tokenize texts to length 512; return ids, mask, labels and similarities on ``self.device``."""
        texts = [item[0] for item in batch]
        labels = torch.stack([item[1] for item in batch])
        similarities = torch.stack([item[2] for item in batch])
        encoded = self.distilbert_tokenizer(
            texts,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        return (encoded['input_ids'].to(self.device), encoded['attention_mask'].to(self.device),
                labels.to(self.device), similarities.to(self.device))

    def get_dataloader(self, batch_size):
        """Return a shuffling DataLoader that batches with :meth:`collate`."""
        return DataLoader(self, batch_size=batch_size, shuffle=True, collate_fn = self.collate)






# Models

In [9]:

class RobertaClassifier(nn.Module):
    """Sentence-pair classifier: DistilBERT first-token embedding + one similarity scalar -> linear.

    Despite the class name, the encoder is DistilBERT. When ``use_similarity`` is False,
    the similarity feature is replaced by zeros.
    """

    def __init__(self, use_similarity, num_labels=3, dropout = 0):
        super(RobertaClassifier, self).__init__()
        self.distilbert = AutoModel.from_pretrained("distilbert/distilbert-base-uncased")
        # Dropout has no parameters, so saved state_dicts are unaffected; p=0 is a no-op.
        self.dropout = nn.Dropout(dropout)
        # +1 input feature for the similarity scalar.
        self.linear = nn.Linear(self.distilbert.config.hidden_size + 1, num_labels)
        self.freeze(0)
        self.use_similarity = use_similarity
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal (leaky-ReLU gain) weights and zero bias for the output layer."""
        nn.init.kaiming_normal_(self.linear.weight, nonlinearity='leaky_relu')
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def freeze(self, epoch):
        """Freeze every DistilBERT parameter at epoch 1; other epochs leave the flags unchanged."""
        if epoch == 1:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, similarities):
        """Return ``(batch, num_labels)`` logits.

        ``similarities`` may have shape ``(batch,)`` or ``(batch, 1)``. It is ignored
        when ``use_similarity`` is False.
        """
        batch_size = input_ids.shape[0]
        if not self.use_similarity:
            # Follow the input's device instead of assuming CUDA whenever it exists.
            similarities = torch.zeros(batch_size, device=input_ids.device)

        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        first_token = outputs.last_hidden_state[:, 0, :]
        similarity_feature = similarities.reshape(batch_size, 1).to(device=first_token.device,
                                                                     dtype=first_token.dtype)
        pooled_output = torch.cat((self.dropout(first_token), similarity_feature), dim=1)
        return self.linear(pooled_output)


In [10]:
class RobertaClassifierGnnLayerNorm(nn.Module):
    """DistilBERT + two-layer GATv2 classifier over sentence-level graph nodes.

    DistilBERT encodes two token sequences per sample in a single call:

    1. ``input_ids``: node sentences, each terminated by the tokenizer's SEP token.
       Each segment's token embeddings are mean-pooled into a node feature, passed
       through two GATv2 layers and mean-pooled per sample.
    2. ``classic_inputs``: a second tokenized text whose first-token embedding is
       concatenated to the pooled graph feature before the classifier head.
    """

    def __init__(self, use_similarity, num_labels=3, dropout=0):
        super(RobertaClassifierGnnLayerNorm, self).__init__()
        # NOTE(review): use_similarity is accepted for API parity but not used by this model.
        self.distilbert = AutoModel.from_pretrained("distilbert/distilbert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        self.gnn = GATv2Conv(self.distilbert.config.hidden_size, 64, dropout=0.3)
        self.gnn2 = GATv2Conv(64, 32, dropout=0.3)
        # Head input = 32 pooled graph features + hidden_size first-token features.
        self.linear = nn.Linear(self.distilbert.config.hidden_size + 32, 64)
        self.classifier = nn.Linear(64, num_labels)
        # Re-initializes only the layers created above; the norms below keep their defaults.
        self.initialize_weights()
        self.relu = nn.LeakyReLU()
        self.batchnorm_linear = nn.LayerNorm(64)
        self.batchnorm_gnn = LayerNorm(64)
        # NOTE(review): defined (and therefore part of the state_dict) but unused in forward.
        self.batchnorm_gnn2 = LayerNorm(32)
        self.special_token = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased").sep_token_id

    def initialize_weights(self):
        """Kaiming-normal weights (a=0.01) and zero biases for direct GATv2 and Linear children."""
        for layer in self.children():
            if isinstance(layer, GATv2Conv):
                init.kaiming_normal_(layer.lin_l.weight, a=0.01)
                if layer.lin_l.bias is not None:
                    init.zeros_(layer.lin_l.bias)
                if hasattr(layer, 'lin_r') and layer.lin_r is not None:
                    init.kaiming_normal_(layer.lin_r.weight, a=0.01)
                    if layer.lin_r.bias is not None:
                        init.zeros_(layer.lin_r.bias)
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0.01)
                if layer.bias is not None:
                    init.zeros_(layer.bias)

    def freeze(self, epoch):
        """Freeze every DistilBERT parameter at epoch 1; other epochs leave the flags unchanged."""
        if epoch == 1:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def _sample_nodes(self, sample_embeddings, sample_input_ids, sample_verb_ids):
        """Return the list of node feature vectors for one sample.

        Each SEP-terminated segment ``j`` is sentence ``j``; its feature is the mean of
        its token embeddings. Sentences listed together in one entry of
        ``sample_verb_ids`` are averaged into a single (predicate) node. Every other
        non-empty sentence is its own node. Empty groups are ignored.
        """
        groups = [list(group) for group in sample_verb_ids if len(group) > 0]
        sep_positions = (sample_input_ids == self.special_token).nonzero(as_tuple=True)[0].tolist()
        nodes = []
        group_pos, k = 0, 0
        next_verb = groups[0][0] if groups else -1
        verb_embedding = []
        start = 0
        for j, sep_pos in enumerate(sep_positions):
            if sep_pos != start:
                frase_embedding = sample_embeddings[start:sep_pos].mean(dim=0)
                if j == next_verb:
                    verb_embedding.append(frase_embedding)
                    if k + 1 >= len(groups[group_pos]):
                        # All sentences of this predicate seen: emit its node.
                        nodes.append(torch.stack(verb_embedding).mean(dim=0))
                        verb_embedding, k = [], 0
                        group_pos += 1
                    else:
                        k += 1
                    next_verb = groups[group_pos][k] if group_pos < len(groups) else -1
                else:
                    nodes.append(frase_embedding)
            start = sep_pos + 1
        if verb_embedding:
            # Truncation cut a predicate's sentences short: keep what was seen.
            nodes.append(torch.stack(verb_embedding).mean(dim=0))
        return nodes

    def forward(self, input_ids, attention_mask, edges, verb_ids, classic_inputs):
        """Return ``(batch, num_labels)`` logits.

        Args:
            input_ids, attention_mask: token ids / mask of the node text, one row per sample.
            edges: edge index passed to both GATv2 layers. Presumably ``(2, E)`` with
                per-sample node indices already offset; confirm against the collate fn.
            verb_ids: one entry per sample (same order as ``input_ids``). Each entry is a
                sequence of sentence-index groups, one group per predicate node.
            classic_inputs: mapping with 'input_ids' and 'attention_mask' for the
                second text, with the same batch size.
        """
        inputs = torch.cat((input_ids, classic_inputs['input_ids']), dim=0)
        mask = torch.cat((attention_mask, classic_inputs['attention_mask']), dim=0)
        outputs = self.distilbert(input_ids=inputs, attention_mask=mask)
        embeddings, output2 = torch.split(self.dropout(outputs.last_hidden_state), input_ids.size(0), dim=0)

        # verb_ids is indexed per sample: the groups of sample i describe its own sentence indices.
        nodes, node_counts = [], []
        for i, sample_input_ids in enumerate(input_ids):
            sample_nodes = self._sample_nodes(embeddings[i], sample_input_ids, verb_ids[i])
            nodes.extend(sample_nodes)
            node_counts.append(len(sample_nodes))

        x = self.gnn(torch.stack(nodes), edges)
        x = self.batchnorm_gnn(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.gnn2(x, edges)

        pooled_embeddings = []
        start_idx = 0
        for num_nodes in node_counts:
            current_nodes = x[start_idx:start_idx + num_nodes]
            # A sample with no nodes gets a zero vector instead of a NaN mean.
            pooled_embeddings.append(current_nodes.mean(dim=0) if num_nodes > 0 else x.new_zeros(x.size(1)))
            start_idx += num_nodes
        pooled_embeddings = torch.stack(pooled_embeddings)

        pooled_embeddings = torch.cat((pooled_embeddings, output2[:, 0, :]), dim=1)
        to_classify = self.relu(self.batchnorm_linear(self.linear(pooled_embeddings)))
        to_classify = self.dropout(to_classify)
        return self.classifier(to_classify)

# Trainers

In [11]:


class TrainerGnn():
    """Training/validation loop for graph-augmented NLI classifiers.

    Batches are unpacked as ``(input_ids, attention_mask, targets, edges, verb_ids,
    classic_inputs)`` and passed to ``model(input_ids, attention_mask, edges, verb_ids,
    classic_inputs)``. The model must also provide ``freeze(epoch)``.
    """

    def __init__(self, model,train_dataloader, validation_dataloader, optimizer, loss_function, device, scheduler=None):
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.device = device
        self.scheduler = scheduler

    @staticmethod
    def evaluation_parameters(y_true, y_pred):
        """Return ``(confusion_matrix, precision, recall, f1, accuracy)``.

        ``y_pred`` holds one row of class scores per sample; the predicted label is the
        row-wise argmax. Precision, recall and F1 use support-weighted averaging.
        """
        y_true = np.asarray(y_true)
        y_pred = np.argmax(np.asarray(y_pred), axis=1)
        cm = confusion_matrix(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')
        accuracy = accuracy_score(y_true, y_pred)
        return cm, precision, recall, f1, accuracy

    @staticmethod
    def format_time_delay(seconds):
        """Split a duration in seconds into ``(hours, minutes, seconds)``."""
        hours, remainder = divmod(seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return hours, minutes, seconds

    def train(self, epochs: int, use_wandb: bool = False, config: dict = None, name: str="", target_f1: float=0.0, early_stopping=1000):
        """Train for up to ``epochs`` epochs, keeping the weights with the best validation F1.

        Whenever validation F1 beats the best so far (initially ``target_f1``), a
        checkpoint ``f"{name}-{f1}.pth"`` is saved and the previous one is deleted.
        Training stops after ``early_stopping`` consecutive epochs without improvement.
        The best weights, if any epoch improved, are then loaded back into the model.
        """
        import copy  # for an independent copy of the best state_dict

        if config is None:
            config = {}
        start_time = time.time()
        patience = early_stopping
        best_model = None
        old_name = ''
        if use_wandb:
            wandb.init(project="nlphw2", name=name, config=config)
        validation_loss, precision, recall, f1, accuracy, cm = self.validate(use_wandb)
        if use_wandb:
            # Baseline before training. NOTE(review): "train_loss" here is the validation
            # loss divided by the number of training batches.
            wandb.log({"validation_loss": validation_loss,
                       "precision": precision,
                       "recall": recall,
                       "f1": f1,
                       "accuracy": accuracy,
                       "train_loss": validation_loss / len(self.train_dataloader)})
        for epoch in range(epochs):
            hours, minutes, seconds = self.format_time_delay(time.time() - start_time)
            print(f"\nTempo trascorso: {hours} ore, {minutes} minuti, {seconds} secondi")
            self.model.freeze(epoch)
            self.model.train()
            total_loss = 0

            for i, batch in enumerate(self.train_dataloader):
                print_progress_bar(i / len(self.train_dataloader), text=f" | training epoch {epoch}")
                inputs, mask, targets, edges, verbs, classic_inputs = batch
                self.optimizer.zero_grad()
                outputs = self.model(inputs, mask, edges, verbs, classic_inputs)
                loss = self.loss_function(outputs, targets)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            validation_loss, precision, recall, f1, accuracy, cm = self.validate(use_wandb)
            if f1 > target_f1:
                # state_dict() shares storage with the live parameters; copy it, or later
                # optimizer steps would silently overwrite the "best" weights.
                best_model = copy.deepcopy(self.model.state_dict())
                target_f1 = f1
                if old_name != '':
                    os.remove(old_name)
                old_name = name + f'-{target_f1}.pth'
                torch.save(best_model, old_name)
                early_stopping = patience
            else:
                early_stopping -= 1
                if early_stopping == 0:
                    break
            if use_wandb:
                wandb.log({"validation_loss": validation_loss,
                           "precision": precision,
                           "recall": recall,
                           "f1": f1,
                           "accuracy": accuracy,
                           "train_loss": total_loss / len(self.train_dataloader)})
            if self.scheduler is not None:
                self.scheduler.step(f1)
        # best_model stays None when no epoch beat target_f1.
        if best_model is not None:
            self.model.load_state_dict(best_model)
        print('\nbest f1: ', target_f1)
        if use_wandb:
            wandb.finish()

    def validate(self, use_wandb: bool = False, test_dataloader=None, load_from=''):
        """Evaluate on the validation set (or ``test_dataloader``) and return metrics.

        If ``load_from`` names an existing file, that state_dict is loaded first.

        Returns:
            ``(loss, precision, recall, f1, accuracy, confusion_matrix)``.

        Raises:
            ValueError: if there is no dataloader or it has no batches.
        """
        if os.path.isfile(load_from):
            self.model.load_state_dict(torch.load(load_from, map_location=torch.device('cpu')))
        dataloader = self.validation_dataloader if test_dataloader is None else test_dataloader
        if dataloader is None or len(dataloader) == 0:
            raise ValueError("empty dataloader!")
        self.model.eval()
        total_loss = 0
        all_predictions, all_targets = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                print_progress_bar(i / len(dataloader), text=" | validation")
                inputs, mask, targets, edges, verbs, classic_inputs = batch
                outputs = self.model(inputs, mask, edges, verbs, classic_inputs)
                total_loss += self.loss_function(outputs, targets).item()
                # Keep the raw (batch, classes) scores. Rounding could create ties that
                # change the argmax, and squeezing drops the batch axis for size-1 batches.
                all_predictions.append(outputs.cpu())
                all_targets.append(targets.cpu())
        validation_loss = total_loss / len(dataloader)
        cm, precision, recall, f1, accuracy = self.evaluation_parameters(
            torch.cat(all_targets), torch.cat(all_predictions))
        return validation_loss, precision, recall, f1, accuracy, cm




In [12]:


class Trainer():
    """Train and evaluate a classifier fed with ``(inputs, mask, targets, similarities)`` batches.

    The wrapped model is called as ``model(inputs, mask, similarities)`` and must
    provide a ``freeze(epoch)`` method, which ``train`` calls at the start of
    every epoch. Batches are handed to the model exactly as the dataloader yields
    them. NOTE(review): nothing here moves batches to ``device``, so presumably
    the dataset/collate function already does — confirm.
    """

    def __init__(self, model, train_dataloader, validation_dataloader, optimizer, loss_function, device, scheduler=None):
        """Store the training components and move ``model`` to ``device``.

        Args:
            model: network to train/evaluate (moved to ``device`` here).
            train_dataloader: training batches; may be ``None`` for evaluation-only use.
            validation_dataloader: default dataloader of ``validate``; may be ``None``.
            optimizer: optimizer over ``model``'s parameters; may be ``None`` for
                evaluation-only use.
            loss_function: callable ``(outputs, targets) -> scalar loss tensor``.
            device: target device of the model (e.g. ``'cuda'`` or ``'cpu'``).
            scheduler: optional scheduler; ``train`` calls ``scheduler.step(f1)``
                with the validation F1 after every completed epoch.
        """
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.device = device
        self.scheduler = scheduler

    @staticmethod
    def evaluation_parameters(y_true, y_pred):
        """Compute classification metrics from targets and per-class scores.

        Args:
            y_true: 1-D array-like (or CPU tensor) of integer class labels.
            y_pred: 2-D array-like (or CPU tensor) of per-class scores, one row per
                sample; the predicted class is the arg-max of each row.

        Returns:
            ``(confusion_matrix, precision, recall, f1, accuracy)``. Precision,
            recall and F1 are support-weighted averages over the classes.
        """
        # Convert explicitly so NumPy/sklearn never operate on torch tensors.
        y_true = np.asarray(y_true)
        y_pred = np.argmax(np.asarray(y_pred), axis=1)
        cm = confusion_matrix(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')
        accuracy = accuracy_score(y_true, y_pred)
        return cm, precision, recall, f1, accuracy

    @staticmethod
    def format_time_delay(seconds):
        """Split a duration in seconds into ``(hours, minutes, seconds)``."""
        hours = seconds // 3600
        minutes = (seconds % 3600) // 60
        seconds = seconds % 60
        return hours, minutes, seconds

    @staticmethod
    def _snapshot_state(model):
        """Return a detached CPU copy of ``model``'s state dict.

        ``state_dict()`` returns references to the live parameter tensors.
        Storing it without copying would make the "best" weights silently
        follow all later training steps.
        """
        return {
            key: value.detach().cpu().clone() if torch.is_tensor(value) else value
            for key, value in model.state_dict().items()
        }

    def train(self, epochs: int, use_wandb: bool = False, config: "dict | None" = None, name: str = "", target_f1: float = 0.0, early_stopping=1000):
        """Train for up to ``epochs`` epochs, keeping the weights with the best validation F1.

        After every epoch the model is validated. Whenever the weighted F1
        exceeds the best value so far (initially ``target_f1``), a copy of the
        weights is kept in memory and saved to ``f"{name}-{f1}.pth"``. That
        file replaces the previous checkpoint written by this call. Training
        stops after ``early_stopping`` consecutive epochs without improvement.
        On exit, the best weights (if any epoch beat ``target_f1``) are loaded
        back into the model.

        Args:
            epochs: maximum number of training epochs.
            use_wandb: log metrics to Weights & Biases (project ``"nlphw2"``).
            config: hyper-parameters recorded in the wandb run (``None`` means ``{}``).
            name: wandb run name and checkpoint file prefix.
            target_f1: F1 that an epoch must exceed to become the new best.
            early_stopping: patience, in consecutive epochs without improvement.
        """
        start_time = time.time()
        restart_stopping = early_stopping
        best_model = None
        old_name = ''
        if use_wandb:
            wandb.init(
                # Set the project where this run will be logged
                project="nlphw2",
                name=name,
                # Track hyperparameters and run metadata
                config={} if config is None else config,
            )
        # Baseline evaluation before any training step. No training loss exists
        # yet, so none is logged.
        validation_loss, precision, recall, f1, accuracy, cm = self.validate(use_wandb)
        if use_wandb:
            wandb.log({"validation_loss": validation_loss,
                       "precision": precision,
                       "recall": recall,
                       "f1": f1,
                       "accuracy": accuracy})
        for epoch in range(epochs):
            time_delay = time.time() - start_time
            hours, minutes, seconds = self.format_time_delay(time_delay)
            print(f"\nTempo trascorso: {hours} ore, {minutes} minuti, {seconds} secondi")
            self.model.freeze(epoch)
            self.model.train()  # Set the model to training mode
            total_loss = 0

            for i, batch in enumerate(self.train_dataloader):
                print_progress_bar(i / len(self.train_dataloader), text=f" | training epoch {epoch}")
                inputs, mask, targets, similarities = batch

                self.optimizer.zero_grad()
                outputs = self.model(inputs, mask, similarities)
                loss = self.loss_function(outputs, targets)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            validation_loss, precision, recall, f1, accuracy, cm = self.validate(use_wandb)
            if f1 > target_f1:
                target_f1 = f1
                # Copy the weights: a bare state_dict() would keep tracking training.
                best_model = self._snapshot_state(self.model)
                if old_name and os.path.isfile(old_name):
                    os.remove(old_name)
                old_name = name + f'-{target_f1}.pth'
                torch.save(best_model, old_name)
                early_stopping = restart_stopping
            else:
                early_stopping -= 1
                if early_stopping == 0:
                    # Leave the loop (instead of returning) so the best weights are
                    # restored and the wandb run is closed below.
                    break
            if use_wandb:
                wandb.log({"validation_loss": validation_loss,
                           "precision": precision,
                           "recall": recall,
                           "f1": f1,
                           "accuracy": accuracy,
                           "train_loss": total_loss / len(self.train_dataloader)})
            if self.scheduler is not None:
                self.scheduler.step(f1)
        # best_model stays None when no epoch beat the initial target_f1.
        if best_model is not None:
            self.model.load_state_dict(best_model)
        print('\nbest f1: ', target_f1)
        if use_wandb:
            wandb.finish()

    def validate(self, use_wandb: bool = False, test_dataloader=None, load_from=''):
        """Evaluate the model on a dataloader and return the loss and metrics.

        Args:
            use_wandb: unused here; kept for signature compatibility.
            test_dataloader: dataloader to evaluate; defaults to
                ``self.validation_dataloader``.
            load_from: optional checkpoint path; when it names an existing file its
                state dict is loaded before evaluating.

        Returns:
            ``(validation_loss, precision, recall, f1, accuracy, cm)`` where
            ``validation_loss`` is the mean per-batch loss.

        Raises:
            ValueError: if there is no dataloader, or it yields no batches.
        """
        if os.path.isfile(load_from):
            self.model.load_state_dict(torch.load(load_from, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
        dataloader = self.validation_dataloader if test_dataloader is None else test_dataloader
        if dataloader is None:
            # Raise instead of exit(1): exiting would kill the notebook kernel.
            raise ValueError("empty dataloader!")
        self.model.eval()  # Set the model to evaluation mode
        total_loss = 0
        batch_scores = []
        batch_targets = []
        with torch.no_grad():  # Do not calculate gradients
            for i, batch in enumerate(dataloader):
                print_progress_bar(i / len(dataloader), text=" | validation")
                inputs, mask, targets, similarities = batch

                outputs = self.model(inputs, mask, similarities)
                loss = self.loss_function(outputs, targets)
                total_loss += loss.item()
                # Keep the raw per-class scores. Rounding them can create ties and
                # change the arg-max in evaluation_parameters, and squeeze() breaks
                # the 2-D shape for a batch of one.
                batch_scores.append(outputs.cpu())
                batch_targets.append(targets.cpu())
        if not batch_scores:
            raise ValueError("empty dataloader!")
        validation_loss = total_loss / len(dataloader)
        # Concatenate once at the end instead of growing a tensor per batch.
        all_predictions = torch.cat(batch_scores)
        all_targets = torch.cat(batch_targets)

        cm, precision, recall, f1, accuracy = self.evaluation_parameters(all_targets, all_predictions)
        return validation_loss, precision, recall, f1, accuracy, cm




# Creating datasets

In [13]:

# Adversarial evaluation data: only its "test" split is used.
adversarial = load_dataset("iperbole/adversarial_fever_nli")["test"]

# Main NLI dataset, unpacked into its three splits.
ds = load_dataset("tommasobonomo/sem_augmented_fever_nli")
training_set, validation_set, test_set = ds["train"], ds["validation"], ds["test"]


In [14]:
new_seed = 108


def set_seed(seed):
    """Seed NumPy, ``random`` and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode.

    Returns:
        ``(seed, seed + 1)``: the seed just applied and the next one to use.
    """
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    rnd.seed(seed)
    np.random.seed(seed)
    following_seed = seed + 1
    return seed, following_seed

In [15]:


# Re-seed every RNG with the current seed and keep the following one for later runs.
seed, new_seed = set_seed(seed=new_seed)

In [None]:
# Datasets for the plain RoBERTa classifier (training split first, then validation).
train_dataset, validation_dataset = NLIDataset(training_set), NLIDataset(validation_set)



In [None]:
# Datasets for the GNN-augmented classifier (training split first, then validation).
train_dataset_gnn, validation_dataset_gnn = NLIDatasetGnn(training_set), NLIDatasetGnn(validation_set)


In [None]:
# Held-out test data for both model families.
test_dataset, test_dataset_gnn = NLIDataset(test_set), NLIDatasetGnn(test_set)

# Adversarial evaluation set (base-model format only).
adversarial_dataset = NLIDataset(adversarial, adversarial=True)

# Training Models

In [18]:

# Re-seed first, presumably so any shuffling inside get_dataloader is reproducible.
set_seed(108)
train_dataloader, validation_dataloader = (
    split.get_dataloader(batch_size=40) for split in (train_dataset, validation_dataset)
)



In [None]:


model = RobertaClassifier(False, dropout=0.3)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-5)
# 'max' mode: the trainer steps the scheduler with the validation F1.
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=2)

trainer = Trainer(
    model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    nn.CrossEntropyLoss(),
    DEVICE,
    scheduler=scheduler,
)
trainer.train(10, use_wandb=False, name=f"base_adamW_3e-5_{seed}", early_stopping=3)

In [None]:
# Re-seed first, presumably so any shuffling inside get_dataloader is reproducible.
set_seed(108)
train_dataloader_gnn, validation_dataloader_gnn = (
    split.get_dataloader(batch_size=20) for split in (train_dataset_gnn, validation_dataset_gnn)
)

In [None]:



model_gnn = RobertaClassifierGnnLayerNorm(False, dropout=0.3)
optimizer = torch.optim.AdamW(model_gnn.parameters(), lr=4e-5, weight_decay=1e-5)
# 'max' mode: the trainer steps the scheduler with the validation F1.
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=2)

trainer = TrainerGnn(
    model_gnn,
    train_dataloader_gnn,
    validation_dataloader_gnn,
    optimizer,
    nn.CrossEntropyLoss(),
    DEVICE,
    scheduler=scheduler,
)
trainer.train(10, use_wandb=False, name=f"gnn_adamW_4e-5_{seed}", early_stopping=3)


# Comparison between models on the test set

In [24]:
set_seed(108)
# Both test sets use batches of 64; the adversarial set uses batches of 32.
test_dataloader, test_dataloader_gnn = (
    split.get_dataloader(batch_size=64) for split in (test_dataset, test_dataset_gnn)
)
adversarial_dataloader = adversarial_dataset.get_dataloader(batch_size=32)

In [None]:



# Evaluation-only trainers: no training data, optimizer or scheduler needed.
trainer = Trainer(model, None, None, None, nn.CrossEntropyLoss(), DEVICE)
(validation_loss_1, precision_1, recall_1,
 f1_1, accuracy_1, cm_1) = trainer.validate(test_dataloader=test_dataloader)

trainer = TrainerGnn(model_gnn, None, None, None, nn.CrossEntropyLoss(), DEVICE)
(validation_loss_5, precision_5, recall_5,
 f1_5, accuracy_5, cm_5) = trainer.validate(test_dataloader=test_dataloader_gnn)




In [None]:
# Report the test-set metrics of both models, one per line.
for caption, value in (
    ("f1_score_base: ", f1_1),
    ("accuracy_base: ", accuracy_1),
    ("f1_score_gnn_layerNorm: ", f1_5),
    ("accuracy_gnn_layerNorm: ", accuracy_5),
):
    print(caption, value)



In [None]:

# One figure per model: plot, title and save, then show all figures at the end.
for matrix, title, filename in (
    (cm_1, 'Confusion Matrix base model', 'Confusion_Matrix_base_model.png'),
    (cm_5, 'Confusion Matrix gnn model layer norm', 'Confusion_Matrix_gnn_model_layer_norm.png'),
):
    disp = ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=['CONTRADICTION', 'NEUTRAL', 'ENTAILMENT'],
    )
    disp.plot()
    plt.title(title)
    plt.savefig(filename)

plt.show()


# Test of base model on adversarial dataset

In [None]:


# Evaluate the base model on the adversarial set.
trainer = Trainer(model, None, None, None, nn.CrossEntropyLoss(), DEVICE)
(adv_validation_loss_1, adv_precision_1, adv_recall_1,
 adv_f1_1, adv_accuracy_1, adv_cm_1) = trainer.validate(test_dataloader=adversarial_dataloader)

print("\nf1_score_base_adv: ", adv_f1_1)
print("accuracy_base_adv: ", adv_accuracy_1)

disp = ConfusionMatrixDisplay(
    confusion_matrix=adv_cm_1,
    display_labels=['CONTRADICTION', 'NEUTRAL', 'ENTAILMENT'],
)
disp.plot()
plt.title('Confusion Matrix base model on adversarial dataset')
plt.savefig('Confusion_Matrix_base_model_on_adversarial_dataset.png')

plt.show()



# Data Augmentation and Training on Augmented Train Dataset

In [None]:
set_seed(108)

# Value passed to augment_data for each augmentation strategy: 1 for every
# strategy except CONVERT_NUMBERS, which gets 3. NOTE(review): presumably a
# relative weight — confirm in the augmentation module.
manipulation_dict = dict.fromkeys(
    [
        'NEGATE_PART_PREMISE',
        'SYNONYM',
        'ANTINOMY_PART_PREMISE',
        'HYPONYM_PREMISE',
        'SWITCH_DATA',
        'SWITCH_PARTIAL_DATA',
        'TAKE_PART_PREMISE',
        'NEGATE_HYPOTHESIS',
        'HYPERNYM_HYPOTHESIS',
        'IMPOSSIBILITY',
        'TRUNCATE_HYPOTHESIS',
        'TAUTOLOGY',
        'DUPLICATE_HYPOTHESIS',
        'CHANGE_NUMBERS',
        'CONVERT_NUMBERS',
    ],
    1,
)
manipulation_dict['CONVERT_NUMBERS'] = 3

# Per-label example counts of the original training split. NOTE(review): the
# augmentation statistics add CONTRADICTION at index 1 and NEUTRAL at index 2,
# so the order appears to be [ENTAILMENT, CONTRADICTION, NEUTRAL] — confirm.
initial_counts = [31128, 12331, 7627]

def print_histogram(example_counts):
    labels = ["Entailment", "Neutral", "Contradiction"]

   
    total_examples = sum(example_counts)
    percentages = [(count / total_examples) * 100 for count in example_counts]


    plt.figure(figsize=(8, 6))
    plt.bar(labels, percentages, color=['skyblue', 'salmon', 'lightgreen'])
    plt.xlabel('Labels')
    plt.ylabel('Percentage (%)')
    plt.title('Distribution of Dataset Labels by Percentage')
    plt.ylim(0, 100)

    
    for i, percentage in enumerate(percentages):
        plt.text(i, percentage + 1, f"{percentage:.2f}%", ha='center')

    plt.show()

seed, new_seed = set_seed(new_seed)
training_set, new_data, info_augmentations = augmentation_module.augment_data(training_set, 10000, manipulation_dict)

# Index in the count lists for each label string; same mapping as before
# (ENTAILMENT -> 0, CONTRADICTION -> 1, NEUTRAL -> 2).
LABEL_TO_COUNT_INDEX = {'ENTAILMENT': 0, 'CONTRADICTION': 1, 'NEUTRAL': 2}

grouped_data = {}
new_counts = initial_counts.copy()
for sample in new_data:
    label_index = LABEL_TO_COUNT_INDEX.get(sample['label'])
    if label_index is None:
        print("error")
    else:
        new_counts[label_index] += 1
    # Append in place; rebuilding the list for every sample was quadratic.
    grouped_data.setdefault(sample['augment_method'], []).append(sample)
pprint(info_augmentations)

total_attempts = sum(values['count'] for values in info_augmentations.values())
total_successes = sum(values['success'] for values in info_augmentations.values())

print(f"Total attempts: {total_attempts}")
print(f"Total successes: {total_successes}")

print_histogram(initial_counts)
print_histogram(new_counts)



In [None]:
# Rebuild the training dataset from the augmented training split.
train_dataset = NLIDataset(training_set)
set_seed(108)

train_dataloader, validation_dataloader = (
    split.get_dataloader(batch_size=40) for split in (train_dataset, validation_dataset)
)

model = RobertaClassifier(False, dropout=0.3)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-5)
# 'max' mode: the trainer steps the scheduler with the validation F1.
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=2)

trainer = Trainer(
    model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    nn.CrossEntropyLoss(),
    DEVICE,
    scheduler=scheduler,
)
trainer.train(10, use_wandb=False, name=f"base_augmented10k_adamW_3e-5_{seed}", early_stopping=3)

# Comparison of the model trained on augmented data on the test and adversarial datasets

In [None]:

set_seed(108)
test_dataloader = test_dataset.get_dataloader(batch_size=64)
adversarial_dataloader = adversarial_dataset.get_dataloader(batch_size=32)

# Evaluation-only trainer around the model trained on augmented data.
trainer = Trainer(model, None, None, None, nn.CrossEntropyLoss(), DEVICE)

(test_validation_loss_1, test_precision_1, test_recall_1,
 test_f1_1, test_accuracy_1, test_cm_1) = trainer.validate(test_dataloader=test_dataloader)
(adv_validation_loss_1, adv_precision_1, adv_recall_1,
 adv_f1_1, adv_accuracy_1, adv_cm_1) = trainer.validate(test_dataloader=adversarial_dataloader)

print("\nf1_score_base_augmented_test: ", test_f1_1)
print("accuracy_base_augmented_test: ", test_accuracy_1)
print("f1_score_base_augmented_adv: ", adv_f1_1)
print("accuracy_base_augmented_adv: ", adv_accuracy_1)

# One figure per evaluation set: plot, title and save, then show them all.
for matrix, title, filename in (
    (test_cm_1,
     'Confusion Matrix base augmented model on test dataset',
     'Confusion_Matrix_base_augmented_model_on_test_dataset.png'),
    (adv_cm_1,
     'Confusion Matrix base augmented model on adversarial dataset',
     'Confusion_Matrix_base_augmented_model_on_adversarial_dataset.png'),
):
    disp = ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=['CONTRADICTION', 'NEUTRAL', 'ENTAILMENT'],
    )
    disp.plot()
    plt.title(title)
    plt.savefig(filename)

plt.show()