In [None]:
import os
import torch
from torch import nn
import numpy as np
from datasets import ClassLabel
from transformers import AutoModelForSequenceClassification, AutoModel, AutoTokenizer
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
# Only ask for the GPU name when CUDA is usable: `torch.cuda.get_device_name(0)`
# raises on CPU-only machines, which would abort this cell despite the CPU fallback above.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
import CXRBERT
from collections import OrderedDict

In [None]:
# Hub identifier used for the config, tokenizer and models loaded in this notebook.
checkpoint = "allenai/scibert_scivocab_cased"
config = CXRBERT.CXRBertConfig.from_pretrained(checkpoint)
tokenizer = CXRBERT.CXRBertTokenizer.from_pretrained(checkpoint,padding="max_length", truncation=True, max_length=128)
# Call `from_pretrained` on the class itself. The previous form,
# `CXRBertModel(config).from_pretrained(...)`, first built a full randomly initialised
# model only to discard it as soon as the pretrained weights were loaded.
# NOTE(review): assumes `CXRBertModel.from_pretrained` is the usual Hugging Face
# classmethod -- confirm against the local CXRBERT module.
model = CXRBERT.CXRBertModel.from_pretrained(checkpoint)

In [None]:
#for the pre-trained models
# path = './CXR-BERT/2v0m7nnw/checkpoints/epoch=49-step=25000.ckpt'
# model_cpt = torch.load(path)

# new_state_dict = OrderedDict()
# for k, v in model_cpt['state_dict'].items():
#     name = k[6:] # strip the 6-character key prefix (e.g. `model.`); note that `module.` would need k[7:]
#     new_state_dict[name] = v


# model.load_state_dict(new_state_dict)


In [None]:
from datasets import load_dataset

# Textual NLI label -> integer class id. The inverse mapping and the ClassLabel feature
# are derived from this single dictionary so the three definitions cannot drift apart.
label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
int2label = {class_id: name for name, class_id in label2int.items()}
g_labels = ClassLabel(names = [int2label[class_id] for class_id in range(len(int2label))])

# Train on MedNLI; validate and test on RadNLI.
dataset = load_dataset("csv", data_files={"train": "mednli_train.csv",
                                          "validation": "radnli_val.csv",
                                          "test": "radnli_test.csv"})

# Add an integer `label` column to every split, encoded from the textual `gold_label`
# column. An unknown label string raises KeyError rather than being silently dropped.
for split_name in list(dataset):
    encoded_labels = [label2int[gold] for gold in dataset[split_name]['gold_label']]
    dataset[split_name] = dataset[split_name].add_column('label', encoded_labels)


In [None]:

# Switch to NLI-style column names and drop the columns that are no longer needed;
# the integer `label` column is kept.
dataset = (
    dataset
    .rename_column("sentence1", "premise")
    .rename_column("sentence2", "hypothesis")
    .remove_columns(['gold_label', 'pair_id'])
)


In [None]:
#to visualise the dataset
import collections
import random
import matplotlib.pyplot as plt
labels = g_labels.names
# Best-effort overview of the dataset: any failure is reported and the notebook carries on.
try:
  # Split names and row counts, in the dataset's own order.
  split_wise_labels = list(dataset.keys())
  split_wise_numbers = [split.num_rows for split in dataset.values()]
  # Every slice except 'train' is pulled out of the pie by a random offset (cosmetic only).
  split_wise_explode = [random.randint (1, 5) / 10 if name != 'train' else 0 for name in split_wise_labels]

  sen_length = {}               # split -> whitespace-token counts per example
  split_wise_label_wise = {}    # split -> examples per label id
  total_label_wise_split = [0] * len (labels)
  for key, value in dataset.items ():
    counter = collections.Counter (value['label'])
    temp = [0] * len (labels)
    for index, total in counter.items ():
      temp[ index ] += total
    split_wise_label_wise[key] = temp
    total_label_wise_split = [x + y for x, y in zip(total_label_wise_split, temp)]

    # Read whole columns instead of iterating row by row over the split.
    premise_lengths = [len (text.split ()) for text in value['premise']]
    hypothesis_lengths = [len (text.split ()) for text in value['hypothesis']]
    sen_length[key] = {'premise': premise_lengths,
                       'hypothesis': hypothesis_lengths,
                       'sum': [p + h for p, h in zip(premise_lengths, hypothesis_lengths)]}

  # Transpose from [split][label] to [label][split] so each bar series is one label.
  per_split_counts = [split_wise_label_wise[name] for name in split_wise_labels]
  split_wise_label_wise = [list (counts) for counts in zip(*per_split_counts)]

  fig, axs = plt.subplots(2,3, figsize=(20,10))

  # Graph 1: share of examples per split
  axs[0, 0].pie(x = split_wise_numbers, explode=split_wise_explode, labels=split_wise_labels, autopct='%1.1f%%', startangle=90)
  axs[0, 0].axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
  axs[0, 0].set_title ('Train, Test, Validation Split')

  # Graph 2: share of examples per label over all splits
  axs[0, 1].pie(x = total_label_wise_split, labels=labels, autopct='%1.1f%%', startangle=90)
  axs[0, 1].axis('equal')
  axs[0, 1].set_title ('Label wise split for entire Dataset')

  # Graph 3: one bar group per split, one coloured bar per label
  bar_width = 0.25
  X = np.arange(len (split_wise_labels))
  for position, (label_counts, colour) in enumerate (zip(split_wise_label_wise, ['b', 'g', 'r'])):
    axs[0, 2].bar(X + position * bar_width, label_counts, color = colour, width = bar_width)
  axs[0, 2].legend(labels=labels)
  # Pin the tick positions before labelling them. Padding the label list with empty
  # strings only lined up with whatever default ticks matplotlib happened to choose.
  axs[0, 2].set_xticks(X + bar_width)
  axs[0, 2].set_xticklabels(split_wise_labels)
  axs[0, 2].set_title ('Split Wise, Label Wise Data')

  # Graphs 4-6: train vs. validation length distributions (in whitespace tokens)
  histogram_panels = [
    (axs[1, 0], 'premise', 'premise', 'Premise Length Comparison'),
    (axs[1, 1], 'hypothesis', 'hypothesis', 'Hypothesis Length Comparison'),
    (axs[1, 2], 'sum', 'Input', 'Input Length Comparison'),
  ]
  for ax, length_key, legend_prefix, title in histogram_panels:
    for split_name in ('train', 'validation'):
      ax.hist(sen_length[split_name][length_key], bins=50,
              label=f"{legend_prefix} length in {split_name} set", alpha=0.5)
    ax.legend(loc='best')
    ax.set_title (title)
    ax.set_xlabel('Sentence Length')
    ax.set_ylabel('Frequency')

  # Lay out only after every title and axis label exists, so they are accounted for.
  fig.tight_layout()
  plt.show()

except Exception as e:
  print ("Cannot plot the graphs: ", str (e))

In [None]:
def tokenize_dataset (dataset, model_checkpoint, max_length = 64):
  """Tokenize the premise/hypothesis pairs of `dataset` with the checkpoint's tokenizer.

  Args:
    dataset: object exposing `.map(...)` whose examples have `premise` and
      `hypothesis` columns; both columns are removed from the result.
    model_checkpoint: identifier handed to `AutoTokenizer.from_pretrained`.
    max_length: length every encoded pair is padded or truncated to.
      Defaults to 64, the value that used to be hard-coded.

  Returns:
    A `(tokenizer, tokenized_datasets)` tuple; the tokenized data uses the
    "torch" output format.
  """
  # NOTE(review): trust_remote_code=True lets the checkpoint repository run its own
  # Python code at load time -- only use it with checkpoints you trust.
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True)

  def tokenize_function (examples):
    # Encode each premise/hypothesis pair as one sequence of fixed length.
    return tokenizer (examples['premise'],
                      examples['hypothesis'],
                      add_special_tokens = True,
                      max_length = max_length,
                      padding = 'max_length',
                      truncation = True)

  tokenized_datasets = dataset.map (tokenize_function,
                                  batched = True,
                                  remove_columns = ['premise', 'hypothesis']
                                  ).with_format("torch")
  return (tokenizer, tokenized_datasets)

def classification_model (dataset, model_checkpoint):
  """Load `model_checkpoint` with a three-class, single-label classification head.

  `dataset` is accepted but not used by this function. The label names come from
  the module-level `int2label` / `label2int` mappings.
  """
  head_options = dict(
      trust_remote_code = True,
      problem_type = "single_label_classification",
      num_labels = 3,
      id2label = int2label,
      label2id = label2int,
  )
  return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, **head_options)

In [None]:
# Tokenize every split for `checkpoint`; this also rebinds the notebook-level `tokenizer`.
tokenizer, tokenized_dataset = tokenize_dataset (dataset, checkpoint)

In [None]:
# Classifier that is fine-tuned below, built from the same checkpoint.
model2 = classification_model (dataset, checkpoint)
#for the pre-trained models in this thesis
# model2.bert.embeddings = model.base_model.embeddings
# model2.bert.encoder = model.base_model.encoder
# model2.bert.pooler.dense = torch.nn.Linear(in_features=768, out_features=768, bias=True)
# model2.bert.pooler.activation = torch.nn.Tanh()

In [None]:
# One forward pass on the first training example to check that the pipeline is wired correctly.
# Fetch the single row directly instead of indexing the whole `input_ids` column.
sample = tokenized_dataset['train'][0]
# Pass the attention mask (and segment ids when the tokenizer produced them): inputs are
# padded to a fixed length, so without the mask the model would attend to the padding.
model_inputs = {name: sample[name].unsqueeze(0)
                for name in ('input_ids', 'attention_mask', 'token_type_ids')
                if name in sample}
# This is only a check, so no autograd graph is needed.
with torch.no_grad():
    outputs = model2(**model_inputs, labels = sample['label'].unsqueeze(0))
outputs



In [None]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support
from transformers import EvalPrediction
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd 

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Score class predictions against `labels`.

    Despite the name, each row of `predictions` is reduced to one class: the scores
    are softmaxed over dim 1 and the arg-max index is the predicted label.
    `threshold` is accepted but not used.

    Returns:
        dict with `accuracy`, macro `f1`, `f1_weighted`, macro `precision` and
        macro `recall`.
    """
    class_probs = torch.Tensor(predictions).softmax(dim=1)
    y_pred = class_probs.argmax(dim=1)
    y_true = labels

    # Macro-averaged precision / recall / F1; the per-class support is not needed.
    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="macro"
    )
    weighted_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')
    accuracy = accuracy_score(y_true, y_pred)

    scores = {}
    scores["accuracy"] = accuracy
    scores["f1"] = macro_f1
    scores["f1_weighted"] = weighted_f1
    scores["precision"] = macro_precision
    scores["recall"] = macro_recall
    return scores
    

def compute_metrics(p: EvalPrediction):
    """Metrics callback: unwrap the predictions and delegate to `multi_label_metrics`.

    When `p.predictions` is a tuple, only its first element is scored.
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, tuple):
        raw_predictions = raw_predictions[0]
    return multi_label_metrics(predictions=raw_predictions, labels=p.label_ids)

In [None]:
batch_size = 16
# Key returned by `compute_metrics` that selects the best checkpoint.
metric_name = "accuracy"
from transformers import EarlyStoppingCallback, IntervalStrategy

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "training_with_callbacks",  # output directory for checkpoints and logs
    # `load_best_model_at_end` and the EarlyStoppingCallback used with this config need
    # periodic evaluation and a save strategy that matches it. With both left at their
    # defaults, TrainingArguments rejects the configuration.
    # NOTE(review): recent transformers releases call this argument `eval_strategy`;
    # rename it if your installed version rejects `evaluation_strategy`.
    evaluation_strategy = 'epoch',
    save_strategy = 'epoch',
    logging_strategy = 'epoch',
    learning_rate= 1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    weight_decay=1e-4,
    max_grad_norm= 5.0,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # NOTE(review): one checkpoint is written per epoch; consider `save_total_limit`
    # if disk space is a concern.
    report_to="none",
    seed = 9000,
    push_to_hub=False,
)

In [None]:

# Bundle the model, data, metrics and early stopping (patience of 5 evaluations) into a Trainer.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)
trainer = Trainer(
    model=model2,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# Run the training loop configured on `trainer`. As the last expression of the cell,
# the call's return value is echoed by the notebook.
trainer.train()

In [None]:
# Run the trained model on the held-out test split.
raw_pred = trainer.predict(tokenized_dataset["test"])
# Echo the metrics attached to the prediction output as the cell result.
raw_pred.metrics