# Imports, Seeds, Initialization of Modules


Set up the required Python modules

In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 4.3MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/08/cd/342e584ee544d044fb573ae697404ce22ede086c9e87ce5960772084cad0/sacremoses-0.0.44.tar.gz (862kB)
[K     |████████████████████████████████| 870kB 18.1MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 17.4MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.44-cp37-none-any.whl size=886084 sha256=fb060

Import everything

In [None]:
import os
import json
import torch
import random
import numpy as np
import torch.nn as nn
import tqdm.notebook as tq
import torch.nn.functional as F
from torch.backends import cudnn
from transformers import AutoTokenizer, AutoModel

from sklearn.metrics import classification_report

from typing import List, Optional

Set seeds for reproducibility and check for a GPU

In [None]:
# Seed every random number generator the notebook relies on.
my_seed = 1
random.seed(my_seed)
np.random.seed(my_seed)
torch.manual_seed(my_seed)

# cudnn.benchmark lets cuDNN auto-tune and pick possibly non-deterministic kernels,
# which defeats the deterministic flag below; keep it off for reproducible runs.
cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Index of the GPU used by the rest of the notebook (when one is available).
gpu_device = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    # Seed every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(my_seed)
    print("Using GPU")

Using GPU


Initialize CRF Layer Module

In [None]:
class CRF(nn.Module):
    """Conditional random field.
    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Masks may be passed as bool or uint8 tensors; they are converted to bool
    internally because ``torch.where`` with a uint8 condition is deprecated.
    Within each sequence the unmasked positions must form a prefix (padding only
    at the end): the last tag is read at index ``mask.sum() - 1`` and Viterbi
    back-tracking walks ``history[:seq_end]``, so a mask with holes in the middle
    produces a score that is inconsistent with the normalizer.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.
    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.
    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.Tensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.Tensor`, bool or uint8): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                The first timestep must be on for every sequence, and the on
                positions must form a prefix of each sequence.
            reduction: Specifies  the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.
        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        Raises:
            ValueError: on shape mismatches, an off first timestep, or an unknown reduction.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        else:
            # torch.where with a uint8 condition is deprecated; normalize to bool.
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.Tensor`, bool or uint8): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        Returns:
            List of list containing the best tag sequence for each batch; each inner
            list has one tag per unmasked timestep.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)
        else:
            # torch.where with a uint8 condition is deprecated; normalize to bool.
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.Tensor] = None) -> None:
        # Reject malformed inputs with a ValueError before any computation.
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.Tensor) -> torch.Tensor:
        # Score of the given tag path (numerator of the likelihood).
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score; the last valid tag is assumed to sit at
        # index (number of on positions - 1), i.e. the mask is a prefix.
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Log partition function via the forward algorithm (denominator of the likelihood).
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size), bool
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.Tensor) -> List[List[int]]:
        # Best tag path per sequence via the Viterbi algorithm.
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size), bool
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

BERT-based Sequence Tagging System Module

In [None]:
class GreekNERSystem(nn.Module):
    """Greek BERT encoder + 2-layer MLP emission head + CRF layer for sequence tagging.

    Each instance is a list of sentences; only the sentence at index 1 (the
    "middle" sentence) is passed through the MLP and tagged, the others only
    provide BERT context.

    Args:
        b_size: batch size (stored on the module, not used by it).
        n_classes_output: number of output tags (MLP output width and CRF tag count).
        hidden_size: width of the hidden MLP layer.
    """

    def __init__(self, b_size, n_classes_output, hidden_size):
        super(GreekNERSystem, self).__init__()

        self.b_size = b_size
        self.n_classes_output = n_classes_output
        self.hidden_size = hidden_size

        # Load AutoModel from huggingface model repository
        self.tokenizer = AutoTokenizer.from_pretrained("nlpaueb/bert-base-greek-uncased-v1")
        self.bert = AutoModel.from_pretrained("nlpaueb/bert-base-greek-uncased-v1")

        # Freeze the parameters
        # NOTE: Delete / Comment this for finetuning
        for param in self.bert.parameters():
            param.requires_grad = False

        # MLP with 2 Layers; the input width is read from the encoder config
        # instead of hard-coding 768.
        # NOTE: You can use less or more layers
        self.linear = nn.Linear(self.bert.config.hidden_size, self.hidden_size, bias=True)
        self.linear_out_output = nn.Linear(self.hidden_size, self.n_classes_output, bias=True)

        # CRF Layer
        self.crf = CRF(num_tags=self.n_classes_output, batch_first=True)

    def preprocess_batch(self, batch):
        """Tokenize a batch of instances into BERT inputs and subtoken-level tags.

        Args:
            batch: list of instances; each instance is a list of sentences and each
                sentence is ``[words, tags]`` (parallel lists).
        Returns:
            Three parallel lists with one entry per kept instance:
            the ``(1, n_subtokens + 2)`` LongTensor of ids wrapped in [CLS] ... [SEP],
            the per-sentence lists of subtoken counts per word, and the flat list
            of subtoken tags over all sentences. Instances with more than 510
            subtokens (512 with [CLS]/[SEP]) are skipped.
        """
        # Special-token ids come from the tokenizer rather than hard-coded 101 / 102
        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id

        preprocessed_batch = list()
        for inst in batch:
            # Tokenize each word inside of every sentence
            inst_tokens = [[self.tokenizer(w, add_special_tokens=False, return_tensors='pt')['input_ids'] for w in s[0]] for s in inst]
            # Number of subtokens of each word (used later to connect the subtokens)
            inst_token_lens = [[w.shape[1] for w in s] for s in inst_tokens]
            # Length of each sentence in subtokens
            inst_sent_lens = [sum(s) for s in inst_token_lens]
            # 510 instead of 512 because [CLS] and [SEP] are added below
            if sum(inst_sent_lens) > 510:
                print('Instance subtokens > 512 , skipping...')
                continue

            # Subtoken tags per sentence: the first subtoken of a word gets the word's
            # tag and the rest get 'X', except 'O' words whose subtokens are all 'O'.
            inst_tags = list()
            for s_i, s in enumerate(inst_token_lens):
                sent_tags = list()
                for tok_i, tok_len in enumerate(s):
                    if tok_len == 0:
                        # The word produced no subtokens, so there is nothing to tag
                        # (indexing its first subtoken tag would fail).
                        continue
                    tok_tag = inst[s_i][1][tok_i]
                    if tok_tag == 'O':
                        sent_tags.extend(['O'] * tok_len)
                    else:
                        sent_tags.append(tok_tag)
                        sent_tags.extend(['X'] * (tok_len - 1))
                inst_tags.append(sent_tags)

            # Concatenate the subtokens of all (non-empty) sentences between [CLS] and [SEP]
            inst_tokens = torch.cat(
                [torch.LongTensor([[cls_id]])]
                + [torch.cat(s, dim=1) for s in inst_tokens if s]
                + [torch.LongTensor([[sep_id]])],
                dim=1)

            # Concatenate the tags of all sentences (any number of sentences)
            flat_tags = [tag for sent_tags in inst_tags for tag in sent_tags]
            preprocessed_batch.append([
                inst_tokens,
                inst_token_lens,
                flat_tags
            ])

        return [e[0] for e in preprocessed_batch], [e[1] for e in preprocessed_batch], [e[2] for e in preprocessed_batch]

    def forward(self, instances, instances_sents):
        """Compute tag logits for the middle sentence of each instance.

        Args:
            instances: list of ``(1, n)`` LongTensors from ``preprocess_batch``.
            instances_sents: per-instance sentence lengths in subtokens; entries 0 and 1
                give the offset and length of the middle sentence.
        Returns:
            List of ``(1, middle_len, n_classes_output)`` logits tensors.
        """
        # Run on whatever device the module's weights live on
        device = self.linear.weight.device
        instances_logits_output = list()
        for inst, sents in zip(instances, instances_sents):
            inst = inst.to(device)
            # Single segment, no padding: all-zero token types, all-one attention mask
            inst_token_type_ids = torch.zeros_like(inst)
            inst_attention_mask = torch.ones_like(inst)

            # Contextual representations of the input subtokens
            bert_out = self.bert(inst,
                                 token_type_ids=inst_token_type_ids,
                                 attention_mask=inst_attention_mask)[0]

            # Drop the representations of [CLS] and [SEP]
            bert_out = bert_out[:, 1:-1, :]

            # Only the middle sentence goes through the MLP
            start, length = sents[0], sents[1]
            linear_out = self.linear(bert_out[:, start:start + length, :])
            logits_out = self.linear_out_output(F.relu(linear_out))

            instances_logits_output.append(logits_out)

        return instances_logits_output

# Main code

## Setup

Initialize global parameters

In [None]:
# ---- Training hyper-parameters ----
batch_size = 2          # instances per batch
hidden_size = 100       # width of the hidden MLP layer
lr = 1e-3               # Adam learning rate
epochs = 100            # upper bound on the number of training epochs
max_patience = 5        # epochs without dev improvement before early stopping

# ---- Output label set ----
# Fixed order: 'O', then a B-/I- pair per entity type, then 'X'.
# NOTE: 'X' labels every subword of a non-'O' word that follows the word's first subword.
class_dict = {
    label: idx
    for idx, label in enumerate(
        ['O']
        + [f'{prefix}-{entity}'
           for entity in ('Person', 'Organization', 'Location', 'GPE', 'Facility')
           for prefix in ('B', 'I')]
        + ['X']
    )
}

# Reverse mapping: class id -> label string
inv_class_dict = {idx: label for label, idx in class_dict.items()}

Load the data and split it into train, dev and test sets

In [None]:
# Load all instances from NER_DATA.json
with open('NER_DATA.json', encoding='utf-8') as fin:
    instances = json.load(fin)

# Find the sizes of the train, dev and test splits using 70% - 20% - 10%
n_instances = len(instances)
train_len = int(n_instances * 0.7)
dev_len = int(n_instances * 0.2)
test_len = int(n_instances * 0.1)

# Add any trailing instances to the train (due to the integer casting)
if n_instances > train_len + dev_len + test_len:
    train_len += n_instances - (train_len + dev_len + test_len)

# Randomly sample the splits by *position* rather than by value. random.sample
# selects items by index, so sampling range(n) picks the same positions as
# sampling a length-n list: the split is unchanged for duplicate-free data, but
# duplicated instances are no longer dropped from dev/test by a value-based
# `not in` check, and the O(n^2) list-membership scans are gone.
train_idx = random.sample(range(n_instances), train_len)
train_idx_set = set(train_idx)
rest_idx = [i for i in range(n_instances) if i not in train_idx_set]
dev_idx = random.sample(rest_idx, dev_len)
dev_idx_set = set(dev_idx)

train_instances = [instances[i] for i in train_idx]
rest_instances = [instances[i] for i in rest_idx]
dev_instances = [instances[i] for i in dev_idx]
test_instances = [instances[i] for i in rest_idx if i not in dev_idx_set]


Initialize the System and Optimizer

In [None]:
model = GreekNERSystem(b_size=batch_size,
                       n_classes_output=len(class_dict),
                       hidden_size=hidden_size)

# Place the model on the GPU selected by `gpu_device`, the same index the
# training code uses when moving labels with `.cuda(gpu_device)`.
if use_cuda:
    model.to(f'cuda:{gpu_device}')

# Adam over the trainable parameters only; parameters with requires_grad=False
# (the BERT encoder while it is frozen) are left out.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

# Initialize the CrossEntropyLoss
# NOTE(review): the training loop computes its loss with the CRF layer, not this criterion.
cross_entropy = nn.CrossEntropyLoss()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=459.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=529930.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=454248854.0, style=ProgressStyle(descri…




Preprocess all batches

In [None]:
def _make_batches(split_instances, split_name):
    """Chunk ``split_instances`` into batches of ``batch_size`` and preprocess each one.

    ``model.preprocess_batch`` skips over-long instances, so a batch can come back
    with no instances at all; such batches are dropped because the training loop
    cannot stack an empty list of per-instance losses.
    """
    print(f'Preprocessing {split_name} batches')
    raw_batches = [split_instances[i:i + batch_size]
                   for i in range(0, len(split_instances), batch_size)]
    processed = [model.preprocess_batch(batch) for batch in tq.tqdm(raw_batches)]
    # Each processed batch is (token tensors, token lengths, tags); keep non-empty ones
    return [b for b in processed if b[0]]


train_batches = _make_batches(train_instances, 'train')
dev_batches = _make_batches(dev_instances, 'dev')
test_batches = _make_batches(test_instances, 'test')

Preprocessing train batches


HBox(children=(FloatProgress(value=0.0, max=402.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Preprocessing dev batches


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Preprocessing test batches


HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))




## Training

Global variables for training

In [None]:
# Early-stopping bookkeeping: `patience` counts down on epochs without a dev improvement.
patience = max_patience
best_f1, best_f1_output = 0.0, 0.0
best_epoch, best_cr = 0, None

Train the model

In [None]:
# Open a file for logging
# NOTE(review): the handle is not closed here because later code may still write to it;
# close it once all logging is finished.
log_f = open('system_log.txt', 'w', encoding='utf-8')
log_f.flush()

# Class id of the 'X' continuation-subword label; it is excluded from the reported scores.
x_label = class_dict['X']
report_labels = [label_id for label_id in inv_class_dict if label_id != x_label]
report_names = [inv_class_dict[label_id] for label_id in report_labels]


def _log(text, end='\n'):
    """Print ``text`` and append it (followed by ``end``) to the log file."""
    print(text)
    log_f.write(text + end)
    log_f.flush()


def _run_batch(p_batch, p_batch_token_lens, label_tags):
    """Forward one preprocessed batch and score the middle sentence of each instance.

    Returns:
        (predicted tag-id lists, gold tag-id lists, mean per-instance CRF loss tensor).
    """
    # Length in subtokens of every sentence of every instance
    p_batch_sents = [[sum(s) for s in b] for b in p_batch_token_lens]

    # Pass the batch through the system
    model_out = model(p_batch, p_batch_sents)

    # Viterbi predictions; each instance is decoded as its own batch of one
    preds = [model.crf.decode(e)[0] for e in model_out]

    # Gold tags of the middle sentence only (it starts after sentence 0's subtokens)
    golds = [[class_dict[t] for t in l[s[0]:s[0] + s[1]]]
             for l, s in zip(label_tags, p_batch_sents)]

    instance_losses = []
    for emissions, gold in zip(model_out, golds):
        gold_tensor = torch.tensor(gold, dtype=torch.long, device=emissions.device).view(1, -1)
        # 'X' is an ordinary CRF tag (it is part of class_dict), so the whole sequence
        # is scored without a mask. Masking the 'X' positions leaves holes in the
        # middle of the sequence, which this CRF does not support: `_compute_score`
        # would still add transitions out of the masked 'X' tags and the end
        # transition of tags[mask.sum() - 1], while the normalizer skips those
        # steps, so the resulting "log-likelihood" is not normalized.
        instance_losses.append(-model.crf(emissions, gold_tensor, reduction='token_mean'))

    # Loss of the batch: mean of the per-instance losses
    return preds, golds, torch.stack(instance_losses).mean()


def _report(true_flat, pred_flat):
    """Return the (dict, text) classification reports over every class except 'X'."""
    kwargs = dict(y_true=true_flat, y_pred=pred_flat,
                  labels=report_labels, target_names=report_names)
    return (classification_report(output_dict=True, **kwargs),
            classification_report(output_dict=False, **kwargs))


# Train for the max epochs
for epoch in range(epochs):
    _log(f'\n{"="*30} EPOCH {epoch} {"="*30}\n')

    # ---------------- Training pass ----------------
    model.train()

    # Shuffle the training batches
    random.shuffle(train_batches)

    losses = list()
    true_labels_output = list()
    pred_labels_output = list()

    for p_batch, p_batch_token_lens, label_tags in tq.tqdm(train_batches):
        optimizer.zero_grad()
        preds, golds, loss = _run_batch(p_batch, p_batch_token_lens, label_tags)
        pred_labels_output.extend(preds)
        true_labels_output.extend(golds)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    # Get the true and pred labels of all batches into a flat list
    true_labels_output_flat = [item for sublist in true_labels_output for item in sublist]
    pred_labels_output_flat = [item for sublist in pred_labels_output for item in sublist]

    _log(f'Epoch {epoch} Train Loss: {np.mean(losses)}')

    # PRF scores for the epoch ('X' labels are not scored)
    cr_output, cr_output_text = _report(true_labels_output_flat, pred_labels_output_flat)
    _log(cr_output_text)
    _log('Train Output F1-Score: {}'.format(cr_output['macro avg']['f1-score']))

    # ---------------- Development pass ----------------
    model.eval()

    losses = list()
    true_labels_output = list()
    pred_labels_output = list()

    # No back-propagation on the dev set, so do not build autograd graphs
    with torch.no_grad():
        for p_batch, p_batch_token_lens, label_tags in tq.tqdm(dev_batches):
            preds, golds, loss = _run_batch(p_batch, p_batch_token_lens, label_tags)
            pred_labels_output.extend(preds)
            true_labels_output.extend(golds)
            losses.append(loss.item())

    # Get the true and pred labels of all batches into a flat list
    true_labels_output_flat = [item for sublist in true_labels_output for item in sublist]
    pred_labels_output_flat = [item for sublist in pred_labels_output for item in sublist]

    _log(f'Epoch {epoch} Dev Loss: {np.mean(losses)}')

    # PRF scores for the epoch ('X' labels are not scored)
    cr_output, cr_output_text = _report(true_labels_output_flat, pred_labels_output_flat)
    _log(cr_output_text)
    _log('Dev Output F1-Score: {}'.format(cr_output['macro avg']['f1-score']), end='\n\n')

    print()

    # Check if the macro avg f1-score has improved
    if cr_output['macro avg']['f1-score'] > best_f1:
        # Record the new best score, reports and epoch, and reset patience
        best_f1 = cr_output['macro avg']['f1-score']
        best_f1_output = cr_output['macro avg']['f1-score']
        best_cr = [cr_output, cr_output_text]
        best_epoch = epoch
        patience = max_patience

        # Save the parameters of the best state of the system
        state = dict(
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
            best_f1=best_f1,
            best_epoch=best_epoch,
            best_cr=best_cr
        )
        torch.save(state, 'system_best_epoch.pth.tar')
        print('Model saved')
        print()
        log_f.write('Model saved\n\n')
        log_f.flush()

    else:
        # No improvement: stop once patience runs out
        patience -= 1
        if patience == 0:
            break

    _log(f'Best epoch: {best_epoch}')
    _log(f'Patience: {patience}', end='\n\n')
    print()

# ---- Final summary -----------------------------------------------------------
# Assemble the report for the best epoch once, then send exactly the same text
# to stdout and to the training log, so the two outputs cannot drift apart.
#   best_epoch / best_f1_output: epoch index and dev macro-avg F1 recorded the
#                                last time the model checkpoint was saved
#   best_cr[1]:                  text form of that epoch's classification report
final_report_parts = [
    f'\n{"="*30} FINAL RESULTS {"="*30}\n',
    'Best Epoch: {}\nOutput F1 Score: {}'.format(best_epoch, best_f1_output),
    best_cr[1],
]

# One print with '\n' separators == one print() call per part.
print(*final_report_parts, sep='\n')

# Mirror print()'s line terminators in the log file.
log_f.write('\n'.join(final_report_parts) + '\n')
log_f.flush()
# Training is finished: release the log file handle.
log_f.close()





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 0 Train Loss: 0.16361038047018195
                precision    recall  f1-score   support

             O       0.96      1.00      0.98     12659
      B-Person       0.62      0.26      0.37        87
      I-Person       0.72      0.32      0.45        71
B-Organization       0.75      0.08      0.14        76
I-Organization       0.53      0.12      0.19        77
    B-Location       0.00      0.00      0.00        47
    I-Location       0.00      0.00      0.00        32
         B-GPE       0.00      0.00      0.00        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.95      0.97      0.96     13082
     macro avg       0.33      0.16      0.19     13082
  weighted avg       0.94      0.97      0.95     13082

Train Output F1-Score: 0.1935711818762136


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 0 Dev Loss: 0.08741513593021133
                precision    recall  f1-score   support

             O       0.95      1.00      0.97      3623
      B-Person       0.60      0.43      0.50        14
      I-Person       0.60      0.50      0.55        12
B-Organization       1.00      0.04      0.07        27
I-Organization       0.20      0.06      0.09        33
    B-Location       0.00      0.00      0.00        10
    I-Location       0.00      0.00      0.00        13
         B-GPE       0.00      0.00      0.00        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.95      0.97      0.96      3750
     macro avg       0.30      0.18      0.20      3750
  weighted avg       0.93      0.97      0.95      3750

Dev Output F1-Score: 0.19861865073837925

Model saved

Best epoch: 0
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 1 Train Loss: 0.06542346943328234
                precision    recall  f1-score   support

             O       0.97      1.00      0.98     12659
      B-Person       0.70      0.62      0.66        87
      I-Person       0.64      0.65      0.64        71
B-Organization       0.76      0.45      0.56        76
I-Organization       0.63      0.31      0.42        77
    B-Location       0.36      0.11      0.16        47
    I-Location       0.00      0.00      0.00        32
         B-GPE       0.29      0.08      0.12        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97     13082
     macro avg       0.39      0.29      0.32     13082
  weighted avg       0.95      0.98      0.96     13082

Train Output F1-Score: 0.3229221581481173


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 1 Dev Loss: 0.06829339210222395
                precision    recall  f1-score   support

             O       0.96      1.00      0.98      3623
      B-Person       0.50      0.57      0.53        14
      I-Person       0.43      0.25      0.32        12
B-Organization       0.82      0.33      0.47        27
I-Organization       0.86      0.18      0.30        33
    B-Location       0.80      0.40      0.53        10
    I-Location       0.00      0.00      0.00        13
         B-GPE       1.00      0.25      0.40        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.97      0.96      3750
     macro avg       0.49      0.27      0.32      3750
  weighted avg       0.95      0.97      0.96      3750

Dev Output F1-Score: 0.32134235381843035

Model saved

Best epoch: 1
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 2 Train Loss: 0.04048093108320885
                precision    recall  f1-score   support

             O       0.97      1.00      0.99     12659
      B-Person       0.76      0.76      0.76        87
      I-Person       0.66      0.76      0.71        71
B-Organization       0.79      0.58      0.67        76
I-Organization       0.71      0.52      0.60        77
    B-Location       0.68      0.40      0.51        47
    I-Location       0.44      0.12      0.20        32
         B-GPE       0.71      0.40      0.51        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.98     13082
     macro avg       0.52      0.41      0.45     13082
  weighted avg       0.96      0.98      0.97     13082

Train Output F1-Score: 0.4484256681385778


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 2 Dev Loss: 0.05594697901786118
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.60      0.43      0.50        14
      I-Person       0.41      0.58      0.48        12
B-Organization       0.78      0.52      0.62        27
I-Organization       0.60      0.27      0.37        33
    B-Location       0.62      0.50      0.56        10
    I-Location       1.00      0.23      0.38        13
         B-GPE       1.00      0.33      0.50        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.54      0.35      0.40      3750
  weighted avg       0.96      0.98      0.96      3750

Dev Output F1-Score: 0.3992546973948006

Model saved

Best epoch: 2
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 3 Train Loss: 0.029152918524361797
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.82      0.76      0.79        87
      I-Person       0.63      0.80      0.71        71
B-Organization       0.82      0.70      0.75        76
I-Organization       0.73      0.64      0.68        77
    B-Location       0.63      0.47      0.54        47
    I-Location       0.69      0.28      0.40        32
         B-GPE       0.44      0.32      0.37        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.99      0.98     13082
     macro avg       0.52      0.45      0.48     13082
  weighted avg       0.97      0.99      0.98     13082

Train Output F1-Score: 0.4751242231945058


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 3 Dev Loss: 0.05585827037445557
                precision    recall  f1-score   support

             O       0.96      1.00      0.98      3623
      B-Person       0.64      0.50      0.56        14
      I-Person       0.42      0.42      0.42        12
B-Organization       0.91      0.37      0.53        27
I-Organization       0.70      0.21      0.33        33
    B-Location       0.67      0.40      0.50        10
    I-Location       1.00      0.15      0.27        13
         B-GPE       0.67      0.67      0.67        12
         I-GPE       0.50      0.17      0.25         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.59      0.35      0.41      3750
  weighted avg       0.96      0.98      0.96      3750

Dev Output F1-Score: 0.4084344127958841

Model saved

Best epoch: 3
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 4 Train Loss: 0.021700804662591246
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.85      0.91      0.88        87
      I-Person       0.62      0.92      0.74        71
B-Organization       0.83      0.68      0.75        76
I-Organization       0.67      0.58      0.62        77
    B-Location       0.72      0.60      0.65        47
    I-Location       0.53      0.31      0.39        32
         B-GPE       0.64      0.56      0.60        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.99      0.98     13082
     macro avg       0.53      0.51      0.51     13082
  weighted avg       0.97      0.99      0.98     13082

Train Output F1-Score: 0.5110459512116129


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 4 Dev Loss: 0.05052083363861643
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.60      0.64      0.62        14
      I-Person       0.42      0.42      0.42        12
B-Organization       0.76      0.48      0.59        27
I-Organization       0.62      0.24      0.35        33
    B-Location       0.50      0.80      0.62        10
    I-Location       0.83      0.38      0.53        13
         B-GPE       0.60      0.50      0.55        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.48      0.41      0.42      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.4223734778480678

Model saved

Best epoch: 4
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 5 Train Loss: 0.01222596714428063
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.79      0.89      0.83        87
      I-Person       0.59      0.89      0.71        71
B-Organization       0.84      0.82      0.83        76
I-Organization       0.80      0.68      0.73        77
    B-Location       0.74      0.79      0.76        47
    I-Location       0.62      0.62      0.62        32
         B-GPE       0.65      0.60      0.63        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.99      0.98     13082
     macro avg       0.55      0.57      0.55     13082
  weighted avg       0.97      0.99      0.98     13082

Train Output F1-Score: 0.5547495308728645


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 5 Dev Loss: 0.04242908328546748
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.46      0.86      0.60        14
      I-Person       0.46      0.50      0.48        12
B-Organization       0.88      0.56      0.68        27
I-Organization       0.75      0.27      0.40        33
    B-Location       0.62      0.80      0.70        10
    I-Location       0.88      0.54      0.67        13
         B-GPE       0.69      0.75      0.72        12
         I-GPE       1.00      0.17      0.29         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.61      0.49      0.50      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5013757637568715

Model saved

Best epoch: 5
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 6 Train Loss: 0.010926936076409158
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.83      0.94      0.88        87
      I-Person       0.63      0.94      0.76        71
B-Organization       0.88      0.78      0.83        76
I-Organization       0.76      0.75      0.76        77
    B-Location       0.71      0.64      0.67        47
    I-Location       0.50      0.47      0.48        32
         B-GPE       0.60      0.48      0.53        25
         I-GPE       0.00      0.00      0.00         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.99      0.98     13082
     macro avg       0.54      0.55      0.54     13082
  weighted avg       0.97      0.99      0.98     13082

Train Output F1-Score: 0.536727496851798


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 6 Dev Loss: 0.04144652491571951
                precision    recall  f1-score   support

             O       0.98      1.00      0.99      3623
      B-Person       0.61      0.79      0.69        14
      I-Person       0.42      0.67      0.52        12
B-Organization       1.00      0.48      0.65        27
I-Organization       0.75      0.18      0.29        33
    B-Location       0.57      0.80      0.67        10
    I-Location       0.69      0.85      0.76        13
         B-GPE       0.62      0.83      0.71        12
         I-GPE       0.60      1.00      0.75         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.97      3750
     macro avg       0.57      0.60      0.55      3750
  weighted avg       0.97      0.98      0.97      3750

Dev Output F1-Score: 0.547513302139961

Model saved

Best epoch: 6
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 7 Train Loss: 0.01466327932506831
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.86      0.95      0.90        87
      I-Person       0.61      0.94      0.74        71
B-Organization       0.91      0.79      0.85        76
I-Organization       0.70      0.71      0.71        77
    B-Location       0.70      0.70      0.70        47
    I-Location       0.52      0.53      0.52        32
         B-GPE       0.67      0.56      0.61        25
         I-GPE       0.20      0.12      0.15         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.99      0.98     13082
     macro avg       0.56      0.57      0.56     13082
  weighted avg       0.97      0.99      0.98     13082

Train Output F1-Score: 0.5609105249934143


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 7 Dev Loss: 0.034227966394850345
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.65      0.79      0.71        14
      I-Person       0.50      0.83      0.62        12
B-Organization       0.71      0.56      0.63        27
I-Organization       0.70      0.21      0.33        33
    B-Location       0.60      0.60      0.60        10
    I-Location       0.75      0.69      0.72        13
         B-GPE       0.67      0.83      0.74        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.60      0.53      0.53      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.530140574901819

Best epoch: 6
Patience: 4





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 8 Train Loss: 0.003972167630215982
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.86      0.94      0.90        87
      I-Person       0.62      0.90      0.73        71
B-Organization       0.90      0.83      0.86        76
I-Organization       0.77      0.77      0.77        77
    B-Location       0.78      0.81      0.79        47
    I-Location       0.61      0.69      0.65        32
         B-GPE       0.79      0.76      0.78        25
         I-GPE       0.50      0.25      0.33         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.98     13082
     macro avg       0.62      0.63      0.62     13082
  weighted avg       0.98      0.99      0.98     13082

Train Output F1-Score: 0.6182288473003902


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 8 Dev Loss: 0.0266755116867147
                precision    recall  f1-score   support

             O       0.98      1.00      0.99      3623
      B-Person       0.71      0.86      0.77        14
      I-Person       0.50      0.83      0.62        12
B-Organization       0.83      0.56      0.67        27
I-Organization       0.58      0.42      0.49        33
    B-Location       0.67      0.60      0.63        10
    I-Location       1.00      0.46      0.63        13
         B-GPE       0.77      0.83      0.80        12
         I-GPE       1.00      0.83      0.91         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.98      3750
     macro avg       0.64      0.58      0.59      3750
  weighted avg       0.97      0.98      0.97      3750

Dev Output F1-Score: 0.592323292594778

Model saved

Best epoch: 8
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 9 Train Loss: -0.0032141403012303587
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.88      0.95      0.92        87
      I-Person       0.64      0.94      0.77        71
B-Organization       0.93      0.86      0.89        76
I-Organization       0.81      0.81      0.81        77
    B-Location       0.81      0.72      0.76        47
    I-Location       0.62      0.41      0.49        32
         B-GPE       0.74      0.80      0.77        25
         I-GPE       0.50      0.88      0.64         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.98     13082
     macro avg       0.63      0.67      0.64     13082
  weighted avg       0.98      0.99      0.98     13082

Train Output F1-Score: 0.639061133526223


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 9 Dev Loss: 0.04590322064691266
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.77      0.71      0.74        14
      I-Person       0.53      0.75      0.62        12
B-Organization       1.00      0.44      0.62        27
I-Organization       0.67      0.30      0.42        33
    B-Location       0.86      0.60      0.71        10
    I-Location       0.75      0.69      0.72        13
         B-GPE       0.88      0.58      0.70        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.58      0.46      0.50      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5002899054131822

Best epoch: 8
Patience: 4





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 10 Train Loss: -0.001953690524582517
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.89      0.95      0.92        87
      I-Person       0.61      0.94      0.74        71
B-Organization       0.89      0.82      0.85        76
I-Organization       0.85      0.86      0.85        77
    B-Location       0.69      0.72      0.71        47
    I-Location       0.69      0.75      0.72        32
         B-GPE       0.78      0.72      0.75        25
         I-GPE       0.75      0.38      0.50         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.99     13082
     macro avg       0.65      0.65      0.64     13082
  weighted avg       0.98      0.99      0.99     13082

Train Output F1-Score: 0.6390618077149989


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 10 Dev Loss: 0.037120975608159704
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.73      0.57      0.64        14
      I-Person       0.60      0.75      0.67        12
B-Organization       0.93      0.52      0.67        27
I-Organization       0.75      0.36      0.49        33
    B-Location       0.75      0.60      0.67        10
    I-Location       0.75      0.69      0.72        13
         B-GPE       1.00      0.58      0.74        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.97      3750
     macro avg       0.68      0.49      0.55      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5518598685347303

Best epoch: 8
Patience: 3





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 11 Train Loss: -0.004969709330938055
                precision    recall  f1-score   support

             O       0.99      1.00      0.99     12659
      B-Person       0.88      0.92      0.90        87
      I-Person       0.63      0.97      0.77        71
B-Organization       0.88      0.89      0.89        76
I-Organization       0.78      0.82      0.80        77
    B-Location       0.85      0.83      0.84        47
    I-Location       0.62      0.72      0.67        32
         B-GPE       0.92      0.88      0.90        25
         I-GPE       1.00      0.75      0.86         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.99     13082
     macro avg       0.69      0.71      0.69     13082
  weighted avg       0.98      0.99      0.99     13082

Train Output F1-Score: 0.6912815790600134


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 11 Dev Loss: 0.05153191605008893
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.73      0.79      0.76        14
      I-Person       0.48      0.83      0.61        12
B-Organization       1.00      0.52      0.68        27
I-Organization       0.57      0.12      0.20        33
    B-Location       0.71      0.50      0.59        10
    I-Location       0.82      0.69      0.75        13
         B-GPE       0.82      0.75      0.78        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.65      0.50      0.53      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5320738399884447

Best epoch: 8
Patience: 2





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 12 Train Loss: -0.00787015117831677
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.85      0.95      0.90        87
      I-Person       0.61      0.92      0.73        71
B-Organization       0.91      0.84      0.88        76
I-Organization       0.75      0.71      0.73        77
    B-Location       0.83      0.81      0.82        47
    I-Location       0.63      0.75      0.69        32
         B-GPE       0.92      0.96      0.94        25
         I-GPE       0.83      0.62      0.71         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.99     13082
     macro avg       0.67      0.69      0.67     13082
  weighted avg       0.98      0.99      0.99     13082

Train Output F1-Score: 0.6715915539952205


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 12 Dev Loss: 0.03027817511258886
                precision    recall  f1-score   support

             O       0.98      1.00      0.99      3623
      B-Person       0.65      0.79      0.71        14
      I-Person       0.38      0.75      0.50        12
B-Organization       0.77      0.63      0.69        27
I-Organization       0.46      0.39      0.43        33
    B-Location       0.73      0.80      0.76        10
    I-Location       0.69      0.69      0.69        13
         B-GPE       0.90      0.75      0.82        12
         I-GPE       1.00      1.00      1.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.98      3750
     macro avg       0.60      0.62      0.60      3750
  weighted avg       0.97      0.98      0.98      3750

Dev Output F1-Score: 0.5990539319682032

Model saved

Best epoch: 12
Patience: 5





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 13 Train Loss: -0.013079912340368493
                precision    recall  f1-score   support

             O       0.98      1.00      0.99     12659
      B-Person       0.84      0.87      0.86        87
      I-Person       0.65      0.92      0.76        71
B-Organization       0.96      0.91      0.93        76
I-Organization       0.84      0.83      0.84        77
    B-Location       0.79      0.81      0.80        47
    I-Location       0.68      0.78      0.72        32
         B-GPE       0.77      0.80      0.78        25
         I-GPE       0.56      0.62      0.59         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.99     13082
     macro avg       0.64      0.69      0.66     13082
  weighted avg       0.98      0.99      0.99     13082

Train Output F1-Score: 0.6615287845583881


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 13 Dev Loss: 0.03679194695885935
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.59      0.71      0.65        14
      I-Person       0.42      0.42      0.42        12
B-Organization       0.79      0.56      0.65        27
I-Organization       0.52      0.39      0.45        33
    B-Location       0.75      0.60      0.67        10
    I-Location       0.82      0.69      0.75        13
         B-GPE       0.73      0.67      0.70        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.60      0.49      0.52      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5236497528649593

Best epoch: 12
Patience: 4





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 14 Train Loss: -0.01833489661667979
                precision    recall  f1-score   support

             O       0.99      1.00      0.99     12659
      B-Person       0.90      0.97      0.93        87
      I-Person       0.63      0.94      0.75        71
B-Organization       0.92      0.92      0.92        76
I-Organization       0.86      0.88      0.87        77
    B-Location       0.84      0.81      0.83        47
    I-Location       0.69      0.78      0.74        32
         B-GPE       0.81      0.84      0.82        25
         I-GPE       1.00      0.75      0.86         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      1.00      0.99     13082
     macro avg       0.69      0.72      0.70     13082
  weighted avg       0.98      1.00      0.99     13082

Train Output F1-Score: 0.7012351749931462


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 14 Dev Loss: 0.034677676616751266
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.69      0.79      0.73        14
      I-Person       0.44      0.58      0.50        12
B-Organization       0.71      0.56      0.63        27
I-Organization       0.45      0.30      0.36        33
    B-Location       0.44      0.70      0.54        10
    I-Location       0.69      0.69      0.69        13
         B-GPE       0.67      0.50      0.57        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.55      0.50      0.50      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.5008816804271349

Best epoch: 12
Patience: 3





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 15 Train Loss: -0.01804284160327615
                precision    recall  f1-score   support

             O       0.99      1.00      0.99     12659
      B-Person       0.87      0.95      0.91        87
      I-Person       0.59      0.94      0.73        71
B-Organization       0.92      0.93      0.93        76
I-Organization       0.80      0.83      0.82        77
    B-Location       0.85      0.87      0.86        47
    I-Location       0.66      0.66      0.66        32
         B-GPE       0.85      0.88      0.86        25
         I-GPE       0.56      0.62      0.59         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.99      0.99     13082
     macro avg       0.64      0.70      0.67     13082
  weighted avg       0.98      0.99      0.99     13082

Train Output F1-Score: 0.6678789555761537


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 15 Dev Loss: 0.04135308049803529
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.63      0.86      0.73        14
      I-Person       0.43      0.75      0.55        12
B-Organization       0.87      0.48      0.62        27
I-Organization       0.47      0.27      0.35        33
    B-Location       0.78      0.70      0.74        10
    I-Location       0.82      0.69      0.75        13
         B-GPE       0.86      0.50      0.63        12
         I-GPE       0.00      0.00      0.00         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.53      0.48      0.49      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.48560326310023194

Best epoch: 12
Patience: 2





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 16 Train Loss: -0.019044991124572273
                precision    recall  f1-score   support

             O       0.99      1.00      0.99     12659
      B-Person       0.86      0.98      0.91        87
      I-Person       0.59      0.99      0.74        71
B-Organization       0.96      0.93      0.95        76
I-Organization       0.86      0.91      0.89        77
    B-Location       0.77      0.70      0.73        47
    I-Location       0.65      0.75      0.70        32
         B-GPE       0.90      0.72      0.80        25
         I-GPE       1.00      0.62      0.77         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      1.00      0.99     13082
     macro avg       0.69      0.69      0.68     13082
  weighted avg       0.98      1.00      0.99     13082

Train Output F1-Score: 0.6798488201459922


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 16 Dev Loss: 0.03146048529244344
                precision    recall  f1-score   support

             O       0.97      1.00      0.98      3623
      B-Person       0.64      0.64      0.64        14
      I-Person       0.57      0.67      0.62        12
B-Organization       0.65      0.56      0.60        27
I-Organization       0.59      0.30      0.40        33
    B-Location       0.75      0.60      0.67        10
    I-Location       1.00      0.46      0.63        13
         B-GPE       0.77      0.83      0.80        12
         I-GPE       1.00      0.50      0.67         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.96      0.98      0.97      3750
     macro avg       0.63      0.51      0.55      3750
  weighted avg       0.96      0.98      0.97      3750

Dev Output F1-Score: 0.546110954254048

Best epoch: 12
Patience: 1





HBox(children=(FloatProgress(value=0.0, max=401.0), HTML(value='')))


Epoch 17 Train Loss: -0.02207267769073718
                precision    recall  f1-score   support

             O       0.99      1.00      0.99     12659
      B-Person       0.86      0.94      0.90        87
      I-Person       0.60      0.99      0.74        71
B-Organization       0.91      0.95      0.93        76
I-Organization       0.81      0.82      0.81        77
    B-Location       0.81      0.83      0.82        47
    I-Location       0.68      0.81      0.74        32
         B-GPE       0.88      0.92      0.90        25
         I-GPE       0.80      1.00      0.89         8
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      1.00      0.99     13082
     macro avg       0.67      0.75      0.70     13082
  weighted avg       0.98      1.00      0.99     13082

Train Output F1-Score: 0.7032078778057259


HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))


Epoch 17 Dev Loss: 0.030150456854080254
                precision    recall  f1-score   support

             O       0.97      1.00      0.99      3623
      B-Person       0.63      0.86      0.73        14
      I-Person       0.39      0.75      0.51        12
B-Organization       0.82      0.67      0.73        27
I-Organization       0.65      0.33      0.44        33
    B-Location       0.73      0.80      0.76        10
    I-Location       0.80      0.62      0.70        13
         B-GPE       1.00      0.67      0.80        12
         I-GPE       1.00      0.33      0.50         6
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.97      0.98      0.97      3750
     macro avg       0.64      0.55      0.56      3750
  weighted avg       0.97      0.98      0.97      3750

Dev Output F1-Score: 0.5599777723534395



Best Epoch: 12
Output F1 Score: 0.5990539319682032
                precis

## Testing

Load the best state of the system

In [None]:
# Load the best state (checkpoint) of the system saved during training.
# map_location='cpu' lets a checkpoint written on a GPU be read on a CPU-only
# runtime; load_state_dict then copies the weights onto the device the model's
# parameters already live on, so GPU runs are unaffected.
state = torch.load('system_best_epoch.pth.tar', map_location='cpu')

# Load the parameters into the model
model.load_state_dict(state['model'])

# Load the best epoch, macro avg f1-score and classification report
# NOTE(review): the test run of this notebook printed 'Best Epoch: 0' while the
# training loop reported 'Best Epoch: 12' -- confirm the checkpoint file is
# rewritten whenever the best epoch improves, otherwise an early model is tested.
best_epoch = state['best_epoch']
best_f1 = state['best_f1']
best_cr = state['best_cr']
# presumably best_cr[0] is the classification report of the output layer -- TODO confirm
cr_output = best_cr[0]

Evaluate the best model on the test set

In [None]:
# Evaluate the loaded model on the test set: compute the mean CRF loss,
# decode the predicted tags, and report per-class precision/recall/F1.
# Results are printed and appended to 'system_log.txt'.
# Relies on `model`, `test_batches`, `class_dict`, `use_cuda`, `gpu_device`,
# `best_epoch` and `tq` defined in earlier cells.

# Label that is masked out of the loss and excluded from the scores
# (its index is looked up instead of hard-coding 11)
ignored_label = 'X'
ignored_label_idx = class_dict[ignored_label]

# Set model to evaluation mode for the test set
model.eval()

losses = list()
true_labels_output = list()
pred_labels_output = list()

# We don't perform back propagation on the test set, so skip building the
# autograd graph (saves memory and time)
with torch.no_grad():
    for p_batch, p_batch_token_lens, label_tags in tq.tqdm(test_batches):

        # Find the batch sentence lengths
        p_batch_sents = [[sum(s) for s in b] for b in p_batch_token_lens]

        # Pass the batch through the system
        model_out = model(p_batch, p_batch_sents)

        # Calculate the predictions of the model using the decode function of the CRF layer and add them to a list
        pred_labels_output.extend([model.crf.decode(e)[0] for e in model_out])

        # Get only the label_tags of the middle sentence
        # (the slice implies s[0] is the start offset of the middle sentence and s[1] its length)
        label_tags_middle = [l[s[0]:s[0] + s[1]] for l, s in zip(label_tags, p_batch_sents)]

        # Convert the middle-sentence tags to class indices and keep them as the gold labels
        label_ids_middle = [[class_dict[tag] for tag in tags] for tags in label_tags_middle]
        true_labels_output.extend(label_ids_middle)

        # Add the label tags of the middle sentence to a Tensor object (use GPU if available)
        labels_output = [torch.LongTensor(ids) for ids in label_ids_middle]
        if use_cuda:
            labels_output = [e.cuda(gpu_device) for e in labels_output]

        # Calculate the loss of each instance of the batch using the CRF Layer,
        # masking out positions tagged with the ignored label
        loss_output = [-model.crf(emissions, tags.view(1, -1),
                                  mask=(tags != ignored_label_idx).view(1, -1),
                                  reduction='token_mean')
                       for emissions, tags in zip(model_out, labels_output)]

        # Calculate the loss of the batch (mean of the losses of each instance)
        loss = torch.mean(torch.stack(loss_output))

        # Add the loss value to a list
        losses.append(loss.item())

# Get the true and pred labels of all batches into a flat list
true_labels_output_flat = [item for sublist in true_labels_output for item in sublist]
pred_labels_output_flat = [item for sublist in pred_labels_output for item in sublist]

test_loss = np.mean(losses)

# Calculate the PRF scores using the true and pred labels for the test set.
# NOTE: We don't take into account the 'X' labels for the scores. `labels` is
# derived from `report_target_names`, so the two are guaranteed to line up.
report_target_names = [name for name in class_dict if name != ignored_label]
report_kwargs = dict(y_true=true_labels_output_flat,
                     y_pred=pred_labels_output_flat,
                     labels=[class_dict[name] for name in report_target_names],
                     target_names=report_target_names)
cr_output = classification_report(**report_kwargs, output_dict=True)
cr_output_text = classification_report(**report_kwargs, output_dict=False)
test_f1 = cr_output['macro avg']['f1-score']

# Append the results to the system log; `with` closes the file even on error
with open('system_log.txt', 'a', encoding='utf-8') as log_f:
    print(f'Best Epoch: {best_epoch} Test Loss: {test_loss}')
    log_f.write(f'Best Epoch: {best_epoch} Test Loss: {test_loss}\n')

    print(cr_output_text)
    log_f.write(cr_output_text + '\n')
    print('Test Output F1-Score: {}'.format(test_f1))
    log_f.write('Test Output F1-Score: {}\n\n'.format(test_f1))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))


Best Epoch: 0 Test Loss: 0.0599751421258022
                precision    recall  f1-score   support

             O       0.99      0.99      0.99      1777
      B-Person       0.78      0.88      0.82         8
      I-Person       0.80      0.57      0.67         7
B-Organization       0.62      0.80      0.70        10
I-Organization       0.73      0.73      0.73        11
    B-Location       0.00      0.00      0.00         4
    I-Location       1.00      0.60      0.75        10
         B-GPE       0.67      1.00      0.80         4
         I-GPE       0.00      0.00      0.00         0
    B-Facility       0.00      0.00      0.00         0
    I-Facility       0.00      0.00      0.00         0

     micro avg       0.98      0.98      0.98      1831
     macro avg       0.51      0.51      0.50      1831
  weighted avg       0.98      0.98      0.98      1831

Test Output F1-Score: 0.49587998059956334


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
