# MNLP class Homework 2: Coarse-Grained WSD
## Notebook used for model training, evaluation and testing
### Ludovico Comito Matr. 1837155

## Google Drive mounting and libraries imports

Install libraries

In [None]:
!pip install transformers lightning wandb evaluate --quiet

Import libraries.
_Notice:_ this project is based on the Pytorch Lightning framework.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset

import transformers
from transformers import AutoTokenizer, AutoModel
from transformers import get_linear_schedule_with_warmup

import evaluate
import wandb

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

import pandas as pd
import numpy as np
import os
import json
import random

In [None]:
# Colab-only step: mounts Google Drive under /content/drive, which is the root of the
# dataset and checkpoint paths used in the rest of this notebook.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Here are defined all the hyperparameters and parameters for this project:

In [5]:
# Single place for every tunable value used by the notebook.
parameters_dict = dict(
  num_epochs=100,
  batch_size=32,
  padding_length=150,
  transformer_name='microsoft/deberta-base', # HuggingFace identifier of the encoder to load.
  learning_rate=0.01,
  num_warmup_steps=5, # warmup steps handed to the lr scheduler
  run_name='deberta-base', # name shown for the run on W&B
  debug_mode=True, # set to False to turn W&B logging and checkpoint saving on
  checkpoint_dir='/content/drive/MyDrive/mnlp_homework_2/checkpoints',
)

### Init Weights and Biases

In [None]:
!wandb login

In [None]:
# A W&B run is opened only when the notebook is not running in debug mode.
logging_enabled = not parameters_dict['debug_mode']
if logging_enabled:
  wandb_run_options = {
      'project': 'mnlp_homework_2',        # W&B project that receives the run
      'name': parameters_dict['run_name'],
      'config': parameters_dict,           # the whole parameter set is stored with the run
      'entity': 'ludocomito',
  }
  wandb.init(**wandb_run_options)

## Preparing the dataset

In [6]:
# Location of each split of the coarse-grained dataset on the mounted drive.
dataset_paths = {
    'train': '/content/drive/MyDrive/mnlp_homework_2/Datasets/coarse-grained/train_coarse_grained.json',
    'val': '/content/drive/MyDrive/mnlp_homework_2/Datasets/coarse-grained/val_coarse_grained.json',
    'test': '/content/drive/MyDrive/mnlp_homework_2/Datasets/coarse-grained/test_coarse_grained.json',
}

# Load the three splits as raw dataframes.
train_df = pd.read_json(dataset_paths['train'])
val_df = pd.read_json(dataset_paths['val'])
test_df = pd.read_json(dataset_paths['test'])

In [None]:
# Collect all the coarse-grained senses from the mapping file.
# The context manager closes the file even when the JSON cannot be parsed.
with open('/content/drive/MyDrive/mnlp_homework_2/Datasets/coarse_fine_defs_map.json', encoding='utf-8') as mapping_file:
  data = json.load(mapping_file)

# Every top-level entry of the parsed JSON is taken as one sense identifier.
senses_list = list(data)

# Number of senses found in the file (the padding entry is added later by index_senses).
num_senses = len(senses_list)
print(f'There are {num_senses} total senses')

In [8]:
def index_senses(senses_list):
  '''
  Builds the lookup tables between sense identifiers and integer indices.
  Index 0 is reserved for the '<PAD>' placeholder; the senses are numbered from 1
  following the order in which they are given.
  Args:
    senses_list: an iterable of hashable sense identifiers, each expected to appear once.
  Returns:
    A tuple (sense2idx, idx2sense, size), where size counts the senses plus the padding entry.
  Raises:
    Exception: if an identifier is already indexed (a repeated sense, or one equal to '<PAD>').
  '''
  sense2idx = {'<PAD>': 0}

  for sense in senses_list:
    if sense in sense2idx:
      raise Exception("There are duplicates in the mapping file!")
    # the next free index is always the current size of the table
    sense2idx[sense] = len(sense2idx)

  idx2sense = dict(zip(sense2idx.values(), sense2idx.keys()))

  return sense2idx, idx2sense, len(sense2idx)


In [9]:
# Index senses and get the number of total senses.
# The count returned by index_senses also includes the '<PAD>' entry stored at index 0,
# so num_senses is overwritten here with len(senses_list) + 1.
sense2idx, idx2sense, num_senses = index_senses(senses_list)

### Retrieve train, validation and test data from the datasets.

In [10]:
def process_dataframe(df):
  '''
  Splits a raw dataset dataframe into parallel per-sample lists.
  Every column of df is read as one sample exposing the 'words', 'senses' and 'candidates' entries.
  Args:
    df: a raw dataframe loaded from the original json dataset.
  Returns:
    A dict with the keys 'sentences_list', 'senses_list', 'candidates_list' and 'target_words'.
    'target_words' holds, for each sample, the keys of its candidates entry converted to int,
    i.e. the positions of the words to disambiguate.
  '''
  collected = {'sentences_list': [], 'senses_list': [], 'candidates_list': [], 'target_words': []}

  for column in df:
    sample = df[column]
    collected['sentences_list'].append(sample['words'])
    collected['senses_list'].append(sample['senses'])

    sample_candidates = sample['candidates']
    collected['candidates_list'].append(sample_candidates)
    collected['target_words'].append([int(position) for position in sample_candidates.keys()])

  return collected

In [11]:
# Break the training dataframe into its parallel per-sentence lists.
train_data = process_dataframe(train_df)

(train_sentences_list,
 train_senses_list,
 train_candidates_list,
 train_target_words) = (train_data[field] for field in
                        ('sentences_list', 'senses_list', 'candidates_list', 'target_words'))

In [12]:
# Break the validation dataframe into its parallel per-sentence lists.
val_data = process_dataframe(val_df)

(val_sentences_list,
 val_senses_list,
 val_candidates_list,
 val_target_words) = (val_data[field] for field in
                      ('sentences_list', 'senses_list', 'candidates_list', 'target_words'))

In [13]:
# Break the test dataframe into its parallel per-sentence lists.
test_data = process_dataframe(test_df)

(test_sentences_list,
 test_senses_list,
 test_candidates_list,
 test_target_words) = (test_data[field] for field in
                       ('sentences_list', 'senses_list', 'candidates_list', 'target_words'))

### Data tokenization

In [14]:
# Initialize the tokenizer for HuggingFace
# The checkpoint is the one named by parameters_dict['transformer_name'].
# NOTE(review): add_prefix_space=True is presumably needed because the sentences are later
# passed as pre-split word lists (is_split_into_words=True) — confirm for this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(parameters_dict['transformer_name'], add_prefix_space=True)

Downloading (…)okenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/474 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

All the sentences are tokenized a priori in order to reduce the overhead while training.

In [15]:
# The three splits are tokenized with exactly the same settings: PyTorch tensors,
# padded/truncated to the configured length, with the input given as pre-split words.
tokenization_options = {
    'return_tensors': 'pt',
    'padding': 'max_length',
    'max_length': parameters_dict['padding_length'],
    'truncation': True,
    'is_split_into_words': True,
}

tokenized_training = tokenizer(train_sentences_list, **tokenization_options)
tokenized_val = tokenizer(val_sentences_list, **tokenization_options)
tokenized_test = tokenizer(test_sentences_list, **tokenization_options)

### Retrieve word_ids for all sentences
For transformers like bert-base-uncased, words are split into sub-word tokens during tokenization. This means that, when assigning the labels, we need to keep track of the mapping between the sub-tokens and the words they belong to. Fortunately, the HuggingFace tokenizer keeps track of this mapping in the word_ids property of each encoded sentence.

In [16]:
def _word_ids_per_sentence(batch_encoding):
  '''
  Returns, for every sentence of a tokenized batch, the word_ids list of its encoding.
  Args:
    batch_encoding: the output of the tokenizer call for one split.
  '''
  sentence_count = len(batch_encoding["input_ids"])
  return [batch_encoding[sentence_idx].word_ids for sentence_idx in range(sentence_count)]


training_word_ids = _word_ids_per_sentence(tokenized_training)
val_word_ids = _word_ids_per_sentence(tokenized_val)
test_word_ids = _word_ids_per_sentence(tokenized_test)

In [17]:
def pad_target_words(lists, padding_length):
  '''
  Right-pads every list of target word indices with -1 and converts it to a tensor.
  Args:
    lists: the list of lists of target words.
    padding_length: the length each list is padded to.
  Returns:
    A list with one tensor per input list.
  '''
  padded_lists = []
  for inner_list in lists:
    # a non-positive difference yields an empty filler, leaving the list untouched
    filler = [-1] * (padding_length - len(inner_list))
    padded_lists.append(torch.tensor(inner_list + filler))
  return padded_lists

In [18]:
# Pad the target word indices of every split to the configured length.
# The raw lists are read from the *_data dicts instead of the variables being assigned,
# so this cell can be re-run safely: it never feeds already padded tensors back into
# pad_target_words.
train_target_words = pad_target_words(train_data['target_words'], parameters_dict['padding_length'])
val_target_words = pad_target_words(val_data['target_words'], parameters_dict['padding_length'])
test_target_words = pad_target_words(test_data['target_words'], parameters_dict['padding_length'])

## Create labels and logits_mask

The following function implements one of the key concepts of this project. In this task we have to pick the correct sense among a huge collection of possible senses. However, for each word to disambiguate we are also given a list of candidate senses. This list is used to narrow down the senses to compare against when computing the loss.<br> In practice this is achieved by creating a logits mask. <br>Each logits mask has a length equal to the number of all possible senses and is initialized to all zeros. Each entry of the mask corresponds to the index of a sense, and only the entries of the candidate senses are set to 1. We will use this mask to mask out the logits returned by our model.

In [19]:
def create_labels_and_mask(current_sentence, sentence_labels, sentence_candidates, sense2idx, padding_length):
  '''
  Builds the label vector and the candidate mask of one sentence.
  Args:
    current_sentence: the sentence as a list of words.
    sentence_labels: dict mapping a word position (as a string) to the list of its gold senses.
    sentence_candidates: dict mapping a word position (as a string) to its candidate senses.
    sense2idx: the dictionary which indexes senses.
    padding_length: the number of word slots in the outputs.
  Returns:
    labels: int32 tensor of size padding_length holding, for each labelled word, the index of
      its first gold sense; every other slot is 0.
    logits_mask: int32 tensor of size [padding_length, len(sense2idx)] with 1 at the candidate
      senses of each labelled word and 0 elsewhere.
  '''
  # words beyond the padding length are dropped
  usable_words = min(len(current_sentence), padding_length)

  labels = torch.zeros(padding_length, dtype=torch.int32)
  logits_mask = torch.zeros([padding_length, len(sense2idx)], dtype=torch.int32)

  for position in range(usable_words):
    key = str(position)
    if key not in sentence_labels:
      continue  # this word has nothing to disambiguate

    labels[position] = sense2idx[sentence_labels[key][0]]
    for candidate in sentence_candidates[key]:
      logits_mask[position, sense2idx[candidate]] = 1

  return labels, logits_mask


In [20]:
def replace_all_occurrences(input_list, old_value, new_value):
  '''
  Returns a copy of input_list in which every item equal to old_value
  is swapped with new_value; the input is left untouched.
  '''
  replaced = []
  for item in input_list:
    replaced.append(new_value if item == old_value else item)
  return replaced

In [21]:
class HomonymyDataset(Dataset):
  '''
  Dataset shared by the train, val and test splits. Each item bundles the tokenized
  sentence with its labels, candidate mask, word ids and target words.
  Args:
    sentences: list of sentences, each one a list of words.
    senses: list of dicts of senses.
    candidates: list of dicts of candidates.
    sense2idx, idx2sense: the dictionaries which index senses.
    num_senses: the number of indexed senses.
    tokenized_sentences: per-sentence token ids.
    attention_masks: per-sentence attention masks.
    word_ids: per-sentence word_ids lists, with None for tokens not belonging to a word.
    target_words: per-sentence padded target word indices.
    padding_length: the number of word slots of the labels and of the mask.
  '''
  def __init__(self, sentences, senses, candidates, sense2idx, idx2sense, num_senses,
               tokenized_sentences, attention_masks, word_ids, target_words, padding_length):
    # raw annotations
    self.sentences = sentences
    self.candidates = candidates
    self.senses = senses

    # sense vocabulary
    self.sense2idx = sense2idx
    self.idx2sense = idx2sense
    self.num_senses = num_senses

    # tokenizer outputs
    self.tokenized_sentences = tokenized_sentences
    self.attention_masks = attention_masks
    self.word_ids = word_ids

    self.padding_length = padding_length
    self.target_words = target_words

  def __len__(self):
    return len(self.sentences)

  def __getitem__(self, index):
    sentence = self.sentences[index]
    sentence_senses = self.senses[index]
    sentence_candidates = self.candidates[index]
    sentence_targets = self.target_words[index]
    sentence_tokens = self.tokenized_sentences[index]
    sentence_attention = self.attention_masks[index]
    sentence_word_ids = self.word_ids[index]

    labels, logits_mask = create_labels_and_mask(sentence, sentence_senses, sentence_candidates,
                                                 self.sense2idx, self.padding_length)

    # None entries cannot be stored in a tensor: they are encoded with the special index -1
    word_ids_updated = torch.Tensor(replace_all_occurrences(sentence_word_ids, None, -1))

    return {
        'tokenized_sentences': sentence_tokens,
        'attention_masks': sentence_attention,
        'labels': labels,
        'logits_mask': logits_mask,
        'word_ids': word_ids_updated,
        'sentence_length': len(sentence),
        'target_words': sentence_targets,
    }



## Model definition

The classifier's architecture is constituted by two fundamental modules:
*   The *transformer module*, which is derived from HuggingFace AutoModel library is used to create the embeddings for the input tokens.
*   The *classification head*, an MLP module which takes as input the generated embeddings and outputs the final logits.

Although the implementation might seem substantially simple, there are a few tricks to notice:


*   Instead of considering just the embeddings at the last hidden layer, the average of the last four hidden layers is taken. This has proven to improve the model's performance in many cases, including this task (look at the experimental data in the report).
*   The encoder takes as input sub-word level tokens. In order to obtain an embedding for each entire word, the word_ids of each sentence are used to compute the average of the embeddings of the sub-words belonging to each word.
*   In order to better fine-tune the model, only the last four hidden layers of the transformer are unfrozen, in order to leverage the low-level knowledge which is already embedded in pre-trained BERT-like models.







In [22]:
def find_indices(word_ids):
  '''
  Groups the positions of the sub-tokens of a sentence by the word they belong to.
  Ex: if word_ids=[0,1,1,2], the output will be {0: [0], 1: [1, 2], 2: [3]}
  Args:
    word_ids: the word_ids of a sentence as a 1-D tensor, where -1 marks the entries to skip.
  Returns:
    A dict mapping each word index to the list of positions of its sub-tokens.
  '''
  idx_dict = {}
  for position, word_id in enumerate(word_ids):
    word = word_id.item()
    if word == -1:  # filters padding
      continue
    idx_dict.setdefault(word, []).append(position)
  return idx_dict

In [40]:
class HomonymyAverage(nn.Module):
  '''
  Word-level sense classifier: a pre-trained transformer encoder followed by a linear head.
  Sub-token embeddings (average of the last four hidden states) are averaged per word
  and every word embedding is classified over all the senses.
  Args:
    transformer_name: the HuggingFace identifier of the encoder.
    num_classes: the number of output classes (senses, padding entry included).
    padding_length: the number of word slots of the output.
    num_unfrozen_layers: how many of the last encoder layers are fine-tuned; every other
      encoder parameter is frozen.
  '''
  def __init__(self, transformer_name, num_classes, padding_length, num_unfrozen_layers=4):
    super(HomonymyAverage, self).__init__()

    self.transformer_model = AutoModel.from_pretrained(transformer_name, output_hidden_states=True)
    self.padding_length = padding_length

    # Unfreeze only the last encoder layers. The layer indices are derived from the model
    # configuration instead of being hard-coded for a 12-layer encoder; the trailing dot
    # prevents e.g. 'layer.1.' from matching 'layer.10.'.
    num_layers = getattr(self.transformer_model.config, 'num_hidden_layers', 12)
    first_trainable = max(num_layers - num_unfrozen_layers, 0)
    trainable_tags = tuple(f'layer.{layer_idx}.' for layer_idx in range(first_trainable, num_layers))
    for name, param in self.transformer_model.named_parameters():
      param.requires_grad = any(tag in name for tag in trainable_tags)

    # Get the size of transformer's embedding
    self.hidden_size = self.transformer_model.config.hidden_size

    # The final classifier head
    self.MLP = nn.Sequential(
      nn.Dropout(0.2),  # Adding dropout for regularization
      nn.Linear(self.hidden_size, num_classes),
    )

  @property
  def device(self):
    '''The device the parameters currently live on (kept for backward compatibility).'''
    return next(self.parameters()).device

  def forward(self, tokenized_input, attention_mask, word_ids, target_words):
    '''
    Computes the sense logits of every word slot.
    Args:
      tokenized_input: the token ids, [batch, tokens].
      attention_mask: the attention mask matching tokenized_input.
      word_ids: for every token the index of the word it belongs to, negative for tokens
        that belong to no word; a [batch, tokens] tensor or a sequence of 1-D tensors.
      target_words: unused, kept for interface compatibility.
    Returns:
      A tensor of size [batch, padding_length, num_classes]. Slots without any sub-token
      are fed to the head as zero embeddings.
    '''
    encoded_input = self.transformer_model(input_ids=tokenized_input, attention_mask=attention_mask)

    # Average the last four hidden states: one embedding per sub-token
    last_four_layers = torch.stack(encoded_input['hidden_states'][-4:], dim=0)
    token_embeddings = torch.mean(last_four_layers, dim=0)

    # All the work happens on the device of the encoder output, so the module behaves
    # correctly wherever it has been moved.
    if not torch.is_tensor(word_ids):
      word_ids = torch.stack(list(word_ids))
    word_ids = word_ids.to(token_embeddings.device)

    # Tokens that belong to a word fitting in the output; the others contribute nothing.
    is_word_token = (word_ids >= 0) & (word_ids < self.padding_length)
    word_slots = word_ids.clamp(min=0, max=self.padding_length - 1).long()

    batch_size, _, hidden_size = token_embeddings.shape
    contributions = token_embeddings.masked_fill(~is_word_token.unsqueeze(-1), 0.0)

    # Sum the sub-token embeddings into the slot of their word, then divide by the number
    # of sub-tokens of that word. Slot k always holds word k, which is what the labels expect.
    summed = token_embeddings.new_zeros(batch_size, self.padding_length, hidden_size)
    summed.scatter_add_(1, word_slots.unsqueeze(-1).expand(-1, -1, hidden_size), contributions)

    counts = token_embeddings.new_zeros(batch_size, self.padding_length)
    counts.scatter_add_(1, word_slots, is_word_token.to(token_embeddings.dtype))

    # empty slots have a zero sum: clamping the count keeps them at zero instead of NaN
    sentences_embeddings = summed / counts.clamp(min=1).unsqueeze(-1)

    logits = self.MLP(sentences_embeddings)

    return logits

## Lightning datamodule

In [41]:
class HomonymyDataModule(pl.LightningDataModule):
    '''
    Lightning datamodule holding the pre-processed train, test and val splits and
    exposing one dataloader for each of them.
    For every split it receives: the sentences, the senses, the candidates, the token ids,
    the attention masks, the word ids and the padded target words.
    '''
    def __init__(self,
                 train_sentences_list, train_senses_list, train_candidates_list, tokenized_training, training_attention_mask, training_word_ids, train_target_words,
                 test_sentences_list, test_senses_list, test_candidates_list, tokenized_test, test_attention_mask,test_word_ids, test_target_words,
                 val_sentences_list, val_senses_list, val_candidates_list, tokenized_val, val_attention_mask ,val_word_ids, val_target_words,
                 tokenizer, sense2idx, idx2sense, batch_size, num_senses, padding_length):
      super().__init__()

      # --- training split ---
      self.train_sentences_list = train_sentences_list
      self.train_senses_list = train_senses_list
      self.train_candidates_list = train_candidates_list
      self.tokenized_training = tokenized_training
      self.training_attention_mask = training_attention_mask
      self.training_word_ids = training_word_ids
      self.train_target_words = train_target_words

      # --- test split ---
      self.test_sentences_list = test_sentences_list
      self.test_senses_list = test_senses_list
      self.test_candidates_list = test_candidates_list
      self.tokenized_test = tokenized_test
      self.test_attention_mask = test_attention_mask
      self.test_word_ids = test_word_ids
      self.test_target_words = test_target_words

      # --- validation split ---
      self.val_sentences_list = val_sentences_list
      self.val_senses_list = val_senses_list
      self.val_candidates_list = val_candidates_list
      self.tokenized_val = tokenized_val
      self.val_attention_mask = val_attention_mask
      self.val_word_ids = val_word_ids
      self.val_target_words = val_target_words

      # --- shared settings ---
      self.tokenizer = tokenizer
      self.sense2idx = sense2idx
      self.idx2sense = idx2sense
      self.batch_size = batch_size
      self.num_senses = num_senses
      self.padding_length = padding_length

    def _build_dataset(self, sentences, senses, candidates, tokenized, attention_mask, word_ids, target_words):
      '''Wraps the data of one split in a HomonymyDataset, adding the shared settings.'''
      return HomonymyDataset(sentences, senses, candidates,
                             self.sense2idx, self.idx2sense, self.num_senses,
                             tokenized, attention_mask, word_ids, target_words, self.padding_length)

    def _build_loader(self, dataset, shuffle):
      '''Creates the dataloader of one split; only the shuffling differs between splits.'''
      return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=4)

    def setup(self, stage: str):
      # The three datasets are built regardless of the requested stage.
      self.train_dataset = self._build_dataset(
          self.train_sentences_list, self.train_senses_list, self.train_candidates_list,
          self.tokenized_training, self.training_attention_mask, self.training_word_ids, self.train_target_words)

      self.test_dataset = self._build_dataset(
          self.test_sentences_list, self.test_senses_list, self.test_candidates_list,
          self.tokenized_test, self.test_attention_mask, self.test_word_ids, self.test_target_words)

      self.val_dataset = self._build_dataset(
          self.val_sentences_list, self.val_senses_list, self.val_candidates_list,
          self.tokenized_val, self.val_attention_mask, self.val_word_ids, self.val_target_words)

    def train_dataloader(self):
      return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
      return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
      return self._build_loader(self.test_dataset, shuffle=False)

In [42]:
# Every argument is passed by keyword so that each piece of data is visibly paired
# with the parameter that receives it.
datamodule = HomonymyDataModule(
    # training split
    train_sentences_list=train_sentences_list,
    train_senses_list=train_senses_list,
    train_candidates_list=train_candidates_list,
    tokenized_training=tokenized_training['input_ids'],
    training_attention_mask=tokenized_training['attention_mask'],
    training_word_ids=training_word_ids,
    train_target_words=train_target_words,
    # test split
    test_sentences_list=test_sentences_list,
    test_senses_list=test_senses_list,
    test_candidates_list=test_candidates_list,
    tokenized_test=tokenized_test['input_ids'],
    test_attention_mask=tokenized_test['attention_mask'],
    test_word_ids=test_word_ids,
    test_target_words=test_target_words,
    # validation split
    val_sentences_list=val_sentences_list,
    val_senses_list=val_senses_list,
    val_candidates_list=val_candidates_list,
    tokenized_val=tokenized_val['input_ids'],
    val_attention_mask=tokenized_val['attention_mask'],
    val_word_ids=val_word_ids,
    val_target_words=val_target_words,
    # shared settings
    tokenizer=tokenizer,
    sense2idx=sense2idx,
    idx2sense=idx2sense,
    batch_size=parameters_dict['batch_size'],
    num_senses=num_senses,
    padding_length=parameters_dict['padding_length'],
)

## Lightning module

The LitHomonymyModel class takes care of instantiating, training, evaluating and testing the model according to the Pytorch Lightning framework.

In [44]:
class LitHomonymyModel(pl.LightningModule):
  '''
  Lightning module which trains, validates and tests the HomonymyAverage classifier.
  Args:
    transformer_name: the HuggingFace identifier of the encoder.
    num_classes: the number of output classes (senses, padding entry included).
    padding_length: the number of word slots of each sentence.
    learning_rate: the learning rate of the optimizer.
    debug_mode: when True nothing is sent to W&B.
    num_warmup_steps: the number of warmup steps of the lr scheduler.
    total_steps: the total number of steps of the lr scheduler.
  '''
  # Value written over the logits of the senses which are not candidates. It is low enough
  # to remove them from both the softmax and the argmax, while staying finite so that
  # fully masked rows (words with no label) do not produce NaN.
  MASKED_LOGIT = -1e4

  def __init__(self, transformer_name, num_classes, padding_length, learning_rate, debug_mode, num_warmup_steps, total_steps):
    super().__init__()

    self.model = HomonymyAverage(transformer_name, num_classes, padding_length)
    self.lr = learning_rate
    self.loss_fn = nn.CrossEntropyLoss(ignore_index=0) # the index 0 is associated to padding values.
    self.f1_metric = evaluate.load("f1") # use the HuggingFace evaluate library to compute micro-F1
    self.debug_mode = debug_mode
    self.index_to_ignore = 0

    self.num_warmup_steps = num_warmup_steps # num of warmup steps for the scheduler
    self.total_steps = total_steps

    # per-batch values collected during an epoch and averaged when it ends
    self.training_step_outputs = {'loss': [], 'accuracy': []}
    self.validation_step_outputs = {'loss': [], 'accuracy': []}
    self.test_step_outputs = {'accuracy': []}

  def _masked_logits_and_labels(self, batch):
    '''
    Runs the model on a batch and restricts its logits to the candidate senses.
    Returns the logits as [batch_size * words_in_a_sentence, num_classes] and the labels
    as [batch_size * words_in_a_sentence].
    '''
    # No squeeze() here: the inputs are already [batch, tokens], and squeezing would drop
    # the batch dimension of a batch holding a single sentence.
    input_ids = batch['tokenized_sentences']
    attention_mask = batch['attention_masks']

    logits = self.model(input_ids, attention_mask, batch['word_ids'], batch['target_words'])

    # Apply logits mask. Multiplying by the 0/1 mask would leave the non-candidates at
    # logit 0, where they still take part in the softmax and can win the argmax; they are
    # pushed to a very low value instead.
    logits = logits.masked_fill(batch['logits_mask'] == 0, self.MASKED_LOGIT)

    labels = batch['labels'].reshape(-1).to(torch.long)
    logits = logits.reshape(-1, logits.shape[-1])
    return logits, labels

  @staticmethod
  def _mean(values):
    '''Average of a list of numbers, 0.0 when the list is empty.'''
    return sum(values) / len(values) if values else 0.0

  def training_step(self, batch, batch_idx):
    '''
    Computes the loss of a training batch and records its loss and micro-F1.
    '''
    logits, labels = self._masked_logits_and_labels(batch)
    loss = self.loss_fn(logits, labels)

    accuracy = self.compute_accuracy(logits, labels)

    self.training_step_outputs['loss'].append(loss.item())
    self.training_step_outputs['accuracy'].append(accuracy['f1'])

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    '''
    Computes the loss of a validation batch and records its loss and micro-F1.
    '''
    logits, labels = self._masked_logits_and_labels(batch)
    loss = self.loss_fn(logits, labels)

    accuracy = self.compute_accuracy(logits, labels)

    self.validation_step_outputs['loss'].append(loss.item())
    self.validation_step_outputs['accuracy'].append(accuracy['f1'])

    self.log('validation/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('validation/accuracy', accuracy['f1'], on_step=False, on_epoch=True, prog_bar=True)

    return loss

  def test_step(self, batch, batch_idx):
    '''
    Records the micro-F1 of a test batch.
    '''
    logits, labels = self._masked_logits_and_labels(batch)

    accuracy = self.compute_accuracy(logits, labels)
    self.test_step_outputs['accuracy'].append(accuracy['f1'])

    self.log('test/accuracy', accuracy['f1'], on_step=False, on_epoch=True, prog_bar=True)

    return accuracy

  def compute_accuracy(self, logits, labels):
    '''
    Computes the accuracy by taking the argmax of the logits and discarding the padding labels.
    Accuracy is computed using HuggingFace's evaluate library with micro-f1.
    Returns the dict produced by the metric, whose 'f1' entry holds the score.
    '''
    # Find non-ignored indices
    non_ignored_indices = (labels != self.index_to_ignore)

    # Compute predicted labels
    predicted_labels = logits.argmax(dim=1)

    accuracy = self.f1_metric.compute(predictions=predicted_labels[non_ignored_indices], references=labels[non_ignored_indices], average="micro")

    return accuracy

  def on_train_epoch_end(self):
    '''
    Compute average loss and accuracy when epoch ends and log them.
    '''
    print(f'Training epoch: {self.current_epoch}:')
    avg_loss = self._mean(self.training_step_outputs['loss'])
    avg_acc = self._mean(self.training_step_outputs['accuracy'])

    print('avg_train_loss', avg_loss)
    print('avg_train_acc', avg_acc)
    print()

    self.log('avg_train_loss', avg_loss)
    self.log('avg_train_acc', avg_acc)

    self.training_step_outputs['loss'].clear()
    self.training_step_outputs['accuracy'].clear()

    # Log to wandb if not in debug mode
    if not self.debug_mode:
      wandb.log({"train/loss": avg_loss})
      wandb.log({"train/accuracy": avg_acc})

  def on_validation_epoch_end(self):
    '''
    Compute average loss and accuracy when epoch ends and log them.
    '''
    print(f'Validation epoch {self.current_epoch}:')
    avg_loss = self._mean(self.validation_step_outputs['loss'])
    avg_acc = self._mean(self.validation_step_outputs['accuracy'])

    print('avg_validation_loss', avg_loss)
    print('avg_validation_acc', avg_acc)
    print()

    self.log('avg_validation_loss', avg_loss)
    self.log('avg_validation_acc', avg_acc)

    self.validation_step_outputs['loss'].clear()
    self.validation_step_outputs['accuracy'].clear()

    # Log to wandb if not in debug mode
    if not self.debug_mode:
      wandb.log({"validation/loss": avg_loss})
      wandb.log({"validation/accuracy": avg_acc})

  def on_test_end(self):
    '''
    Compute the final test accuracy.
    '''
    print('Test result:')
    avg_acc = self._mean(self.test_step_outputs['accuracy'])

    print('avg_test_acc', avg_acc)
    print()

    self.test_step_outputs['accuracy'].clear()

    # Log to wandb if not in debug mode
    if not self.debug_mode:
      wandb.log({"test/accuracy": avg_acc})

  def configure_optimizers(self):
    '''
    Initializes the optimizer (RAdam) and the linear learning rate scheduler with warmup.
    '''
    optimizer = torch.optim.RAdam(self.parameters(), self.lr)

    # the warmup length comes from the constructor argument instead of a hard-coded value
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_warmup_steps, num_training_steps=self.total_steps)
    return [optimizer], [scheduler]

In [45]:
# Build the lightning module from the shared parameters.
# The number of epochs is what the scheduler receives as its total number of steps.
model = LitHomonymyModel(
    transformer_name=parameters_dict['transformer_name'],
    num_classes=num_senses,
    padding_length=parameters_dict['padding_length'],
    learning_rate=parameters_dict['learning_rate'],
    debug_mode=parameters_dict['debug_mode'],
    num_warmup_steps=parameters_dict['num_warmup_steps'],
    total_steps=parameters_dict['num_epochs'],
)

In [46]:
# Keeps on disk only the checkpoint which reached the highest 'avg_validation_acc'.
checkpoint_options = dict(
    monitor='avg_validation_acc',
    verbose=True,
    save_top_k=1,
    mode='max',
    dirpath=parameters_dict['checkpoint_dir'],
    filename="{epoch}-{avg_validation_acc:.4f}",
)
checkpoint_callback = ModelCheckpoint(**checkpoint_options)

In [None]:
# In debug mode checkpointing is disabled; otherwise the best model is saved by checkpoint_callback.
trainer = pl.Trainer(
    max_epochs=parameters_dict['num_epochs'],
    # Leaving out the callback is not enough: unless this flag is False, Lightning installs
    # its own default checkpoint callback and debug runs keep writing checkpoints.
    enable_checkpointing=not parameters_dict['debug_mode'],
    callbacks=None if parameters_dict['debug_mode'] else [checkpoint_callback],
)

In [None]:
# Train the model for the configured number of epochs; the data is supplied by the datamodule.
trainer.fit(model=model, datamodule=datamodule)

## Model testing

Test the model using the Lightning framework.

In [None]:
# Evaluate on the test split. ckpt_path must point to a checkpoint file (not to the
# checkpoint directory): the best checkpoint saved during training is used when there is
# one, otherwise (e.g. in debug mode) the weights currently held by the model are tested.
best_checkpoint_path = checkpoint_callback.best_model_path or None
trainer.test(model=model, datamodule=datamodule, ckpt_path=best_checkpoint_path)