In [1]:
import torch
import os
from tqdm.notebook import trange, tqdm

from torch.utils import data
from torch.nn.utils.rnn import pad_sequence
from torch.nn import CrossEntropyLoss

import torch.nn.functional as F
from seqeval.metrics import accuracy_score, f1_score, classification_report

import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:
from transformers import BertForTokenClassification
from transformers import BertTokenizer, AdamW

11/29/2019 22:43:06 - INFO - transformers.file_utils -   PyTorch version 1.3.1 available.


# Data

In [3]:
UNIQUE_LABELS = {'X', '[CLS]', '[SEP]'}

In [4]:
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

In [5]:
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text, label=None, segment_ids=None):
        """Constructs a InputExample.
        Args:
          guid: Unique id for the example.
          text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text = text
        self.label = label
        self.segment_ids = segment_ids

In [6]:
def readfile(filename):
    '''
    read file
    '''
    f = open(filename)
    data = []
    sentence = []
    label = []
    for line in f:
        if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == "\n":
            if len(sentence) > 0:
                data.append((sentence, label))
                sentence = []
                label = []
            continue
        splits = line.split(' ')
        sentence.append(splits[0])
        label.append(splits[-1][:-1])

    if len(sentence) > 0:
        data.append((sentence, label))
        sentence = []
        label = []
    return data

In [7]:
test_data = readfile('./NER_datasets/CONLL2003/test.txt')
print(test_data[0])

(['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',', 'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.'], ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O'])


In [8]:
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        return readfile(input_file)

In [9]:
class NerProcessor(DataProcessor):
    """Processor for the CoNLL-2003 data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.txt")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "valid.txt")),
            "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.txt")),
            "test")

    def get_labels(self):
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC",
                "[CLS]", "[SEP]", "X"]

    @staticmethod
    def _create_examples(lines, set_type):
        examples = []
        for i, (sentence, label) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = ' '.join(sentence)
            label = label
            examples.append(InputExample(guid=guid, text=text_a, label=label))
        return examples

In [10]:
train_examples = NerProcessor().get_train_examples('./NER_datasets/CONLL2003/')

In [11]:
x = train_examples[101]
print((x.guid, x.segment_ids))
print((x.text, x.label))

('train-101', None)
('He will be replaced by Eliahu Ben-Elissar , a former Israeli envoy to Egypt and right-wing Likud party politician .', ['O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'B-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'O'])


In [12]:
class NERDataSet(data.Dataset):
    def __init__(self, data_list, tokenizer, label_map, max_len):
        self.max_len = max_len
        self.label_map = label_map
        self.data_list = data_list
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        input_example = self.data_list[idx]
        text = input_example.text
        label = input_example.label
        word_tokens = ['[CLS]']
        label_list = ['[CLS]']
        label_mask = [0]  # value in (0, 1) - 0 signifies invalid token

        input_ids = [self.tokenizer.convert_tokens_to_ids('[CLS]')]
        label_ids = [self.label_map['[CLS]']]

        # iterate over individual tokens and their labels
        for word, label in zip(text.split(), label):
            tokenized_word = self.tokenizer.tokenize(word)

            for token in tokenized_word:
                word_tokens.append(token)
                input_ids.append(self.tokenizer.convert_tokens_to_ids(token))

            label_list.append(label)
            label_ids.append(self.label_map[label])
            label_mask.append(1)
            # len(tokenized_word) > 1 only if it splits word in between, in which case
            # the first token gets assigned NER tag and the remaining ones get assigned
            # X
            for i in range(1, len(tokenized_word)):
                label_list.append('X')
                label_ids.append(self.label_map['X'])
                label_mask.append(0)

        assert len(word_tokens) == len(label_list) == len(input_ids) == len(label_ids) == len(
            label_mask)

        if len(word_tokens) >= self.max_len:
            word_tokens = word_tokens[:(self.max_len - 1)]
            label_list = label_list[:(self.max_len - 1)]
            input_ids = input_ids[:(self.max_len - 1)]
            label_ids = label_ids[:(self.max_len - 1)]
            label_mask = label_mask[:(self.max_len - 1)]

        assert len(word_tokens) < self.max_len, len(word_tokens)

        word_tokens.append('[SEP]')
        label_list.append('[SEP]')
        input_ids.append(self.tokenizer.convert_tokens_to_ids('[SEP]'))
        label_ids.append(self.label_map['[SEP]'])
        label_mask.append(0)

        assert len(word_tokens) == len(label_list) == len(input_ids) == len(label_ids) == len(
            label_mask)

        sentence_id = [0 for _ in input_ids]
        attention_mask = [1 for _ in input_ids]

        while len(input_ids) < self.max_len:
            input_ids.append(0)
            label_ids.append(self.label_map['X'])
            attention_mask.append(0)
            sentence_id.append(0)
            label_mask.append(0)

        assert len(word_tokens) == len(label_list)
        assert len(input_ids) == len(label_ids) == len(attention_mask) == len(sentence_id) == len(
            label_mask) == self.max_len, len(input_ids)
        # return word_tokens, label_list,
        return torch.LongTensor(input_ids), torch.LongTensor(label_ids), torch.LongTensor(
            attention_mask), torch.LongTensor(sentence_id), torch.BoolTensor(label_mask)

In [13]:
ner_processor = NerProcessor()

tags_vals = ner_processor.get_labels()
label_map = {}

for (i, label) in enumerate(tags_vals):
    label_map[label] = i
    
label_map[''] = 11 # same as 'X'
print(label_map)

{'O': 0, 'B-MISC': 1, 'I-MISC': 2, 'B-PER': 3, 'I-PER': 4, 'B-ORG': 5, 'I-ORG': 6, 'B-LOC': 7, 'I-LOC': 8, '[CLS]': 9, '[SEP]': 10, 'X': 11, '': 11}


In [14]:
train_examples = ner_processor.get_train_examples('./NER_datasets/CONLL2003/')

In [15]:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

11/29/2019 22:43:09 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/dominykas/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [16]:
tokenizer.convert_tokens_to_ids('[CLS]'), label_map['[CLS]']

(101, 9)

In [17]:
train_dataset = NERDataSet(data_list=train_examples, tokenizer=tokenizer, label_map=label_map, max_len=128)

In [18]:
train_dataset[0][4]
# returns: (text, label, mask, idk, label_mask)

tensor([False,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False])

In [19]:
# difference between mask and label_mask
train_dataset[0][2].type(torch.int), train_dataset[0][4].type(torch.int)

(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32),
 tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32))

# Model

In [20]:
class CoNLLClassifier(BertForTokenClassification):
    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None, label_masks=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]  # (b, MAX_LEN, 768)

        token_reprs = [embedding[mask] for mask, embedding in zip(label_masks, sequence_output)]
        token_reprs = pad_sequence(sequences=token_reprs, batch_first=True,
                                   padding_value=-1)  # (b, local_max_len, 768)
        sequence_output = self.dropout(token_reprs)
        logits = self.classifier(sequence_output)  # (b, local_max_len, num_labels)

        outputs = (logits,)
        if labels is not None:
            labels = [label[mask] for mask, label in zip(label_masks, labels)]
            labels = pad_sequence(labels, batch_first=True, padding_value=-1)  # (b, local_max_len)
            loss_fct = CrossEntropyLoss(ignore_index=-1, reduction='sum')
            mask = labels != -1
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss /= mask.float().sum()
            outputs = (loss, logits, labels)

        return outputs  # (loss), scores, (hidden_states), (attentions)

# Training

In [21]:
train_examples = ner_processor.get_train_examples('./NER_datasets/CONLL2003/')
val_examples = ner_processor.get_dev_examples('./NER_datasets/CONLL2003/')
test_examples = ner_processor.get_test_examples('./NER_datasets/CONLL2003/')

In [22]:
train_dataset = NERDataSet(train_examples, tokenizer, label_map, max_len=128)
eval_dataset = NERDataSet(val_examples, tokenizer, label_map, max_len=128)
test_dataset = NERDataSet(test_examples, tokenizer, label_map, max_len=128)

In [23]:
bs = 16
train_iter = data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True,  num_workers=4)
eval_iter =  data.DataLoader(dataset=eval_dataset,  batch_size=bs, shuffle=False, num_workers=4)
test_iter =  data.DataLoader(dataset=test_dataset,  batch_size=bs, shuffle=False, num_workers=4)

In [24]:
model = CoNLLClassifier.from_pretrained("bert-base-cased", num_labels=len(label_map)).to('cuda')

11/29/2019 22:43:17 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/dominykas/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.d7a3af18ce3a2ab7c0f48f04dc8daff45ed9a3ed333b9e9a79d012a0dedf87a6
11/29/2019 22:43:17 - INFO - transformers.configuration_utils -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 13,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 289

In [25]:
num_epochs = 5
num_train_optimization_steps = int(len(train_examples) / bs) * num_epochs

In [26]:
FULL_FINETUNING = True
lr = 3e-5

if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

In [27]:
warmup_steps = int(0.1 * num_train_optimization_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
# scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
scheduler = None

In [30]:
train(train_iter, eval_iter, model, optimizer, scheduler, num_epochs)

11/29/2019 22:47:21 - INFO - __main__ -   starting to train


HBox(children=(IntProgress(value=0, description='Epoch', max=5, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=878), HTML(value='')))

11/29/2019 23:37:08 - INFO - __main__ -   Train loss: 0.11004682384714577
11/29/2019 23:37:08 - INFO - __main__ -   starting to evaluate





HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

11/29/2019 23:40:23 - INFO - __main__ -   Validation loss: 0.045760134218189884
11/29/2019 23:40:23 - INFO - __main__ -   Seq eval accuracy: 0.9886795393877988
11/29/2019 23:40:24 - INFO - __main__ -   F1-Score: 0.9304035424847524
11/29/2019 23:40:24 - INFO - __main__ -   Classification report: -- 





11/29/2019 23:40:24 - INFO - __main__ -              precision    recall  f1-score   support

      ORG       0.93      0.88      0.91      1341
      PER       0.92      0.99      0.95      1836
      LOC       0.96      0.95      0.96      1837
     MISC       0.84      0.89      0.87       922
        X       0.00      0.00      0.00         1

micro avg       0.92      0.94      0.93      5937
macro avg       0.92      0.94      0.93      5937



HBox(children=(IntProgress(value=0, max=878), HTML(value='')))

11/30/2019 00:29:53 - INFO - __main__ -   Train loss: 0.026279549470812014
11/30/2019 00:29:53 - INFO - __main__ -   starting to evaluate





HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

11/30/2019 00:33:08 - INFO - __main__ -   Validation loss: 0.04150540274585022
11/30/2019 00:33:08 - INFO - __main__ -   Seq eval accuracy: 0.9905890146717846
11/30/2019 00:33:08 - INFO - __main__ -   F1-Score: 0.9439126784214944
11/30/2019 00:33:08 - INFO - __main__ -   Classification report: -- 





11/30/2019 00:33:09 - INFO - __main__ -              precision    recall  f1-score   support

      ORG       0.92      0.92      0.92      1341
      PER       0.97      0.97      0.97      1836
      LOC       0.96      0.96      0.96      1837
     MISC       0.88      0.91      0.89       922
        X       0.00      0.00      0.00         1

micro avg       0.94      0.95      0.94      5937
macro avg       0.94      0.95      0.94      5937



HBox(children=(IntProgress(value=0, max=878), HTML(value='')))

11/30/2019 01:22:29 - INFO - __main__ -   Train loss: 0.014622718369912309
11/30/2019 01:22:29 - INFO - __main__ -   starting to evaluate





HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

11/30/2019 01:25:44 - INFO - __main__ -   Validation loss: 0.04302565876666146
11/30/2019 01:25:44 - INFO - __main__ -   Seq eval accuracy: 0.9907059213218246
11/30/2019 01:25:44 - INFO - __main__ -   F1-Score: 0.9428595331715887
11/30/2019 01:25:44 - INFO - __main__ -   Classification report: -- 





11/30/2019 01:25:45 - INFO - __main__ -              precision    recall  f1-score   support

      ORG       0.90      0.93      0.91      1341
      PER       0.97      0.98      0.97      1836
      LOC       0.97      0.96      0.96      1837
     MISC       0.87      0.90      0.89       922
        X       0.00      0.00      0.00         1

micro avg       0.94      0.95      0.94      5937
macro avg       0.94      0.95      0.94      5937



HBox(children=(IntProgress(value=0, max=878), HTML(value='')))

11/30/2019 02:15:01 - INFO - __main__ -   Train loss: 0.008680306672296807
11/30/2019 02:15:01 - INFO - __main__ -   starting to evaluate





HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

11/30/2019 02:18:17 - INFO - __main__ -   Validation loss: 0.054337821371054994
11/30/2019 02:18:17 - INFO - __main__ -   Seq eval accuracy: 0.9902188102799915
11/30/2019 02:18:17 - INFO - __main__ -   F1-Score: 0.9447101266884805
11/30/2019 02:18:17 - INFO - __main__ -   Classification report: -- 





11/30/2019 02:18:17 - INFO - __main__ -              precision    recall  f1-score   support

      ORG       0.87      0.95      0.91      1341
      PER       0.97      0.97      0.97      1836
      LOC       0.98      0.95      0.97      1837
     MISC       0.92      0.90      0.91       922
        X       0.00      0.00      0.00         1

micro avg       0.94      0.95      0.94      5937
macro avg       0.94      0.95      0.94      5937



HBox(children=(IntProgress(value=0, max=878), HTML(value='')))

11/30/2019 03:07:33 - INFO - __main__ -   Train loss: 0.006678770585474213
11/30/2019 03:07:33 - INFO - __main__ -   starting to evaluate





HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

11/30/2019 03:10:48 - INFO - __main__ -   Validation loss: 0.055189460339765815
11/30/2019 03:10:48 - INFO - __main__ -   Seq eval accuracy: 0.9903162324883581
11/30/2019 03:10:48 - INFO - __main__ -   F1-Score: 0.944258453297623
11/30/2019 03:10:48 - INFO - __main__ -   Classification report: -- 





11/30/2019 03:10:48 - INFO - __main__ -              precision    recall  f1-score   support

      ORG       0.91      0.93      0.92      1341
      PER       0.97      0.98      0.97      1836
      LOC       0.96      0.97      0.96      1837
     MISC       0.88      0.90      0.89       922
        X       0.00      0.00      0.00         1

micro avg       0.94      0.95      0.94      5937
macro avg       0.94      0.95      0.94      5937






In [31]:
eval(test_iter, model)

11/30/2019 03:10:48 - INFO - __main__ -   starting to evaluate


HBox(children=(IntProgress(value=0, max=216), HTML(value='')))

11/30/2019 03:14:16 - INFO - __main__ -   Validation loss: 0.16370227087335715
11/30/2019 03:14:16 - INFO - __main__ -   Seq eval accuracy: 0.9806967274920826
11/30/2019 03:14:16 - INFO - __main__ -   F1-Score: 0.90792928408892
11/30/2019 03:14:16 - INFO - __main__ -   Classification report: -- 





11/30/2019 03:14:16 - INFO - __main__ -              precision    recall  f1-score   support

     MISC       0.77      0.83      0.80       702
      ORG       0.88      0.91      0.89      1661
      PER       0.96      0.96      0.96      1615
      LOC       0.92      0.93      0.92      1666
        X       0.00      0.00      0.00         1

micro avg       0.90      0.92      0.91      5645
macro avg       0.90      0.92      0.91      5645



In [28]:
def train(train_iter, eval_iter, model, optimizer, scheduler, num_epochs, device='cuda'):
    logger.info("starting to train")
    max_grad_norm = 1.0  # should be a flag
    for _ in trange(num_epochs, desc="Epoch"):
        # TRAIN loop
        model = model.train()
        tr_loss = 0
        nb_tr_steps = 0
        for step, batch in enumerate(tqdm(train_iter)):
            # add batch to gpu
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_labels, b_input_mask, b_token_type_ids, b_label_masks = batch
            # forward pass
            loss, logits, labels = model(b_input_ids, token_type_ids=b_token_type_ids,
                                         attention_mask=b_input_mask, labels=b_labels,
                                         label_masks=b_label_masks)
            # backward pass
            loss.backward()
            # track train loss
            tr_loss += loss.item()
            nb_tr_steps += 1
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
            # update parameters
            optimizer.step()
#             scheduler.step()
            model.zero_grad()
        # print train loss per epoch
        logger.info("Train loss: {}".format(tr_loss / nb_tr_steps))
        eval(eval_iter, model)

In [29]:
def eval(iter_data, model, device='cuda'):
    logger.info("starting to evaluate")
    model = model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps = 0
    predictions, true_labels = [], []
    for batch in tqdm(iter_data):
        batch = tuple(t.to(device) for t in batch)

        b_input_ids, b_labels, b_input_mask, b_token_type_ids, b_label_masks = batch

        with torch.no_grad():
            tmp_eval_loss, logits, reduced_labels = model(b_input_ids,
                                                          token_type_ids=b_token_type_ids,
                                                          attention_mask=b_input_mask,
                                                          labels=b_labels,
                                                          label_masks=b_label_masks)

        logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
        logits = logits.detach().cpu().numpy()
        reduced_labels = reduced_labels.to('cpu').numpy()

        labels_to_append = []
        predictions_to_append = []

        for prediction, r_label in zip(logits, reduced_labels):
            preds = []
            labels = []
            for pred, lab in zip(prediction, r_label):
                if lab.item() == -1:  # masked label; -1 means do not collect this label
                    continue
                preds.append(pred)
                labels.append(lab)
            predictions_to_append.append(preds)
            labels_to_append.append(labels)

        predictions.extend(predictions_to_append)
        true_labels.append(labels_to_append)

        eval_loss += tmp_eval_loss.mean().item()

        nb_eval_steps += 1
    eval_loss = eval_loss / nb_eval_steps
    logger.info("Validation loss: {}".format(eval_loss))
    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]
    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]
    logger.info("Seq eval accuracy: {}".format(accuracy_score(valid_tags, pred_tags)))
    logger.info("F1-Score: {}".format(f1_score(valid_tags, pred_tags)))
    logger.info("Classification report: -- ")
    logger.info(classification_report(valid_tags, pred_tags))