In [1]:
!pip install --upgrade datasets --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from google.colab import drive
import os

drive.mount("/content/drive")
%cd /content/drive/MyDrive

PARENT_DIR = os.getcwd()
MAIN_DIR = os.path.join(PARENT_DIR, "Implement Classic Papers", "BERT")
DATA_DIR = os.path.join(PARENT_DIR, "Data", "Text", "conll2003")

Mounted at /content/drive
/content/drive/MyDrive


In [110]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import random
import numpy as np

from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Union, Callable, Optional, Dict
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

MAIN_DIR = os.getcwd()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer

In [4]:
tokenizer = AutoTokenizer.from_pretrained(
    # os.path.join(MAIN_DIR, "tokenizers", "google-bert", "bert-large-uncased")
    "bert-base-cased"
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]



# Finetuning

- Task Type: NER
- Dataset: CoNLL-2003

## Embeddings
- Token embeddings: For each token in vocabulary
- Absolute positional embedding: Using sinusoidal function to model absolute position in the sentence sequence.
- Segment Embedding: Used for next sentence prediction and NLI tasks

In [5]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(
        self, d_model: int, dropout: float=0.1, max_len: int=512
    ):
        super(SinusoidalPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(size=(max_len, d_model)) # (max_len, d_model)
        const_term = math.log(10000) / d_model
        div_terms = torch.exp(-torch.arange(0, d_model, 2) * const_term) # (d_model//2)
        positions = torch.arange(0, max_len).unsqueeze(1) # (max_len, 1)
        pe[:, ::2] = torch.sin(positions*div_terms)  # sin(pos * div_term)
        pe[:, 1::2] = torch.cos(positions*div_terms)  # sin(pos * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(
        self, inputs: torch.Tensor
    ):
        # Input sequence: (batch_size, seq_length)
        seq_len = inputs.size(1)
        x = self.pe[:, : seq_len, :].requires_grad_(False)
        return self.dropout(x)

class TokenEmbeddings(nn.Module):
    def __init__(
        self,
        d_model: int,
        vocab_size: int=30522,
        padding_idx: int=0
    ):
        super(TokenEmbeddings, self).__init__()
        self.d_model, self.vocab_size = d_model, vocab_size
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model, padding_idx=padding_idx
        )

    def forward(
        self, inputs: torch.Tensor
    ):
        return self.embeddings(inputs) * math.sqrt(self.d_model)

class SegmentEmbedding(nn.Module):
    def __init__(
        self,
        d_model: int = 768,
        padding_idx: int = 2,
    ):
        super(SegmentEmbedding, self).__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(
                3, embedding_dim=d_model, padding_idx=padding_idx
            )

    def forward(
        self,
        token_types_id: torch.Tensor, # (batch_size, seq_len)
        attn_mask: Optional[torch.Tensor] = None
        ):
        token_types_id = token_types_id.masked_fill(
            attn_mask==0, self.padding_idx
        )
        segment_embeddings = self.embedding(token_types_id)
        return segment_embeddings # (batch_size, seq_len, d_model)

class BERTEmbedding(nn.Module):
    def __init__(
        self,
        d_model: int = 768,
        vocab_size: int = 30522,
        use_segment_embedding: bool = False,
        dropout: float = 0.1,
        token_padding_idx: int = 0,
        segment_padding_idx: int = 2
    ):
        super(BERTEmbedding, self).__init__()
        self.token_embedding = TokenEmbeddings(
            d_model=d_model,
            vocab_size=vocab_size,
            padding_idx=token_padding_idx
            )
        self.positional_embedding = SinusoidalPositionalEncoding(
            d_model=d_model, max_len=512
        )
        self.segment_embedding = None
        if use_segment_embedding:
            self.segment_embedding = SegmentEmbedding(
                d_model=d_model, padding_idx=segment_padding_idx
            )
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        inputs: torch.Tensor, # (batch_size, seq_len)
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        token_embeddings = self.token_embedding(inputs)
        pos_embeddings = self.positional_embedding(inputs)
        embeddings = token_embeddings + pos_embeddings
        if self.segment_embedding:
            embeddings = embeddings + self.segment_embedding(token_type_ids, attn_mask)
        return self.dropout(embeddings)

## Model Architecture

In [6]:
class GroupAttentionHead(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_q: int,
        d_k: int,
        d_v: int,
        n_heads: int = 1,
        dropout: float = 0.1
    ):
        super(GroupAttentionHead, self).__init__()
        self.d_q, self.d_k, self.d_v = d_q, d_k, d_v
        self.Q = nn.ModuleList([nn.Linear(d_model, d_q) for _ in range(n_heads)])
        self.K = nn.Linear(d_model, d_k)
        self.V = nn.Linear(d_model, d_v)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        attn_mask: Optional[torch.Tensor]=None,
        dropout: Optional[float]=None,
        is_causal: bool=False
    ):
        """Calculate attention values for a single attention head.
        For self attention, q = k = v; Dimension = (batch_size, seq_len, d_model)
        For seq2seq cross attention
        - q = output of masked attention previous component of the decoder. Dimension = (batch_size, tgt_seq_len, d_model)
        - k = v = output of the encoder. Dimension = (batch_size, src_seq_len, d_model)
        """
        dropout = dropout or self.dropout
        query = torch.cat(
            [q(query).unsqueeze(1) for q in self.Q], dim = 1
        ) # (batch_size, n_groups, seq_len, d_q)
        key = self.K(key).unsqueeze(1) # (batch_size, 1, seq_len, d_k)
        value = self.V(value).unsqueeze(1) # (batch_size, 1, seq_len, d_v)
        if isinstance(attn_mask, torch.Tensor):
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1).type(torch.bool)

        values = F.scaled_dot_product_attention(
            query, key, value, attn_mask, dropout, is_causal
            )

        return values

class MultiHeadedAttention(nn.Module):
    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        n_groups: int = 12,
        dropout: float = 0.1,
        d_v: Optional[int] = None
    ):
        super(MultiHeadedAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_groups == 0, "Total number of heads must be divisible by number of groups"
        self.heads_per_group = n_heads // n_groups
        self.d_q = d_model // n_heads
        self.d_k = self.d_q
        self.d_v = d_v or self.d_q
        self.attn_heads = nn.ModuleList(
            [
                GroupAttentionHead(d_model, self.d_q, self.d_k, self.d_v, n_heads=self.heads_per_group, dropout=dropout)
                for _ in range(n_groups)
                ]
            )
        self.O = nn.Linear(self.d_v * self.n_heads, self.d_model)

    def forward(
        self,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        attn_mask: Optional[torch.Tensor]=None,
        dropout: Optional[float]=None,
        is_causal: bool=False
    ):
        batch_size = query.size(0)
        tgt_seq_len = value.size(1)
        values = torch.cat(
            [
                att_head(
                    query, key, value,
                    attn_mask=attn_mask, dropout=dropout, is_causal=is_causal)
                for att_head in self.attn_heads
                ], dim = 1
            ).transpose(1, 2) # (bs, tgt_seq_len, n_heads * n_groups, d_v)
        values = values.contiguous().view(batch_size, tgt_seq_len, -1)
        return self.O(values)

class PointwiseFeedForward(nn.Module):
    def __init__(
        self,
        d_model: int = 768,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super(PointwiseFeedForward, self).__init__()
        self.d_feedforward = d_feedfoward or d_model * 4
        self.linear1 = nn.Linear(d_model, self.d_feedforward)
        self.linear2 = nn.Linear(self.d_feedforward, d_model)
        self.activation = getattr(nn.functional, activation)

    def forward(
        self, inputs: torch.Tensor
    ):
        x = self.linear1(inputs)
        x = self.linear2(self.activation(x))
        return x

class ResidualLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 768,
        dropout: float = 0.1
    ):
        super(ResidualLayer, self).__init__()
        self.d_model = d_model
        self.layer_norm = nn.LayerNorm(normalized_shape=d_model, eps=1e-5)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self, inputs: torch.Tensor, sublayer: Callable
    ):
        return self.layer_norm(inputs + self.dropout(sublayer(inputs)))

class EncoderBlock(nn.Module):
    def __init__(
        self,
        d_model: int=768,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super(EncoderBlock, self).__init__()
        self.d_model = d_model
        self.multihead_group_attention = MultiHeadedAttention(
            d_model=d_model, n_heads=n_heads, n_groups=n_groups, dropout=dropout, d_v=d_v,
        )
        self.feedforward = PointwiseFeedForward(d_model=d_model, d_feedfoward=d_feedfoward, activation=activation)
        self.att_residual = ResidualLayer(d_model=d_model)
        self.ff_residual = ResidualLayer(d_model=d_model)

    def forward(
        self,
        inputs: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        x = self.att_residual(
            inputs,
            lambda x: self.multihead_group_attention(
                x, x, x, attn_mask=attn_mask, is_causal=is_causal
            )
        )
        x = self.ff_residual(x, self.feedforward)
        return x

class TransformerEncoder(nn.Module):
    def __init__(
        self,
        d_model: int=768,
        n_layers: int=12,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList(
            [
                EncoderBlock(
                    d_model=d_model, n_heads=n_heads, n_groups=n_groups, dropout=dropout, d_v=d_v,
                    d_feedfoward=d_feedfoward, activation=activation)
                for _ in range(n_layers)
                ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        for layer in self.layers:
            x = layer(x, attn_mask=attn_mask, is_causal=is_causal)
        return x

class BERTBackbone(nn.Module):
    def __init__(
        self,
        d_model: int=768,
        vocab_size: int=30522,
        n_layers: int=12,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
        use_segment_embedding: bool = False
    ):
        super(BERTBackbone, self).__init__()

        self.embedding = BERTEmbedding(
            d_model=d_model,
            vocab_size=vocab_size,
            use_segment_embedding=use_segment_embedding,
            dropout=dropout
        )

        self.encoder = TransformerEncoder(
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
            n_groups=n_groups,
            dropout=dropout,
            d_v=d_v,
            d_feedfoward=d_feedfoward,
            activation=activation
        )

        self.d_model = d_model
        self.vocab_size = vocab_size

        # Initialize weights
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        inputs: torch.Tensor, # (batch_size, seq_len)
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        x = self.embedding(
            inputs=inputs, token_type_ids=token_type_ids, attn_mask=attn_mask
        )
        x = self.encoder(
            x,
            attn_mask=attn_mask.type(torch.float),
            is_causal=is_causal
            )
        return x

class BertForTokenClassification(nn.Module):
    def __init__(
        self,
        bert: BERTBackbone,
        num_classes: int,
        projection_layer: Optional[nn.Module] = None
    ):
        super(BertForTokenClassification, self).__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.d_model = self.bert.d_model
        self.proj = projection_layer or nn.Linear(
            self.d_model, num_classes
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        final_hiddens = self.bert(
            inputs = input_ids,
            token_type_ids = token_type_ids,
            attn_mask = attn_mask,
            is_causal = is_causal
        )
        return self.proj(final_hiddens) # Use CrossEntropyLoss

In [64]:
def _extract_first_word_indices(
    word_ids: List[int]
):
    extracted_indices = []
    curr_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is not None and word_id != curr_word_id:
            extracted_indices.append(idx)
            curr_word_id = word_id
    return extracted_indices

def _map_words_to_tokens(
    word_ids: List[int]
):
    words_to_tokens_mapping = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is not None:
            if word_id in words_to_tokens_mapping:
                words_to_tokens_mapping[word_id].append(idx)
            else:
                words_to_tokens_mapping[word_id] = [idx]

    return words_to_tokens_mapping

def word_level_prediction(
    probs: torch.Tensor, # (token_seq_len, num_classes)
    word_ids: List[int], # (word_seq_len)
    aggregation_strategy: str = "first"
):
    if aggregation_strategy == "first":
        extracted_indices = _extract_first_word_indices(word_ids)
        extracted_probs = probs[extracted_indices, :] # (word_seq_len, num_classes)
        return extracted_probs.argmax(dim=-1)
    elif aggregation_strategy == "mean":
        words_to_tokens_mapping = _map_words_to_tokens(word_ids)
        print(words_to_tokens_mapping)
        token_probs = []
        for word_id, token_indices in words_to_tokens_mapping.items():
            extracted_probs = probs[token_indices, :] # (n_tokens_per_word, num_classes)
            avg_probs = extracted_probs.mean(dim=0) # (num_classes)
            token_probs.append(avg_probs)
        return torch.stack(token_probs, dim=0).argmax(dim=-1)
    else:
        raise ValueError("Invalid Aggregation Strategy Type!!!")

# Data Preparation



In [7]:
dataset = load_dataset("arrow", data_files = {
    "train": os.path.join(DATA_DIR, "train", "data-00000-of-00001.arrow"),
    "validation": os.path.join(DATA_DIR, "validation", "data-00000-of-00001.arrow"),
    "test": os.path.join(DATA_DIR, "test", "data-00000-of-00001.arrow")
    })

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [74]:
def _align_token_labels(
    word_labels,
    word_ids,
    special_tokens_mask,
    offset_mappings,
    ignore_token_id: int= -100,
    aggregation_strategy: str = "first"
):
    batch_size = special_tokens_mask.size(0)
    seq_len = special_tokens_mask.size(1)
    token_labels = torch.ones((batch_size, seq_len), dtype=torch.long) * ignore_token_id

    for seq_idx in range(batch_size):
        for token_idx in range(seq_len):
            if aggregation_strategy == "first":
                if special_tokens_mask[seq_idx][token_idx] == 0 and offset_mappings[seq_idx][token_idx][0] == 0:
                    token_labels[seq_idx][token_idx] = word_labels[seq_idx][word_ids[seq_idx][token_idx]]
            elif aggregation_strategy == "mean":
                if special_tokens_mask[seq_idx][token_idx] == 0:
                    token_labels[seq_idx][token_idx] = word_labels[seq_idx][word_ids[seq_idx][token_idx]]
            else:
                raise ValueError("Invalid Aggregation Strategy!!!")

    return token_labels

class CoNLL2003Dataset(Dataset):
    def __init__(
        self, dataset: Dataset,
    ):
        self.dataset = dataset
        self.id2classes = {idx:class_name for idx, class_name in enumerate(self.dataset.features["ner_tags"].feature.names)}
        self.classes2id = {class_name:idx for idx, class_name in enumerate(self.dataset.features["ner_tags"].feature.names)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        words = self.dataset[idx]['tokens']
        labels = self.dataset[idx]['ner_tags']
        return (words, labels)

class CoNLL2003DataManager:
    def __init__(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        test_dataset: Dataset,
        tokenizer: Callable,
        aggregation_strategy = "first",
        ignore_token_id: int = -100
    ):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.ignore_token_id = ignore_token_id
        self.aggregation_strategy = aggregation_strategy

    def collate_fn(self, batch):
        words, word_labels = zip(*batch)

        tokens_info = self.tokenizer(
            words,
            is_split_into_words = True if isinstance(words[0], list) else False,
            padding=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_tensors="pt"
        )
        input_ids = tokens_info["input_ids"]
        offset_mappings = tokens_info["offset_mapping"]
        special_tokens_mask = tokens_info["special_tokens_mask"]
        word_ids = [tokens_info.word_ids(idx) for idx in range(input_ids.size(0))]

        tokens_info["labels"] = _align_token_labels(
            word_labels=word_labels,
            word_ids=word_ids,
            special_tokens_mask=special_tokens_mask,
            offset_mappings=offset_mappings,
            ignore_token_id=self.ignore_token_id,
            aggregation_strategy=self.aggregation_strategy
        )
        tokens_info["word_ids"] = word_ids
        tokens_info["word_labels"] = word_labels

        return tokens_info

    def get_train_dataloader(
        self, batch_size = 32,
    ):
        return DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers = 1, pin_memory=True, collate_fn = self.collate_fn
        )

    def get_validation_dataloader(
        self, batch_size = 32,
    ):
        return DataLoader(
            self.val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers = 1, pin_memory=True, collate_fn = self.collate_fn
        )

    def get_test_dataloader(
        self, batch_size = 32,
    ):
        return DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers = 1, pin_memory=True, collate_fn = self.collate_fn
        )

# Trainer Class

In [115]:
def custom_lr_schedule(
    step_no: int, warm_up: int = 10000, total_steps: int = 100000
):
    post_warm_up_steps = total_steps - warm_up
    if step_no == 0:
        step_no = 1
    if step_no <= warm_up:
        return step_no / warm_up
    else:
        return (total_steps - step_no) / post_warm_up_steps

def _extract_first_word_indices(
    word_ids: List[int]
):
    extracted_indices = []
    curr_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is not None and word_id != curr_word_id:
            extracted_indices.append(idx)
            curr_word_id = word_id
    return extracted_indices

def _map_words_to_tokens(
    word_ids: List[int]
):
    words_to_tokens_mapping = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is not None:
            if word_id in words_to_tokens_mapping:
                words_to_tokens_mapping[word_id].append(idx)
            else:
                words_to_tokens_mapping[word_id] = [idx]

    return words_to_tokens_mapping

def word_level_prediction(
    probs: torch.Tensor, # (token_seq_len, num_classes)
    word_ids: List[int], # (word_seq_len)
    aggregation_strategy: str = "first"
):
    if aggregation_strategy == "first":
        extracted_indices = _extract_first_word_indices(word_ids)
        extracted_probs = probs[extracted_indices, :] # (word_seq_len, num_classes)
        return extracted_probs.argmax(dim=-1)
    elif aggregation_strategy == "mean":
        words_to_tokens_mapping = _map_words_to_tokens(word_ids)
        print(words_to_tokens_mapping)
        token_probs = []
        for word_id, token_indices in words_to_tokens_mapping.items():
            extracted_probs = probs[token_indices, :] # (n_tokens_per_word, num_classes)
            avg_probs = extracted_probs.mean(dim=0) # (num_classes)
            token_probs.append(avg_probs)
        return torch.stack(token_probs, dim=0).argmax(dim=-1)
    else:
        raise ValueError("Invalid Aggregation Strategy Type!!!")

class TrainState:
    n_steps: int = 0
    n_updates: int = 0
    train_loss = {"step": [], "epoch": []}
    val_loss = {"step": [], "epoch": []}
    lr = []

class BERTFTTrainer:
    def __init__(
        self,
        model: nn.Module,
        optimizer_args: Dict,
        data_manager,
        device: str,
    ):
        self.model = model.to(device)
        self.optimizer_args = optimizer_args
        self.device = device
        self.train_state = TrainState()
        self.accumulation_steps = optimizer_args["accumulation_steps"]
        self.data_manager = data_manager

        self.optimizer = AdamW(
            params = self.model.parameters(),
            lr = optimizer_args["lr"],
            betas = optimizer_args["betas"],
            weight_decay = optimizer_args["weight_decay"]
        )

        self.scheduler = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda x: custom_lr_schedule(
                x, warm_up=optimizer_args["warmup"], total_steps=optimizer_args["total_steps"]
                )
            )
        self.loss_criterion = nn.CrossEntropyLoss(ignore_index=self.data_manager.ignore_token_id)
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def run_epoch(
        self,
        epoch: int,
        dataloader: DataLoader,
        train_state: Optional[TrainState] = None,
        is_train: bool = True
    ):
        aggregation_strategy = self.data_manager.aggregation_strategy
        train_state = train_state or self.train_state
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        mode = "Train" if is_train else "Eval"

        data_iter = tqdm(
            enumerate(dataloader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(dataloader),
            bar_format="{l_bar}{r_bar}"
        )

        epoch_loss = 0.0
        epoch_steps = 0
        total_preds, total_corrects = 0, 0
        all_word_preds = []
        all_word_labels = []

        for idx, batch in data_iter:
            input_ids = batch["input_ids"].to(self.device)
            token_type_ids = batch["token_type_ids"].to(self.device)
            attn_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)
            word_labels = batch["word_labels"]
            word_ids = batch["word_ids"]

            if is_train:
                outs = self.model(
                    input_ids=input_ids, token_type_ids=token_type_ids, attn_mask=attn_mask,
                )
                logits = outs.view(-1, outs.size(-1))
                labels = labels.view(-1)
                loss = self.loss_criterion(logits, labels)
            else:
                with torch.no_grad():
                    outs = self.model(
                        input_ids=input_ids, token_type_ids=token_type_ids, attn_mask=attn_mask,
                    )
                    logits = outs.view(-1, outs.size(-1))
                    labels = labels.view(-1)
                    loss = self.loss_criterion(logits, labels)

            probs = F.softmax(outs, dim=-1)

            for seq_idx in range(probs.size(0)):
                word_pred = word_level_prediction(
                    probs=probs[seq_idx, :],
                    word_ids=word_ids[seq_idx],
                    aggregation_strategy=aggregation_strategy
                ).detach().cpu()
                word_label = word_labels[seq_idx]

                all_word_preds.extend(word_pred.tolist())
                all_word_labels.extend(word_label)

            if is_train:
                self.train_state.n_steps += 1
                loss.backward()
                if self.train_state.n_steps % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.train_state.n_updates += 1

                lr = self.optimizer.param_groups[0]["lr"]
                self.train_state.lr.append(lr)
                self.train_state.train_loss["step"].append(loss.item())
                self.scheduler.step()

            else:
                self.train_state.val_loss["step"].append(loss.item())

            epoch_loss += loss.item()
            epoch_steps += 1

            if (idx + 1) % 100 == 0:
                avg_loss = epoch_loss / epoch_steps

                print(
                    f"Epoch {epoch} - Step {self.train_state.n_steps}, mode {mode}: avg_loss={avg_loss}")

        avg_loss = epoch_loss / epoch_steps
        epoch_micro_acc = accuracy_score(all_word_labels, all_word_preds)
        epoch_macro_precision = precision_score(all_word_labels, all_word_preds, average="macro", zero_division=np.nan)
        epoch_macro_recall = recall_score(all_word_labels, all_word_preds, average="macro", zero_division=np.nan)

        if is_train:
            train_state.train_loss["epoch"].append(avg_loss)
        else:
            train_state.val_loss["epoch"].append(avg_loss)

        print(
            f"Whole Epoch {epoch}, mode {mode}: avg_loss={avg_loss}, accuracy={epoch_micro_acc}, macro precision={epoch_macro_precision}, macro recall={epoch_macro_recall}")

    def train(
        self,
        training_args: Dict
        ):
        train_dataloader = self.data_manager.get_train_dataloader(batch_size=training_args["batch_size"])
        validation_dataloader = self.data_manager.get_validation_dataloader(batch_size=training_args["batch_size"])

        for epoch_no in range(training_args["n_epochs"]):
            self.run_epoch(
                epoch=epoch_no,
                dataloader=train_dataloader,
                is_train=True
            )
            torch.cuda.empty_cache()
            self.run_epoch(
                epoch=epoch_no,
                dataloader=validation_dataloader,
                is_train=False
            )

    def test(
        self,
        batch_size: int = 32
        ):
        test_dataloader = self.data_manager.get_test_dataloader(batch_size=batch_size)
        self.run_epoch(
            epoch=0,
            dataloader=test_dataloader,
            is_train=False
        )

# Finetune !!!

### Load Pretrain Backbone

In [92]:
model_args = dict(
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_groups=4,
    dropout=0.1,
    use_segment_embedding=True,
) # BERT BASE

optimizer_args = dict(
    lr=3e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    warmup=10000,
    total_steps=100000,
    accumulation_steps=1
)

training_args = dict(n_epochs=20, batch_size=32)

In [98]:
train_dataset = CoNLL2003Dataset(dataset["train"].remove_columns(['id', 'pos_tags', 'chunk_tags']))
val_dataset = CoNLL2003Dataset(dataset["validation"].remove_columns(['id', 'pos_tags', 'chunk_tags']))
test_dataset = CoNLL2003Dataset(dataset["test"].remove_columns(['id', 'pos_tags', 'chunk_tags']))
num_classes = len(train_dataset.id2classes.keys())

data_manager = CoNLL2003DataManager(
    train_dataset = train_dataset,
    val_dataset = val_dataset,
    test_dataset = test_dataset,
    tokenizer = tokenizer
)

train_loader = data_manager.get_train_dataloader(batch_size=3)
sample_batch = next(iter(train_loader))

In [99]:
# Load in backbone weights
bert_backbone = BERTBackbone(**model_args)
ckpt_dir = None
if ckpt_dir:
    bert_backbone.load_state_dict(torch.load(os.path.join(ckpt_dir, "best_checkpoint.pt")))

model = BertForTokenClassification(
    bert=bert_backbone,
    num_classes=num_classes
    )

In [116]:
bert_trainer = BERTFTTrainer(
    model=model,
    data_manager=data_manager,
    optimizer_args=optimizer_args,
    device=DEVICE,
)

Total Parameters: 99055113


In [None]:
bert_trainer.train(training_args)

EP_Train:0:  23%|| 100/438 [00:34<01:53,  2.99it/s]

Epoch 0 - Step 100, mode Train: avg_loss=0.7303533762693405


EP_Train:0:  46%|| 200/438 [01:05<01:03,  3.74it/s]

Epoch 0 - Step 200, mode Train: avg_loss=0.7330129128694535


EP_Train:0:  68%|| 300/438 [01:36<00:42,  3.23it/s]

Epoch 0 - Step 300, mode Train: avg_loss=0.7299988355239232


EP_Train:0:  91%|| 400/438 [02:08<00:11,  3.31it/s]

Epoch 0 - Step 400, mode Train: avg_loss=0.7322510153800249


EP_Train:0: 100%|| 438/438 [02:20<00:00,  3.11it/s]


Whole Epoch 0, mode Train: avg_loss=0.7324057758943131, accuracy=0.832717540502113, macro precision=0.4183727893999801, macro recall=0.11474006655101067


EP_Eval:0: 100%|| 102/102 [00:11<00:00,  8.69it/s]

Epoch 0 - Step 438, mode Eval: avg_loss=0.788083046078682





Whole Epoch 0, mode Eval: avg_loss=0.783482560340096, accuracy=0.8324247498150383, macro precision=0.6402719447981865, macro recall=0.11781555640963745


EP_Train:1:  23%|| 100/438 [00:32<01:48,  3.11it/s]

Epoch 1 - Step 538, mode Train: avg_loss=0.7194894766807556


EP_Train:1:  46%|| 200/438 [01:04<01:14,  3.18it/s]

Epoch 1 - Step 638, mode Train: avg_loss=0.7256212276220322


EP_Train:1:  68%|| 300/438 [01:35<00:42,  3.22it/s]

Epoch 1 - Step 738, mode Train: avg_loss=0.7338892636696498


EP_Train:1:  91%|| 400/438 [02:05<00:12,  3.10it/s]

Epoch 1 - Step 838, mode Train: avg_loss=0.72885621920228


EP_Train:1: 100%|| 438/438 [02:17<00:00,  3.18it/s]


Whole Epoch 1, mode Train: avg_loss=0.7276347859265053, accuracy=0.8327304226281, macro precision=0.3568109511871917, macro recall=0.11608648563103914


EP_Eval:1: 100%|| 102/102 [00:11<00:00,  8.71it/s]

Epoch 1 - Step 876, mode Eval: avg_loss=0.7663249835371971





Whole Epoch 1, mode Eval: avg_loss=0.7619962873412114, accuracy=0.8324052801682178, macro precision=0.646473601735652, macro recall=0.1237173856775739


EP_Train:2:  23%|| 100/438 [00:30<01:41,  3.33it/s]

Epoch 2 - Step 976, mode Train: avg_loss=0.7115887609124184


EP_Train:2:  46%|| 200/438 [01:01<01:13,  3.26it/s]

Epoch 2 - Step 1076, mode Train: avg_loss=0.7177570323646069


EP_Train:2:  68%|| 300/438 [01:33<00:51,  2.69it/s]

Epoch 2 - Step 1176, mode Train: avg_loss=0.7175290818015735


EP_Train:2:  91%|| 400/438 [02:04<00:11,  3.28it/s]

Epoch 2 - Step 1276, mode Train: avg_loss=0.7153616753965616


EP_Train:2: 100%|| 438/438 [02:16<00:00,  3.21it/s]


Whole Epoch 2, mode Train: avg_loss=0.7154320800277196, accuracy=0.8326799636014854, macro precision=0.3380137911181363, macro recall=0.11694726912289743


EP_Eval:2:  99%|| 101/102 [00:10<00:00,  8.45it/s]

Epoch 2 - Step 1314, mode Eval: avg_loss=0.7474773308634758


EP_Eval:2: 100%|| 102/102 [00:10<00:00,  9.67it/s]


Whole Epoch 2, mode Eval: avg_loss=0.7434982499655556, accuracy=0.8326778552237062, macro precision=0.5783087227292311, macro recall=0.12129498946475427


EP_Train:3:  23%|| 100/438 [00:31<01:56,  2.91it/s]

Epoch 3 - Step 1414, mode Train: avg_loss=0.7043641597032547


EP_Train:3:  46%|| 200/438 [01:02<01:10,  3.37it/s]

Epoch 3 - Step 1514, mode Train: avg_loss=0.7062022814154625


EP_Train:3:  68%|| 300/438 [01:34<00:39,  3.48it/s]

Epoch 3 - Step 1614, mode Train: avg_loss=0.7039688166975975


EP_Train:3:  91%|| 400/438 [02:05<00:11,  3.39it/s]

Epoch 3 - Step 1714, mode Train: avg_loss=0.6986699825525284


EP_Train:3: 100%|| 438/438 [02:17<00:00,  3.18it/s]


Whole Epoch 3, mode Train: avg_loss=0.6993759469082367, accuracy=0.832813391855453, macro precision=0.43253743846809584, macro recall=0.12054473313620291


EP_Eval:3:  99%|| 101/102 [00:11<00:00,  9.19it/s]

Epoch 3 - Step 1752, mode Eval: avg_loss=0.7424794596433639


EP_Eval:3: 100%|| 102/102 [00:11<00:00,  8.87it/s]


Whole Epoch 3, mode Eval: avg_loss=0.7384240750004264, accuracy=0.8322884622872941, macro precision=0.5620256751460159, macro recall=0.12627929092968643


EP_Train:4:  23%|| 100/438 [00:31<01:46,  3.16it/s]

Epoch 4 - Step 1852, mode Train: avg_loss=0.6788500699400902


EP_Train:4:  46%|| 200/438 [01:02<01:07,  3.52it/s]

Epoch 4 - Step 1952, mode Train: avg_loss=0.6815797655284405


EP_Train:4:  68%|| 300/438 [01:33<00:46,  2.98it/s]

Epoch 4 - Step 2052, mode Train: avg_loss=0.6819700487454732


EP_Train:4:  91%|| 400/438 [02:05<00:11,  3.45it/s]

Epoch 4 - Step 2152, mode Train: avg_loss=0.6816404520720244


EP_Train:4: 100%|| 438/438 [02:16<00:00,  3.20it/s]


Whole Epoch 4, mode Train: avg_loss=0.6800295093424245, accuracy=0.8327291864428454, macro precision=0.3760045229775863, macro recall=0.12329913581176924


EP_Eval:4:  99%|| 101/102 [00:11<00:00,  9.20it/s]

Epoch 4 - Step 2190, mode Eval: avg_loss=0.7271996891498566


EP_Eval:4: 100%|| 102/102 [00:11<00:00,  8.92it/s]


Whole Epoch 4, mode Eval: avg_loss=0.7232310114537969, accuracy=0.8330283088664772, macro precision=0.5395298341996655, macro recall=0.12426831185601632


EP_Train:5:  23%|| 100/438 [00:31<01:53,  2.98it/s]

Epoch 5 - Step 2290, mode Train: avg_loss=0.6667281699180603


EP_Train:5:  46%|| 200/438 [01:02<01:08,  3.49it/s]

Epoch 5 - Step 2390, mode Train: avg_loss=0.6662799876928329


EP_Train:5:  68%|| 300/438 [01:33<00:43,  3.19it/s]

Epoch 5 - Step 2490, mode Train: avg_loss=0.6632698959112168


EP_Train:5:  91%|| 400/438 [02:05<00:12,  3.07it/s]

Epoch 5 - Step 2590, mode Train: avg_loss=0.6649074966460466


EP_Train:5: 100%|| 438/438 [02:17<00:00,  3.19it/s]


Whole Epoch 5, mode Train: avg_loss=0.6640648162120009, accuracy=0.8328379508882436, macro precision=0.3803470732948977, macro recall=0.12692881686378146


EP_Eval:5: 100%|| 102/102 [00:11<00:00,  9.19it/s]

Epoch 5 - Step 2628, mode Eval: avg_loss=0.7049059212207794





Whole Epoch 5, mode Eval: avg_loss=0.7012133741495656, accuracy=0.8331256571005802, macro precision=0.6072642968879968, macro recall=0.12878731501855


EP_Train:6:  23%|| 100/438 [00:32<01:46,  3.17it/s]

Epoch 6 - Step 2728, mode Train: avg_loss=0.6618957799673081


EP_Train:6:  46%|| 200/438 [01:03<01:19,  2.99it/s]

Epoch 6 - Step 2828, mode Train: avg_loss=0.6478245574235916


EP_Train:6:  68%|| 300/438 [01:33<00:43,  3.20it/s]

Epoch 6 - Step 2928, mode Train: avg_loss=0.6467839127779007


EP_Train:6:  91%|| 400/438 [02:05<00:11,  3.42it/s]

Epoch 6 - Step 3028, mode Train: avg_loss=0.6470156691968441


EP_Train:6: 100%|| 438/438 [02:16<00:00,  3.21it/s]


Whole Epoch 6, mode Train: avg_loss=0.6495273756246044, accuracy=0.8329889502218811, macro precision=0.4175040435286658, macro recall=0.1304983541223243


EP_Eval:6:  99%|| 101/102 [00:10<00:00,  8.07it/s]

Epoch 6 - Step 3066, mode Eval: avg_loss=0.6942897269129753


EP_Eval:6: 100%|| 102/102 [00:10<00:00,  9.40it/s]


Whole Epoch 6, mode Eval: avg_loss=0.6905674575006261, accuracy=0.8330088392196565, macro precision=0.5191542318078728, macro recall=0.12130371336167729


EP_Train:7:  23%|| 100/438 [00:31<01:43,  3.27it/s]

Epoch 7 - Step 3166, mode Train: avg_loss=0.6363174089789391


EP_Train:7:  46%|| 200/438 [01:02<01:16,  3.10it/s]

Epoch 7 - Step 3266, mode Train: avg_loss=0.638593435138464


EP_Train:7:  68%|| 300/438 [01:33<00:47,  2.89it/s]

Epoch 7 - Step 3366, mode Train: avg_loss=0.6359055504202843


EP_Train:7:  91%|| 400/438 [02:04<00:11,  3.37it/s]

Epoch 7 - Step 3466, mode Train: avg_loss=0.6336079628020524


EP_Train:7: 100%|| 438/438 [02:16<00:00,  3.22it/s]


Whole Epoch 7, mode Train: avg_loss=0.6330826381161877, accuracy=0.8332694010032458, macro precision=0.4343810711916842, macro recall=0.13354687539383436


EP_Eval:7: 100%|| 102/102 [00:12<00:00,  8.12it/s]

Epoch 7 - Step 3504, mode Eval: avg_loss=0.6915683579444886





Whole Epoch 7, mode Eval: avg_loss=0.6880335559447607, accuracy=0.8328141427514505, macro precision=0.4993190809045191, macro recall=0.1399463339400096


EP_Train:8:  23%|| 100/438 [00:31<01:39,  3.40it/s]

Epoch 8 - Step 3604, mode Train: avg_loss=0.6275615635514259


EP_Train:8:  46%|| 200/438 [01:02<01:05,  3.62it/s]

Epoch 8 - Step 3704, mode Train: avg_loss=0.6225086092948914


EP_Train:8:  68%|| 300/438 [01:33<00:42,  3.28it/s]

Epoch 8 - Step 3804, mode Train: avg_loss=0.622448493440946


EP_Train:8:  91%|| 400/438 [02:05<00:12,  2.96it/s]

Epoch 8 - Step 3904, mode Train: avg_loss=0.618564380928874


EP_Train:8: 100%|| 438/438 [02:16<00:00,  3.20it/s]


Whole Epoch 8, mode Train: avg_loss=0.618620239096145, accuracy=0.8328612733817194, macro precision=0.3882313612999112, macro recall=0.13404642180057036


EP_Eval:8: 100%|| 102/102 [00:11<00:00,  9.05it/s]

Epoch 8 - Step 3942, mode Eval: avg_loss=0.6739881280064582





Whole Epoch 8, mode Eval: avg_loss=0.6706786909524132, accuracy=0.832989369572836, macro precision=0.4441034911381944, macro recall=0.13056730561558455


EP_Train:9:  23%|| 100/438 [00:31<01:40,  3.37it/s]

Epoch 9 - Step 4042, mode Train: avg_loss=0.6152951392531395


EP_Train:9:  46%|| 200/438 [01:02<01:15,  3.17it/s]

Epoch 9 - Step 4142, mode Train: avg_loss=0.60964819714427


EP_Train:9:  68%|| 300/438 [01:33<00:46,  2.95it/s]

Epoch 9 - Step 4242, mode Train: avg_loss=0.6106923664609591


EP_Train:9:  91%|| 400/438 [02:04<00:12,  3.01it/s]

Epoch 9 - Step 4342, mode Train: avg_loss=0.6111584222316742


EP_Train:9: 100%|| 438/438 [02:16<00:00,  3.21it/s]


Whole Epoch 9, mode Train: avg_loss=0.6079617393071248, accuracy=0.8331193648826125, macro precision=0.4184650120118425, macro recall=0.13570177803074537


EP_Eval:9:  99%|| 101/102 [00:10<00:00,  7.91it/s]

Epoch 9 - Step 4380, mode Eval: avg_loss=0.6684470245242119


EP_Eval:9: 100%|| 102/102 [00:10<00:00,  9.48it/s]


Whole Epoch 9, mode Eval: avg_loss=0.6651803322282492, accuracy=0.8331061874537596, macro precision=0.5186842518551853, macro recall=0.13141734540569516


EP_Train:10:  23%|| 100/438 [00:31<01:51,  3.04it/s]

Epoch 10 - Step 4480, mode Train: avg_loss=0.5891771438717842


EP_Train:10:  46%|| 200/438 [01:02<01:13,  3.22it/s]

Epoch 10 - Step 4580, mode Train: avg_loss=0.5968253390491008


EP_Train:10:  68%|| 300/438 [01:33<00:42,  3.27it/s]

Epoch 10 - Step 4680, mode Train: avg_loss=0.6015199932456017


EP_Train:10:  91%|| 400/438 [02:04<00:11,  3.41it/s]

Epoch 10 - Step 4780, mode Train: avg_loss=0.5988377591967583


EP_Train:10: 100%|| 438/438 [02:16<00:00,  3.22it/s]


Whole Epoch 10, mode Train: avg_loss=0.5974449863716892, accuracy=0.8334719017656406, macro precision=0.4337720996766834, macro recall=0.13666947969678012


EP_Eval:10:  99%|| 101/102 [00:11<00:00,  8.87it/s]

Epoch 10 - Step 4818, mode Eval: avg_loss=0.6589709433913231


EP_Eval:10: 100%|| 102/102 [00:11<00:00,  8.86it/s]


Whole Epoch 10, mode Eval: avg_loss=0.6557231650632971, accuracy=0.8331645963942214, macro precision=0.49596012128570904, macro recall=0.14810099737962776


EP_Train:11:  23%|| 100/438 [00:32<01:43,  3.27it/s]

Epoch 11 - Step 4918, mode Train: avg_loss=0.5804343569278717


EP_Train:11:  46%|| 200/438 [01:04<01:16,  3.09it/s]

Epoch 11 - Step 5018, mode Train: avg_loss=0.5831435942649841


EP_Train:11:  63%|| 277/438 [01:27<00:46,  3.44it/s]

In [15]:
output_dir = os.path.join(
    MAIN_DIR, "artifacts", "checkpoints"
)
os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(output_dir, "ft_ner_checkpoint.pt"))