This notebook implements the ambiguous pronoun detection. The model reaches pretty good performances: 91% accuracy.

However, I did not delve into it too much, since I did not find a good model for entity identification and resolution.

In [None]:
! pip install transformers

In [1]:
import transformers
from transformers import (
    AutoTokenizer,
    BertModel,
    logging
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import numpy as np
import pandas as pd

import os
import random
import time
import math
import yaml
from typing import *
from datetime import datetime
from collections import namedtuple

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler

SEED = 10

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Display the entire text
pd.set_option("display.max_colwidth", None)
logging.set_verbosity_error()

In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device

device(type='cuda')

In [6]:
# For Colab
from google.colab import drive
drive.mount('/content/drive/')

curr_location = "/content/drive/MyDrive/Colab Notebooks/Deep Learning/NLP/nlp2022-hw3/hw3/stud/"
os.chdir(curr_location)

Mounted at /content/drive/


In [3]:
curr_location = "H:/My Drive/Colab Notebooks/Deep Learning/NLP/nlp2022-hw3/hw3/stud"
os.chdir(curr_location)

In [4]:
from arguments import *

In [5]:
yaml_file = "./train_notebook2.yaml"
# Read configuration file with all the necessary parameters
with open(yaml_file) as file:
    config = yaml.safe_load(file)
    
model_args = ModelArguments(**config['model_args'])
model_name_or_path = model_args.model_name_or_path
model_name_or_path

'bert-base-uncased'

In [6]:
train_clean_path = "../../model/data/train_clean.tsv"
valid_clean_path = "../../model/data/valid_clean.tsv"

In [7]:
df_train = pd.read_csv(filepath_or_buffer=train_clean_path, sep="\t")
df_valid = pd.read_csv(filepath_or_buffer=valid_clean_path, sep="\t")

In [9]:
class GAP_Ambiguous_Detection_Dataset(Dataset):
    """
    Custom GAP dataset for ambiguous pronoun identification.
    
    Parameters
    ----------
    df: pd.DataFrame
        A dataframe from the GAP dataset.
        
    tokenizer: PreTrainedTokenizerBase
        The tokenizer used to preprocess the data.
        
    pronouns_list: List[str]
        A list containing the pronound that may apper
        in the dataset.
        
    pronoun_tag: strdd
        The pronoun tag that will be inserted to detect the ambiguous
        pronoun, if labeled is True.
          
    keep_tags: bool
        If true the tags added to text are kept even after
        the tokenization process.
  
    labeled: bool
        If the dataset also contains the labels.
        
    cleaned: bool
        Whether the GAP dataframe is already cleaned or not.
    """
    def __init__(
        self, 
        df: pd.DataFrame, 
        tokenizer: PreTrainedTokenizerBase, 
        pronouns_list: List[str],
        pronoun_tag: str, 
        keep_tags: bool=False, 
        labeled: bool=True, 
        cleaned: bool=True
    ):
        
        if not cleaned:
             self.clean_dataframe(df)

        self.df = df
        self.tokenizer = tokenizer
        self.pronouns_list = pronouns_list
        self.pronoun_tag = pronoun_tag
        self.keep_tags = keep_tags
        self.labeled = labeled
        
        self.samples = []
        self._convert_tokens_to_ids()
        
    @staticmethod
    def clean_text(text: str):
        text = text.translate(str.maketrans("`", "'"))
        return text

    def clean_dataframe(self, df: pd.DataFrame):
        df['text'] = df['text'].map(self.clean_text)
        df['entity_A'] = df['entity_A'].map(self.clean_text) 
        df['entity_B'] = df['entity_B'].map(self.clean_text) 

    def _assign_class_to_pronouns(self, offsets: List[int],
                                ambiguous_offset: int) -> List[int]:
        """
        Returns
        -------
            A list of integers defining the pronoun class.
            The class id is:
            - 2 if the pronoun is the ambiguous one
            - 1 for all the other pronouns
            - (0 will be used for padding)
        """
        labels = []
        for off in offsets:
            if off == ambiguous_offset:
                labels.append(2)
            else:
                labels.append(1)

        return labels
        
    def _convert_tokens_to_ids(self):
        CLS = [self.tokenizer.cls_token]
        SEP = [self.tokenizer.sep_token]

        Sample = namedtuple("Sample", ['tokens', 'offsets'])
        if self.labeled:
            Sample = namedtuple("Sample", Sample._fields + ("labels",))

        for _, row in self.df.iterrows():
            tokens, pronouns_offsets, ambiguous_offset = self._tokenize(row)

            tokens_to_convert = CLS + tokens + SEP
            final_tokens = self.tokenizer.convert_tokens_to_ids(tokens_to_convert)
                
            sample = {'tokens': final_tokens,
                      'offsets': self._get_offsets_list(pronouns_offsets)}
    
            if self.labeled:
                sample['labels'] = self._assign_class_to_pronouns(pronouns_offsets,
                                                                  ambiguous_offset)

            sample_namedtuple = Sample(**sample)
            self.samples.append(sample_namedtuple)

            
    def _get_offsets_list(self, offsets: List[int]) -> List[int]:
        # 1 is added for the introduction of the CLS token
        return list(map(lambda off: off+1, offsets))

        
    def _insert_tag(self, text: str, offsets: Tuple[int, int], 
                    start_tag: str, end_tag: str = None) -> str:
        start_off, end_off = offsets 

        # Starting tag only
        if end_tag is None:
            text = text[:start_off] + start_tag + text[start_off:]
            return text

        text = text[:start_off] + start_tag + text[start_off:end_off] + end_tag + text[end_off:]
        return text

    
    def _tokenize(self, row: pd.Series) -> Tuple[List[int], List[int], int]: 
        """
        Tokenize the text.
        If keep_tags is True, also the tags are tokenized.
        """        
        tokens =[]
        pronouns_offsets = []
        ambiguous_offset = -1
   
        text = row['text']
        
        if self.labeled:
            text = self._insert_tag(text, (row['p_offset'], None),
                                    self.pronoun_tag)

        # Also the tags are added to the tokens
        if self.keep_tags:
            for token in self.tokenizer.tokenize(text):
                if token == self.pronoun_tag:
                    ambiguous_offset = len(tokens) + 1
                
                if token.lower() in self.pronouns_list:
                    pronouns_offsets.append(len(tokens)) 
                    
                tokens.append(token)
        
        # The tags are skipped
        else:
            for token in self.tokenizer.tokenize(text):
                
                if token == self.pronoun_tag:
                    ambiguous_offset = len(tokens)
                    continue
                
                if token.lower() in self.pronouns_list:
                    pronouns_offsets.append(len(tokens)) 
                    
                tokens.append(token)
        
        return tokens, pronouns_offsets, ambiguous_offset

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [10]:
tokenizer_name_or_path = model_args.tokenizer
if tokenizer_name_or_path is None:
    tokenizer_name_or_path = model_name_or_path
tokenizer_name_or_path

'bert-base-uncased'

In [11]:
pronoun_tag = "<p>"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, never_split=[pronoun_tag])
tokenizer.add_tokens([pronoun_tag], special_tokens=True)

1

In [12]:
pronouns_list = ['he','she','him','her','his','hers']

In [13]:
train_ds = GAP_Ambiguous_Detection_Dataset(df_train, tokenizer, pronouns_list, pronoun_tag)
valid_ds = GAP_Ambiguous_Detection_Dataset(df_valid, tokenizer, pronouns_list, pronoun_tag)

In [14]:
def compute_max_len(sequences: Union[List[List[int]], Tuple[List[int]]], truncate_len: int) -> int:
    """
    Computes the maximum length in the sequences.         
    """
    max_len = min(
        max((len(x) for x in sequences)),
        truncate_len
    )
    return max_len
    
def pad_sequence(sequences: Union[List[List[int]], Tuple[List[int]]], max_len: int, pad: int) -> np.ndarray:
    """
    Returns
    -------
        A numpy array padded with the 'pad' value until 
        the 'max_len' length. 

    Parameters
    ----------
    sequences: Union[List[List[int]], Tuple[List[int]]]
        A list or tuple of lists. 

    max_len: int
        The length to which the input is padded.

    pad: int
        The padding value.
    """
    array_sequences = np.full((len(sequences), max_len), pad, dtype=np.int64)

    # Padding
    for i, sequence in enumerate(sequences):
        array_sequences[i, :len(sequence)] = sequence

    return array_sequences

In [16]:
class Collator_Token_Classification:
    """
    Collator for Token Classification.
    
    Returns
    -------
        A dictionary of tensors of the batch sequences in input.

    Parameters
    ----------
    device: str
        Where (CPU/GPU) to load the features.
        
    pad: int
        The padding token.

    truncate_len: int
        Maximum length possible in the batch.

    labeled: bool
        If the batch also contains the labels.
    """
    def __init__(self, device: str, pad: int=0, 
                 truncate_len: int=512, labeled=True):
        self.device = device
        self.pad = pad
        self.truncate_len = truncate_len
        self.labeled = labeled
        
    def __call__(self, batch):

        if self.labeled:
            batch_features, batch_offsets, batch_labels = zip(*batch)

        else:
            batch_features, batch_offsets = zip(*batch)

        max_len_features_in_batch = compute_max_len(batch_features, self.truncate_len)
        max_len_offsets_in_batch = compute_max_len(batch_offsets, self.truncate_len)

        collate_sample = {}
        
        # Features        
        padded_features = pad_sequence(batch_features, max_len_features_in_batch, self.pad)
        features_tensor = torch.tensor(padded_features, device=self.device)
        collate_sample['features'] = features_tensor
        
        # Offsets
        padded_offsets = pad_sequence(batch_offsets, max_len_offsets_in_batch, self.pad)
        offsets_tensor = torch.tensor(padded_offsets, device=self.device)
        collate_sample['offsets'] = offsets_tensor
        
        if not self.labeled:
            return collate_sample

        # Labels
        padded_labels = pad_sequence(batch_labels, max_len_offsets_in_batch, self.pad)
        labels_tensor = torch.tensor(padded_labels, dtype=torch.uint8, device=self.device)
        collate_sample['labels'] = labels_tensor
        
        return collate_sample

In [17]:
class Ambiguous_Detection_Head(nn.Module):
    def __init__(self, bert_hidden_size: int, args: ModelArguments):
        super().__init__()
        
        self.args = args
        self.bert_hidden_size = bert_hidden_size

        input_size = bert_hidden_size
        if args.output_strategy == "concat":
            input_size *= 4

        self.ffnn = nn.Sequential(
            nn.Linear(input_size, args.head_hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(args.dropout),
        )
        
        self.classifier = nn.Linear(args.head_hidden_size, args.num_output)

    def forward(self, bert_outputs, offsets):
        embeddings = self._retrieve_pronouns_embeddings(bert_outputs, offsets)

        x = self.ffnn(embeddings)


        output = self.classifier(x)
        return output
    
    def _retrieve_pronouns_embeddings(self, bert_embeddings: torch.Tensor, 
                                      pronouns_offsets: torch.Tensor):
        
        # bert_embeddings shape: batch_size x seq_length x embed_dim
        # entities_and_pron_offsets shape: batch_size x seq_length

        pronouns_embeddings = []
        
        # Consider embeddings and offsets in each batch separately
        for embeddings, offsets in zip(bert_embeddings, pronouns_offsets):
            pronouns_embeddings.append(embeddings[offsets])

        # Merge outputs
        merged_pronouns_embeddings = torch.stack(pronouns_embeddings, dim=0)
        
        # shape: batch_size x seq_length x embedding_dim
        return merged_pronouns_embeddings

In [18]:
class CR_Model(nn.Module):
    """The main model."""

    def __init__(self, bert_model: str, tokenizer, args: ModelArguments):
        super().__init__()

        self.args = args

        if bert_model in {"bert-base-uncased", "bert-base-cased"}:
            self.bert_hidden_size = 768
        elif bert_model in {"bert-large-uncased", "bert-large-cased"}:
            self.bert_hidden_size = 1024
        else:
            self.bert_hidden_size = args.bert_hidden_size

        self.bert = BertModel.from_pretrained(
            bert_model).to(device, non_blocking=True)
        
        # If the tag tokens (e.g., <p>, <a> etc.) are present in the features,
        # the embedding dimension of the bert embeddings must be changed
        # to be compliant with the new size of the tokenizer vocabulary. 
        if args.resize_embeddings:
            self.bert.resize_token_embeddings(len(tokenizer.vocab))
        
        self.head = Ambiguous_Detection_Head(self.bert_hidden_size, args).to(
            device, non_blocking=True) 

    def forward(self, sample):
        x = sample['features']
        x_offsets = sample['offsets']

        bert_outputs = self.bert(
            x, attention_mask=(x > 0).long(),
            token_type_ids=None, output_hidden_states=True)

        if self.args.output_strategy == "last":
            out = bert_outputs.last_hidden_state

        elif self.args.output_strategy == "concat":
            out = torch.cat([bert_outputs.hidden_states[x] for x in [-1, -2, -3, -4]], dim=-1)

        elif self.args.output_strategy == "sum":
            layers_to_sum = torch.stack([bert_outputs.hidden_states[x] for x in [-1, -2, -3, -4]], dim=0)
            out = torch.sum(layers_to_sum, dim=0)

        else:
            raise ValueError("Unsupported output strategy.")

        head_outputs = self.head(out, x_offsets)
        return head_outputs

**Gradient Scaling**

If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.

To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.

The method `step(optimizer, *args, **kwargs)` internally invokes `unscale_(optimizer)`and if no inf/NaN gradients are found, invokes `optimizer.step()` using the unscaled gradients. Otherwise `optimizer.step()` is skipped to avoid corrupting the params.

\**Note for Gradient Clipping*

If you wish to modify the gradients (like in gradient clipping), you should unscale them first. If you attempted to clip *without* unscaling, the gradients' norm magnitude would also be scaled, so your requested threshold would be invalid.

In [29]:
class TokenClassificationTrainer:    
    def __init__(
        self,
        device: str,
        model: nn.Module,
        args: CustomTrainingArguments,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        criterion: torch.nn,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler = None,
        pad: int = 0,
    ):
        
        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.pad = pad
        
        assert args is not None, "No training arguments passed!"
        self.args = args
        
    def train(self):
        args = self.args
        valid_dataloader = self.valid_dataloader
        epochs = args.num_train_epochs
        
        train_losses = []
        train_acc_list = []
        valid_losses = []
        valid_acc_list = []
        
        if args.use_early_stopping:
            patience_counter = 0 

        scaler = GradScaler() if args.use_scaler else None

        training_start_time = time.time()
        print("\nTraining...")
        for epoch in range(epochs):
            train_loss, train_acc = self._inner_training_loop(scaler)
            train_losses.append(train_loss)
            train_acc_list.append(train_acc)

            valid_loss, valid_acc = self.evaluate(valid_dataloader)
            valid_losses.append(valid_loss)
            valid_acc_list.append(valid_acc)

            if self.scheduler is not None:
                self._print_sceduler_lr()
                self.scheduler.step()

            self._print_epoch_log(epoch, epochs, train_loss, valid_loss, valid_acc)

            if args.use_early_stopping and len(valid_acc_list) >= 2:
                stop, patience_counter = self._early_stopping(patience_counter, epoch, valid_acc_list)
                if stop:
                    break
        
        training_time = time.time() - training_start_time
        print(f'Training time: {self._print_time(training_time)}')

        metrics_history = {
            "train_losses": train_losses,
            "train_acc": train_acc_list,
            "valid_losses": valid_losses,
            "valid_acc": valid_acc_list,
        }

#         print(metrics_history)
        if args.save_model:
            self._save_model(args.task_type, epoch, valid_acc, scaler, metrics_history)
    
        return metrics_history

    def _inner_training_loop(self, scaler):
        args = self.args
        train_dataloader = self.train_dataloader
        
        train_loss = 0.0
        train_correct, total_count = 0.0, 0.0

        self.model.train()
        for step, sample in enumerate(train_dataloader):
            ### Empty gradients ###
            self.optimizer.zero_grad(set_to_none=True)
            
            ### Forward ###
            if scaler is None:
                predictions = self.model(sample)
                labels = sample['labels']
                train_correct, total_count = self.compute_metrics(predictions, labels, 
                                                              train_correct, total_count)
                labels = labels.view(-1)
                predictions = predictions.view(-1, predictions.shape[-1])
                loss = self.criterion(predictions, labels)
            else:
                with torch.autocast(device_type=self.device):
                    predictions = self.model(sample)
                    labels = sample['labels']
                    train_correct, total_count = self.compute_metrics(predictions, labels, 
                                                              train_correct, total_count)
                    
                    labels = labels.view(-1)
                    predictions = predictions.view(-1, predictions.shape[-1])
                    loss = self.criterion(predictions, labels)
                    
            
            ### Backward  ###
            if scaler is None:
                loss.backward()
            else: 
                # Backward pass without mixed precision
                # It's not recommended to use mixed precision for backward pass
                # Because we need more precise loss
                scaler.scale(loss).backward()
            
            if args.grad_clipping is not None:
                if scaler is not None:
                    scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.grad_clipping)
            
            ### Update weights ### 
            if scaler is None:
                self.optimizer.step()
            else:
                scaler.step(self.optimizer)
                scaler.update()

            train_loss += loss.item()

            if step % args.logging_steps == args.logging_steps - 1:
                running_loss = train_loss / (step + 1)
                running_acc = train_correct / total_count
                self._print_step_log(step, running_loss, running_acc)
                
        return train_loss / len(train_dataloader), train_correct / total_count
  
    def evaluate(self, eval_dataloader):
        valid_loss = 0.0
        eval_correct, total_count = 0, 0
        
        self.model.eval()
        with torch.no_grad():
            for sample in eval_dataloader:
                
                predictions = self.model(sample)
                labels = sample['labels']
                eval_correct, total_count = self.compute_metrics(predictions, labels, 
                                                                 eval_correct, total_count)
                
                labels = labels.view(-1)
                predictions = predictions.view(-1, predictions.shape[-1])
                loss = self.criterion(predictions, labels)
                valid_loss += loss.item()

        
        return valid_loss / len(eval_dataloader), eval_correct / total_count

    def compute_metrics(self, predictions, labels, num_correct, total_count):  
        # Iterate one batch at a time
        for one_batch_predictions, one_batch_labels in zip(predictions, labels):
            num_batch_correct, batch_count = 0.0, 0.0

            mask = one_batch_labels != self.pad
            one_batch_labels = one_batch_labels[mask]

            one_batch_predictions = one_batch_predictions[mask]
            maximum_logits, predicted_labels = one_batch_predictions.max(1)

            # It may happen that more than one pronoun is classify as ambiguous
            multiple_ambiguous_pronouns_mask = predicted_labels == 2
            ambiguous_pronouns_logits = maximum_logits[multiple_ambiguous_pronouns_mask]

            # More than one pronoun is classify as ambiguous
            if len(ambiguous_pronouns_logits) > 1:
                # Get the highest logit among the ambiguous ones
                highest_ambiguous_pronoun_logit = ambiguous_pronouns_logits.max()

                # Identity the position of the logit that should correspond to the ambiguous prononun class (2)
                ambiguous_pronoun_mask = maximum_logits == highest_ambiguous_pronoun_logit

                # All the predictions that are not of that class are set to the "not ambiguous class" (1)
                predicted_labels[~ambiguous_pronoun_mask] = 1

                # However, it may happen again that we have multiple pronouns classified as ambiguous, 
                # since there may be more than one logit with value = highest_ambiguous_pronoun_logit


            # When the model predicts that all the pronouns are not ambiguous (no class 2)
            if not torch.any(predicted_labels == 2):
                # Try to select the most probable ambiguous pronoun

                probable_ambiguous_index = one_batch_predictions[:,-1].argmax(dim=0)
                predicted_labels[probable_ambiguous_index] = 2

            label_ambiguous_mask = one_batch_labels == 2
            num_batch_correct += (one_batch_labels[label_ambiguous_mask] == predicted_labels[label_ambiguous_mask]).sum().item()
            batch_count += 1
        
        num_correct += num_batch_correct
        total_count += batch_count
    
        return num_correct, total_count

    def _early_stopping(self, patience_counter, epoch, valid_acc_list):
        args = self.args

        # stop = args.early_stopping_mode == 'min' and epoch > 0 and valid_acc_list[-1] > valid_acc_list[-2]
        stop = args.early_stopping_mode == 'max' and epoch > 0 and valid_acc_list[-1] < valid_acc_list[-2]
        if stop:
            if patience_counter >= args.early_stopping_patience:
                print('Early stop.')
                return stop, patience_counter
            else:
                print('-- Patience.\n')
                patience_counter += 1

        return False, patience_counter   
    
    def _print_time(self, s):
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    def _print_sceduler_lr(self):
        print('-' * 17)
        print(f"| LR: {self.scheduler.get_last_lr()[0]:.3e} |")

    def _print_step_log(self, step, running_loss, running_acc):
        print(f'\t| step {step+1:4d}/{len(self.train_dataloader):d} | train_loss: {running_loss:.3f} | ' \
                f'train_acc: {running_acc:.3f} |')

    def _print_epoch_log(self, epoch, epochs, train_loss, valid_loss, valid_acc):
        print('-' * 76)
        print(f'| epoch {epoch+1:>3d}/{epochs:<3d} | train_loss: {train_loss:.3f} | ' \
                f'valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} |')
        print('-' * 76)
        
    
    def _save_model(self, task_type, epoch, valid_acc, scaler, metrics_history):
        print("Saving model...")
        params_to_save = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "metrics_history": metrics_history,
        }
        
        if self.scheduler is not None:
            params_to_save["scheduler_state_dict"] = self.scheduler.state_dict()
            
        if scaler is not None:
            params_to_save["scaler_state_dict"] = scaler.state_dict()
            
        save_path = f"{self.args.output_dir}my_model{str(task_type)}_{str(valid_acc)[2:5]}_{epoch+1}"
        now = datetime.now()
        current_time = now.strftime("%H-%M-%S")
        
        if os.path.exists(f"{save_path}_{current_time}.pth"):
            torch.save(params_to_save, f"{save_path}_{current_time}_new.pth")
        else:
            torch.save(params_to_save, f"{save_path}_{current_time}.pth")
        
        print("Model saved.")

In [30]:
def freeze_weights(modules):
    for module in modules:
        for param in module.parameters():
            if hasattr(param, 'requires_grad'):
                param.requires_grad = False

In [31]:
model = CR_Model(model_name_or_path, tokenizer, model_args).to(device, non_blocking=True)

# last_frozen_layer = 12
# modules = [model.bert.embeddings, *model.bert.encoder.layer[:last_frozen_layer]]
# # modules = [*model.bert.encoder.layer[:last_frozen_layer]]
# freeze_weights(modules)

yaml_file = "./train_notebook2.yaml"
# Read configuration file with all the necessary parameters
with open(yaml_file) as file:
    config = yaml.safe_load(file)
    
training_args = CustomTrainingArguments(**config['training_args'])

# Make sure that the learning rate is read as a number and not as a string
training_args.learning_rate = float(training_args.learning_rate)
print(training_args)

criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([0, 0.1, 0.9]), ignore_index=0).to(device=device, non_blocking=True)
optimizer = torch.optim.Adam(model.parameters(), lr=training_args.learning_rate)
# scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
scheduler = None

batch_size = 4

collator = Collator_Token_Classification(device)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, 
                              collate_fn=collator, shuffle=True)
valid_dataloader = DataLoader(valid_ds, batch_size=batch_size, 
                              collate_fn=collator, shuffle=False)

trainer = TokenClassificationTrainer(str(device), model, training_args, 
                  train_dataloader, valid_dataloader, 
                  criterion, optimizer, scheduler)


CustomTrainingArguments(output_dir='../../model/checkpoints/', task_type=1, save_model=False, num_train_epochs=2, logging_steps=250, learning_rate=5e-06, grad_clipping=None, use_early_stopping=True, early_stopping_mode='max', early_stopping_patience=2, use_scaler=True)


In [32]:
metrics_history = trainer.train()


Training...
	| step  250/750 | train_loss: 0.431 | train_acc: 0.684 |
	| step  500/750 | train_loss: 0.322 | train_acc: 0.768 |
	| step  750/750 | train_loss: 0.271 | train_acc: 0.807 |
----------------------------------------------------------------------------
| epoch   1/2   | train_loss: 0.271 | valid_loss: 0.166 | valid_acc: 0.877 |
----------------------------------------------------------------------------
	| step  250/750 | train_loss: 0.119 | train_acc: 0.944 |
	| step  500/750 | train_loss: 0.114 | train_acc: 0.936 |
	| step  750/750 | train_loss: 0.109 | train_acc: 0.943 |
----------------------------------------------------------------------------
| epoch   2/2   | train_loss: 0.109 | valid_loss: 0.120 | valid_acc: 0.912 |
----------------------------------------------------------------------------
Training time: 4m 49s


In [None]:
metrics_history

In [34]:
y_true_list = []
y_pred_list = []
logits = []

eval_correct, total_count = 0.0, 0.0

model.eval()
with torch.no_grad():
    collator = Collator_Token_Classification(device)
    dataloader = DataLoader(valid_ds, batch_size=1, collate_fn=collator, shuffle=False)
    for idx, sample in enumerate(dataloader):
        predictions = model(sample)
        labels = sample['labels']
  
        # Iterate one batch at a time
        for one_batch_predictions, one_batch_labels in zip(predictions, labels):
            eval_batch_correct, batch_count = 0.0, 0.0
            logits.append(one_batch_predictions)

            mask = one_batch_labels != 0
            one_batch_labels = one_batch_labels[mask]
            y_true_list.append(one_batch_labels.tolist())

            one_batch_predictions = one_batch_predictions[mask]
            maximum_logits, predicted_labels = one_batch_predictions.max(1)

            # It may happen that more than one pronoun is classify as ambiguous
            multiple_ambiguous_pronouns_mask = predicted_labels == 2
            ambiguous_pronouns_logits = maximum_logits[multiple_ambiguous_pronouns_mask]

            # More than one pronoun is classify as ambiguous
            if len(ambiguous_pronouns_logits) > 1:
                # Get the highest logit among the ambiguous ones
                highest_ambiguous_pronoun_logit = ambiguous_pronouns_logits.max()

                # Identity the position of the logit that should correspond to the ambiguous prononun class (2)
                ambiguous_pronoun_mask = maximum_logits == highest_ambiguous_pronoun_logit

                # All the predictions that are not of that class are set to the "not ambiguous class" (1)
                predicted_labels[~ambiguous_pronoun_mask] = 1

                # However, it may happen again that we have multiple pronouns classified as ambiguous, 
                # since there may be more than one logit with value = highest_ambiguous_pronoun_logit


            # When the model predicts that all the pronouns are not ambiguous (no class 2)
            if not torch.any(predicted_labels == 2):
                # Try to select the most probable ambiguous pronoun
                probable_ambiguous_index = one_batch_predictions[:,-1].argmax(dim=0)
                predicted_labels[probable_ambiguous_index] = 2


            y_pred_list.append(predicted_labels.tolist())


            label_ambiguous_mask = one_batch_labels == 2
            eval_batch_correct += (one_batch_labels[label_ambiguous_mask] == predicted_labels[label_ambiguous_mask]).sum().item()
            batch_count += 1
        
        eval_correct += eval_batch_correct
        total_count += batch_count

In [35]:
eval_correct

413.0

In [36]:
total_count

454.0

In [37]:
eval_correct / total_count

0.9096916299559471