In [1]:
import os
os.chdir('../')

In [2]:
import re
import sys
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import transformers
import torch.nn as nn
from transformers import AdamW
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [29]:
device = torch.device("cpu")

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [5]:
dataset = pd.read_pickle('pickle/dataset.pickle')
print(dataset.shape)

(123235361, 2)


In [6]:
sentences = dataset['Sentences'].values.tolist()
tags = dataset['Tags'].values.tolist()

In [7]:
# cleaned_tags = [re.sub(r'''[',"\[\]]''', "", s) for s in tags]

# with open('cleaned_tags.pkl', 'wb') as f:
#     pickle.dump(cleaned_tags, f)

with open('cleaned_tags.pkl', 'rb') as f:
    cleaned_tags = pickle.load(f)

print(len(sentences), len(cleaned_tags))

123235361 123235361


In [8]:
unique_tags = list(set(cleaned_tags))

tag2idx = {}
for idx, tag in enumerate(unique_tags):
    tag2idx[tag] = idx

In [9]:
sentences_list = []
token_tag_list = []
sentence = []
token_tag = []
for token, tag in tqdm(zip(sentences, cleaned_tags)):
    sentence.append(token)
    token_tag.append(tag)
    try:
        if bool(re.match(r"[.]", token)):
            if len(sentence) >= 4:
                sentences_list.append(sentence)
                token_tag_list.append(token_tag)
                sentence = []
                token_tag = []
    except:
        print(f"Error in re for token: {token}")
        sentences_list.append(sentence)
        token_tag_list.append(token_tag)
        sentence = []
        token_tag = []

1805258it [00:07, 288972.60it/s]

Error in re for token: nan


2487122it [00:10, 131836.90it/s]

Error in re for token: nan
Error in re for token: nan


3240080it [00:11, 663371.86it/s]

Error in re for token: nan


3527263it [00:11, 668662.45it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


4686564it [00:14, 715135.11it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


5286490it [00:17, 404305.81it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


5683454it [00:17, 684726.57it/s]

Error in re for token: nan


6377357it [00:18, 732522.26it/s]

Error in re for token: nan


6529093it [00:18, 724528.96it/s]

Error in re for token: nan
Error in re for token: nan


6819593it [00:20, 279013.75it/s]

Error in re for token: nan


8530300it [00:23, 742572.29it/s]

Error in re for token: nan


9525771it [00:26, 672135.43it/s]

Error in re for token: nan


9754910it [00:26, 726178.32it/s]

Error in re for token: nan


9991767it [00:26, 764251.23it/s]

Error in re for token: nan


10302878it [00:27, 770048.27it/s]

Error in re for token: nan


10615012it [00:27, 775052.31it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


11873162it [00:30, 501775.25it/s]

Error in re for token: nan


12265185it [00:31, 715656.85it/s]

Error in re for token: nan


12659628it [00:31, 772733.63it/s]

Error in re for token: nan
Error in re for token: nan


16514327it [00:38, 749063.43it/s]

Error in re for token: nan


17341979it [00:39, 701352.21it/s]

Error in re for token: nan
Error in re for token: nan


17736682it [00:40, 767743.99it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


18924287it [00:41, 788592.13it/s]

Error in re for token: nan


19080991it [00:43, 169750.93it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


19388533it [00:44, 410662.89it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


20254981it [00:45, 768753.74it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


20868940it [00:46, 727036.29it/s]

Error in re for token: nan


21862675it [00:47, 752850.27it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


22561520it [00:48, 768816.44it/s]

Error in re for token: nan
Error in re for token: nan


23182182it [00:49, 747514.45it/s]

Error in re for token: nan


24152573it [00:51, 383525.98it/s]

Error in re for token: nan


24352536it [00:53, 173330.73it/s]

Error in re for token: nan


24762750it [00:54, 497840.67it/s]

Error in re for token: nan
Error in re for token: nan


24971519it [00:54, 605857.90it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


27068803it [00:57, 769548.70it/s]

Error in re for token: nan


27374035it [00:57, 746392.20it/s]

Error in re for token: nan


27678910it [00:58, 751679.46it/s]

Error in re for token: nan


27993152it [00:58, 763381.34it/s]

Error in re for token: nan
Error in re for token: nan


28449499it [00:59, 733421.83it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


29075773it [01:00, 776992.84it/s]

Error in re for token: nan


31905287it [01:06, 743795.35it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


33224394it [01:07, 758481.76it/s]

Error in re for token: nan
Error in re for token: nan


33538349it [01:08, 770231.15it/s]

Error in re for token: nan


33928375it [01:08, 774428.32it/s]

Error in re for token: nan


34390236it [01:09, 758302.61it/s]

Error in re for token: nan
Error in re for token: nan


35091142it [01:10, 769477.98it/s]

Error in re for token: nan
Error in re for token: nan


36650408it [01:12, 771904.52it/s]

Error in re for token: nan


37038162it [01:12, 759293.85it/s]

Error in re for token: nan


37965382it [01:14, 736500.29it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


38419778it [01:14, 719454.69it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


39154312it [01:22, 286605.09it/s]

Error in re for token: nan


40803526it [01:24, 703268.79it/s]

Error in re for token: nan


42242080it [01:26, 776974.24it/s]

Error in re for token: nan


42472201it [01:26, 746135.56it/s]

Error in re for token: nan


43096063it [01:27, 773300.58it/s]

Error in re for token: nan
Error in re for token: nan


43879129it [01:28, 779445.65it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


44193249it [01:29, 760164.76it/s]

Error in re for token: nan


44961865it [01:30, 753099.96it/s]

Error in re for token: nan


45357790it [01:30, 780944.29it/s]

Error in re for token: nan


47041943it [01:32, 759780.85it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


47357034it [01:33, 777585.44it/s]

Error in re for token: nan
Error in re for token: nan


48774649it [01:38, 126498.44it/s]

Error in re for token: nan


50294936it [01:40, 752604.36it/s]

Error in re for token: nan


51287892it [01:41, 761001.41it/s]

Error in re for token: nan
Error in re for token: nan


51521283it [01:41, 772093.21it/s]

Error in re for token: nan


52485405it [01:43, 683372.41it/s]

Error in re for token: nan


54472799it [01:45, 752657.70it/s]

Error in re for token: nan


54774056it [01:46, 719575.19it/s]

Error in re for token: nan
Error in re for token: nan


55365671it [01:47, 741758.96it/s]

Error in re for token: nan


55592977it [01:47, 749927.71it/s]

Error in re for token: nan


59901070it [01:54, 644473.99it/s]

Error in re for token: nan


60453095it [01:55, 692126.36it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


60965750it [01:55, 711969.88it/s]

Error in re for token: nan


62156417it [02:23, 557262.49it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


63153766it [02:24, 763578.61it/s]

Error in re for token: nan
Error in re for token: nan


63693255it [02:25, 768469.37it/s]

Error in re for token: nan
Error in re for token: nan


63925347it [02:25, 766942.30it/s]

Error in re for token: nan


65155605it [02:27, 752790.19it/s]

Error in re for token: nan


68434304it [02:31, 748952.63it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


69577751it [02:33, 729686.36it/s]

Error in re for token: nan


72380788it [02:37, 749882.29it/s]

Error in re for token: nan


72981966it [02:37, 740478.42it/s]

Error in re for token: nan


73279250it [02:38, 734569.63it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


75307122it [02:41, 740636.19it/s]

Error in re for token: nan


75978857it [02:42, 732098.46it/s]

Error in re for token: nan


76425157it [02:42, 735862.48it/s]

Error in re for token: nan


76713153it [03:23, 15144.12it/s] 

Error in re for token: nan


78082315it [03:25, 719964.48it/s]

Error in re for token: nan
Error in re for token: nan


78237241it [03:25, 744269.44it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


79829624it [03:27, 749419.19it/s]

Error in re for token: nan


80055101it [03:28, 742776.04it/s]

Error in re for token: nan
Error in re for token: nan


80875480it [03:29, 736473.87it/s]

Error in re for token: nan


82747001it [03:31, 752822.76it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


85696718it [03:35, 721926.85it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


85838570it [03:36, 684640.57it/s]

Error in re for token: nan
Error in re for token: nan


86365672it [03:37, 601835.55it/s]

Error in re for token: nan
Error in re for token: nan


87416679it [03:38, 722602.32it/s]

Error in re for token: nan


87633261it [03:38, 718088.44it/s]

Error in re for token: nan
Error in re for token: nan


87847877it [03:39, 709513.52it/s]

Error in re for token: nan
Error in re for token: nan


89415477it [03:41, 652950.51it/s]

Error in re for token: nan


90616969it [03:43, 709574.23it/s]

Error in re for token: nan


91553480it [03:44, 662556.99it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


91909822it [03:45, 688020.33it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


92179536it [03:45, 660742.70it/s]

Error in re for token: nan
Error in re for token: nan


92672983it [03:46, 691279.41it/s]

Error in re for token: nan


93035132it [03:46, 716099.00it/s]

Error in re for token: nan


94389390it [03:48, 715510.36it/s]

Error in re for token: nan


95165163it [03:49, 692405.58it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


95745287it [03:50, 716139.78it/s]

Error in re for token: nan


96059157it [04:43, 8236.75it/s]  

Error in re for token: nan


97229566it [04:45, 558365.27it/s]

Error in re for token: nan
Error in re for token: nan


97592679it [04:45, 693953.20it/s]

Error in re for token: nan


97951717it [04:46, 692312.95it/s]

Error in re for token: nan
Error in re for token: nan


98245953it [04:46, 725043.77it/s]

Error in re for token: nan


99091472it [04:47, 663993.59it/s]

Error in re for token: nan


99290320it [04:48, 644172.85it/s]

Error in re for token: nan


99981523it [04:49, 693209.70it/s]

Error in re for token: nan
Error in re for token: nan


100791437it [04:50, 721445.13it/s]

Error in re for token: nan
Error in re for token: nan


102545439it [04:52, 662072.41it/s]

Error in re for token: nan
Error in re for token: nan


102827791it [04:53, 690117.42it/s]

Error in re for token: nan


103186965it [04:53, 713474.12it/s]

Error in re for token: nan


103332178it [04:53, 716796.68it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


104226641it [04:55, 647785.78it/s]

Error in re for token: nan


105673516it [04:57, 676378.00it/s]

Error in re for token: nan


106026971it [04:57, 706382.62it/s]

Error in re for token: nan


106323363it [04:58, 728228.83it/s]

Error in re for token: nan


107261070it [04:59, 716727.46it/s]

Error in re for token: nan


107689336it [05:00, 710816.63it/s]

Error in re for token: nan
Error in re for token: nan


108422747it [05:01, 730258.21it/s]

Error in re for token: nan


109371868it [05:02, 704797.76it/s]

Error in re for token: nan


110583961it [05:04, 679198.20it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


110949770it [05:04, 724097.68it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


111813117it [05:06, 692625.11it/s]

Error in re for token: nan


112295761it [05:06, 679473.41it/s]

Error in re for token: nan
Error in re for token: nan


113352129it [05:08, 693729.06it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


113781009it [05:08, 710527.24it/s]

Error in re for token: nan
Error in re for token: nan


115138357it [05:10, 709289.77it/s]

Error in re for token: nan
Error in re for token: nan


116055750it [05:12, 705633.69it/s]

Error in re for token: nan
Error in re for token: nan


117783522it [05:14, 590208.22it/s]

Error in re for token: nan


120153630it [06:32, 16740.21it/s] 

Error in re for token: nan
Error in re for token: nan


120576669it [06:32, 129332.62it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


120716966it [06:33, 218317.69it/s]

Error in re for token: nan


121487248it [06:34, 677719.59it/s]

Error in re for token: nan


121776027it [06:34, 711584.23it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


122267057it [06:35, 659009.00it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


122666426it [06:35, 664626.60it/s]

Error in re for token: nan


123235361it [06:36, 310643.37it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
E




In [10]:
sentences_length = []
for val in tqdm(sentences_list):
    sentences_length.append(len(val))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5911818/5911818 [00:02<00:00, 2583444.99it/s]


In [11]:
tokenized_sentences_list = []
tokenized_sent_tag_list = []
for token_list, tag_list in tqdm(zip(sentences_list, token_tag_list)):
    updated_token_list = []
    updated_tag_list = []
    for token, tag in zip(token_list, tag_list):
        try:
            tokenized_list = tokenizer.tokenize(token)
            updated_token_list.extend(tokenized_list)
            updated_tag_list.extend([tag2idx[tag]]*len(tokenized_list))
        except:
            print(f"Tokenizatin failed for token: {token}")
    tokenized_sentences_list.append(updated_token_list)
    tokenized_sent_tag_list.append(updated_tag_list)

1000it [00:01, 821.09it/s]


In [None]:
# Storing Tokenized Data
with open('pickle/tokenized_sentences_list.pkl', 'wb') as f:
    pickle.dump(tokenized_sentences_list, f)
with open('pickle/tokenized_sent_tag_list.pkl', 'wb') as f:
    pickle.dump(tokenized_sent_tag_list, f)
    
# # Loading Processed Tokenized Data
# with open('pickle/tokenized_sentences_list.pkl', 'rb') as f:
#     tokenized_sentences_list = pickle.load(f)
    
# with open('pickle/tokenized_sent_tag_list.pkl', 'rb') as f:
#     tokenized_sent_tag_list = pickle.load(f)

In [12]:
# Mapping tokens to ids
input_ids = []
for tokenized_sentence in tqdm(tokenized_sentences_list):
    input_ids.append(tokenizer.convert_tokens_to_ids(tokenized_sentence))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 55714.57it/s]


In [13]:
len(input_ids), len(tokenized_sent_tag_list)

(1000, 1000)

In [14]:
attention_mask = []
for input_ in tqdm(input_ids):
    attention_mask.append(torch.ones(len(input_)))

padded_attention_mask = torch.nn.utils.rnn.pad_sequence(
    attention_mask, 
    batch_first=True, 
    padding_value=0.0)

padded_input_ids = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(input_) for input_ in input_ids], 
    batch_first=True, 
    padding_value=0.0)

padded_tags = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(tag_) for tag_ in tokenized_sent_tag_list], 
    batch_first=True, 
    padding_value=0.0)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 21211.74it/s]


In [15]:
class NERBert(nn.Module):
    
    def __init__(self, tag_count=4):
        super(NERBert, self).__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, tag_count)
    
    def forward(self, input_ids, attention_mask):
        
        output = self.bert(input_ids, attention_mask=attention_mask) # Model gives 'last_hidden_state' and 'pooler_output'
        
        pre_classifier_layer = self.dropout(output.last_hidden_state)
        model_output = self.classifier(pre_classifier_layer)
        
        return model_output

In [54]:
train_tokens, temp_tokens, train_tags, temp_tags, train_mask, temp_mask =  train_test_split(
    padded_input_ids, padded_tags, padded_attention_mask,
    random_state=2018, 
    test_size=0.3
)

val_tokens, test_tokens, val_tags, test_tags, val_mask, test_mask = train_test_split(
    temp_tokens, temp_tags, temp_mask,
    random_state=2018, 
    test_size=0.5
)

In [17]:
batch_size = 32
# wrap tensors
train_data = TensorDataset(train_tokens, train_mask, train_tags)

# sampler for sampling the data during training
train_sampler = RandomSampler(train_data)

# dataLoader for train set
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# wrap tensors
val_data = TensorDataset(val_tokens, val_mask, val_tags)

# sampler for sampling the data during training
val_sampler = SequentialSampler(val_data)

# dataLoader for validation set
val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)

In [25]:
def train(model, optimizer, loss_criteria, train_dataloader):
    try:
        model.train()
        
        total_loss = 0
        total_logits = []
        
        # iterate over batches
        for step,batch in enumerate(train_dataloader):

            # progress update after every 50 batches.
            if step % 50 == 0 and not step == 0:
                print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))

            # push the batch to gpu
            batch = [r.to(device) for r in batch]

            sent_id, mask, labels = batch

            # clear previously calculated gradients 
            model.zero_grad()        

            # get model predictions for the current batch
            logits = model(sent_id, mask)

            # compute the loss between actual and predicted values
            loss = loss_criteria(logits.permute(0, 2, 1), labels)

            # add on to the total loss
            total_loss = total_loss + loss.item()

            # backward pass to calculate the gradients
            loss.backward()

            # clip the the gradients to 1.0. It helps in preventing the exploding gradient problem
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # update parameters
            optimizer.step()

            # model predictions are stored on GPU. So, push it to CPU
            logits = logits.detach().cpu().numpy()

            # append the model predictions
            total_logits.append(logits)
            
        # compute the training loss of the epoch
        avg_loss = total_loss / len(train_dataloader)
        
        total_logits = np.concatenate(total_logits, axis=0)
        

        return avg_loss, total_logits
    except Exception as e:
        print(f"Error during training the model on line: {sys.exc_info()[2].tb_lineno}")
        print(e)

In [32]:
# function for evaluating the model
def evaluate(model, val_dataloader, loss_criteria):
  
    print("\nEvaluating...")

    # deactivate dropout layers
    model.eval()

    total_loss, total_accuracy = 0, 0

    # empty list to save the model predictions
    total_logits = []

    # iterate over batches
    for step,batch in enumerate(val_dataloader):

        # Progress update every 50 batches.
        if step % 50 == 0 and not step == 0:

            # Calculate elapsed time in minutes.
#             elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))

        # push the batch to gpu
        batch = [t.to(device) for t in batch]

        sent_id, mask, labels = batch

        # deactivate autograd
        with torch.no_grad():

            # model predictions
            logits = model(sent_id, mask)

            # compute the validation loss between actual and predicted values
            loss = loss_criteria(logits.permute(0, 2, 1),labels)

            total_loss = total_loss + loss.item()

            logits = logits.detach().cpu().numpy()

            total_logits.append(logits)

    # compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader) 

    # reshape the predictions in form of (number of samples, no. of classes)
    total_logits  = np.concatenate(total_logits, axis=0)

    return avg_loss, total_logits

In [33]:
model = NERBert(tag_count=len(tag2idx))
model = model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5) # learning rate

criterion = nn.CrossEntropyLoss()

In [34]:
epochs=50
# set initial loss to infinite
best_valid_loss = float('inf')

# empty lists to store training and validation loss of each epoch
train_losses=[]
valid_losses=[]

#for each epoch
for epoch in range(epochs):
     
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    
    #train model
    train_loss, _ = train(
        model = model, 
        optimizer = optimizer, 
        loss_criteria = criterion, 
        train_dataloader = train_dataloader 
        
    )
    
    #evaluate model
    valid_loss, _ = evaluate(
        model = model, 
        val_dataloader = val_dataloader, 
        loss_criteria = criterion
    )
    
    #save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'trained_weights/saved_weights.pt')
    
    # append training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')


 Epoch 1 / 10

Evaluating...

Training Loss: 1.271
Validation Loss: 0.642

 Epoch 2 / 10

Evaluating...

Training Loss: 0.484
Validation Loss: 0.319

 Epoch 3 / 10

Evaluating...

Training Loss: 0.299
Validation Loss: 0.276

 Epoch 4 / 10

Evaluating...

Training Loss: 0.262
Validation Loss: 0.250

 Epoch 5 / 10

Evaluating...

Training Loss: 0.234
Validation Loss: 0.224

 Epoch 6 / 10

Evaluating...

Training Loss: 0.208
Validation Loss: 0.202

 Epoch 7 / 10

Evaluating...

Training Loss: 0.183
Validation Loss: 0.185

 Epoch 8 / 10

Evaluating...

Training Loss: 0.164
Validation Loss: 0.177

 Epoch 9 / 10

Evaluating...

Training Loss: 0.147
Validation Loss: 0.166

 Epoch 10 / 10

Evaluating...

Training Loss: 0.134
Validation Loss: 0.161


In [36]:
def get_prediction_from_logits(logits):
    try:
        tag_prob = nn.Softmax(dim=2)(logits)
        tag_prediction = torch.argmax(tag_prob, dim=2).detach().cpu().numpy()
        return tag_prediction
    except Exception as e:
        print(f"Error in line: {sys.exc_info()[2].tb_lineno}")
        print(e)

In [52]:
def classification_result(tag2idx, c_tag_id):
    try:
        prediction_result = []
        for sent_ in c_tag_id:
            prediction_result.append(
                list(map(lambda x: list(tag2idx.keys())[list(tag2idx.values()).index(x)], sent_))
            )
            
        tagged_entity = np.concatenate(prediction_result, axis=0)
        return tagged_entity
    except Exception as e:
        print(f"Error in line: {sys.exc_info()[2].tb_lineno}")
        print(e)        

In [35]:
#load weights of best model
path = 'trained_weights/saved_weights.pt'
model.load_state_dict(torch.load(path))

<All keys matched successfully>

In [37]:
# get predictions for test data
with torch.no_grad():
    logits = model(test_tokens.to(device), test_mask.to(device))
    preds = get_prediction_from_logits(logits=logits)

In [55]:
test_tags = test_tags.detach().cpu().numpy()

In [56]:
test_tags = classification_result(
    tag2idx = tag2idx, 
    c_tag_id = test_tags
)
preds = classification_result(
    tag2idx = tag2idx, 
    c_tag_id = preds
)

In [57]:
test_tags.shape, preds.shape

((9900,), (9900,))

In [58]:
print(classification_report(np.squeeze(test_tags.reshape(1, -1)), np.squeeze(preds.reshape(1, -1))))

                precision    recall  f1-score   support

    B-LOCATION       0.30      0.37      0.33        57
        B-MISC       0.40      0.32      0.36       212
B-ORGANIZATION       0.00      0.00      0.00        16
      B-PERSON       0.51      0.59      0.55       133
    I-LOCATION       0.00      0.00      0.00        24
        I-MISC       1.00      0.99      0.99      5593
I-ORGANIZATION       0.00      0.00      0.00         1
      I-PERSON       0.47      0.34      0.40        53
             O       0.94      0.96      0.95      3811

      accuracy                           0.95      9900
     macro avg       0.40      0.40      0.40      9900
  weighted avg       0.95      0.95      0.95      9900

