In [None]:
import numpy as np

from tokenization import make_cfn
from model_wrappers import Classifier
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from random import seed, shuffle
from sklearn.metrics import confusion_matrix, f1_score
import seaborn as sns
from matplotlib import pyplot as plt
import pickle

from tqdm import tqdm

# Run configuration: everything a reader might want to tune lives here.
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCHS = 15
MAX_TOKENS = 120  # inputs with more tokens than this are dropped
TRAIN_FRACTION = 0.75  # share of the shuffled data used for training; the rest is dev
SHUFFLE_SEED = 42  # seeds Python's `random` for the train/dev shuffle
TRAIN_BATCH_SIZE = 32
DEV_BATCH_SIZE = 512

# Load the pre-tokenized dataset: a list of (nom, tokens, label) triples plus `classes`.
# NOTE: pickle.load can execute arbitrary code; only use it on a locally produced, trusted file.
with open('../data/tokenized.p', 'rb') as f:
    (tokenized, classes) = pickle.load(f)
tokenized = [(nom, tokens, label) for nom, tokens, label in tokenized if len(tokens) <= MAX_TOKENS]
seed(SHUFFLE_SEED)
shuffle(tokenized)

# Number of distinct labels that survive the length filter.
# NOTE(review): this can be smaller than len(classes) if filtering removes every example
# of some class — confirm Classifier(nc, ...) and the label ids tolerate that.
nc = len({y for _, _, y in tokenized})
print(f'Read {len(tokenized)} data points with {nc} classes.')

split_point = int(len(tokenized) * TRAIN_FRACTION)
train, dev = tokenized[:split_point], tokenized[split_point:]
print(f'Training on {len(train)} entries, evaluating on {len(dev)}.')

# The dev loader is not shuffled, so evaluation order is fixed across epochs and runs.
train_dl = DataLoader([(ts, y) for _, ts, y in train], batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=make_cfn(DEVICE))  # noqa
dev_dl = DataLoader([(ts, y) for _, ts, y in dev], batch_size=DEV_BATCH_SIZE, shuffle=False, collate_fn=make_cfn(DEVICE))  # noqa

# Train NUM_RUNS independent models from scratch. For each run, keep the checkpoint with
# the lowest dev loss, then record that checkpoint's dev predictions and per-class dev F1.
NUM_RUNS = 3
RUN_SEED_BASE = 42  # run `it` seeds torch with RUN_SEED_BASE + it

output = []  # dev predictions of each run's best checkpoint
f1s = []  # per-class dev F1 (average=None) of each run's best checkpoint
for it in range(NUM_RUNS):
    # Seed torch per run so each run is reproducible while runs still differ from one
    # another. Presumably weight init and the train DataLoader's shuffle both draw from
    # torch's global RNG. Some CUDA kernels may still be nondeterministic.
    torch.manual_seed(RUN_SEED_BASE + it)
    ckpt_path = f'../data/weights_{it}.pt'
    model = Classifier(nc, False).to(DEVICE)
    optim = AdamW(model.parameters(), lr=5e-5)
    best_loss, best_epoch = float('inf'), None
    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        _ = model.train_epoch(train_dl, optim)
        with torch.no_grad():
            model.eval()
            epoch_loss, _, _ = model.eval_epoch(dev_dl)
        # Model selection on dev loss: checkpoint whenever it improves.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_epoch = epoch
            torch.save(model.state_dict(), ckpt_path)
    # If dev loss never improved (e.g. it was NaN), no checkpoint was written this run.
    # Loading would then silently pick up a stale file from an earlier execution.
    assert best_epoch is not None, f'run {it}: dev loss never improved; no checkpoint saved'

    print(f'Best epoch was {best_epoch}')
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    with torch.no_grad():
        model.eval()
        epoch_loss, predictions, truth = model.eval_epoch(dev_dl)
    print(f'\tDev loss: {epoch_loss}')
    print(f'\tDev F1 (M): {f1_score(truth, predictions, average="macro")}')
    print(f'\tDev F1 (m): {f1_score(truth, predictions, average="micro")}')
    # Restrict per-class F1 to the gold labels so every run yields a vector of the same
    # length and class order. `truth` is presumably identical across runs because dev_dl
    # is unshuffled. Otherwise a run that predicts a class absent from `truth` gains an
    # extra entry, which breaks or misaligns the later np.array(f1s).
    f1 = f1_score(truth, predictions, labels=np.unique(truth), average=None)
    print(f'\tDev F1 (-): {f1}')
    f1s.append(f1)
    output.append(predictions)


# Persist the dev split together with every run's predictions.
# NOTE: despite the .pt extension this is a plain pickle, not a torch.save file.
with open('../data/results.pt', 'wb') as f:
    pickle.dump((dev, output), f)

In [None]:
# Summarise per-class dev F1 across the training runs.
# Each element of f1s is one run's f1_score(..., average=None) array, so after stacking,
# rows are runs and columns are per-class scores.
f1s = np.array(f1s)

f1_means = f1s.mean(axis=0)
f1_stds = f1s.std(axis=0)  # population std (ddof=0) across runs
# Means are shown in percent. The spread is shown as 2 standard deviations in percentage
# points, presumably intended as the half-width of mean ± 2σ.
print('Per-class F1 mean (%):', f1_means * 100)
print('Per-class F1 2*std (pp):', f1_stds * 200)

In [None]:
# Mean ± std confusion matrix over the training runs, drawn as an annotated heatmap.
# `truth` holds the dev gold labels left over from the last run of the training cell.
# It is presumably identical for every run because dev_dl is not shuffled.
# One shared label set is used for all runs. Without it, confusion_matrix sizes each
# matrix by the labels seen in that run, so runs predicting different extra classes
# give differently shaped or silently misaligned matrices.
cm_labels = np.unique(np.concatenate([np.asarray(truth)] + [np.asarray(preds) for preds in output]))
cms = np.array([confusion_matrix(truth, preds, labels=cm_labels) for preds in output])
mus = cms.mean(0)
stds = cms.std(0)
# Cell annotations of the form "mean±std", rounded to whole counts.
labels = np.asarray(
    [f'{mu:.0f}±{std:.0f}' for mu, std in zip(mus.flatten(), stds.flatten())]
).reshape(mus.shape)

# NOTE(review): the tick labels assume the sorted label ids in cm_labels line up
# one-to-one with `classes` (e.g. ids 0..len(classes)-1, all present) — confirm.
fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(
    mus,
    annot=labels,
    fmt='',
    cmap=sns.light_palette('seagreen', as_cmap=True),
    xticklabels=classes,
    yticklabels=classes,
    annot_kws={'size': 8},
    cbar=False,
    ax=ax,
)
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('True', fontsize=12)
fig.tight_layout()
fig.savefig('../data/trained.pdf')
plt.show()
