In [None]:
import csv
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AdamW
import random

import numpy as np
from datasets import load_metric

In [None]:
# Load the pretrained tokenizer used to encode the texts of every split.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-cased"
)

In [None]:
# data loading and preprocessing, might take a bit
import math
from sklearn.model_selection import train_test_split

SPLIT = 0.8  # fraction of samples kept on the training side of each split
SEED = 42    # fixed seed so the shuffle and both splits are reproducible

random.seed(SEED)

# Read in the data. The file is opened in a `with` block so the handle is
# closed again, and with newline='' as the csv module recommends.
with open("./trainingdata/combined_2500.csv", 'r', newline='') as csv_file:
    # csv.reader yields an empty list for a blank line; skip those rows so the
    # column unpacking below does not fail on them.
    data = [row for row in csv.reader(csv_file) if row]
random.shuffle(data)

# unpack the columns, ignoring publisher, year, and answer fields
_, _, texts, _, tags = list(zip(*data))

# get sorted list of unique PTB tags used in the data
ptb_tags = sorted(set(tags))

def tag_to_label(t):
    """Return the integer label for tag `t`: its position in `ptb_tags`.

    Raises ValueError if `t` is not one of the tags in `ptb_tags`.
    """
    return ptb_tags.index(t)

def label_to_tag(l):
    """Return the tag stored at position `l` of `ptb_tags`."""
    return ptb_tags[l]

# convert tags to numbered labels
labels = [tag_to_label(tag) for tag in tags]

# Generate train, test, validation sets: keep SPLIT of the data for training
# and hold out the rest for testing, then split the training part again the
# same way to obtain the validation set.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, train_size=SPLIT, random_state=SEED
)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, train_size=SPLIT, random_state=SEED
)

# prepare token vectors as PyTorch tensors, one encoding batch per split
train_encodings = tokenizer(train_texts, padding=True, truncation=True, return_tensors='pt')
val_encodings = tokenizer(val_texts, padding=True, truncation=True, return_tensors='pt')
test_encodings = tokenizer(test_texts, padding=True, truncation=True, return_tensors='pt')

In [None]:
# creating pytorch dataset interfaces
import torch

class CrossWordDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with their integer labels.

    Args:
        encodings: mapping from field name to a tensor that is indexed by
            sample position (anything exposing ``.items()`` works).
        labels: sequence holding one label per sample; its length defines the
            length of the dataset.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, i):
        """Return sample `i` as a dict of tensors, with its label under 'labels'."""
        # Each field is copied and detached so the returned tensors are
        # independent of the stored encodings. Nothing is printed here: this
        # method runs once per sample access, so any output would flood the
        # training log.
        item = {key: val[i].clone().detach() for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[i])
        return item

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

# Wrap each split's encodings and labels in a CrossWordDataset.
train_dataset, val_dataset, test_dataset = (
    CrossWordDataset(split_encodings, split_labels)
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
        (test_encodings, test_labels),
    )
)

In [None]:
# Download the pretrained sequence-classification model, configured with one
# label per tag in ptb_tags.
model_type = "bert-base-cased" #could use "distilbert-base-cased"?

model = AutoModelForSequenceClassification.from_pretrained(
    model_type,
    problem_type="single_label_classification",
    num_labels=len(ptb_tags),
)

In [None]:
# Training configuration: output goes to "test_trainer" and the evaluation
# strategy is set to "epoch".
training_args = TrainingArguments(
    evaluation_strategy="epoch",
    output_dir="test_trainer",
)

# Accuracy metric consumed by compute_metrics.
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Score an evaluation result with the module-level `metric`.

    `eval_pred` unpacks into a (logits, labels) pair. The predicted class of
    each row is the index of the largest value along the last logits axis,
    and the labels are passed to the metric as references.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    return metric.compute(references=references, predictions=predicted_classes)

# Bundle the model, training configuration, train/validation datasets and the
# metric function into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Detach the PrinterCallback from the trainer before training; the intent is
# to cut down on printed output.
from transformers.trainer_callback import PrinterCallback

trainer.remove_callback(PrinterCallback)

# Train the model on the datasets handed to the trainer.
# (about 20 minutes for bert-base-cased)
train_output = trainer.train()

In [None]:
from sklearn.metrics import f1_score

# Run the trained model over the test set and take, for each sample, the
# label index with the largest prediction value.
predicted = trainer.predict(test_dataset)
predicted_labels = np.argmax(predicted.predictions, axis=-1)

# Report the F1 score under each multiclass averaging scheme.
for av in ("micro", "macro", "weighted"):
    score = f1_score(test_labels, predicted_labels, average=av)
    print(f"F1 {av}: {score}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from collections import Counter

# Confusion matrix restricted to the most frequently predicted labels; every
# remaining label is collapsed into one trailing "other" bucket.
matrix_size = 5
most_common = [int(label) for (label, count) in Counter(predicted_labels).most_common(matrix_size)]

# Fewer than matrix_size distinct labels may have been predicted, so the
# "other" bucket index comes from the actual list length. It must be the last
# entry of `labels` below; an index outside that list would make the confusion
# matrix drop those samples and leave the "other" row/column empty.
other_index = len(most_common)
bucket_of = {label: position for position, label in enumerate(most_common)}

# Map true and predicted labels onto bucket indices.
disp_test = [bucket_of.get(int(label), other_index) for label in test_labels]
disp_predictions = [bucket_of.get(int(label), other_index) for label in predicted_labels]

# Tick labels: the tag name of each kept label, then the catch-all bucket.
disp_tags = [label_to_tag(label) for label in most_common] + ['other']

fig, ax = plt.subplots(figsize=(10, 10))
disp = ConfusionMatrixDisplay.from_predictions(
    disp_test,
    disp_predictions,
    normalize='all',
    labels=list(range(other_index + 1)),
    display_labels=disp_tags,
    include_values=False,
    ax=ax,
)
plt.savefig('test', dpi=300)