In [None]:
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, MegaForSequenceClassification
import torchinfo
import os
import pickle

In [None]:
# Every tensor and the model are moved to this device; the notebook runs on CPU.
DEVICE_NAME = "cpu"
device = torch.device(DEVICE_NAME)


In [None]:
# Tokenizer and classification model are loaded from the same Hub checkpoint.
CHECKPOINT = "mnaylor/mega-base-wikitext"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = MegaForSequenceClassification.from_pretrained(CHECKPOINT)
model.to(device)

# Last expression of the cell: shows the layer/parameter summary as cell output.
torchinfo.summary(model)

In [None]:
# Print the module tree of the loaded model to inspect its layers.
print(model)

In [None]:
import pandas as pd

# Fixed seed for the train/test partition. Without it the split changes on
# every kernel restart, so a checkpoint pickled in an earlier session could be
# evaluated on rows it was trained on.
SPLIT_SEED = 42

data_check_test = load_dataset(
    "csv", data_files="datasets/Hatemoji-main/HatemojiCheck/test.csv"
    )
print(data_check_test['train'][0])

# The CSV loads as a single 'train' split, so carve a held-out test set
# (25% of the rows) out of it.
data_check_test = data_check_test['train'].train_test_split(
    test_size=0.25, seed=SPLIT_SEED)


In [None]:
def tokenize_func(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Sequences are padded to the tokenizer's maximum length and truncated when
    longer, so every encoded example has the same length.

    Args:
        examples: mapping with a ``"text"`` entry, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, padding="max_length", truncation=True)
    return encoded


# Tokenize both splits batch-wise; the tokenizer outputs are added as columns.
tokenized_data = data_check_test.map(tokenize_func, batched=True)


In [None]:
# Prepare for torch
from torch.utils.data import DataLoader

# Columns of the CSV that are not passed to the model.
unused_columns = [
    "text",
    "case_id",
    "templ_id",
    "test_group_id",
    "target",
    "functionality",
    "set",
    "unrealistic_flags",
    "included_in_test_suite",
]

# Drop them, expose the gold label under the name "labels" (the key the
# training loop reads from each batch) and return torch tensors on indexing.
tokenized_data = (
    tokenized_data
    .remove_columns(unused_columns)
    .rename_column("label_gold", "labels")
)
tokenized_data.set_format("torch")

print(tokenized_data)


In [None]:

# Same batch size for both loaders; only the training split is shuffled.
BATCH_SIZE = 64

train_dataloader = DataLoader(
    tokenized_data['train'], batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(
    tokenized_data['test'], batch_size=BATCH_SIZE)

### Fine-tuning

In [None]:
from tqdm.auto import tqdm
import evaluate
import matplotlib.pyplot as plt
from transformers import get_scheduler
from torch.optim import AdamW
import copy

def train_epoch(model, train_dataloader, optimizer, lr_scheduler):
    """Run one training pass over ``train_dataloader``.

    For every batch the model is run forward, the cross-entropy loss between
    its logits and ``batch["labels"]`` is back-propagated, and the optimizer
    and learning-rate scheduler each take one step.

    Args:
        model: model called as ``model(**batch)`` whose output exposes
            ``.logits``.
        train_dataloader: iterable of dict batches of tensors, each containing
            a ``"labels"`` entry.
        optimizer: optimizer over the model's parameters.
        lr_scheduler: scheduler stepped once per batch.

    Returns:
        ``(epoch_loss, accuracy)``: the loss averaged over batches and the
        accuracy over all samples seen in this epoch.
    """
    progress_bar = tqdm(range(len(train_dataloader)))
    metric = evaluate.load("accuracy")
    model.train()
    criterion = torch.nn.CrossEntropyLoss()

    epoch_loss = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        logits = outputs.logits
        loss = criterion(logits, batch["labels"])

        loss.backward()
        optimizer.step()
        # The scheduler is sized in batches (num_epochs * len(train_dataloader)
        # training steps), so it has to advance once per batch. Stepping it
        # once per epoch would leave the schedule almost entirely unused.
        lr_scheduler.step()
        optimizer.zero_grad()

        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

        epoch_loss += loss.item()
        progress_bar.update(1)

    progress_bar.close()
    epoch_loss /= len(train_dataloader)
    accuracy = metric.compute()['accuracy']

    return epoch_loss, accuracy

def eval(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` without updating it.

    The model is switched to evaluation mode and run under ``torch.no_grad()``.
    The ``requires_grad`` flags of its parameters are left untouched, so
    calling this function does not freeze (or unfreeze) any part of the model.

    NOTE: the name shadows the built-in ``eval``; it is kept because other
    cells call it under this name.

    Args:
        model: model called as ``model(**batch)`` whose output exposes
            ``.logits``.
        test_dataloader: iterable of dict batches of tensors, each containing
            a ``"labels"`` entry.

    Returns:
        ``(epoch_loss, accuracy)``: the cross-entropy loss averaged over
        batches and the accuracy over all evaluated samples.
    """
    progress_bar = tqdm(range(len(test_dataloader)))
    metric = evaluate.load("accuracy")
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    epoch_loss = 0
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # no_grad() already disables gradient tracking for the forward pass,
        # so there is no need to flip requires_grad on the parameters.
        with torch.no_grad():
            outputs = model(**batch)
            logits = outputs.logits
            loss = criterion(logits, batch["labels"])

        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

        epoch_loss += loss.item()
        progress_bar.update(1)

    progress_bar.close()
    epoch_loss /= len(test_dataloader)
    accuracy = metric.compute()['accuracy']

    return epoch_loss, accuracy

def train(model,
          train_dataloader,
          test_dataloader,
          num_epochs=2,
          learning_rate=5e-5,
          patience=4):
    """Fine-tune ``model`` with early stopping on validation accuracy.

    Each epoch runs ``train_epoch`` followed by ``eval``. After the evaluation
    of the second epoch (epoch index >= 1) ``model.mega`` is made trainable.
    An epoch counts as an improvement only if its validation accuracy beats
    the best one so far by more than 1e-4; after ``patience`` consecutive
    epochs without improvement the loop stops.

    Args:
        model: model exposing ``.mega`` and ``.parameters()``.
        train_dataloader: batches used for training.
        test_dataloader: batches used for validation.
        num_epochs: maximum number of epochs.
        learning_rate: learning rate handed to AdamW.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        ``(train_accuracies, val_accuracies, train_losses, val_losses,
        best_epoch, best_model)`` where the first four are per-epoch lists,
        ``best_epoch`` is the 0-based index of the best epoch and
        ``best_model`` is a deep copy of the state dict at that epoch
        (``None`` if no epoch ever improved).
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="polynomial",
        optimizer=optimizer,
        num_warmup_steps=1,
        num_training_steps=total_steps,
    )

    # Per-epoch history.
    acc_train_hist, acc_val_hist = [], []
    loss_train_hist, loss_val_hist = [], []

    # Early-stopping bookkeeping.
    top_val_acc = 0
    top_epoch = 0
    top_state = None
    stale_epochs = 0  # consecutive epochs without improvement

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_dataloader, optimizer, lr_scheduler)
        val_loss, val_acc = eval(model, test_dataloader)

        print(f"Epoch {epoch+1} accuracy: train={train_acc:.3f}, test={val_acc:.3f}")

        acc_train_hist.append(train_acc)
        acc_val_hist.append(val_acc)
        loss_train_hist.append(train_loss)
        loss_val_hist.append(val_loss)

        # From the second epoch on, make the MEGA backbone trainable.
        if epoch >= 1:
            model.mega.requires_grad_(True)

        no_gain = val_acc <= top_val_acc + 1e-4
        if not no_gain:
            # New best: remember it and reset the patience counter.
            top_val_acc = val_acc
            stale_epochs = 0
            top_epoch = epoch
            top_state = copy.deepcopy(model.state_dict())
            continue

        stale_epochs += 1
        if stale_epochs >= patience:
            break

    return (acc_train_hist, acc_val_hist,
            loss_train_hist, loss_val_hist,
            top_epoch, top_state)

# Start with the MEGA backbone frozen; train() makes it trainable later on.
model.mega.requires_grad_(False)
N_EPOCHS = 5

all_data = train(
    model, train_dataloader, test_dataloader, N_EPOCHS, learning_rate=1e-4)
train_acc, val_acc, train_losses, val_losses, best_epoch, best_model = all_data

# Persist histories and the best state dict so later cells can reload them.
with open("mega_hatemoji.obj", "wb") as f:
    pickle.dump(all_data, f)

In [None]:
# Reload the results of the single fine-tuning run (file written by this notebook).
with open("mega_hatemoji.obj", "rb") as f:
    train_acc, val_acc, train_losses, val_losses, best_epoch, best_model = pickle.load(f)

fig, ax = plt.subplots(1,2, figsize=(10,4))
epoch_axis = range(1, len(train_losses)+1)

# Left panel: loss curves; right panel: accuracy curves.
panels = (
    (ax[0], train_losses, val_losses, "Loss"),
    (ax[1], train_acc, val_acc, "Accuracy"),
)
for panel, train_curve, val_curve, y_label in panels:
    panel.plot(epoch_axis, train_curve, label='train')
    panel.plot(epoch_axis, val_curve, label='val')
    # best_epoch is 0-based while the x-axis is 1-based.
    panel.axvline(best_epoch+1, label='best', linestyle='--')
    panel.set_xlabel("Epoch")
    panel.set_ylabel(y_label)
    panel.set_xticks(epoch_axis)
    panel.legend()
plt.show()

### Evaluating

In [None]:

# Metrics at each run's best epoch.
best_train_accs = []
best_train_losses = []
best_val_accs = []
best_val_losses = []

# 0-based index of the best epoch of each run.
best_epochs = []

# Full per-epoch histories of each run.
all_train_accs = []
all_train_losses = []
all_val_accs = []
all_val_losses = []

N_SEEDS = 5
N_EPOCHS = 20
for seed in tqdm(range(N_SEEDS)):
    # Actually seed the run: without this the loop variable is unused and the
    # runs are not reproducible. Seeding torch is presumably what fixes the
    # newly initialised weights, the shuffling order and dropout
    # (TODO confirm that no other RNG is involved).
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Fresh pretrained model for every seed, starting with a frozen backbone.
    model = MegaForSequenceClassification.from_pretrained(
    "mnaylor/mega-base-wikitext")
    model.to(device)
    model.mega.requires_grad_(False)

    train_acc, val_acc, train_losses, val_losses, best_epoch, _ = train(model, train_dataloader, test_dataloader, N_EPOCHS, learning_rate=1e-4, patience=2)

    best_train_accs.append(train_acc[best_epoch])
    best_val_accs.append(val_acc[best_epoch])
    best_train_losses.append(train_losses[best_epoch])
    best_val_losses.append(val_losses[best_epoch])
    best_epochs.append(best_epoch)

    all_train_accs.append(train_acc)
    all_train_losses.append(train_losses)
    all_val_accs.append(val_acc)
    all_val_losses.append(val_losses)

all_data_seeds = (best_train_accs, best_val_accs, best_train_losses, best_val_losses, best_epochs, all_train_accs, all_train_losses, all_val_accs, all_val_losses)

# Persist everything so the plotting cells can run without retraining.
with open(f"mega_hatemoji_seeds.obj", "wb") as f:
    pickle.dump(all_data_seeds, f)

In [None]:


# Reload the multi-seed results written by the training cell of this notebook.
with open("mega_hatemoji_seeds.obj", "rb") as f:
    all_data_seeds = pickle.load(f)

(best_train_accs, best_val_accs, best_train_losses, best_val_losses,
 best_epochs,
 all_train_accs, all_train_losses, all_val_accs, all_val_losses) = all_data_seeds

# Plot with epochs
def aggregate_over_seeds(data, epochs):
    """Collapse per-run curves into per-epoch mean/min/max curves.

    Args:
        data: iterable of per-run sequences holding one value per completed
            epoch. Runs shorter than ``epochs`` (e.g. stopped early) are
            extended by repeating their last value.
        epochs: number of epochs every run is padded to; also the length of
            each returned list.

    Returns:
        ``(means, mins, maxs)``: three lists of length ``epochs`` whose
        entries are 0-dim tensors computed across runs for that epoch.
    """
    # Some seeds may have stopped early, hence pad with edge values.
    padded_data = [
        np.pad(seed_data, (0, epochs - len(seed_data)), 'edge')
        for seed_data in data
    ]

    means, mins, maxs = [], [], []
    # Iterate over the ``epochs`` argument rather than a notebook-level
    # constant, so the function only depends on its own inputs.
    for epoch in range(epochs):
        epoch_data = torch.tensor([float(seed[epoch]) for seed in padded_data])
        means.append(torch.mean(epoch_data))
        mins.append(torch.min(epoch_data))
        maxs.append(torch.max(epoch_data))

    return means, mins, maxs

mean_train_accs, min_train_accs, max_train_accs = aggregate_over_seeds(all_train_accs, N_EPOCHS)
mean_val_losses, min_val_losses, max_val_losses= aggregate_over_seeds(all_val_losses, N_EPOCHS)
mean_val_accs, min_val_accs, max_val_accs = aggregate_over_seeds(all_val_accs, N_EPOCHS)
mean_train_losses, min_train_losses, max_train_losses = aggregate_over_seeds(all_train_losses, N_EPOCHS)

os.makedirs(f"figures/", exist_ok=True)

# Mark the average best epoch over all seeds. The scalar `best_epoch` only
# describes a single run (whichever one assigned it last), so it does not
# belong on a plot that aggregates every seed. best_epochs is 0-based while
# the x-axis is 1-based, hence the +1.
mean_best_epoch = float(np.mean(best_epochs)) + 1

fig, ax = plt.subplots(1,2, figsize=(10,4))
epoch_axis = range(1, N_EPOCHS+1)

# Left panel: mean loss with a min/max band across seeds.
ax[0].plot(epoch_axis, mean_train_losses, color='C0', label='train')
ax[0].fill_between(epoch_axis, min_train_losses, max_train_losses, color='C0', alpha=0.3)
ax[0].plot(epoch_axis, mean_val_losses, color='C1', label='val')
ax[0].fill_between(epoch_axis, min_val_losses, max_val_losses, color='C1', alpha=0.3)
ax[0].axvline(mean_best_epoch, label='best (mean)', linestyle='--')
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Loss")
ax[0].grid()
ax[0].legend()

# Right panel: mean accuracy with a min/max band across seeds.
ax[1].plot(epoch_axis, mean_train_accs, color='C0', label='train')
ax[1].fill_between(epoch_axis, min_train_accs, max_train_accs, color='C0', alpha=0.3)
ax[1].plot(epoch_axis, mean_val_accs, color='C1', label='val')
ax[1].fill_between(epoch_axis, min_val_accs, max_val_accs, color='C1', alpha=0.3)
ax[1].axvline(mean_best_epoch, label='best (mean)', linestyle='--')
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
ax[1].grid()
ax[1].legend()

fig.suptitle(f"Training on Hatemoji averaged across {N_SEEDS} seeds")
plt.tight_layout()
plt.savefig(f"figures/epochs_{N_SEEDS}seeds")
plt.show()

print(f"Best (mean) validation accuracy: {np.mean(best_val_accs)}")

In [None]:
# Only consider best epochs

def plot_errorbar(ax, name: str, data, color):
    """Draw a min-mean-max range bar for ``data`` plus its individual points.

    Two stacked bars are drawn at category ``name``: one from the minimum up
    to the mean and one from the mean up to the maximum. Every value in
    ``data`` is then scattered on top at the same category.

    Args:
        ax: axes-like object providing ``bar`` and ``scatter``.
        name: category label used as the x position.
        data: 1-D tensor of values.
        color: fill colour of the two bars.
    """
    lowest = torch.min(data)
    centre = torch.mean(data)
    upper_span = torch.max(data) - centre
    lower_span = centre - lowest

    ax.bar(name, lower_span, bottom=lowest, color=color, width=0.5, ec='k')
    ax.bar(name, upper_span, bottom=centre, color=color, width=0.5, ec='k')

    # One x entry per value so that all points line up on the same category.
    x_positions = [name] * len(data)
    ax.scatter(x_positions, data, alpha=0.5, color='k')

fig, ax = plt.subplots(1,3, figsize=(10,4))

# best_epochs stores 0-based epoch indices. Shift them to the 1-based epoch
# numbers used on the other figures and centre one bin on every epoch, so a
# run whose best epoch is the first one is not dropped from the histogram.
best_epoch_numbers = np.asarray(best_epochs) + 1
ax[0].hist(best_epoch_numbers, bins=np.arange(0.5, N_EPOCHS + 1.5), ec='k')
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Count")

plot_errorbar(ax[1], 'Train accuracy', torch.Tensor(best_train_accs), 'indianred')
plot_errorbar(ax[1], 'Validation accuracy', torch.Tensor(best_val_accs), 'indianred')

plot_errorbar(ax[2], 'Train loss', torch.Tensor(best_train_losses), 'darkorange')
plot_errorbar(ax[2], 'Validation loss', torch.Tensor(best_val_losses), 'darkorange')

# Report the number of runs that were actually aggregated instead of a
# hard-coded count.
fig.suptitle(f"Best epochs from {len(best_epochs)} seeds")
plt.tight_layout()
# NOTE(review): the output file name still says "10seeds"; it is kept unchanged
# so that existing references to this figure keep working. Rename if nothing
# depends on it.
plt.savefig(f"figures/best_epochs_from_10seeds")
plt.show()


In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

def eval_confusion(model, test_dataloader):
    """Compute the confusion matrix of ``model`` over ``test_dataloader``.

    The model is put in evaluation mode and run under ``torch.no_grad()``;
    the predicted class of each sample is the argmax of its logits.

    Args:
        model: model called as ``model(**batch)`` whose output exposes
            ``.logits``.
        test_dataloader: iterable of dict batches of tensors, each containing
            a ``"labels"`` entry.

    Returns:
        The confusion matrix reported by the ``confusion_matrix`` metric,
        converted to a NumPy array.
    """
    progress_bar = tqdm(range(len(test_dataloader)))
    model.eval()
    confusion_metric = evaluate.load("confusion_matrix")

    for raw_batch in test_dataloader:
        batch = {key: tensor.to(device) for key, tensor in raw_batch.items()}

        with torch.no_grad():
            logits = model(**batch).logits

        confusion_metric.add_batch(
            predictions=torch.argmax(logits, dim=-1),
            references=batch["labels"],
        )
        progress_bar.update(1)

    result = confusion_metric.compute()
    return np.array(result["confusion_matrix"])

# Restore the best checkpoint before computing the confusion matrix.
model.load_state_dict(best_model)
confusion_matrix = eval_confusion(model, test_dataloader)

# Divide every row by its row total so that each row sums to 1.
row_totals = confusion_matrix.sum(axis=1, keepdims=True)
confusion_matrix_normalized = confusion_matrix.astype('float') / row_totals

class_names = ('not hate', 'hate')
disp = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix_normalized,
    display_labels=class_names,
)
disp.plot(cmap=plt.cm.binary)
print(confusion_matrix_normalized)

plt.savefig("hatemoji_confusion.png", dpi=300, bbox_inches='tight')