In [None]:
import json
import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt

from collections import Counter
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import f1_score, cohen_kappa_score, confusion_matrix, classification_report, ConfusionMatrixDisplay

In [None]:
def _binarize_scores(scores, threshold):
    """Map each raw score to 1 when it is >= ``threshold`` and to 0 otherwise."""
    return [1 if score >= threshold else 0 for score in scores]


def generate_metrics(predicted_labels_raw, true_labels, return_metrics=False, thresholds=None):
    """Binarise raw BLEURT scores at the F1-maximising threshold and report metrics.

    A histogram of the raw scores is always shown first. Every candidate
    threshold is then tried: a score becomes label 1 when it is >= the
    threshold and 0 otherwise. The threshold with the highest F1 against
    ``true_labels`` is selected; on ties the threshold that appears later in
    the sweep wins.

    Parameters
    ----------
    predicted_labels_raw : sequence of float
        Raw model scores, one per sample.
    true_labels : sequence
        Binary ground-truth labels aligned with ``predicted_labels_raw``
        (0 is displayed as 'incorrect', 1 as 'correct').
    return_metrics : bool, default False
        When True, return the metrics instead of printing a report.
    thresholds : iterable of float, optional
        Candidate thresholds to sweep. Defaults to 0.50, 0.55, ..., 0.95.

    Returns
    -------
    tuple or None
        ``(optimal_threshold, f1, cohen_kappa)`` when ``return_metrics`` is
        True; otherwise None, after printing the classification report and
        showing the confusion matrix.

    Raises
    ------
    ValueError
        If ``thresholds`` is given but empty.
    """
    if thresholds is None:
        # round() strips float noise from np.arange (e.g. 0.6000000000000001).
        thresholds = [round(th, 2) for th in np.arange(0.5, 1.0, 0.05)]
    else:
        thresholds = list(thresholds)
        if not thresholds:
            raise ValueError('thresholds must contain at least one value')

    plt.hist(predicted_labels_raw)
    plt.title('BLEURT Score Distribution')
    plt.show()

    max_f1_at_th = -1.0
    max_f1 = 0.0

    for th in thresholds:
        temp_f1 = f1_score(true_labels, _binarize_scores(predicted_labels_raw, th))

        # `>=` (not `>`) so that ties resolve to the later threshold in the sweep.
        if temp_f1 >= max_f1:
            max_f1 = temp_f1
            max_f1_at_th = th

    # All reported metrics are taken at the threshold that gave the best F1.
    predicted_labels = _binarize_scores(predicted_labels_raw, max_f1_at_th)
    best_f1 = f1_score(true_labels, predicted_labels)
    kappa = cohen_kappa_score(true_labels, predicted_labels)

    if return_metrics:
        return max_f1_at_th, best_f1, kappa

    print(f'Optimal Threshold: {max_f1_at_th} \n')
    print(f'Predicted Label Count: {Counter(predicted_labels)}\n')
    print('Classification Report:')
    print(classification_report(true_labels, predicted_labels), '\n')
    print('F1 Score: ', best_f1, '\n')
    print('Cohen Kappa: ', kappa, '\n')
    cm_display = ConfusionMatrixDisplay(confusion_matrix = confusion_matrix(true_labels, predicted_labels), display_labels = ['incorrect', 'correct'])
    cm_display.plot()
    plt.show()

    return None

In [None]:
class BLEURTDataset(Dataset):
    """Torch dataset that tokenises (candidate, reference) pairs for BLEURT.

    Each record in ``data`` must provide the keys ``"candidate"``,
    ``"reference"`` and ``"score"``.

    Parameters
    ----------
    data : sequence of dict
        The records to serve.
    tokenizer : callable
        Called as ``tokenizer(candidate, reference, ...)``; its result must
        expose ``'input_ids'`` and ``'attention_mask'`` tensors.
    max_length : int, default 128
        Length every pair is padded / truncated to.

    NOTE(review): only ``input_ids`` and ``attention_mask`` are returned, so
    any other tokenizer output (e.g. segment / token-type ids, if produced)
    is dropped — confirm the model does not rely on it.
    NOTE(review): the pair is passed as (candidate, reference); confirm this
    matches the order the pretrained checkpoint expects.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenise record ``idx`` and return model-ready tensors.

        Returns a dict with 1-D ``input_ids`` and ``attention_mask`` tensors
        and a float ``labels`` tensor built from the record's score.
        """
        item = self.data[idx]
        encoding = self.tokenizer(
            item["candidate"],
            item["reference"],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        return {
            # flatten() collapses the tokenizer output to 1-D
            # (presumably from shape (1, max_length) — TODO confirm).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(item["score"], dtype=torch.float)
        }

In [None]:
# --- Run configuration -------------------------------------------------------
# Alternative (chatgpt-vicuna) locations, kept for reference:
# data_path = "/home/jovyan/active-projects/textbook-question-generation/bleurt-models/bleurtmodel/bleurt/bleurt/test_data/chatgpt-vicuna/{}_samples.jsonl"
# model_save_path = "/home/jovyan/active-projects/textbook-question-generation/src/chatgpt-vicuna-bleurt/"

# Dataset to fine-tune on, and a tag appended to saved model / tokenizer folders.
dataset_name = "multirc"
unique_identifier = "20230912"

# BLEURT checkpoint size and number of fine-tuning epochs.
model_type = "large"
num_training_epochs = 10

# The `{}` placeholder is filled later with the split name via str.format.
data_path = (
    "/home/jovyan/active-projects/textbook-question-generation/bleurt-models/bleurtmodel/bleurt/bleurt/test_data/"
    + dataset_name
    + "-dataset/{}_samples.jsonl"
)
model_save_path = f"/home/jovyan/active-projects/textbook-question-generation/src/{dataset_name}-bleurt/"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Set to True to resume from the checkpoint previously saved under
# `model_save_path`; leave False to start from the published BLEURT weights.
is_model_already_trained = False

if not is_model_already_trained:
    # Fresh start: load the base checkpoint and its tokenizer by hub name.
    model = AutoModelForSequenceClassification.from_pretrained(f"Elron/bleurt-{model_type}-128")
    tokenizer = AutoTokenizer.from_pretrained(f"Elron/bleurt-{model_type}-128")
else:
    # Resume: model and tokenizer live in sibling folders tagged with the run id.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_save_path + f'model_{model_type}_' + unique_identifier
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_save_path + f'tokenizer_{model_type}_' + unique_identifier
    )
                         

In [None]:
# Each split is a JSON Lines file: one JSON record per line.
# UTF-8 is requested explicitly so decoding does not depend on the platform's
# default encoding, and blank lines are skipped instead of crashing json.loads.
with open(data_path.format('train'), 'r', encoding='utf-8') as file:
    train_samples = [json.loads(line) for line in file if line.strip()]

with open(data_path.format('validation'), 'r', encoding='utf-8') as file:
    validation_samples = [json.loads(line) for line in file if line.strip()]

train_dataset = BLEURTDataset(train_samples, tokenizer)
validation_dataset = BLEURTDataset(validation_samples, tokenizer)

# Only the training data is shuffled; validation keeps file order.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=8)

In [None]:
# AdamW over every model parameter with a fixed learning rate of 2e-5.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)
# loss_fn = torch.nn.CrossEntropyLoss()

# Release cached GPU blocks, then move the model onto the selected device.
torch.cuda.empty_cache()
model = model.to(device)

# Stop here if the run would fall back to the CPU.
assert device.type == 'cuda'

In [None]:
# Establish the F1 score a new epoch has to beat before its model is saved.
if not is_model_already_trained:
    # No earlier checkpoint: any epoch's F1 will exceed this sentinel.
    best_f1_score = -1

else:
    # Score the reloaded checkpoint on the validation split.
    temp_vp = []
    temp_vtl = []
    with torch.no_grad():
        for batch in tqdm.tqdm(validation_loader):
            temp_vtl.extend(batch['labels'])

            outputs = model(
                batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            logits = outputs.logits
            preds = logits.detach().cpu().numpy()
            temp_vp.extend(preds)

    # Each prediction row contributes its first entry as the raw score.
    _, best_f1_score, _ = generate_metrics(
        [arr[0] for arr in temp_vp],
        [t.item() for t in temp_vtl],
        return_metrics=True,
    )

best_f1_score

In [None]:
# Per-epoch mean losses (rounded to 2 decimals), consumed by the loss plot.
all_epochs_validation_losses = []
all_epochs_training_losses = []

for epoch in range(num_training_epochs):
    print(f"Epoch {epoch + 1} of {num_training_epochs}:")

    predicted_labels_raw = []
    true_labels = []

    epoch_training_loss = 0
    epoch_validation_loss = 0

    # Release cached GPU blocks once per epoch. Doing this on every training
    # step only adds allocator overhead and does not give PyTorch more memory.
    torch.cuda.empty_cache()

    # ---- training pass ----
    model.train()
    for batch in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        # Passing `labels` makes the model return the loss alongside the logits.
        outputs = model(batch['input_ids'].to(device), attention_mask=batch['attention_mask'].to(device), labels=batch['labels'].to(device))
        batch_training_loss = outputs.loss
        epoch_training_loss += batch_training_loss.item()
        batch_training_loss.backward()
        optimizer.step()

    # ---- validation pass (no gradients) ----
    model.eval()
    with torch.no_grad():
        for batch in tqdm.tqdm(validation_loader):
            true_labels.extend(batch['labels'])
            outputs = model(batch['input_ids'].to(device), attention_mask=batch['attention_mask'].to(device), labels=batch['labels'].to(device))
            batch_validation_loss = outputs.loss
            epoch_validation_loss += batch_validation_loss.item()
            predicted_labels_raw.extend(outputs.logits.detach().cpu().numpy())

    # Mean loss per batch for this epoch.
    all_epochs_training_losses.append(round(epoch_training_loss / len(train_loader), 2))
    all_epochs_validation_losses.append(round(epoch_validation_loss / len(validation_loader), 2))

    # Despite the `batch_` prefix these are epoch-level validation metrics,
    # taken at the threshold that maximises validation F1.
    batch_optimal_threshold, batch_f1_score, batch_kappa = generate_metrics([arr[0] for arr in predicted_labels_raw], [t.item() for t in true_labels], return_metrics=True)

    print(f"\tTraining loss: {all_epochs_training_losses[-1]}")
    print(f"\tValidation loss: {all_epochs_validation_losses[-1]}")
    print(f"\tOptimal Threshold: {batch_optimal_threshold}")
    print(f"\tF1 Score: {batch_f1_score}")
    print(f"\tCohen Kappa: {batch_kappa}")

    # Checkpoint only when validation F1 strictly improves on the best so far.
    if batch_f1_score > best_f1_score:
        print(f"\n\tSaving this epoch's model. Previous Best F1: {best_f1_score}, Current Best F1: {batch_f1_score}")
        best_f1_score = batch_f1_score
        model.save_pretrained(model_save_path+f'model_{model_type}_'+unique_identifier)
        tokenizer.save_pretrained(model_save_path+f'tokenizer_{model_type}_'+unique_identifier)

print("Training complete!")

In [None]:
# Loss curves per epoch. The x-axis is derived from the recorded losses (not
# from num_training_epochs) so the plot still works if fewer epochs were
# recorded, and it is 1-based to match the "Epoch N of M" log lines.
plt.plot(range(1, len(all_epochs_training_losses) + 1), all_epochs_training_losses, color='blue', label='Training loss')
plt.plot(range(1, len(all_epochs_validation_losses) + 1), all_epochs_validation_losses, color='orange', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Mean loss per batch')
plt.title('Training vs. validation loss')
plt.legend()
plt.show()

In [None]:
# Reload the checkpoint saved during training and put it in inference mode.
ft_best_model = AutoModelForSequenceClassification.from_pretrained(model_save_path+f'model_{model_type}_'+unique_identifier)
ft_best_model.to(device)
ft_best_model.eval()
ft_best_tokenizer = AutoTokenizer.from_pretrained(model_save_path+f'tokenizer_{model_type}_'+unique_identifier)

# Test split: one JSON record per line. UTF-8 is requested explicitly so
# decoding does not depend on the platform default, and blank lines are
# skipped instead of crashing json.loads.
with open(data_path.format('test'), 'r', encoding='utf-8') as file:
    test_samples = [json.loads(line) for line in file if line.strip()]

test_dataset = BLEURTDataset(test_samples, ft_best_tokenizer)
test_loader = DataLoader(test_dataset, batch_size=8)

In [None]:
# Score the test split with the best checkpoint and print the full report.
predictions = []
true_labels = []
with torch.no_grad():
    for batch in tqdm.tqdm(test_loader):
        outputs = ft_best_model(
            batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        )
        logits = outputs.logits
        preds = logits.detach().cpu().numpy()

        predictions.extend(preds)
        true_labels.extend(batch['labels'])

# NOTE(review): generate_metrics sweeps thresholds against the labels it is
# given, so here the decision threshold is tuned on the test labels themselves.
# Consider fixing the threshold from the validation split instead.
generate_metrics(
    [arr[0] for arr in predictions],
    [tn.item() for tn in true_labels],
)