# Named Entity Recognition Using BERT

### Required packages
* pytorch-pretrained-bert
* pandas
* seqeval
* unicode

In [31]:
import sys
import os
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import torch
import random

from pytorch_pretrained_bert.tokenization import BertTokenizer
from torch.optim import Adam

bert_utils_path = os.path.abspath('../../utils_nlp/bert')
if bert_utils_path not in sys.path:
    sys.path.insert(0, bert_utils_path)

from configs import (PathConfig,
                     GlobalConfig, 
                     DeviceConfig, 
                     ModelConfig, 
                     OptimizerConfig, 
                     TrainConfig, 
                     EvalConfig)
from bert_data_utils import KaggleNERProcessor
from bert_utils import (convert_examples_to_token_features,
                        create_train_dataloader, 
                        create_eval_dataloader, 
                        load_model, 
                        get_optimizer_params, 
                        train_model, 
                        eval_token_model)

## Configurations

### Path configuration

In [2]:
path_config_dict = {"data_dir": "./data/NER/", 
                    "output_dir": "./NER_output/"}
path_config = PathConfig(path_config_dict)

### Global configuration

In [3]:
global_config_dict = {"fp16": False}
global_config = GlobalConfig(global_config_dict)

### Device configuration

In [4]:
device_config_dict = {"no_cuda": False}
device_config = DeviceConfig(device_config_dict)
print("device name: {}".format(torch.cuda.get_device_name(0)))
print("number of gpus: {}".format(device_config.n_gpu))

device name: Tesla K80
number of gpus: 1


### Model configuration

In [5]:
model_config_dict = {"bert_model": "bert-base-uncased",
                     "max_seq_length": 75,
                     "num_labels": 18,
                     "model_type": "token"}
model_config = ModelConfig(model_config_dict)

### Optimizer configuration

In [6]:
optimizer_config_dict = {"no_decay_params": ['bias', 'gamma', 'beta'],
                         "learning_rate": 3e-5}
optimizer_config = OptimizerConfig(optimizer_config_dict)

### Train configuration

In [7]:
train_config_dict = {"train_batch_size": 32,
                     "num_train_epochs": 2, 
                     "clip_gradient": True}
train_config = TrainConfig(train_config_dict)

### Evaluation configuration

In [8]:
eval_config = EvalConfig({"eval_batch_size":32})

### Set random seeds

In [9]:
random.seed(global_config.seed)
np.random.seed(global_config.seed)
torch.manual_seed(global_config.seed)

<torch._C.Generator at 0x7f6f5748d170>

## Preprocess Data

### Create training and validation examples
KaggleNERProcessor is a dataset specific class that generates training and evaluation examples in the format accepted by all utility functions. 

In [10]:
kaggle_ner_processor = KaggleNERProcessor(data_dir="./data/NER/ner_dataset.csv", dev_percentage = 0.1)

In [11]:
train_examples = kaggle_ner_processor.get_train_examples(data_dir="./data/NER/ner_dataset.csv")
dev_examples = kaggle_ner_processor.get_dev_examples(data_dir="./data/NER/ner_dataset.csv")
label_list = kaggle_ner_processor.get_labels()

In [12]:
print('Sample sentence: \n{}\n'.format(train_examples[0].text_a))
print('Sample sentence labels: \n{}\n'.format(train_examples[1].label))

Sample sentence: 
Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .

Sample sentence labels: 
['B-gpe', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-tim', 'O', 'O', 'O', 'B-org', 'O', 'O', 'O', 'O', 'O']



### Convert examples to features
The function `convert_examples_to_token_features` converts raw string data to numerical features, involving the following steps:
1. Tokenization
2. Convert tokens and labels to numerical values
3. Sequence padding or truncation

In [13]:
tokenizer = BertTokenizer.from_pretrained(model_config.bert_model,
                                          do_lower_case=model_config.do_lower_case)
train_features = convert_examples_to_token_features(examples=train_examples,
                                                    tokenizer=tokenizer,
                                                    label_list=label_list, 
                                                    model_config=model_config)
dev_features = convert_examples_to_token_features(examples=dev_examples,
                                                  tokenizer=tokenizer,
                                                  label_list=label_list, 
                                                  model_config=model_config)

In [14]:
print("Sample token id:\n{}\n".format(train_features[0].input_ids))
print("Sample attention mask:\n{}\n".format(train_features[0].input_mask))
print("Sample label ids:\n{}\n".format(train_features[0].label_id))

Sample token id:
[5190, 1997, 28337, 2031, 9847, 2083, 2414, 2000, 6186, 1996, 2162, 1999, 5712, 1998, 5157, 1996, 10534, 1997, 2329, 3629, 2013, 2008, 2406, 1012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

Sample attention mask:
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

Sample label ids:
[1, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

### Create dataloaders
The utility functions `create_train_dataloader` and `create_eval_dataloader` creates Pytorch dataloaders from features, which can be used for model training and evaluation. The following two steps are performed:
1. Convert numpy arrays to Pytorch tensors
2. Create dataloader for sampling and serving data in batches

In [15]:
train_dataloader = create_train_dataloader(train_features=train_features,
                                           model_config=model_config,
                                           train_config=train_config,
                                           device_config=device_config)

In [16]:
valid_dataloader = create_eval_dataloader(eval_features=dev_features, 
                                          model_config=model_config, 
                                          eval_config=eval_config)

In [17]:
it = iter(train_dataloader)
first = next(it)
print("Sample token id tensor:\n{}\n".format(first[0][0]))
print("Sample attention mask tensor:\n{}\n".format(first[1][0]))
print("Sample label id tensor:\n{}\n".format(first[3][0]))

Sample token id tensor:
tensor([ 1057, 29625,  2015, 29625, 29624,  3709,  2749,  1999,  7041,  2360,
         2027,  2730,  2321, 17671,  2076,  2019, 11585,  3169,  1999,  1996,
         2264,  1010,  2096,  2334,  4584,  2758,  2216,  2730,  2020,  9272,
         1012,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0])

Sample attention mask tensor:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0])

Sample label id tensor:
tensor([ 1, 17, 17, 17, 17, 17,  1,  1, 13,  1,  1,  1,  1,  1,  1,  1, 

## Load Model

In [18]:
model = load_model(model_config=model_config, 
                   path_config=path_config, 
                   device_config=device_config,
                   global_config=global_config)

## Configure Optimizer
This step must be done after loading the model, because the load_model function moves all model parameters to the device, e.g. GPU. 

In [19]:
optimizer_config = get_optimizer_params(optimizer_config=optimizer_config,
                                        train_config=train_config, 
                                        device_config=device_config, 
                                        model=model, 
                                        num_train_examples=len(train_dataloader))

In [20]:
# optimizer = BertAdam(optimizer_grouped_parameters,
#                      lr=optimizer_config.learning_rate,
#                      warmup=optimizer_config.warmup_proportion,
#                      t_total=num_train_optimization_steps)
optimizer = Adam(optimizer_config.grouped_parameters, lr=optimizer_config.learning_rate)

## Train Model

In [21]:
model, train_loss = train_model(model=model, 
                                train_dataloader=train_dataloader, 
                                optimizer=optimizer,
                                train_config=train_config, 
                                model_config=model_config, 
                                optimizer_config=optimizer_config,
                                device_config=device_config,
                                global_config=global_config)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1349 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1349 [00:01<22:55,  1.02s/it][A
Iteration:   0%|          | 2/1349 [00:01<22:19,  1.01it/s][A
Iteration:   0%|          | 3/1349 [00:02<21:49,  1.03it/s][A
Iteration:   0%|          | 4/1349 [00:03<21:30,  1.04it/s][A
Iteration:   0%|          | 5/1349 [00:04<21:19,  1.05it/s][A
Iteration:   0%|          | 6/1349 [00:05<21:13,  1.05it/s][A
Iteration:   1%|          | 7/1349 [00:06<21:04,  1.06it/s][A
Iteration:   1%|          | 8/1349 [00:07<20:58,  1.07it/s][A
Iteration:   1%|          | 9/1349 [00:08<20:49,  1.07it/s][A
Iteration:   1%|          | 10/1349 [00:09<20:46,  1.07it/s][A
Iteration:   1%|          | 11/1349 [00:10<20:45,  1.07it/s][A
Iteration:   1%|          | 12/1349 [00:11<20:44,  1.07it/s][A
Iteration:   1%|          | 13/1349 [00:12<20:42,  1.07it/s][A
Iteration:   1%|          | 14/1349 [00:13<20:39,  1.08it/s][A
Iteration:   

Iteration:   9%|▉         | 127/1349 [01:58<18:58,  1.07it/s][A
Iteration:   9%|▉         | 128/1349 [01:59<18:56,  1.07it/s][A
Iteration:  10%|▉         | 129/1349 [02:00<19:02,  1.07it/s][A
Iteration:  10%|▉         | 130/1349 [02:01<19:03,  1.07it/s][A
Iteration:  10%|▉         | 131/1349 [02:02<19:03,  1.07it/s][A
Iteration:  10%|▉         | 132/1349 [02:03<19:00,  1.07it/s][A
Iteration:  10%|▉         | 133/1349 [02:04<18:56,  1.07it/s][A
Iteration:  10%|▉         | 134/1349 [02:05<18:56,  1.07it/s][A
Iteration:  10%|█         | 135/1349 [02:05<18:56,  1.07it/s][A
Iteration:  10%|█         | 136/1349 [02:06<18:58,  1.07it/s][A
Iteration:  10%|█         | 137/1349 [02:07<18:55,  1.07it/s][A
Iteration:  10%|█         | 138/1349 [02:08<18:53,  1.07it/s][A
Iteration:  10%|█         | 139/1349 [02:09<18:52,  1.07it/s][A
Iteration:  10%|█         | 140/1349 [02:10<18:50,  1.07it/s][A
Iteration:  10%|█         | 141/1349 [02:11<18:47,  1.07it/s][A
Iteration:  11%|█        

Iteration:  19%|█▉        | 253/1349 [03:56<17:09,  1.06it/s][A
Iteration:  19%|█▉        | 254/1349 [03:57<17:06,  1.07it/s][A
Iteration:  19%|█▉        | 255/1349 [03:58<17:08,  1.06it/s][A
Iteration:  19%|█▉        | 256/1349 [03:59<17:05,  1.07it/s][A
Iteration:  19%|█▉        | 257/1349 [04:00<17:05,  1.06it/s][A
Iteration:  19%|█▉        | 258/1349 [04:01<17:02,  1.07it/s][A
Iteration:  19%|█▉        | 259/1349 [04:02<17:02,  1.07it/s][A
Iteration:  19%|█▉        | 260/1349 [04:02<17:02,  1.06it/s][A
Iteration:  19%|█▉        | 261/1349 [04:03<17:02,  1.06it/s][A
Iteration:  19%|█▉        | 262/1349 [04:04<17:00,  1.07it/s][A
Iteration:  19%|█▉        | 263/1349 [04:05<17:01,  1.06it/s][A
Iteration:  20%|█▉        | 264/1349 [04:06<17:00,  1.06it/s][A
Iteration:  20%|█▉        | 265/1349 [04:07<16:59,  1.06it/s][A
Iteration:  20%|█▉        | 266/1349 [04:08<16:58,  1.06it/s][A
Iteration:  20%|█▉        | 267/1349 [04:09<16:57,  1.06it/s][A
Iteration:  20%|█▉       

Iteration:  28%|██▊       | 379/1349 [05:55<15:16,  1.06it/s][A
Iteration:  28%|██▊       | 380/1349 [05:55<15:15,  1.06it/s][A
Iteration:  28%|██▊       | 381/1349 [05:56<15:16,  1.06it/s][A
Iteration:  28%|██▊       | 382/1349 [05:57<15:15,  1.06it/s][A
Iteration:  28%|██▊       | 383/1349 [05:58<15:13,  1.06it/s][A
Iteration:  28%|██▊       | 384/1349 [05:59<15:13,  1.06it/s][A
Iteration:  29%|██▊       | 385/1349 [06:00<15:11,  1.06it/s][A
Iteration:  29%|██▊       | 386/1349 [06:01<15:09,  1.06it/s][A
Iteration:  29%|██▊       | 387/1349 [06:02<15:07,  1.06it/s][A
Iteration:  29%|██▉       | 388/1349 [06:03<15:06,  1.06it/s][A
Iteration:  29%|██▉       | 389/1349 [06:04<15:05,  1.06it/s][A
Iteration:  29%|██▉       | 390/1349 [06:05<15:05,  1.06it/s][A
Iteration:  29%|██▉       | 391/1349 [06:06<15:02,  1.06it/s][A
Iteration:  29%|██▉       | 392/1349 [06:07<15:01,  1.06it/s][A
Iteration:  29%|██▉       | 393/1349 [06:08<15:01,  1.06it/s][A
Iteration:  29%|██▉      

Iteration:  37%|███▋      | 505/1349 [07:54<13:19,  1.06it/s][A
Iteration:  38%|███▊      | 506/1349 [07:55<13:17,  1.06it/s][A
Iteration:  38%|███▊      | 507/1349 [07:56<13:17,  1.06it/s][A
Iteration:  38%|███▊      | 508/1349 [07:57<13:18,  1.05it/s][A
Iteration:  38%|███▊      | 509/1349 [07:58<13:18,  1.05it/s][A
Iteration:  38%|███▊      | 510/1349 [07:58<13:17,  1.05it/s][A
Iteration:  38%|███▊      | 511/1349 [07:59<13:15,  1.05it/s][A
Iteration:  38%|███▊      | 512/1349 [08:00<13:13,  1.05it/s][A
Iteration:  38%|███▊      | 513/1349 [08:01<13:11,  1.06it/s][A
Iteration:  38%|███▊      | 514/1349 [08:02<13:10,  1.06it/s][A
Iteration:  38%|███▊      | 515/1349 [08:03<13:11,  1.05it/s][A
Iteration:  38%|███▊      | 516/1349 [08:04<13:09,  1.05it/s][A
Iteration:  38%|███▊      | 517/1349 [08:05<13:10,  1.05it/s][A
Iteration:  38%|███▊      | 518/1349 [08:06<13:07,  1.06it/s][A
Iteration:  38%|███▊      | 519/1349 [08:07<13:07,  1.05it/s][A
Iteration:  39%|███▊     

Iteration:  47%|████▋     | 631/1349 [09:53<11:24,  1.05it/s][A
Iteration:  47%|████▋     | 632/1349 [09:54<11:22,  1.05it/s][A
Iteration:  47%|████▋     | 633/1349 [09:55<11:23,  1.05it/s][A
Iteration:  47%|████▋     | 634/1349 [09:56<11:21,  1.05it/s][A
Iteration:  47%|████▋     | 635/1349 [09:57<11:21,  1.05it/s][A
Iteration:  47%|████▋     | 636/1349 [09:58<11:20,  1.05it/s][A
Iteration:  47%|████▋     | 637/1349 [09:59<11:19,  1.05it/s][A
Iteration:  47%|████▋     | 638/1349 [10:00<11:16,  1.05it/s][A
Iteration:  47%|████▋     | 639/1349 [10:01<11:14,  1.05it/s][A
Iteration:  47%|████▋     | 640/1349 [10:02<11:13,  1.05it/s][A
Iteration:  48%|████▊     | 641/1349 [10:03<11:11,  1.05it/s][A
Iteration:  48%|████▊     | 642/1349 [10:04<11:10,  1.05it/s][A
Iteration:  48%|████▊     | 643/1349 [10:05<11:10,  1.05it/s][A
Iteration:  48%|████▊     | 644/1349 [10:06<11:09,  1.05it/s][A
Iteration:  48%|████▊     | 645/1349 [10:07<11:08,  1.05it/s][A
Iteration:  48%|████▊    

Iteration:  56%|█████▌    | 757/1349 [11:53<09:24,  1.05it/s][A
Iteration:  56%|█████▌    | 758/1349 [11:54<09:23,  1.05it/s][A
Iteration:  56%|█████▋    | 759/1349 [11:55<09:23,  1.05it/s][A
Iteration:  56%|█████▋    | 760/1349 [11:56<09:20,  1.05it/s][A
Iteration:  56%|█████▋    | 761/1349 [11:57<09:20,  1.05it/s][A
Iteration:  56%|█████▋    | 762/1349 [11:58<09:18,  1.05it/s][A
Iteration:  57%|█████▋    | 763/1349 [11:59<09:16,  1.05it/s][A
Iteration:  57%|█████▋    | 764/1349 [12:00<09:15,  1.05it/s][A
Iteration:  57%|█████▋    | 765/1349 [12:01<09:16,  1.05it/s][A
Iteration:  57%|█████▋    | 766/1349 [12:02<09:15,  1.05it/s][A
Iteration:  57%|█████▋    | 767/1349 [12:03<09:14,  1.05it/s][A
Iteration:  57%|█████▋    | 768/1349 [12:04<09:13,  1.05it/s][A
Iteration:  57%|█████▋    | 769/1349 [12:05<09:12,  1.05it/s][A
Iteration:  57%|█████▋    | 770/1349 [12:06<09:12,  1.05it/s][A
Iteration:  57%|█████▋    | 771/1349 [12:07<09:12,  1.05it/s][A
Iteration:  57%|█████▋   

Iteration:  65%|██████▌   | 883/1349 [13:53<07:24,  1.05it/s][A
Iteration:  66%|██████▌   | 884/1349 [13:54<07:24,  1.05it/s][A
Iteration:  66%|██████▌   | 885/1349 [13:55<07:23,  1.05it/s][A
Iteration:  66%|██████▌   | 886/1349 [13:56<07:22,  1.05it/s][A
Iteration:  66%|██████▌   | 887/1349 [13:57<07:21,  1.05it/s][A
Iteration:  66%|██████▌   | 888/1349 [13:58<07:21,  1.05it/s][A
Iteration:  66%|██████▌   | 889/1349 [13:59<07:20,  1.04it/s][A
Iteration:  66%|██████▌   | 890/1349 [14:00<07:18,  1.05it/s][A
Iteration:  66%|██████▌   | 891/1349 [14:01<07:17,  1.05it/s][A
Iteration:  66%|██████▌   | 892/1349 [14:02<07:17,  1.04it/s][A
Iteration:  66%|██████▌   | 893/1349 [14:03<07:16,  1.05it/s][A
Iteration:  66%|██████▋   | 894/1349 [14:04<07:14,  1.05it/s][A
Iteration:  66%|██████▋   | 895/1349 [14:05<07:13,  1.05it/s][A
Iteration:  66%|██████▋   | 896/1349 [14:06<07:11,  1.05it/s][A
Iteration:  66%|██████▋   | 897/1349 [14:07<07:11,  1.05it/s][A
Iteration:  67%|██████▋  

Iteration:  75%|███████▍  | 1008/1349 [15:53<05:25,  1.05it/s][A
Iteration:  75%|███████▍  | 1009/1349 [15:54<05:24,  1.05it/s][A
Iteration:  75%|███████▍  | 1010/1349 [15:55<05:23,  1.05it/s][A
Iteration:  75%|███████▍  | 1011/1349 [15:56<05:21,  1.05it/s][A
Iteration:  75%|███████▌  | 1012/1349 [15:57<05:21,  1.05it/s][A
Iteration:  75%|███████▌  | 1013/1349 [15:58<05:20,  1.05it/s][A
Iteration:  75%|███████▌  | 1014/1349 [15:59<05:20,  1.05it/s][A
Iteration:  75%|███████▌  | 1015/1349 [16:00<05:20,  1.04it/s][A
Iteration:  75%|███████▌  | 1016/1349 [16:01<05:18,  1.04it/s][A
Iteration:  75%|███████▌  | 1017/1349 [16:02<05:17,  1.05it/s][A
Iteration:  75%|███████▌  | 1018/1349 [16:02<05:16,  1.05it/s][A
Iteration:  76%|███████▌  | 1019/1349 [16:03<05:16,  1.04it/s][A
Iteration:  76%|███████▌  | 1020/1349 [16:04<05:15,  1.04it/s][A
Iteration:  76%|███████▌  | 1021/1349 [16:05<05:14,  1.04it/s][A
Iteration:  76%|███████▌  | 1022/1349 [16:06<05:13,  1.04it/s][A
Iteration:

Iteration:  84%|████████▍ | 1132/1349 [17:52<03:27,  1.05it/s][A
Iteration:  84%|████████▍ | 1133/1349 [17:52<03:27,  1.04it/s][A
Iteration:  84%|████████▍ | 1134/1349 [17:53<03:26,  1.04it/s][A
Iteration:  84%|████████▍ | 1135/1349 [17:54<03:24,  1.04it/s][A
Iteration:  84%|████████▍ | 1136/1349 [17:55<03:23,  1.04it/s][A
Iteration:  84%|████████▍ | 1137/1349 [17:56<03:22,  1.05it/s][A
Iteration:  84%|████████▍ | 1138/1349 [17:57<03:21,  1.05it/s][A
Iteration:  84%|████████▍ | 1139/1349 [17:58<03:20,  1.05it/s][A
Iteration:  85%|████████▍ | 1140/1349 [17:59<03:19,  1.05it/s][A
Iteration:  85%|████████▍ | 1141/1349 [18:00<03:18,  1.05it/s][A
Iteration:  85%|████████▍ | 1142/1349 [18:01<03:17,  1.05it/s][A
Iteration:  85%|████████▍ | 1143/1349 [18:02<03:16,  1.05it/s][A
Iteration:  85%|████████▍ | 1144/1349 [18:03<03:15,  1.05it/s][A
Iteration:  85%|████████▍ | 1145/1349 [18:04<03:14,  1.05it/s][A
Iteration:  85%|████████▍ | 1146/1349 [18:05<03:13,  1.05it/s][A
Iteration:

Iteration:  93%|█████████▎| 1256/1349 [19:50<01:29,  1.04it/s][A
Iteration:  93%|█████████▎| 1257/1349 [19:51<01:28,  1.04it/s][A
Iteration:  93%|█████████▎| 1258/1349 [19:52<01:27,  1.04it/s][A
Iteration:  93%|█████████▎| 1259/1349 [19:53<01:26,  1.04it/s][A
Iteration:  93%|█████████▎| 1260/1349 [19:54<01:25,  1.04it/s][A
Iteration:  93%|█████████▎| 1261/1349 [19:55<01:24,  1.05it/s][A
Iteration:  94%|█████████▎| 1262/1349 [19:56<01:23,  1.04it/s][A
Iteration:  94%|█████████▎| 1263/1349 [19:57<01:22,  1.05it/s][A
Iteration:  94%|█████████▎| 1264/1349 [19:58<01:21,  1.05it/s][A
Iteration:  94%|█████████▍| 1265/1349 [19:59<01:20,  1.05it/s][A
Iteration:  94%|█████████▍| 1266/1349 [20:00<01:19,  1.04it/s][A
Iteration:  94%|█████████▍| 1267/1349 [20:01<01:18,  1.04it/s][A
Iteration:  94%|█████████▍| 1268/1349 [20:02<01:17,  1.04it/s][A
Iteration:  94%|█████████▍| 1269/1349 [20:03<01:16,  1.04it/s][A
Iteration:  94%|█████████▍| 1270/1349 [20:03<01:15,  1.05it/s][A
Iteration:

Train loss: 0.11175098000653502



Iteration:   0%|          | 1/1349 [00:00<21:31,  1.04it/s][A
Iteration:   0%|          | 2/1349 [00:01<21:25,  1.05it/s][A
Iteration:   0%|          | 3/1349 [00:02<21:23,  1.05it/s][A
Iteration:   0%|          | 4/1349 [00:03<21:21,  1.05it/s][A
Iteration:   0%|          | 5/1349 [00:04<21:23,  1.05it/s][A
Iteration:   0%|          | 6/1349 [00:05<21:21,  1.05it/s][A
Iteration:   1%|          | 7/1349 [00:06<21:22,  1.05it/s][A
Iteration:   1%|          | 8/1349 [00:07<21:25,  1.04it/s][A
Iteration:   1%|          | 9/1349 [00:08<21:24,  1.04it/s][A
Iteration:   1%|          | 10/1349 [00:09<21:26,  1.04it/s][A
Iteration:   1%|          | 11/1349 [00:10<21:24,  1.04it/s][A
Iteration:   1%|          | 12/1349 [00:11<21:22,  1.04it/s][A
Iteration:   1%|          | 13/1349 [00:12<21:21,  1.04it/s][A
Iteration:   1%|          | 14/1349 [00:13<21:18,  1.04it/s][A
Iteration:   1%|          | 15/1349 [00:14<21:12,  1.05it/s][A
Iteration:   1%|          | 16/1349 [00:15<21:13

Iteration:  18%|█▊        | 249/1349 [03:58<17:30,  1.05it/s][A
Iteration:  19%|█▊        | 250/1349 [03:59<17:28,  1.05it/s][A
Iteration:  19%|█▊        | 251/1349 [04:00<17:27,  1.05it/s][A
Iteration:  19%|█▊        | 252/1349 [04:01<17:29,  1.05it/s][A
Iteration:  19%|█▉        | 253/1349 [04:02<17:26,  1.05it/s][A
Iteration:  19%|█▉        | 254/1349 [04:03<17:25,  1.05it/s][A
Iteration:  19%|█▉        | 255/1349 [04:04<17:27,  1.04it/s][A
Iteration:  19%|█▉        | 256/1349 [04:05<17:26,  1.04it/s][A
Iteration:  19%|█▉        | 257/1349 [04:05<17:25,  1.04it/s][A
Iteration:  19%|█▉        | 258/1349 [04:06<17:25,  1.04it/s][A
Iteration:  19%|█▉        | 259/1349 [04:07<17:22,  1.05it/s][A
Iteration:  19%|█▉        | 260/1349 [04:08<17:21,  1.05it/s][A
Iteration:  19%|█▉        | 261/1349 [04:09<17:21,  1.04it/s][A
Iteration:  19%|█▉        | 262/1349 [04:10<17:19,  1.05it/s][A
Iteration:  19%|█▉        | 263/1349 [04:11<17:18,  1.05it/s][A
Iteration:  20%|█▉       

Iteration:  28%|██▊       | 375/1349 [05:58<15:25,  1.05it/s][A
Iteration:  28%|██▊       | 376/1349 [05:59<15:26,  1.05it/s][A
Iteration:  28%|██▊       | 377/1349 [06:00<15:26,  1.05it/s][A
Iteration:  28%|██▊       | 378/1349 [06:01<15:28,  1.05it/s][A
Iteration:  28%|██▊       | 379/1349 [06:02<15:27,  1.05it/s][A
Iteration:  28%|██▊       | 380/1349 [06:03<15:26,  1.05it/s][A
Iteration:  28%|██▊       | 381/1349 [06:04<15:27,  1.04it/s][A
Iteration:  28%|██▊       | 382/1349 [06:05<15:26,  1.04it/s][A
Iteration:  28%|██▊       | 383/1349 [06:06<15:24,  1.04it/s][A
Iteration:  28%|██▊       | 384/1349 [06:07<15:23,  1.04it/s][A
Iteration:  29%|██▊       | 385/1349 [06:08<15:20,  1.05it/s][A
Iteration:  29%|██▊       | 386/1349 [06:09<15:19,  1.05it/s][A
Iteration:  29%|██▊       | 387/1349 [06:10<15:18,  1.05it/s][A
Iteration:  29%|██▉       | 388/1349 [06:11<15:16,  1.05it/s][A
Iteration:  29%|██▉       | 389/1349 [06:12<15:14,  1.05it/s][A
Iteration:  29%|██▉      

Iteration:  37%|███▋      | 501/1349 [07:59<13:29,  1.05it/s][A
Iteration:  37%|███▋      | 502/1349 [08:00<13:28,  1.05it/s][A
Iteration:  37%|███▋      | 503/1349 [08:01<13:29,  1.05it/s][A
Iteration:  37%|███▋      | 504/1349 [08:02<13:28,  1.05it/s][A
Iteration:  37%|███▋      | 505/1349 [08:03<13:27,  1.04it/s][A
Iteration:  38%|███▊      | 506/1349 [08:04<13:26,  1.05it/s][A
Iteration:  38%|███▊      | 507/1349 [08:05<13:24,  1.05it/s][A
Iteration:  38%|███▊      | 508/1349 [08:06<13:23,  1.05it/s][A
Iteration:  38%|███▊      | 509/1349 [08:07<13:21,  1.05it/s][A
Iteration:  38%|███▊      | 510/1349 [08:08<13:22,  1.04it/s][A
Iteration:  38%|███▊      | 511/1349 [08:09<13:21,  1.05it/s][A
Iteration:  38%|███▊      | 512/1349 [08:10<13:22,  1.04it/s][A
Iteration:  38%|███▊      | 513/1349 [08:11<13:22,  1.04it/s][A
Iteration:  38%|███▊      | 514/1349 [08:12<13:22,  1.04it/s][A
Iteration:  38%|███▊      | 515/1349 [08:13<13:21,  1.04it/s][A
Iteration:  38%|███▊     

Iteration:  46%|████▋     | 627/1349 [10:00<11:28,  1.05it/s][A
Iteration:  47%|████▋     | 628/1349 [10:01<11:27,  1.05it/s][A
Iteration:  47%|████▋     | 629/1349 [10:02<11:27,  1.05it/s][A
Iteration:  47%|████▋     | 630/1349 [10:02<11:26,  1.05it/s][A
Iteration:  47%|████▋     | 631/1349 [10:03<11:25,  1.05it/s][A
Iteration:  47%|████▋     | 632/1349 [10:04<11:24,  1.05it/s][A
Iteration:  47%|████▋     | 633/1349 [10:05<11:23,  1.05it/s][A
Iteration:  47%|████▋     | 634/1349 [10:06<11:24,  1.04it/s][A
Iteration:  47%|████▋     | 635/1349 [10:07<11:24,  1.04it/s][A
Iteration:  47%|████▋     | 636/1349 [10:08<11:24,  1.04it/s][A
Iteration:  47%|████▋     | 637/1349 [10:09<11:22,  1.04it/s][A
Iteration:  47%|████▋     | 638/1349 [10:10<11:20,  1.04it/s][A
Iteration:  47%|████▋     | 639/1349 [10:11<11:20,  1.04it/s][A
Iteration:  47%|████▋     | 640/1349 [10:12<11:18,  1.05it/s][A
Iteration:  48%|████▊     | 641/1349 [10:13<11:19,  1.04it/s][A
Iteration:  48%|████▊    

Iteration:  56%|█████▌    | 753/1349 [12:00<09:31,  1.04it/s][A
Iteration:  56%|█████▌    | 754/1349 [12:01<09:31,  1.04it/s][A
Iteration:  56%|█████▌    | 755/1349 [12:02<09:29,  1.04it/s][A
Iteration:  56%|█████▌    | 756/1349 [12:03<09:29,  1.04it/s][A
Iteration:  56%|█████▌    | 757/1349 [12:04<09:27,  1.04it/s][A
Iteration:  56%|█████▌    | 758/1349 [12:05<09:25,  1.04it/s][A
Iteration:  56%|█████▋    | 759/1349 [12:06<09:24,  1.04it/s][A
Iteration:  56%|█████▋    | 760/1349 [12:07<09:22,  1.05it/s][A
Iteration:  56%|█████▋    | 761/1349 [12:08<09:22,  1.05it/s][A
Iteration:  56%|█████▋    | 762/1349 [12:09<09:20,  1.05it/s][A
Iteration:  57%|█████▋    | 763/1349 [12:10<09:20,  1.05it/s][A
Iteration:  57%|█████▋    | 764/1349 [12:11<09:20,  1.04it/s][A
Iteration:  57%|█████▋    | 765/1349 [12:12<09:20,  1.04it/s][A
Iteration:  57%|█████▋    | 766/1349 [12:13<09:19,  1.04it/s][A
Iteration:  57%|█████▋    | 767/1349 [12:14<09:18,  1.04it/s][A
Iteration:  57%|█████▋   

Iteration:  65%|██████▌   | 879/1349 [14:01<07:29,  1.05it/s][A
Iteration:  65%|██████▌   | 880/1349 [14:02<07:29,  1.04it/s][A
Iteration:  65%|██████▌   | 881/1349 [14:03<07:28,  1.04it/s][A
Iteration:  65%|██████▌   | 882/1349 [14:04<07:26,  1.05it/s][A
Iteration:  65%|██████▌   | 883/1349 [14:05<07:25,  1.05it/s][A
Iteration:  66%|██████▌   | 884/1349 [14:06<07:25,  1.04it/s][A
Iteration:  66%|██████▌   | 885/1349 [14:07<07:25,  1.04it/s][A
Iteration:  66%|██████▌   | 886/1349 [14:08<07:23,  1.04it/s][A
Iteration:  66%|██████▌   | 887/1349 [14:09<07:22,  1.04it/s][A
Iteration:  66%|██████▌   | 888/1349 [14:10<07:21,  1.04it/s][A
Iteration:  66%|██████▌   | 889/1349 [14:11<07:21,  1.04it/s][A
Iteration:  66%|██████▌   | 890/1349 [14:11<07:20,  1.04it/s][A
Iteration:  66%|██████▌   | 891/1349 [14:12<07:19,  1.04it/s][A
Iteration:  66%|██████▌   | 892/1349 [14:13<07:18,  1.04it/s][A
Iteration:  66%|██████▌   | 893/1349 [14:14<07:17,  1.04it/s][A
Iteration:  66%|██████▋  

Iteration:  74%|███████▍  | 1004/1349 [16:01<05:31,  1.04it/s][A
Iteration:  74%|███████▍  | 1005/1349 [16:02<05:29,  1.04it/s][A
Iteration:  75%|███████▍  | 1006/1349 [16:03<05:28,  1.04it/s][A
Iteration:  75%|███████▍  | 1007/1349 [16:04<05:27,  1.05it/s][A
Iteration:  75%|███████▍  | 1008/1349 [16:04<05:25,  1.05it/s][A
Iteration:  75%|███████▍  | 1009/1349 [16:05<05:25,  1.04it/s][A
Iteration:  75%|███████▍  | 1010/1349 [16:06<05:24,  1.05it/s][A
Iteration:  75%|███████▍  | 1011/1349 [16:07<05:22,  1.05it/s][A
Iteration:  75%|███████▌  | 1012/1349 [16:08<05:21,  1.05it/s][A
Iteration:  75%|███████▌  | 1013/1349 [16:09<05:20,  1.05it/s][A
Iteration:  75%|███████▌  | 1014/1349 [16:10<05:20,  1.05it/s][A
Iteration:  75%|███████▌  | 1015/1349 [16:11<05:20,  1.04it/s][A
Iteration:  75%|███████▌  | 1016/1349 [16:12<05:19,  1.04it/s][A
Iteration:  75%|███████▌  | 1017/1349 [16:13<05:17,  1.04it/s][A
Iteration:  75%|███████▌  | 1018/1349 [16:14<05:16,  1.05it/s][A
Iteration:

Iteration:  84%|████████▎ | 1128/1349 [17:59<03:31,  1.04it/s][A
Iteration:  84%|████████▎ | 1129/1349 [18:00<03:30,  1.04it/s][A
Iteration:  84%|████████▍ | 1130/1349 [18:01<03:29,  1.04it/s][A
Iteration:  84%|████████▍ | 1131/1349 [18:02<03:28,  1.05it/s][A
Iteration:  84%|████████▍ | 1132/1349 [18:03<03:27,  1.04it/s][A
Iteration:  84%|████████▍ | 1133/1349 [18:04<03:27,  1.04it/s][A
Iteration:  84%|████████▍ | 1134/1349 [18:05<03:25,  1.04it/s][A
Iteration:  84%|████████▍ | 1135/1349 [18:06<03:24,  1.05it/s][A
Iteration:  84%|████████▍ | 1136/1349 [18:07<03:23,  1.04it/s][A
Iteration:  84%|████████▍ | 1137/1349 [18:08<03:22,  1.05it/s][A
Iteration:  84%|████████▍ | 1138/1349 [18:09<03:21,  1.05it/s][A
Iteration:  84%|████████▍ | 1139/1349 [18:10<03:20,  1.05it/s][A
Iteration:  85%|████████▍ | 1140/1349 [18:11<03:19,  1.05it/s][A
Iteration:  85%|████████▍ | 1141/1349 [18:12<03:18,  1.05it/s][A
Iteration:  85%|████████▍ | 1142/1349 [18:13<03:17,  1.05it/s][A
Iteration:

## Evaluate Model

In [22]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [23]:
preds, eval_loss, eval_accuracy = eval_token_model(model=model, 
                                                   eval_dataloader=valid_dataloader, 
                                                   model_config=model_config, 
                                                   device_config=device_config, 
                                                   label_list=label_list,
                                                   eval_func=flat_accuracy)

Validation loss: 0.2494839150706927
Validation Accuracy: 0.9776226190476196
Validation F1-Score: 0.7582687879422232
