In [None]:
import ast
import json
import math
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from torch.nn.utils.rnn import pad_sequence
from tqdm.notebook import tqdm
from transformers import AdamW
warnings.filterwarnings('always')

In [None]:
def get_data(path, n_train=6, n_test=2, n_val=2):
    """Load a CSV and carve consecutive train / test / val splits off its top.

    Rows are taken in file order: the first ``n_train`` rows form the training
    split, the next ``n_test`` rows the test split and the following ``n_val``
    rows the validation split. The ``sentences`` column of every split is
    parsed from its textual form into a Python object.

    Args:
        path: Location of a CSV file containing a ``sentences`` column whose
            cells hold Python literals stored as text.
        n_train: Number of leading rows used for training.
        n_test: Number of rows following the training rows used for testing.
        n_val: Number of rows following the test rows used for validation.

    Returns:
        A ``(train, val, test)`` tuple of independent DataFrames. Note the
        order: validation is returned before test.

    Raises:
        ValueError, SyntaxError: If a ``sentences`` cell is not a valid
            Python literal.
    """
    data = pd.read_csv(path)

    test_end = n_train + n_test
    val_end = test_end + n_val

    # Copy each slice so the column assignment below writes to an independent
    # frame instead of a view of ``data`` (avoids SettingWithCopyWarning and
    # the assignment silently not sticking).
    train = data.iloc[:n_train].copy()
    test = data.iloc[n_train:test_end].copy()
    val = data.iloc[test_end:val_end].copy()

    for split in (train, test, val):
        # literal_eval only accepts Python literals, so text coming from the
        # file cannot execute arbitrary code the way ``eval`` would allow.
        split['sentences'] = split['sentences'].apply(ast.literal_eval)
    return train, val, test


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding sentence embeddings and a label per row.

    Each item is a dict with:
        ``'embeddings'``: tensor built from ``sentence_model.encode`` applied
            to the row's ``sentences`` value.
        ``'label'``: one-element float tensor holding the row's ``label``.

    Args:
        df: DataFrame with ``sentences`` and ``label`` columns.
        sentence_model: Optional object exposing ``encode(sentences)`` that
            returns a NumPy array. Pass one shared encoder to several datasets
            to avoid loading the model once per dataset. When omitted, the
            ``paraphrase-multilingual-mpnet-base-v2`` SentenceTransformer is
            loaded, as before.
        cache_embeddings: When true (default), the embeddings computed for an
            index are kept in memory and reused on later accesses, so the
            encoder runs once per row rather than once per row per epoch.
            Set to false to trade speed for memory.
    """

    def __init__(self, df, sentence_model=None, cache_embeddings=True):
        self.df = df
        self.labels = self.df['label']
        if sentence_model is None:
            sentence_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
        self.sentence_model = sentence_model
        self.cache_embeddings = cache_embeddings
        # idx -> embeddings tensor; assumes the encoder returns the same
        # output for the same sentences on every call (it is never trained
        # in this file) -- TODO confirm if the encoder is ever fine-tuned.
        self._embedding_cache = {}

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        embeddings = self._embedding_cache.get(idx) if self.cache_embeddings else None
        if embeddings is None:
            embeddings = torch.from_numpy(self.sentence_model.encode(row['sentences']))
            if self.cache_embeddings:
                self._embedding_cache[idx] = embeddings

        return {
            'embeddings': embeddings,
            'label': torch.Tensor([row['label']]),
        }
    
def custom_collate(batch):
    """Merge a list of dataset samples into batched tensors.

    Args:
        batch: Sequence of dicts, each holding an ``'embeddings'`` tensor and
            a ``'label'`` tensor.

    Returns:
        ``(embs, labels)`` where ``embs`` stacks the per-sample embeddings
        along a new leading batch dimension, padding shorter ones up to the
        longest in the batch, and ``labels`` is the equally batched label
        tensor cast to ``long``.
    """
    label_list = [sample['label'] for sample in batch]
    embedding_list = [sample['embeddings'] for sample in batch]

    padded_embeddings = pad_sequence(embedding_list, batch_first=True)
    batched_labels = pad_sequence(label_list, batch_first=True).long()
    return padded_embeddings, batched_labels


In [None]:
def train_step(model, dataloader, device, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module mapping a batch of embeddings to class logits.
        dataloader: Iterable of ``(embeddings, label)`` batches supporting
            ``len()``; ``label`` carries a singleton second dimension that is
            squeezed away before the loss is computed.
        device: Device the batch tensors are moved to.
        optimizer: Optimizer stepping the model's parameters.

    Returns:
        Cross-entropy loss averaged over the number of batches.
    """
    model.train()
    running_loss = 0
    for embeddings, label in dataloader:
        inputs = embeddings.to(device)
        target = label.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, target.squeeze(1))
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)

In [None]:
def eval_step(model, dataloader, device):
    """Evaluate ``model`` over ``dataloader`` without updating it.

    Args:
        model: Module mapping a batch of embeddings to class logits.
        dataloader: Iterable of ``(embeddings, label)`` batches supporting
            ``len()``; ``label`` carries a singleton second dimension.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, targets, predictions)``: the cross-entropy loss
        averaged over batches, and two NumPy arrays holding the true labels
        and the arg-max predicted labels of every sample, in iteration order.
    """
    model.eval()
    running_loss = 0
    all_targets, all_predictions = [], []
    for embeddings, label in dataloader:
        inputs = embeddings.to(device)
        target = label.to(device)

        model.zero_grad()
        with torch.no_grad():
            logits = model(inputs)

        running_loss += F.cross_entropy(logits, target.squeeze(1)).item()

        batch_predictions = torch.argmax(logits, dim=1).flatten()
        all_predictions.append(batch_predictions.cpu().numpy())
        all_targets.append(target.squeeze(1).cpu().numpy())

    targets = np.concatenate(all_targets, axis=0)
    predictions = np.concatenate(all_predictions, axis=0)

    return running_loss / len(dataloader), targets, predictions

In [None]:
input_path = "../results/ranked/"
batch_size = 16
# os.listdir returns entries in arbitrary order; sort them so files[0] picks
# the same file on every run and on every machine.
files = sorted(os.listdir(input_path))
if not files:
    raise FileNotFoundError(f"No input files found in {input_path}")

# Use the first GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train, val, test = get_data(os.path.join(input_path, files[0]))
train_dataset = Dataset(train)
val_dataset = Dataset(val)
test_dataset = Dataset(test)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate
)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# validation / test batches identical from one epoch and run to the next.
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate
)

In [None]:

class MultiTaskModel(nn.Module):
    """Transformer encoder over a sequence of embeddings with a 2-way head.

    The input sequence is optionally prefixed with a learned CLS vector, run
    through a ``nn.TransformerEncoder`` and pooled (CLS position, or sum over
    the sequence when ``use_cls`` is false) before a linear classifier
    producing two logits.

    Args:
        nhead: Number of attention heads per encoder layer.
        nlayers: Number of stacked encoder layers.
        use_cls: Prepend a learned CLS vector and classify from its output
            when true; otherwise classify from the sum over positions.
        d_model: Size of the input embeddings and of the encoder.
    """

    def __init__(self,
                 nhead=1,
                 nlayers=1,
                 use_cls=True,
                 d_model=768):
        super().__init__()
        self.classifier = nn.Linear(d_model, 2)

        self.use_cls = use_cls
        self.d_model = d_model
        if use_cls:
            # Single-row embedding table used purely as a learnable CLS vector.
            self.cls_embed = nn.Embedding(1, self.d_model)

        self.encoder_layer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True),
            nlayers,
            norm=None)

    def forward(self, x, padding_mask=None):
        """Compute class logits for a batch of embedding sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            padding_mask: Optional boolean tensor of shape
                ``(batch, seq_len)`` on the same device as ``x``, with
                ``True`` marking padded positions. Those positions are
                ignored by attention and, for sum pooling, excluded from the
                sum. When omitted every position is treated as real input.

        Returns:
            Tensor of shape ``(batch, 2)`` with the class logits.
        """
        batch_size = x.size(0)
        if self.use_cls:
            cls_tokens = self.cls_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
            if padding_mask is not None:
                # The CLS slot is always a real token, so it is never masked.
                cls_mask = padding_mask.new_zeros((batch_size, 1))
                padding_mask = torch.cat([cls_mask, padding_mask], dim=1)

        x = self.encoder_layer(x, src_key_padding_mask=padding_mask)

        if self.use_cls:
            label_x = x[:, 0, :]
        else:
            if padding_mask is not None:
                # Some PyTorch fast paths return an output trimmed to the
                # longest unpadded sequence; trim the mask to match.
                pooled_mask = padding_mask[:, :x.size(1)].unsqueeze(-1)
                x = x.masked_fill(pooled_mask, 0.0)
            label_x = torch.sum(x, dim=1)
        label_logits = self.classifier(label_x)
        return label_logits


In [None]:
epochs = 100
save_model = True
os.makedirs("../models/", exist_ok=True)
model_path = "../models/normal.pt"
d_model = 768

model = MultiTaskModel(d_model=d_model)

model.to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to the transformers defaults so the optimisation
# hyper-parameters stay as they were.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
best_loss = np.inf
best_epoch = 0

for epoch in range(epochs):
    train_loss = train_step(model, train_dataloader, device, optimizer)

    # Only the loss is needed here; index it rather than unpacking into `_`,
    # which would overwrite IPython's last-output variable.
    val_loss = eval_step(model, val_dataloader, device)[0]

    print(f"\nEpoch: {epoch+1} | Training loss: {train_loss} | Validation Loss: {val_loss}")
    if val_loss < best_loss:
        # Track the best epoch regardless of whether checkpoints are written.
        best_loss = val_loss
        best_epoch = epoch + 1
        if save_model:
            torch.save(model.state_dict(), model_path)

print(f"Best validation loss: {best_loss} (epoch {best_epoch})")

In [None]:
# Restore the weights stored at model_path before scoring the test split.
loaded_state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(loaded_state_dict)

test_loss, targets, predictions = eval_step(model, test_dataloader, device)

# Fraction of test samples whose predicted label equals the target label.
accuracy = np.mean(targets == predictions)
print(f"Accuracy: {accuracy}")

ConfusionMatrixDisplay.from_predictions(targets, predictions)
# plt.savefig(f"./confusion_matrix.png", dpi=300)