In [None]:
import json
import torch
import string
import random
import numpy as np
import pandas as pd

from tqdm.notebook import *
from transformers import BartModel, BartTokenizer

In [None]:
# Data files, seed and device. All three splits get the same missing-value
# treatment: an empty CSV cell would otherwise reach the tokenizer as a float NaN.
TRAIN_CSV = "train.csv"
# NOTE(review): the validation split is read from the SAME file as training, so
# the "Valid" accuracy reported later measures training-set fit, not
# generalization. Point this at a held-out file if one exists.
VALID_CSV = "train.csv"
TEST_CSV = "updated_test_data.csv"
SEED = 42

# Seed every RNG in use (python `random` for augmentation, numpy, torch for
# DataLoader shuffling and weight init) so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

df_test = pd.read_csv(TEST_CSV).fillna("")
df_valid = pd.read_csv(VALID_CSV).fillna("")
df_train = pd.read_csv(TRAIN_CSV).fillna("")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def add_rand(text):
    r = text.split()
    for i in range(random.randrange(max(3, len(text)//16))):
        start = random.randrange(max(3, len(text.split())-64))
        r = r[:start] + ["".join(random.choices(string.ascii_lowercase, k = random.randrange(5) + 4))] + r[start:]
    return " ".join(r)

class TrainDataset(torch.utils.data.Dataset):
    """Flattened (row, source column) training samples with one-hot labels.

    Every row of ``data`` holds one text per source column listed in ``keys``.
    Flat index ``idx`` maps to row position ``idx // len(keys)`` and column
    ``idx % len(keys)``; the label is a float one-hot vector of length
    ``len(keys)`` marking that column.

    Args:
        data: DataFrame containing every column named in ``keys``.
        keys: Column names in class order. Defaults to ``DEFAULT_KEYS``.
        augment: When True, apply the random noise augmentation in
            ``_augment`` (which calls this notebook's ``add_rand``). Off by
            default, matching the previous behaviour.
    """

    # Source columns in class order: position i is class i.
    DEFAULT_KEYS = ('Human_story', 'gemma-2-9b', 'mistral-7B', 'qwen-2-72B', 'llama-8B', 'accounts/yi-01-ai/models/yi-large', 'GPT_4-o')

    def __init__(self, data, keys=None, augment=False):
        self.data = data
        self.keys = list(self.DEFAULT_KEYS if keys is None else keys)
        self.augment = augment

    def __len__(self):
        return len(self.data) * len(self.keys)

    def __getitem__(self, idx):
        data_idx, label_idx = divmod(idx, len(self.keys))
        labels = torch.zeros(len(self.keys))
        labels[label_idx] = 1.0
        # Positional lookup: the flat index counts rows by position, so .iloc
        # stays correct even when the DataFrame has a non-default index.
        text = self.data[self.keys[label_idx]].iloc[data_idx]
        # Missing cells arrive as float NaN, which the tokenizer rejects.
        if not isinstance(text, str) and pd.isna(text):
            text = ""
        if self.augment:
            text = self._augment(text)
        return {"text": text, "label": labels}

    @staticmethod
    def _augment(text):
        """With p=0.1 splice in junk words; then, independently with p=0.1, drop a random-length prefix of words."""
        if random.random() < 0.1:
            text = add_rand(text)
        if random.random() < 0.1:
            words = text.split()
            start = random.randrange(max(2, len(words)))
            text = " ".join(words[start:])
        return text

class TestDataset(torch.utils.data.Dataset):
    """Unlabelled inference samples taken from the ``Text`` column of ``data``.

    Item ``idx`` is ``{"id": idx, "text": ...}``; ``id`` is the row's
    position in ``data`` (not its index label).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional lookup keeps "id" and "text" aligned even when the
        # DataFrame has a non-default index.
        text = self.data["Text"].iloc[idx]
        # Missing cells arrive as float NaN, which the tokenizer rejects.
        if not isinstance(text, str) and pd.isna(text):
            text = ""
        return {"id": idx, "text": text}

In [None]:
class ModelFactory(torch.nn.Module):
    """Linear classification head over position 0 of a text model's output.

    ``forward`` returns class PROBABILITIES (softmax over dim 1), not raw
    logits. Do not feed the output straight into ``torch.nn.CrossEntropyLoss``,
    which applies its own log-softmax; take ``torch.log`` first.

    Args:
        text_model: Module called as ``text_model(input_ids=..., attention_mask=...)``
            whose result exposes ``last_hidden_state``.
        n_classes: Number of output classes.
        hashLength: Unused; kept only for backward compatibility.
        hidden_size: Feature width of ``last_hidden_state``. When ``None`` it
            is read from ``text_model.config.hidden_size`` if present,
            otherwise 1024.
    """

    def __init__(self, text_model, n_classes, hashLength=512, hidden_size=None):
        super().__init__()
        self.text_model = text_model

        if hidden_size is None:
            config = getattr(text_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 1024)
        self.fc = torch.nn.Linear(hidden_size, n_classes)
        self.output_layer = torch.nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        """Return (batch, n_classes) class probabilities."""
        hidden = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Features are the hidden state at sequence position 0.
        text_features = hidden[:, 0, :]
        logits = self.fc(text_features)
        return self.output_layer(logits)

In [None]:
class Accuracy:
    """Running top-1 accuracy accumulated over batches.

    ``update`` takes ``predictions`` with classes on dim 1 and ``targets``
    either in the same per-class layout (one-hot / scores, argmax'd) or as a
    1-D tensor of class indices.
    """

    def __init__(self):
        self.correct_count = 0
        self.total_count = 0

    def update(self, predictions, targets):
        """Add one batch of predictions/targets to the running counts."""
        predicted_labels = torch.argmax(predictions, dim=1)
        # 1-D targets are already class indices; otherwise take the argmax.
        if targets.dim() == 1:
            target_labels = targets
        else:
            target_labels = torch.argmax(targets, dim=1)
        target_labels = target_labels.to(predicted_labels.device)
        self.correct_count += (predicted_labels == target_labels).sum().item()
        self.total_count += targets.size(0)

    def reset(self):
        """Clear the running counts."""
        self.correct_count = 0
        self.total_count = 0

    def compute(self):
        """Return correct / total, or 0.0 if nothing has been recorded yet."""
        if self.total_count == 0:
            return 0.0
        return self.correct_count / self.total_count

In [None]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
text_model = BartModel.from_pretrained("facebook/bart-large").to(device)

# Despite the names, these are DataLoaders (batches of 8), not Datasets.
train_dataset = torch.utils.data.DataLoader(TrainDataset(df_train), batch_size=8, shuffle=True, num_workers=0)
# Evaluation order does not change accuracy, so keep it deterministic.
valid_dataset = torch.utils.data.DataLoader(TrainDataset(df_valid), batch_size=8, shuffle=False, num_workers=0)
test_dataset = torch.utils.data.DataLoader(TestDataset(df_test), batch_size=8, shuffle=False, num_workers=0)

# 7 classes = the 7 source columns used by TrainDataset.
model = ModelFactory(text_model, 7).to(device)

# Resume from a previous run if a checkpoint exists. The training loop saves
# the whole module via torch.save(model, ...), so unpickling is required.
# SECURITY NOTE: weights_only=False runs arbitrary pickle code -- only load
# checkpoints you produced yourself.
try:
    checkpoint = torch.load("model.pth", map_location=device, weights_only=False)
except FileNotFoundError:
    checkpoint = None
    print("No checkpoint at 'model.pth'; starting from pretrained BART weights.")
if checkpoint is not None:
    state_dict = checkpoint.state_dict() if isinstance(checkpoint, torch.nn.Module) else checkpoint
    # A checkpoint taken from a DataParallel wrapper prefixes keys with "module.".
    state_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

_cross_entropy = torch.nn.CrossEntropyLoss()


def loss_function(probabilities, targets):
    """Cross-entropy for a model whose forward already applies softmax.

    CrossEntropyLoss expects unnormalized logits and applies log-softmax
    itself; feeding it probabilities would apply softmax twice and flatten
    the gradients. log(probabilities) is a valid logit vector whose
    log-softmax is itself, so this equals the true cross-entropy. The clamp
    keeps log() finite if a probability underflows to 0.
    """
    return _cross_entropy(torch.log(probabilities.float().clamp_min(1e-12)), targets)


metrics = [Accuracy()]

# Mixed-precision loss scaling only applies on CUDA.
scaler = torch.amp.GradScaler(device="cuda", enabled=torch.cuda.is_available())
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

In [None]:
# Human-readable class names for the answer files; position i is class i
# (same order as TrainDataset's source columns, with a display name for Yi).
LABEL_NAMES = ['Human_story', 'gemma-2-9b', 'mistral-7B', 'qwen-2-72B', 'llama-8B', 'Yi-Large', 'GPT_4-o']
MAX_LENGTH = 128
NUM_EPOCHS = 100
# fp16 autocast is only meaningful on CUDA; on CPU it is disabled.
use_amp = device.type == "cuda"


def _encode(texts):
    """Tokenize a batch of strings to MAX_LENGTH-padded tensors on `device`."""
    encoded = tokenizer(texts, return_tensors="pt", padding="max_length", max_length=MAX_LENGTH, truncation=True)
    return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)


def _autocast():
    """fp16 autocast context on CUDA; a no-op context on CPU."""
    return torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp)


def _metric_summary():
    """Render every metric as 'Name: value', e.g. 'Accuracy: 0.12345'."""
    return ", ".join(f"{type(m).__name__}: {m.compute():.5f}" for m in metrics)


def _reset_metrics():
    """Reset every metric (not just the last one) before a new pass."""
    for m in metrics:
        m.reset()


for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    optimizer.zero_grad()
    _reset_metrics()
    pbar = tqdm(train_dataset)
    for batch in pbar:
        input_ids, attention_mask = _encode(batch["text"])
        labels = batch["label"].to(device)
        with _autocast():
            output = model(input_ids, attention_mask)
            loss = loss_function(output, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        for metric in metrics:
            metric.update(output.detach(), labels)
        pbar.set_description(f"Epoch: {epoch}, Loss: {loss.item():.5f}, " + _metric_summary())

    # Save the unwrapped module so checkpoint keys never carry DataParallel's
    # "module." prefix (the loading cell builds a bare ModelFactory).
    torch.save(getattr(model, "module", model), "model.pth")

    # ---- validation pass (no autograd graph needed) ----
    model.eval()
    _reset_metrics()
    pbar = tqdm(valid_dataset)
    with torch.no_grad():
        for batch in pbar:
            input_ids, attention_mask = _encode(batch["text"])
            labels = batch["label"].to(device)
            with _autocast():
                output = model(input_ids, attention_mask)
            for metric in metrics:
                metric.update(output, labels)
            pbar.set_description(f"Valid: {epoch}, " + _metric_summary())
    _reset_metrics()

    # ---- test predictions ----
    predictions = []
    with torch.no_grad():
        for batch in tqdm(test_dataset):
            input_ids, attention_mask = _encode(batch["text"])
            with _autocast():
                probs = model(input_ids, attention_mask)
            probs = probs.float().cpu().numpy()
            for row, prob in enumerate(probs):
                argmax_pred = int(prob.argmax())
                predictions.append({
                    'id': int(batch["id"][row]),
                    'Text': batch["text"][row],
                    # 0 when class 0 ('Human_story') wins, 1 for any model class.
                    'Label_A': int(argmax_pred > 0),
                    'Label_B': LABEL_NAMES[argmax_pred],
                    'Prob': prob.tolist(),
                })

    with open(f"answer_{epoch}.json", 'w') as f:
        json.dump(predictions, f, indent=4)