In [None]:
import datasets
from transformers import AutoTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the "en_ewt" configuration of the "universal_dependencies" dataset.
dataset = datasets.load_dataset("universal_dependencies", "en_ewt")

In [None]:
# Converter from an integer UPOS class id to its string tag name.
idx2upos = dataset['train'].features['upos'].feature.int2str

In [None]:
# Collect every distinct integer UPOS id that occurs in the training split.
unique_tags = set()
for sentence_tags in dataset["train"]["upos"]:
    unique_tags.update(sentence_tags)

# Author's note: id 15 is missing from this set, which caused issues later on.
print("Unique POS tags:", unique_tags)
print("Unique POS tags:", {idx2upos(tag_id) for tag_id in unique_tags})

In [None]:
# This section has to be modified to test each model

tokenizer_checkpoint = "google-bert/bert-base-uncased"
encoder_checkpoint = "bert-base-uncased" # Change this line to the trained model you want to test

tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
bert = BertModel.from_pretrained(encoder_checkpoint)
bert.to(device)

In [None]:
def tokenize_and_align_labels(examples, label_all_tokens=False, skip_index=-100):
    """Tokenize pre-split sentences and align word-level UPOS labels to the resulting tokens.

    Args:
        examples: Batch with a "tokens" entry (one list of words per sentence) and a
            "upos" entry (one list of per-word label ids per sentence).
        label_all_tokens: If True, every token of a word receives that word's label.
            If False, only the first token of each word is labelled and the remaining
            tokens receive ``skip_index``.
        skip_index: Label assigned to positions that should be ignored (special tokens
            and, unless ``label_all_tokens`` is set, non-initial tokens of a word).

    Returns:
        The tokenizer output with an extra "labels" entry holding one label list per
        sentence, aligned position-by-position with the tokens. Sequences are left
        unpadded.
    """
    # Deliberately no padding here: padding at this stage would pad every sentence to the
    # longest one in the whole batch handed to `map`, and pad_collate_fn already pads each
    # mini-batch to its own longest sentence when the DataLoader assembles it.
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []

    for i, label in enumerate(examples["upos"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids : list[int] = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to skip_index so they are
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(skip_index)

            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])

            # For the other tokens in a word, we set the label to either the current label or skip_index,
            # depending on the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else skip_index)

            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
# How many training sentences to pull out for the small inspection sample.
num_samples = 5

In [None]:
# Tokenize and label-align the first `num_samples` training sentences.
samples = tokenize_and_align_labels(dataset['train'][:num_samples])

In [None]:
# Human-readable tokens: turn each sentence's input ids back into token strings.
samples['tokens_hr'] = list(map(tokenizer.convert_ids_to_tokens, samples['input_ids']))

In [None]:
# Human-readable labels: tag names for labelled positions, a blank for ignored (-100) ones.
samples['labels_hr'] = [
    [' ' if label_id == -100 else idx2upos(label_id) for label_id in sentence_labels]
    for sentence_labels in samples['labels']
]

In [None]:
# Run the tokenization / label alignment over the dataset in batches.
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)

In [None]:
def pad_collate_fn(batch):
    """Collate a list of examples into a single batch of equally long tensors.

    Each example's "input_ids", "attention_mask" and "labels" lists are turned into
    tensors on `device` and padded (batch-first) to the longest sequence in the batch.
    """

    def _pad_field(key, fill_value):
        # pad_sequence needs tensors, so convert each example's list first
        sequences = [torch.tensor(example[key], device=device) for example in batch]
        return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=fill_value)

    return {
        "input_ids": _pad_field("input_ids", tokenizer.pad_token_id),
        "attention_mask": _pad_field("attention_mask", 0),  # pads are masked out
        "labels": _pad_field("labels", -100),
    }

In [None]:
# Only the training loader needs shuffling. Evaluation order does not affect accuracy, so the
# validation/test loaders keep dataset order, and validation uses full-size batches instead of
# one sentence per forward pass.
train_dataloader = DataLoader(tokenized_dataset['train'], batch_size=128, shuffle=True, collate_fn=pad_collate_fn)
val_dataloader = DataLoader(tokenized_dataset['validation'], batch_size=128, shuffle=False, collate_fn=pad_collate_fn)
test_dataloader = DataLoader(tokenized_dataset['test'], batch_size=128, shuffle=False, collate_fn=pad_collate_fn)

In [None]:
class NeuralTagger(nn.Module):
    """Per-token classifier: an encoder followed by a single linear layer."""

    def __init__(self, bert, output_size, hidden_size=None):
        """Build the tagger.

        Args:
            bert: Encoder module. Its forward must accept ``input_ids`` and
                ``attention_mask`` keyword arguments and return an object with a
                ``last_hidden_state`` attribute.
            output_size: Number of output classes of the linear layer.
            hidden_size: Width of the encoder's hidden states. When omitted it is read
                from ``bert.config.hidden_size``, falling back to 768 if the encoder
                exposes no such attribute.
        """
        super(NeuralTagger, self).__init__()
        self.bert = bert
        if hidden_size is None:
            # Follow the encoder's own width so swapping in a different model keeps working.
            hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        """Return the linear layer applied to the encoder's last hidden state."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hs = outputs.last_hidden_state
        return self.linear(hs)

In [None]:
# Freeze every BERT parameter, because this already takes too long to run.

for frozen_param in bert.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Size the output layer from the label schema itself rather than from the tags that happen to
# occur in the training split (that set is missing id 15), so every valid label id has a logit.
num_upos_classes = dataset['train'].features['upos'].feature.num_classes
model = NeuralTagger(bert, num_upos_classes)
model.to(device)

In [None]:
# Some hyperparams

lr = 0.001  # learning rate handed to the optimizer
num_epochs = 5  # number of full passes over the training dataloader

In [None]:
# Define the optimizer and the loss_fn

# -100 marks the positions to skip; it is CrossEntropyLoss's default ignore_index,
# spelled out here so the link to the label alignment is visible.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optim = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn.to(device)

In [None]:
train_losses = []
val_accs = []
for epoch in tqdm(range(num_epochs), position=0):
    # ---- Training pass: sum the per-batch losses, then average over the batches ----
    model.train()
    train_loss = 0
    for batch in tqdm(train_dataloader, position=1, leave=False):
        optim.zero_grad()
        logits = model(batch['input_ids'], batch['attention_mask'])

        # Flatten so the loss sees one row of class scores per token position
        flat_logits = logits.view(-1, logits.shape[-1])
        flat_targets = batch['labels'].view(-1)

        loss = loss_fn(flat_logits, flat_targets)
        loss.backward()
        optim.step()

        train_loss += loss.item()
    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    # ---- Validation pass: token accuracy over the positions that carry a real label ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader, position=1):
            targets = batch['labels']
            predicted_tags = model(batch['input_ids'], batch['attention_mask']).argmax(dim=-1)

            mask = targets != -100  # True wherever the position is relevant (not ignored)
            # Compare predictions with targets on the relevant positions only
            correct += (predicted_tags[mask] == targets[mask]).sum().item()
            total += mask.sum().item()
    val_accs.append(correct / total)

    print(f'Epoch {epoch}: Train Loss {round(train_loss, 2)}, Val accuracy {round(val_accs[-1], 2)}')

In [None]:
epochs_range = range(1, num_epochs + 1)

fig, ax1 = plt.subplots()

# Left axis: training loss per epoch
ax1.plot(epochs_range, train_losses, color='blue', label='Training Loss')
ax1.set(xlabel='Epochs', ylabel='Training Loss')
ax1.set_xticks(epochs_range)

# Right axis (shares the x axis): validation accuracy, fixed to the 0-1 range
ax2 = ax1.twinx()
ax2.plot(epochs_range, val_accs, color='red', label='Validation Accuracy')
ax2.set(ylabel='Validation Accuracy', ylim=(0, 1))

fig.legend()

plt.title("Training Loss and Validation Accuracy Curves")
plt.show()