In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import csv
from transformers import XLMRobertaTokenizer

class CustomDataset(Dataset):
    """Map-style dataset over a dict of parallel, index-aligned columns.

    ``encoded_dataset`` must expose ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'``, each indexable by sample position. The number of samples is
    taken from the ``'labels'`` column.
    """

    # Keys copied, in this order, into every sample dict.
    _FIELDS = ('input_ids', 'attention_mask', 'labels')

    def __init__(self, encoded_dataset):
        self.encoded_dataset = encoded_dataset

    def __len__(self):
        labels_column = self.encoded_dataset['labels']
        return len(labels_column)

    def __getitem__(self, idx):
        columns = self.encoded_dataset
        return {field: columns[field][idx] for field in self._FIELDS}

TOKENIZER_NAME = "xlm-roberta-base"
# NOTE(review): the fine-tuned encoder comes from a separate COMET checkpoint;
# presumably it shares this tokenizer's vocabulary — confirm before trusting ids.
tokenizer = XLMRobertaTokenizer.from_pretrained(TOKENIZER_NAME)

def tokenize_sentences(df, text_tokenizer=None, max_length=None):
    """Tokenize ``df['target']`` and pair the result with ``df['labels']``.

    Args:
        df: DataFrame with a text column ``'target'`` and an integer-coded
            ``'labels'`` column.
        text_tokenizer: callable tokenizer (Hugging Face style). Defaults to the
            module-level ``tokenizer``.
        max_length: forwarded to the tokenizer call; ``None`` keeps the
            tokenizer's own default length.

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` (the tensors returned
        by the tokenizer for the whole batch, one row per sentence) and
        ``'labels'`` (1-D ``torch.long`` tensor), in that key order.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("tokenize_sentences: received an empty DataFrame")
    if text_tokenizer is None:
        text_tokenizer = tokenizer

    # One batched call instead of a per-row loop: with padding='max_length'
    # every row has the same length, so the returned tensors are already stacked.
    encoded = text_tokenizer(
        df['target'].tolist(),
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )

    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': torch.tensor(df['labels'].tolist(), dtype=torch.long),
    }

# All three splits share the same '|'-separated layout, so they are parsed with
# identical options (csv.QUOTE_NONE: quote characters are treated as plain text).
# NOTE(review): assumes the test/validation files follow the training file's
# unquoted format — confirm; override via load_split(..., quoting=...) if not.
LABEL_MAP = {'mt': 0, 'human': 1}
SEED = 42


def load_split(path, quoting=csv.QUOTE_NONE):
    """Read one '|'-separated UTF-8 split and integer-encode its 'labels' column.

    Args:
        path: CSV file to read.
        quoting: csv quoting mode forwarded to ``pd.read_csv``.

    Returns:
        DataFrame whose 'labels' column holds LABEL_MAP codes
        ('mt' -> 0, 'human' -> 1).

    Raises:
        ValueError: if any label is missing or not a key of LABEL_MAP
            (``Series.map`` would otherwise silently turn it into NaN).
    """
    frame = pd.read_csv(path, sep='|', quoting=quoting, encoding='utf-8')
    codes = frame['labels'].map(LABEL_MAP)
    unknown = codes.isna()
    if unknown.any():
        bad_values = sorted(frame.loc[unknown, 'labels'].astype(str).unique())
        raise ValueError(f"{path}: unexpected label values {bad_values}")
    frame['labels'] = codes.astype(int)
    return frame


df_train = load_split('training--95-tgt.csv')
df_test = load_split('test--95-tgt.csv')
df_val = load_split('validation--95-tgt.csv')

train_encoded = tokenize_sentences(df_train)
valid_encoded = tokenize_sentences(df_val)
test_encoded = tokenize_sentences(df_test)

train_dataset = CustomDataset(train_encoded)
valid_dataset = CustomDataset(valid_encoded)
test_dataset = CustomDataset(test_encoded)

batch_size = 16
# A dedicated, seeded generator makes the training shuffle order reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from comet import download_model, load_from_checkpoint

COMET_CHECKPOINT = "Unbabel/wmt22-comet-da"

# Fetch the checkpoint, load the full COMET model, and keep a handle to its
# encoder for fine-tuning.
model_path = download_model(COMET_CHECKPOINT)
model = load_from_checkpoint(model_path)
encoder = model.encoder

In [None]:
class ClassificationHead(nn.Module):
    """Two-layer MLP that maps a feature vector to per-class logits.

    Args:
        input_dim: size of the incoming feature vector.
        num_classes: number of output logits.
        hidden_dim: width of the hidden layer (default 512).
        dropout: dropout probability applied after the hidden activation
            (default 0.1).
    """

    def __init__(self, input_dim, num_classes, hidden_dim=512, dropout=0.1):
        super(ClassificationHead, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            # Input width must equal the hidden layer's output (hidden_dim);
            # using input_dim here makes forward() fail with a shape mismatch.
            nn.Linear(hidden_dim, num_classes, bias=True),
        )

    def forward(self, x):
        """Return logits of shape ``(*, num_classes)`` for ``x`` of shape ``(*, input_dim)``."""
        return self.classifier(x)


# input_dim must equal the encoder's hidden size; presumably 1024 for the COMET
# wmt22-comet-da encoder — TODO confirm against the loaded encoder's config.
classification_head = ClassificationHead(input_dim=1024, num_classes=2)

In [None]:
class FineTuneModel(nn.Module):
    """Encoder followed by a classification head.

    The head is applied to the last-layer hidden state at sequence position 0
    (presumably the sentence-start token for XLM-R — confirm).
    """

    def __init__(self, encoder, classification_head):
        super().__init__()
        self.encoder = encoder
        self.classification_head = classification_head

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the head's logits for the first position.

        Assumes ``self.encoder.model(...)`` returns an object exposing
        ``last_hidden_state`` shaped (batch, seq, hidden).
        """
        last_layer = self.encoder.model(input_ids, attention_mask).last_hidden_state
        first_position = last_layer[:, 0, :]
        return self.classification_head(first_position)

# Wrap the pretrained encoder and the freshly initialised head into one trainable module.
model = FineTuneModel(encoder=encoder, classification_head=classification_head)

In [None]:
# Freeze every encoder weight so that only parameters outside the encoder
# (the classification head) keep requires_grad=True.
model.encoder.requires_grad_(False)

In [None]:
from torcheval.metrics import BinaryAUROC

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# The head emits raw logits of shape (batch, 2) and labels are integer class ids,
# so cross-entropy is the matching loss; the same criterion is used for training
# and evaluation so the reported losses are comparable.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)
val_roc_auc_metric = BinaryAUROC()
test_roc_auc_metric = BinaryAUROC()
num_epochs = 5


def _unpack(batch, target_device):
    """Move one batch to ``target_device``; return (input_ids, attention_mask, labels)."""
    return (
        batch["input_ids"].to(target_device),
        batch["attention_mask"].to(target_device),
        batch["labels"].long().to(target_device),
    )


def evaluate_loader(net, loader, loss_fn, auroc_metric, target_device):
    """Score ``net`` on ``loader`` without gradient tracking.

    Args:
        net: model returning (batch, 2) logits for (input_ids, attention_mask).
        loader: DataLoader yielding dicts with 'input_ids', 'attention_mask', 'labels'.
        loss_fn: criterion applied to (logits, labels).
        auroc_metric: BinaryAUROC instance; it is reset before use.
        target_device: device the batches are moved to.

    Returns:
        (mean per-batch loss, ROC AUC of the class-1 probability).
    """
    net.eval()
    auroc_metric.reset()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            input_ids, attention_mask, labels = _unpack(batch, target_device)

            # Logits stay (batch, num_classes) so softmax runs over the classes.
            logits = net(input_ids=input_ids, attention_mask=attention_mask)
            total_loss += loss_fn(logits, labels).item()

            positive_probs = torch.softmax(logits, dim=1)[:, 1]
            auroc_metric.update(positive_probs.cpu(), labels.cpu())
    return total_loss / len(loader), auroc_metric.compute()


# Training
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        input_ids, attention_mask, labels = _unpack(batch, device)

        optimizer.zero_grad()
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation
    avg_val_loss, val_roc_auc = evaluate_loader(
        model, valid_loader, criterion, val_roc_auc_metric, device
    )
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}, Validation ROC AUC: {val_roc_auc:.4f}")

# Testing
avg_test_loss, test_roc_auc = evaluate_loader(
    model, test_loader, criterion, test_roc_auc_metric, device
)
print(f"Test Loss: {avg_test_loss:.4f}, Test ROC AUC: {test_roc_auc:.4f}")