# Named Entity Recognition (NER) System R&D
    

## Load data

In [None]:
!pip install -qq datasets
!pip install -qq seqeval

In [None]:
from datasets import load_dataset

dataset = load_dataset("conll2003")
dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [None]:
dataset["train"][0]

{'id': '0',
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.'],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}

## BERT-CRF

### Data Loading and Processing

In [None]:
from typing import Dict, List, Optional
from pathlib import Path
import logging

class NERConfig:
    """Configuration shared by the data pipeline and the BERT-CRF model.

    Args:
        label_types: All tag names the model can emit.
        id2label: Mapping from integer tag id to tag name; ``LABEL2ID`` is
            derived from it.
        max_len: Length (in word pieces, ``[CLS]``/``[SEP]`` included) that
            the feature converter truncates to and the collate function pads to.
        data_path: Optional location of the raw data; stored as ``DATA_PATH``.
        base_model_path: Name or path of the pretrained BERT checkpoint.
            Defaults to ``"bert-base-cased"``.
        models_path: Optional location for the saved model weights; stored as
            ``MODEL_PATH`` and used as the ``torch.save`` target by the
            training loop. Defaults to ``"./model.pt"``.
    """

    def __init__(
        self,
        label_types: List[str],
        id2label: Dict[int, str],
        max_len: int = 128,
        data_path: Optional[str | Path] = None,
        base_model_path: Optional[str] = None,
        models_path: Optional[str | Path] = None,
    ):
        # Imported here so the class does not depend on a later notebook cell
        # having already imported `transformers`.
        import transformers

        self.DATA_PATH = Path(data_path) if isinstance(data_path, str) else data_path
        # Respect the caller's location. Previously the argument was always
        # overwritten by the hard-coded default (and a Path was never stored).
        self.MODEL_PATH = Path(models_path) if models_path is not None else "./model.pt"
        self.BASE_MODEL_PATH = base_model_path or "bert-base-cased"

        self.LABEL_TYPES = label_types
        self.ID2LABEL = id2label
        self.LABEL2ID = {label: idx for idx, label in id2label.items()}

        self.MAX_LEN = max_len
        self.TRAIN_BATCH_SIZE = 64
        self.VALID_BATCH_SIZE = 32
        self.EPOCHS = 15
        self.OUT_DIM = 768  # bert-base-cased: 768, bert-large-cased: 1024

        # Cased tokenizer matching the base checkpoint.
        self.TOKENIZER = transformers.BertTokenizer.from_pretrained(
            self.BASE_MODEL_PATH, do_lower_case=False
        )
        self.logger = logging.getLogger(self.__class__.__name__)

In [None]:
import torch
import transformers

# Tag inventory: 'X' marks non-initial word pieces, '[CLS]'/'[SEP]' are the
# BERT markers, the rest are the CoNLL-2003 tags.
LABEL_TYPES = ['X', '[CLS]', '[SEP]', 'O', 'I-LOC', 'B-PER', 'I-PER', 'I-ORG', 'I-MISC', 'B-MISC', 'B-LOC', 'B-ORG']
# A tag's id is its position in LABEL_TYPES.
ID2LABEL = dict(enumerate(LABEL_TYPES))
# NOTE(review): this loads "bert-large-cased" while NERConfig is created with
# "bert-base-cased" further down — confirm which checkpoint is intended.
TOKENIZER = transformers.BertTokenizer.from_pretrained(
    "bert-large-cased", do_lower_case=False
)
# Prefer the GPU whenever one is visible.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
ID2LABEL

{0: 'X',
 1: '[CLS]',
 2: '[SEP]',
 3: 'O',
 4: 'I-LOC',
 5: 'B-PER',
 6: 'I-PER',
 7: 'I-ORG',
 8: 'I-MISC',
 9: 'B-MISC',
 10: 'B-LOC',
 11: 'B-ORG'}

In [None]:
from typing import List, Dict, Optional

import torch
import numpy as np

# from ner_system.config import NERConfig

# Global configuration object read by the dataset, model and training code.
ner_config = NERConfig(
    base_model_path="bert-base-cased",
    max_len=128,
    id2label=ID2LABEL,
    label_types=LABEL_TYPES,
)


class CoNLL2003Document:
    """
    Class to represent a single CoNLL2003 document in its base form.

    Supports target labels as integers (if coming from `datasets` library)
    or strings (if coming from the "traditional" CoNLL2003 format).
    """

    def __init__(
        self,
        unique_id: int,
        tokens: List[str],
        ner_tags: List[str] | List[int],
    ):
        self.unique_id = unique_id
        self.tokens = tokens
        if isinstance(ner_tags[0], int):
            self.ner_tags_str = [ner_config.ID2LABEL[tag] for tag in ner_tags]
            self.ner_tags_int = ner_tags
        elif isinstance(ner_tags[0], str):
            self.ner_tags_int = [ner_config.LABEL2ID[tag] for tag in ner_tags]
            self.ner_tags_str = ner_tags

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        _token = self.tokens[idx]
        _tag = self.ner_tags_str[idx]
        return (_token, _tag)

    def to_dict(self):
        return {
            "unique_id": self.unique_id,
            "tokens": self.tokens,
            "ner_tags_str": self.ner_tags_str,
            "ner_tags_int": self.ner_tags_int,
        }


class CoNLL2003Features:
    """
    Model-ready (featurized) form of a single CoNLL2003 document.

    Produced by the `_convert_to_features` method of `NERDataset`. Every
    argument is stored unchanged on an attribute of the same name.
    """

    def __init__(
        self,
        input_ids: List[int],
        input_mask: List[int],
        token_type_ids: List[int],
        crf_mask: List[int],
        predict_mask: List[int],
        label_ids: List[int],
    ):
        # Inputs consumed by BERT.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.token_type_ids = token_type_ids
        # CRF() requires its own mask to avoid:
        # ValueError: mask of the first timestep must all be on
        self.crf_mask = crf_mask
        # Prediction mask and target label ids.
        self.predict_mask, self.label_ids = predict_mask, label_ids

In [None]:
class NERDataset:
    """Dataset that converts ``CoNLL2003Document`` objects into BERT-CRF features.

    Each item is a tuple of six equally long lists ordered as in ``FIELDS``.
    ``pad`` is the matching ``collate_fn`` for a ``DataLoader``: it pads every
    list to ``ner_config.MAX_LEN`` and returns a dict of tensors.
    """

    # Position of every feature inside an item tuple. __getitem__ and pad both
    # read this, so they cannot disagree about which index holds which feature
    # (previously two accidental crf_mask/token_type_ids swaps cancelled out).
    FIELDS = (
        "input_ids",
        "input_mask",
        "crf_mask",
        "token_type_ids",
        "predict_mask",
        "label_ids",
    )
    # Features collated as uint8 tensors; all others become int64.
    BYTE_FIELDS = ("crf_mask", "predict_mask")

    def __init__(self, data: List[CoNLL2003Document]):
        self.data = data
        # Hard-coded IDs of special tokens (not read elsewhere in this class).
        self.CLS = 101
        self.SEP = 102
        self.PAD = 0

    def __len__(self):
        """Return the number of documents."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the features of document ``idx`` as a tuple ordered like ``FIELDS``."""
        features = self._convert_to_features(self.data[idx])
        return tuple(getattr(features, name) for name in self.FIELDS)

    def _convert_to_features(self, document: CoNLL2003Document) -> CoNLL2003Features:
        """Tokenize a document into word pieces and build all model inputs.

        Only the first word piece of each token keeps the token's label and is
        marked in ``predict_mask``; the remaining pieces get the ``"X"`` label.
        Sequences are cut so that, with ``[SEP]`` appended, they contain at
        most ``ner_config.MAX_LEN`` pieces.
        """
        continuation_id = ner_config.LABEL2ID["X"]

        # Sequence starts with [CLS], which is never predicted.
        pieces = ["[CLS]"]
        predict_mask = [0]
        label_ids = [0]

        for token, tag in zip(document.tokens, document.ner_tags_str):
            subwords = ner_config.TOKENIZER.tokenize(token)
            if not subwords:
                # Keep one position per token even if nothing was produced.
                subwords = ["[UNK]"]
            pieces.extend(subwords)

            trailing = len(subwords) - 1
            predict_mask.append(1)
            predict_mask.extend([0] * trailing)
            label_ids.append(ner_config.LABEL2ID[tag])
            label_ids.extend([continuation_id] * trailing)

        # Truncation strategy: chop the end, leaving room for [SEP].
        limit = ner_config.MAX_LEN - 1
        if len(pieces) > limit:
            pieces = pieces[:limit]
            predict_mask = predict_mask[:limit]
            label_ids = label_ids[:limit]

        pieces.append("[SEP]")
        predict_mask.append(0)
        label_ids.append(0)

        input_ids = ner_config.TOKENIZER.convert_tokens_to_ids(pieces)
        length = len(input_ids)
        input_mask = [1] * length
        crf_mask = [1] * length
        token_type_ids = [0] * length

        # Internal invariant: all features describe the same positions.
        assert length == len(predict_mask) == len(label_ids)

        # Keywords make it impossible to mix up the two mask-like arguments.
        return CoNLL2003Features(
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=token_type_ids,
            crf_mask=crf_mask,
            predict_mask=predict_mask,
            label_ids=label_ids,
        )

    @classmethod
    def pad(cls, batch: List[tuple]):
        """Collate item tuples into a dict of tensors padded to ``ner_config.MAX_LEN``.

        Args:
            batch: Items as returned by ``__getitem__``.

        Returns:
            Dict keyed by the names in ``FIELDS``. Masks listed in
            ``BYTE_FIELDS`` are ``uint8`` tensors, the rest ``int64``; padded
            positions hold 0.
        """
        target_len = ner_config.MAX_LEN

        collated = {}
        for position, name in enumerate(cls.FIELDS):
            rows = [
                sample[position] + [0] * (target_len - len(sample[position]))
                for sample in batch
            ]
            dtype = torch.uint8 if name in cls.BYTE_FIELDS else torch.long
            collated[name] = torch.tensor(rows, dtype=dtype)
        return collated

In [None]:
def convert_to_conll2003_document(
    tokens: List[str],
    ner_tags: List[int],
    unique_id: int,
) -> CoNLL2003Document:
    """
    Wrap the tokens, NER tags and id of one dataset row in a CoNLL2003Document.
    """
    document = CoNLL2003Document(unique_id, tokens, ner_tags)
    return document


# _dummy = dataset["train"][:10]
# dummy = [
#     convert_to_conll2003_document(
#         tokens=toks,
#         ner_tags=tags,
#         unique_id=id,
#     )
#     for toks, tags, id in zip(_dummy["tokens"], _dummy["ner_tags"], _dummy["id"])
# ]
# a = NERDataset(dummy)

In [None]:
from datasets import Dataset


def prepare_NER_data(dataset: Dataset, split: str = None):
    """Build an ``NERDataset`` from one split of the loaded dataset.

    Integer tags are translated to tag names with the names stored on the
    dataset's own ``ner_tags`` feature before the documents are built. The ids
    delivered by the dataset are not in the id order used by ``ner_config``,
    so decoding them with ``ner_config.ID2LABEL`` would assign wrong labels.

    Args:
        dataset: Mapping from split name to a dataset with ``tokens``,
            ``ner_tags`` and ``id`` columns.
        split: One of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        An ``NERDataset`` holding one ``CoNLL2003Document`` per row.

    Raises:
        ValueError: If ``split`` is missing or not a known split name.
    """
    if not split:
        raise ValueError("At least one of `train`, `validation` or `test` must be provided.")

    if split not in ("train", "validation", "test"):
        raise ValueError("`split` must be one of `train`, `validation` or `test`.")

    _data = dataset[split]

    # Tag names as declared by the dataset itself, when it exposes them.
    _features = getattr(_data, "features", None) or {}
    _tag_names = getattr(
        getattr(_features.get("ner_tags"), "feature", None), "names", None
    )

    _all_tags = _data["ner_tags"]
    if _tag_names is not None:
        _all_tags = [[_tag_names[tag] for tag in tags] for tags in _all_tags]

    _docs = [
        convert_to_conll2003_document(
            tokens=toks,
            ner_tags=tags,
            unique_id=doc_id,
        )
        for toks, tags, doc_id in zip(_data["tokens"], _all_tags, _data["id"])
    ]

    return NERDataset(_docs)

In [None]:
from torch.utils.data import DataLoader

# One NERDataset per split.
train_data = prepare_NER_data(dataset, split="train")
val_data = prepare_NER_data(dataset, split="validation")
test_data = prepare_NER_data(dataset, split="test")

# Only the training data is shuffled. Evaluation gains nothing from a random
# order, and a fixed order keeps validation runs comparable.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=ner_config.TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=NERDataset.pad,
)
valid_loader = DataLoader(
    dataset=val_data,
    batch_size=ner_config.VALID_BATCH_SIZE,
    shuffle=False,
    collate_fn=NERDataset.pad,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=ner_config.VALID_BATCH_SIZE,
    shuffle=False,
    collate_fn=NERDataset.pad,
)

In [None]:
# Smoke test: move the first training batch to DEVICE and display it.
for data in train_loader:
    batch = dict((name, tensor.to(DEVICE)) for name, tensor in data.items())
    print(batch)
    break

{'input_ids': tensor([[  101,   138,  1602,  ...,     0,     0,     0],
        [  101,   107,  1109,  ...,     0,     0,     0],
        [  101,  1109, 22593,  ...,     0,     0,     0],
        ...,
        [  101, 10163,  1105,  ...,     0,     0,     0],
        [  101, 24890,  9741,  ...,     0,     0,     0],
        [  101,  6232,  1496,  ...,     0,     0,     0]], device='cuda:0'), 'input_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'), 'crf_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
   

In [None]:
batch.keys()

dict_keys(['input_ids', 'input_mask', 'crf_mask', 'token_type_ids', 'predict_mask', 'label_ids'])

### Model Training

In [None]:
import torch.nn as nn

# NOTE: Forked from https://pytorch-crf.readthedocs.io/en/stable/_modules/torchcrf.html#CRF

class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Masks may be boolean or (legacy) uint8 tensors; they are converted to
    boolean internally because ``torch.where`` expects a boolean condition.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.


    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.Tensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.Tensor`): Boolean or uint8 mask tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise. Defaults to all positions on.
            reduction: Specifies  the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.

        Raises:
            ValueError: If the shapes are inconsistent, the first timestep is
                masked out for any sequence, or ``reduction`` is unknown.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        else:
            # torch.where needs a boolean condition; accept legacy uint8 masks.
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.Tensor`): Boolean or uint8 mask tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise. Defaults to all positions on.

        Returns:
            List of list containing the best tag sequence for each batch. Each
            inner list is as long as the number of unmasked positions of that
            sequence.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)
        else:
            # torch.where needs a boolean condition; accept legacy uint8 masks.
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.Tensor] = None) -> None:
        """Check shapes of the inputs and that no sequence starts masked out."""
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.Tensor) -> torch.Tensor:
        """Return the unnormalized score of the given tag sequences."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return the log partition function (log-sum-exp over all tag sequences)."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size), boolean
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask is on)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.Tensor) -> List[List[int]]:
        """Return the highest-scoring tag sequence of every batch element."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size), boolean
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask is on)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # Index of the last valid timestep per sample, as plain ints so they
        # can be used directly as slice bounds below.
        seq_ends = (mask.long().sum(dim=0) - 1).tolist()
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

In [None]:
import logging

import numpy as np

import torch
import torch.nn as nn
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LRScheduler
#from torchcrf import CRF

from transformers import get_linear_schedule_with_warmup

from seqeval.metrics import classification_report, accuracy_score, f1_score

from tqdm import tqdm


def filter_sequence_output(sequence_of_tags, input_mask, predict_mask):
    """Keep only the tags at predicted positions and pack them to the left.

    For every sample, the tags at positions where ``predict_mask`` is on are
    copied, in order, to the start of the corresponding output row; the rest
    of the row stays 0. Each sample is processed over its own length, so the
    (variable-length) lists returned by ``CRF.decode`` are handled correctly.

    Args:
        sequence_of_tags: One tag sequence per sample, as a list of lists or
            as a 2-D tensor. Sequences may be shorter than the padded width.
        input_mask: Tensor ``(batch, width)``; the kept tags are multiplied by
            it and its values at kept positions form the returned mask.
        predict_mask: Tensor ``(batch, width)``; non-zero marks positions whose
            tag is kept.

    Returns:
        ``(filtered_tags, filtered_input_mask)``: two ``int64`` tensors of
        shape ``(batch, width)`` on the device of ``input_mask``.

    Raises:
        ValueError: If the number of tag sequences differs from the number of
            rows of ``predict_mask``.
    """
    batch_size, width = predict_mask.shape[0], predict_mask.shape[1]
    if len(sequence_of_tags) != batch_size:
        raise ValueError(
            f"got {len(sequence_of_tags)} tag sequences for a batch of "
            f"{batch_size} samples; check that the CRF layout (batch_first) "
            "matches the layout of the inputs"
        )

    device = input_mask.device
    filtered_tags = torch.zeros(batch_size, width, dtype=torch.long, device=device)
    filtered_input_mask = torch.zeros_like(filtered_tags)

    for row, tags in enumerate(sequence_of_tags):
        tags = torch.as_tensor(tags, dtype=torch.long, device=device)
        length = min(tags.numel(), width)
        if length == 0:
            continue

        keep = predict_mask[row, :length].to(device).bool()
        kept_mask = input_mask[row, :length][keep].long()
        kept_tags = tags[:length][keep] * kept_mask

        count = kept_tags.numel()
        filtered_tags[row, :count] = kept_tags
        filtered_input_mask[row, :count] = kept_mask

    return filtered_tags, filtered_input_mask


class BERTCRFModel(nn.Module):
    """BERT encoder followed by a linear emission layer and a CRF.

    Args:
        num_tags: Number of tags scored by the linear layer and the CRF.
        out_dim: Size of the encoder's hidden states (input size of the
            linear layer).
        batch_first: Layout the CRF layer is created with. ``forward`` always
            receives batch-first tensors and transposes them for the CRF when
            this is ``False``.
        dropout: Dropout probability applied to the encoder output; no
            dropout when ``None`` or 0.
    """

    def __init__(
        self,
        num_tags: int,
        out_dim: int,
        batch_first: bool = True,
        dropout: float = None,
    ):
        super().__init__()
        self.num_tags = num_tags
        self.out_dim = out_dim
        # Always define the attribute so forward() can test it safely.
        self.dropout = nn.Dropout(dropout) if dropout else None
        self.bert = transformers.BertModel.from_pretrained(
            ner_config.BASE_MODEL_PATH, return_dict=False
        )
        self.linear = nn.Linear(self.out_dim, self.num_tags)
        self.crf = CRF(num_tags, batch_first=batch_first)
        self.logger = logging.getLogger(self.__class__.__name__)
        # Log only the scalar settings: formatting the whole module __dict__
        # would render the entire encoder on every construction.
        self.logger.info(
            "Initialized BERTCRFModel with the following params: %s",
            {
                "num_tags": num_tags,
                "out_dim": out_dim,
                "batch_first": batch_first,
                "dropout": dropout,
            },
        )

    def forward(
        self, input_ids, input_mask, crf_mask, token_type_ids, predict_mask, label_ids
    ):
        """Compute the CRF loss and the decoded tags for one batch.

        All arguments are batch-first tensors as produced by
        ``NERDataset.pad``.

        Returns:
            Tuple ``(pred_tags, gold_tags, filtered_mask, loss)``:
            ``pred_tags`` and ``gold_tags`` hold the decoded and the reference
            tag ids at the positions selected by ``predict_mask``, packed to
            the left of each row; ``filtered_mask`` marks the filled
            positions; ``loss`` is the mean negative log-likelihood.
        """
        sequence_output, _ = self.bert(
            input_ids=input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
        )
        if self.dropout is not None:
            sequence_output = self.dropout(sequence_output)

        # Per-token tag scores: (batch, seq_len, num_tags).
        emissions = self.linear(sequence_output)

        # The CRF requires the first timestep of every sequence to be on.
        # Work on a copy so the caller's tensor is left untouched.
        mask = crf_mask.bool().clone()
        mask[:, 0] = True

        # Hand the tensors to the CRF in the layout it was created with.
        if self.crf.batch_first:
            crf_emissions, crf_tags, crf_layout_mask = emissions, label_ids, mask
        else:
            crf_emissions = emissions.transpose(0, 1)
            crf_tags = label_ids.transpose(0, 1)
            crf_layout_mask = mask.transpose(0, 1)

        self.logger.debug(
            "input_ids: %s, predict_mask: %s, label_ids: %s, emissions: %s",
            input_ids.shape, predict_mask.shape, label_ids.shape, emissions.shape,
        )

        log_likelihood = self.crf(
            emissions=crf_emissions,
            tags=crf_tags,
            mask=crf_layout_mask,
            reduction="mean",
        )

        # Most likely tag sequence per sample (one list per batch element).
        sequence_of_tags = self.crf.decode(crf_emissions, mask=crf_layout_mask)

        # Reduce predictions and references to the predicted positions with
        # the same function, so both line up position by position.
        pred_tags, filtered_mask = filter_sequence_output(
            sequence_of_tags, input_mask, predict_mask
        )
        gold_tags, _ = filter_sequence_output(label_ids, input_mask, predict_mask)

        loss = -1 * log_likelihood

        return pred_tags, gold_tags, filtered_mask, loss

In [None]:
def _batch_label_sequences(pred_tags, mask, batch):
    """Return per-sentence gold and predicted label names for one batch.

    Gold labels are read from the batch at the positions selected by
    ``predict_mask``, which is the order the model packs its predictions in.
    Padding is dropped so it is never scored.
    """
    gold_sequences, pred_sequences = [], []
    keep_rows = batch["predict_mask"].bool()
    for pred_row, mask_row, label_row, keep_row in zip(
        pred_tags, mask, batch["label_ids"], keep_rows
    ):
        gold_ids = label_row[keep_row].tolist()
        # Guard against the two sides disagreeing on the number of positions.
        length = min(int(mask_row.sum().item()), len(gold_ids))
        gold_sequences.append([ner_config.ID2LABEL[tag] for tag in gold_ids[:length]])
        pred_sequences.append(
            [ner_config.ID2LABEL[tag] for tag in pred_row[:length].tolist()]
        )
    return gold_sequences, pred_sequences


def train(
    dataloader: DataLoader,
    model: BERTCRFModel,
    optimizer: Optimizer,
    device: torch.device,
    scheduler: Optional[LRScheduler] = None,
):
    """Run one training epoch.

    Args:
        dataloader: Yields dict batches as produced by ``NERDataset.pad``.
        model: Model returning ``(pred_tags, gold_tags, mask, loss)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        scheduler: Optional learning-rate scheduler stepped once per batch.

    Returns:
        ``(mean_loss, f1)``: the loss averaged over batches and the seqeval
        F1 score over all sentences of the epoch.
    """
    model.train()

    final_loss = 0
    y_gold = []
    y_pred = []

    for data in tqdm(dataloader, total=len(dataloader)):
        batch = {k: v.to(device) for k, v in data.items()}
        optimizer.zero_grad()

        pred_tags, _, mask, loss = model(**batch)

        gold_sequences, pred_sequences = _batch_label_sequences(pred_tags, mask, batch)
        y_gold.extend(gold_sequences)
        y_pred.extend(pred_sequences)

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        final_loss += loss.item()

    return final_loss / len(dataloader), f1_score(y_gold, y_pred)


def evaluate(dataloader: DataLoader, model: BERTCRFModel, device: torch.device):
    """Evaluate the model without updating it.

    Args:
        dataloader: Yields dict batches as produced by ``NERDataset.pad``.
        model: Model returning ``(pred_tags, gold_tags, mask, loss)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, f1)``: the loss averaged over batches and the seqeval
        F1 score over all evaluated sentences.
    """
    model.eval()

    final_loss = 0
    y_gold = []
    y_pred = []

    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            batch = {k: v.to(device) for k, v in data.items()}

            pred_tags, _, mask, loss = model(**batch)

            gold_sequences, pred_sequences = _batch_label_sequences(
                pred_tags, mask, batch
            )
            y_gold.extend(gold_sequences)
            y_pred.extend(pred_sequences)

            final_loss += loss.item()

    return final_loss / len(dataloader), f1_score(y_gold, y_pred)

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BERTCRFModel(
    num_tags=len(ner_config.LABEL_TYPES),
    out_dim=ner_config.OUT_DIM,
    # Batches from NERDataset.pad are (batch, seq_len). A time-major CRF would
    # read the batch axis as time and decode one sequence per token position,
    # which is what produced the IndexError seen during training.
    batch_first=True,
    dropout=0.1,
).to(DEVICE)

# Small learning rate for the pretrained encoder, larger one for the freshly
# initialised emission layer and CRF.
optim_parameters = [
    {"params": model.bert.parameters(), "lr": 5e-5},
    {"params": model.linear.parameters(), "lr": 1e-3},
    {"params": model.crf.parameters(), "lr": 1e-3},
]
optimizer = AdamW(optim_parameters, weight_decay=0.01)

# The scheduler is stepped once per batch, so count batches (including the
# final partial one) rather than dividing the row count by the batch size.
num_train_steps = len(train_loader) * ner_config.EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_train_steps),  # step counts are integers
    num_training_steps=num_train_steps,
)

In [None]:
# Start from a clean CUDA cache.
torch.cuda.empty_cache()

# Lowest validation loss seen so far; the weights are saved whenever it improves.
best_loss = np.inf
for epoch in range(ner_config.EPOCHS):
    train_loss, train_f1 = train(
        train_loader, model, optimizer, device=DEVICE, scheduler=scheduler
    )
    val_loss, val_f1 = evaluate(valid_loader, model, device=DEVICE)

    print(f"Train loss = {train_loss} Valid loss = {val_loss} ")
    print(f"Train f1_score = {train_f1} Valid f1_score = {val_f1} ")

    if not (val_loss >= best_loss):
        torch.save(model.state_dict(), ner_config.MODEL_PATH)
        best_loss = val_loss

    # Release cached GPU memory between epochs.
    torch.cuda.empty_cache()

  0%|          | 0/220 [00:00<?, ?it/s]

inputs_ids: torch.Size([64, 128])
predict_mask: torch.Size([64, 128])
label_ids: torch.Size([64, 128])
emissions: torch.Size([64, 128, 12])
sequence_of_tags: [[11, 5, 2, 5, 4, 5, 0, 5, 4, 5, 11, 4, 10, 5, 8, 5, 4, 5, 4, 0, 5, 8, 8, 7, 4, 4, 8, 8, 4, 5, 4, 5, 4, 5, 4, 0, 11, 5, 8, 8, 8, 8, 4, 4, 5, 8, 2, 5, 8, 5, 0, 5, 5, 4, 5, 4, 4, 11, 5, 5, 4, 4, 4, 5], [5, 4, 1, 2, 1, 0, 11, 4, 4, 1, 7, 0, 11, 5, 4, 1, 5, 3, 5, 1, 4, 3, 8, 7, 3, 8, 8, 4, 1, 3, 4, 5, 1, 2, 1, 5, 1, 2, 7, 5, 1, 1, 2, 5, 8, 0, 11, 8, 0, 8, 0, 5, 1, 5, 9, 4, 5, 3, 5, 1, 4, 4, 5, 1], [8, 8, 2, 8, 1, 7, 0, 4, 3, 2, 7, 7, 6, 11, 8, 7, 11, 7, 11, 8, 8, 7, 11, 8, 7, 8, 7, 8, 8, 2, 8, 8, 7, 7, 8, 2, 8, 11, 2, 1, 7, 11, 7, 11, 7, 8, 7, 8, 8, 2, 8, 8, 11, 7, 7, 8, 7, 11, 7, 7, 11, 7, 7, 11], [7, 11, 1, 7, 8, 2, 7, 11, 8, 7, 8, 7, 7, 11, 1, 7, 1, 2, 8, 7, 6, 3, 2, 1, 7, 5, 1, 7, 8, 8, 7, 3, 7, 8, 8, 2, 1, 3, 2, 1, 7, 6, 7, 7, 8, 7, 7, 11, 7, 4, 8, 1, 11, 1, 7, 8, 7, 8, 7, 11, 8, 7, 11, 7], [11, 1, 7, 8, 7, 1, 1, 7, 8, 2, 7, 11, 




IndexError: list index out of range