In [1]:
import os
os.chdir('../')

In [2]:
import re
import sys
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import transformers
import torch.nn as nn
from transformers import AdamW
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [3]:
device = torch.device("cpu")

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [5]:
dataset = pd.read_pickle('pickle/dataset.pickle')
print(dataset.shape)

(123235361, 2)


In [6]:
sentences = dataset['Sentences'].values.tolist()
tags = dataset['Tags'].values.tolist()

In [7]:
# cleaned_tags = [re.sub(r'''[',"\[\]]''', "", s) for s in tags]

# with open('cleaned_tags.pkl', 'wb') as f:
#     pickle.dump(cleaned_tags, f)

with open('cleaned_tags.pkl', 'rb') as f:
    cleaned_tags = pickle.load(f)

print(len(sentences), len(cleaned_tags))

123235361 123235361


In [8]:
unique_tags = list(set(cleaned_tags))

tag2idx = {}
for idx, tag in enumerate(unique_tags):
    tag2idx[tag] = idx

In [9]:
sentences_list = []
token_tag_list = []
sentence = []
token_tag = []
for token, tag in tqdm(zip(sentences, cleaned_tags)):
    sentence.append(token)
    token_tag.append(tag)
    try:
        if bool(re.match(r"[.]", token)):
            if len(sentence) >= 4:
                sentences_list.append(sentence)
                token_tag_list.append(token_tag)
                sentence = []
                token_tag = []
    except:
        print(f"Error in re for token: {token}")
        sentences_list.append(sentence)
        token_tag_list.append(token_tag)
        sentence = []
        token_tag = []

1831750it [00:07, 382015.51it/s]

Error in re for token: nan


2463479it [00:08, 742251.19it/s]

Error in re for token: nan
Error in re for token: nan


3333244it [00:10, 690676.49it/s]

Error in re for token: nan


3492530it [00:11, 739512.51it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


4668797it [00:14, 771409.45it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


5289218it [00:16, 364539.75it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


5679742it [00:16, 652422.23it/s]

Error in re for token: nan


6397807it [00:17, 782122.08it/s]

Error in re for token: nan


6556808it [00:18, 785678.41it/s]

Error in re for token: nan
Error in re for token: nan


6862829it [00:19, 296406.38it/s]

Error in re for token: nan


8578611it [00:22, 793524.68it/s]

Error in re for token: nan


9446360it [00:24, 657901.90it/s]

Error in re for token: nan


9764659it [00:25, 753120.76it/s]

Error in re for token: nan


10082292it [00:25, 762835.99it/s]

Error in re for token: nan


10317177it [00:26, 769544.76it/s]

Error in re for token: nan


10550942it [00:26, 766533.14it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


11866578it [00:29, 466511.46it/s]

Error in re for token: nan


12266669it [00:30, 720803.53it/s]

Error in re for token: nan


12587819it [00:30, 774833.15it/s]

Error in re for token: nan
Error in re for token: nan


16381461it [00:37, 776790.71it/s]

Error in re for token: nan


17336783it [00:38, 796841.07it/s]

Error in re for token: nan
Error in re for token: nan


17734774it [00:39, 780334.93it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


18931752it [00:40, 776005.34it/s]

Error in re for token: nan


19157333it [00:42, 194424.95it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


19393790it [00:43, 388553.19it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


20270328it [00:44, 767081.67it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


20819999it [00:44, 759753.03it/s]

Error in re for token: nan


21939663it [00:46, 794284.79it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


22661066it [00:47, 793524.28it/s]

Error in re for token: nan
Error in re for token: nan


23220302it [00:48, 768072.20it/s]

Error in re for token: nan


24091387it [00:49, 783847.13it/s]

Error in re for token: nan


24325791it [00:51, 189092.42it/s]

Error in re for token: nan


24635883it [00:51, 442570.90it/s]

Error in re for token: nan
Error in re for token: nan


24944652it [00:52, 649218.39it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


26952383it [00:55, 739358.62it/s]

Error in re for token: nan


27341942it [00:55, 770095.31it/s]

Error in re for token: nan


27731014it [00:56, 772438.82it/s]

Error in re for token: nan


28047637it [00:56, 784303.82it/s]

Error in re for token: nan
Error in re for token: nan


28439395it [00:57, 773644.85it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


29081880it [00:57, 751140.19it/s]

Error in re for token: nan


31782680it [01:03, 742359.80it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


33333248it [01:05, 748289.78it/s]

Error in re for token: nan
Error in re for token: nan


33562231it [01:05, 750941.98it/s]

Error in re for token: nan


33955648it [01:06, 772408.85it/s]

Error in re for token: nan


34465736it [01:07, 672770.30it/s]

Error in re for token: nan
Error in re for token: nan


35129280it [01:07, 731311.72it/s]

Error in re for token: nan
Error in re for token: nan


36641851it [01:10, 732370.45it/s]

Error in re for token: nan


37019596it [01:10, 726507.89it/s]

Error in re for token: nan


37960538it [01:11, 784155.89it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


38194782it [01:12, 766555.36it/s]

Error in re for token: nan


38423864it [01:12, 752503.65it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


39119795it [01:15, 428158.58it/s]

Error in re for token: nan


40795743it [01:17, 770225.74it/s]

Error in re for token: nan


42121291it [01:24, 565708.30it/s]

Error in re for token: nan


42503064it [01:24, 711225.06it/s]

Error in re for token: nan


43098309it [01:25, 695461.56it/s]

Error in re for token: nan
Error in re for token: nan


43958125it [01:26, 775781.89it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


44197790it [01:27, 788094.78it/s]

Error in re for token: nan


44916700it [01:28, 780434.11it/s]

Error in re for token: nan


45386937it [01:28, 766708.94it/s]

Error in re for token: nan


47125582it [01:31, 709910.38it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


47360019it [01:31, 754834.87it/s]

Error in re for token: nan
Error in re for token: nan


48792772it [01:53, 30897.81it/s] 

Error in re for token: nan


50278132it [01:55, 745281.15it/s]

Error in re for token: nan


51219109it [01:56, 704415.00it/s]

Error in re for token: nan
Error in re for token: nan


51419907it [01:57, 569441.31it/s]

Error in re for token: nan


52354575it [01:58, 704204.56it/s]

Error in re for token: nan


54461757it [02:01, 759336.85it/s]

Error in re for token: nan


54851396it [02:02, 774803.70it/s]

Error in re for token: nan
Error in re for token: nan


55387017it [02:03, 742925.46it/s]

Error in re for token: nan


55535603it [02:03, 706764.93it/s]

Error in re for token: nan


59861557it [02:10, 714991.71it/s]

Error in re for token: nan


60468679it [02:11, 751361.53it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


60926026it [02:12, 730270.19it/s]

Error in re for token: nan


62176008it [02:44, 540148.60it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


63128632it [02:45, 772936.49it/s]

Error in re for token: nan
Error in re for token: nan


63603816it [02:46, 782313.29it/s]

Error in re for token: nan
Error in re for token: nan


63915607it [02:46, 771742.49it/s]

Error in re for token: nan


65147969it [02:48, 751490.29it/s]

Error in re for token: nan


68468328it [02:53, 745385.40it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


69574146it [02:54, 731352.82it/s]

Error in re for token: nan


72393576it [02:58, 741783.37it/s]

Error in re for token: nan


72997723it [02:59, 749681.92it/s]

Error in re for token: nan


73147087it [02:59, 735640.85it/s]

Error in re for token: nan


75248381it [03:02, 748584.59it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


75994084it [03:03, 720473.05it/s]

Error in re for token: nan


76437808it [03:04, 731082.47it/s]

Error in re for token: nan


76723213it [03:46, 16553.59it/s] 

Error in re for token: nan


78086695it [03:48, 685358.83it/s]

Error in re for token: nan
Error in re for token: nan


78236569it [03:48, 715206.45it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


79855203it [03:50, 769694.45it/s]

Error in re for token: nan


80008141it [03:50, 751394.21it/s]

Error in re for token: nan
Error in re for token: nan


80761782it [03:51, 741309.60it/s]

Error in re for token: nan


82729542it [03:54, 749829.64it/s]

Error in re for token: nan
Error in re for token: nan


82956746it [03:54, 744439.07it/s]

Error in re for token: nan


85632297it [03:58, 750935.68it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


85860323it [03:58, 752543.99it/s]

Error in re for token: nan
Error in re for token: nan


86396651it [03:59, 755715.79it/s]

Error in re for token: nan
Error in re for token: nan


87465005it [04:00, 759719.83it/s]

Error in re for token: nan
Error in re for token: nan


87760863it [04:01, 672386.81it/s]

Error in re for token: nan
Error in re for token: nan


87911195it [04:01, 709873.63it/s]

Error in re for token: nan


89499305it [04:03, 672377.03it/s]

Error in re for token: nan


90674434it [04:05, 724423.22it/s]

Error in re for token: nan


91578435it [04:06, 737431.34it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


91796033it [04:06, 692036.65it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


92200280it [04:07, 651609.93it/s]

Error in re for token: nan
Error in re for token: nan


92702111it [04:08, 719436.51it/s]

Error in re for token: nan


93069897it [04:08, 707969.83it/s]

Error in re for token: nan


94395146it [04:10, 720346.79it/s]

Error in re for token: nan


95282870it [04:11, 738509.72it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


95719801it [04:12, 712590.48it/s]

Error in re for token: nan


96067248it [05:08, 7024.56it/s]  

Error in re for token: nan


97238385it [05:10, 547776.91it/s]

Error in re for token: nan
Error in re for token: nan


97564981it [05:10, 617827.47it/s]

Error in re for token: nan


97962340it [05:11, 634911.49it/s]

Error in re for token: nan
Error in re for token: nan


98233072it [05:11, 656338.89it/s]

Error in re for token: nan


99079646it [05:12, 699319.12it/s]

Error in re for token: nan


99361269it [05:13, 689181.04it/s]

Error in re for token: nan


99983157it [05:14, 688629.39it/s]

Error in re for token: nan
Error in re for token: nan


100818752it [05:15, 691237.67it/s]

Error in re for token: nan
Error in re for token: nan


102578706it [05:18, 708877.66it/s]

Error in re for token: nan
Error in re for token: nan


102858193it [05:18, 682141.24it/s]

Error in re for token: nan


103200409it [05:19, 665377.10it/s]

Error in re for token: nan
Error in re for token: nan


103408037it [05:19, 682724.95it/s]

Error in re for token: nan
Error in re for token: nan


104147395it [05:20, 733838.95it/s]

Error in re for token: nan


105642583it [05:22, 546960.52it/s]

Error in re for token: nan


106033899it [05:23, 652346.25it/s]

Error in re for token: nan


106237487it [05:23, 668236.10it/s]

Error in re for token: nan


107286307it [05:25, 684933.59it/s]

Error in re for token: nan


107712726it [05:25, 692112.13it/s]

Error in re for token: nan
Error in re for token: nan


108294180it [05:26, 706102.15it/s]

Error in re for token: nan


109361113it [05:28, 688285.20it/s]

Error in re for token: nan


110552984it [05:30, 678939.37it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


110905040it [05:30, 698339.31it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


111745825it [05:31, 676444.53it/s]

Error in re for token: nan


112316808it [05:32, 708357.28it/s]

Error in re for token: nan


112459925it [05:32, 698774.61it/s]

Error in re for token: nan


113297922it [05:34, 643288.30it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


113889273it [05:35, 661552.25it/s]

Error in re for token: nan
Error in re for token: nan


115170917it [05:36, 685538.28it/s]

Error in re for token: nan
Error in re for token: nan


116024584it [05:38, 712284.61it/s]

Error in re for token: nan
Error in re for token: nan


117782432it [05:40, 624894.96it/s]

Error in re for token: nan


120129055it [06:59, 12786.00it/s] 

Error in re for token: nan
Error in re for token: nan


120576845it [07:00, 100617.33it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


121472105it [07:01, 688793.13it/s]

Error in re for token: nan
Error in re for token: nan


121697876it [07:01, 725063.10it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


122219221it [07:02, 735438.87it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan


122655365it [07:03, 705617.15it/s]

Error in re for token: nan


123235361it [07:03, 290671.57it/s]

Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
Error in re for token: nan
E




In [10]:
sentences_length = []
for val in tqdm(sentences_list):
    sentences_length.append(len(val))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5911818/5911818 [00:02<00:00, 2593087.87it/s]


In [None]:
tokenized_sentences_list = []
tokenized_sent_tag_list = []
for token_list, tag_list in tqdm(zip(sentences_list, token_tag_list)):
    updated_token_list = []
    updated_tag_list = []
    for token, tag in zip(token_list, tag_list):
        try:
            tokenized_list = tokenizer.tokenize(token)
            updated_token_list.extend(tokenized_list)
            updated_tag_list.extend([tag2idx[tag]]*len(tokenized_list))
        except:
            print(f"Tokenizatin failed for token: {token}")
    tokenized_sentences_list.append(updated_token_list)
    tokenized_sent_tag_list.append(updated_tag_list)

81788it [01:30, 919.94it/s] 

Tokenizatin failed for token: nan


111672it [02:05, 926.44it/s] 

Tokenizatin failed for token: nan


113539it [02:07, 928.52it/s]

Tokenizatin failed for token: nan


145558it [02:41, 921.12it/s] 

Tokenizatin failed for token: nan


162669it [02:58, 1010.25it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


212893it [03:51, 945.22it/s] 

Tokenizatin failed for token: nan


216794it [03:55, 973.80it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


242878it [04:22, 1039.41it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


257586it [04:37, 1042.71it/s]

Tokenizatin failed for token: nan


294918it [05:16, 924.62it/s] 

Tokenizatin failed for token: nan


309760it [05:31, 972.28it/s] 

Tokenizatin failed for token: nan


310058it [05:31, 950.96it/s]

Tokenizatin failed for token: nan


318281it [05:39, 1022.16it/s]

Tokenizatin failed for token: nan


399409it [07:13, 694.58it/s] 

Tokenizatin failed for token: nan


442898it [08:06, 845.35it/s]

Tokenizatin failed for token: nan


456378it [08:22, 932.53it/s]

Tokenizatin failed for token: nan


469203it [08:36, 921.92it/s] 

Tokenizatin failed for token: nan


481012it [08:49, 916.69it/s]

Tokenizatin failed for token: nan


495459it [09:05, 897.69it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


557180it [10:13, 760.24it/s] 

Tokenizatin failed for token: nan


575203it [10:34, 860.67it/s]

Tokenizatin failed for token: nan


589525it [10:50, 929.63it/s] 

Tokenizatin failed for token: nan


592687it [10:53, 916.17it/s] 

Tokenizatin failed for token: nan


770016it [14:07, 933.50it/s] 

Tokenizatin failed for token: nan


811362it [14:49, 965.09it/s] 

Tokenizatin failed for token: nan


811684it [14:49, 1034.88it/s]

Tokenizatin failed for token: nan


830392it [15:08, 896.59it/s] 

Tokenizatin failed for token: nan


832698it [15:11, 892.38it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


890271it [16:10, 938.21it/s] 

Tokenizatin failed for token: nan


894729it [16:14, 953.78it/s] 

Tokenizatin failed for token: nan


898701it [16:18, 1007.82it/s]

Tokenizatin failed for token: nan


901626it [16:21, 984.13it/s] 

Tokenizatin failed for token: nan


906773it [16:26, 929.07it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


908062it [16:28, 938.61it/s]

Tokenizatin failed for token: nan


949187it [17:10, 1018.46it/s]

Tokenizatin failed for token: nan


950649it [17:11, 1056.70it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


974746it [17:35, 1024.28it/s]

Tokenizatin failed for token: nan


1024489it [18:26, 922.87it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1028800it [18:31, 1038.31it/s]

Tokenizatin failed for token: nan


1057573it [19:00, 965.65it/s] 

Tokenizatin failed for token: nan


1061954it [19:04, 917.04it/s] 

Tokenizatin failed for token: nan


1084949it [19:28, 990.59it/s] 

Tokenizatin failed for token: nan


1135410it [20:19, 1094.54it/s]

Tokenizatin failed for token: nan


1139175it [20:22, 1008.76it/s]

Tokenizatin failed for token: nan


1155426it [20:39, 973.85it/s] 

Tokenizatin failed for token: nan


1158772it [20:42, 953.41it/s] 

Tokenizatin failed for token: nan


1168756it [20:52, 979.44it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1168972it [20:53, 1005.97it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1262818it [22:27, 851.17it/s] 

Tokenizatin failed for token: nan


1280507it [22:46, 996.07it/s] 

Tokenizatin failed for token: nan


1294924it [25:22, 977.57it/s] 

Tokenizatin failed for token: nan


1310181it [25:38, 899.38it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1331997it [26:00, 957.49it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1360967it [26:30, 1041.36it/s]

Tokenizatin failed for token: nan


1491529it [28:38, 1055.49it/s]

Tokenizatin failed for token: nan


1498960it [28:46, 1080.10it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1559723it [29:45, 1023.70it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1572987it [29:58, 1088.86it/s]

Tokenizatin failed for token: nan


1592790it [30:17, 1085.09it/s]

Tokenizatin failed for token: nan


1616450it [30:39, 1055.27it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1652097it [31:14, 1058.20it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1725183it [32:25, 1021.28it/s]

Tokenizatin failed for token: nan


1744687it [32:44, 1011.75it/s]

Tokenizatin failed for token: nan


1790202it [33:29, 1089.64it/s]

Tokenizatin failed for token: nan


1792312it [33:31, 936.84it/s] 

Tokenizatin failed for token: nan


1795135it [33:34, 907.11it/s] 

Tokenizatin failed for token: nan


1795418it [33:34, 904.15it/s]

Tokenizatin failed for token: nan


1799021it [33:38, 924.13it/s] 

Tokenizatin failed for token: nan


1810705it [33:51, 982.73it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


1847289it [34:30, 972.44it/s] 

Tokenizatin failed for token: nan


1926079it [35:50, 1010.78it/s]

Tokenizatin failed for token: nan


1990691it [36:53, 1019.86it/s]

Tokenizatin failed for token: nan


2005751it [37:08, 977.51it/s] 

Tokenizatin failed for token: nan


2034142it [37:35, 987.51it/s] 

Tokenizatin failed for token: nan


2034686it [37:36, 1071.11it/s]

Tokenizatin failed for token: nan


2072452it [38:13, 1010.24it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2089621it [38:30, 964.57it/s] 

Tokenizatin failed for token: nan


2124277it [39:04, 989.42it/s] 

Tokenizatin failed for token: nan


2144295it [39:24, 1071.38it/s]

Tokenizatin failed for token: nan


2224145it [40:42, 966.60it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2241466it [41:00, 987.86it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2311435it [42:14, 847.73it/s] 

Tokenizatin failed for token: nan


2381417it [43:25, 1009.98it/s]

Tokenizatin failed for token: nan


2427124it [44:10, 1023.97it/s]

Tokenizatin failed for token: nan


2429340it [44:12, 921.09it/s] 

Tokenizatin failed for token: nan


2437183it [44:20, 944.74it/s] 

Tokenizatin failed for token: nan


2486304it [45:09, 947.44it/s] 

Tokenizatin failed for token: nan


2579906it [46:43, 994.45it/s] 

Tokenizatin failed for token: nan


2597416it [47:00, 1036.48it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2625130it [47:27, 1007.00it/s]

Tokenizatin failed for token: nan


2634019it [47:36, 1076.74it/s]

Tokenizatin failed for token: nan


2842897it [51:01, 1009.30it/s]

Tokenizatin failed for token: nan


2872087it [51:29, 1010.32it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2892923it [51:50, 984.43it/s] 

Tokenizatin failed for token: nan


2951827it [52:47, 919.64it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


2959393it [52:54, 1063.58it/s]

Tokenizatin failed for token: nan


3002430it [53:37, 1102.53it/s]

Tokenizatin failed for token: nan


3005777it [53:40, 1119.13it/s]

Tokenizatin failed for token: nan


3023218it [53:57, 893.31it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


3038912it [54:14, 1000.03it/s]

Tokenizatin failed for token: nan


3093636it [55:13, 915.88it/s] 

Tokenizatin failed for token: nan


3257422it [1:04:13, 802.45it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


3259291it [1:04:16, 851.01it/s]

Tokenizatin failed for token: nan


3262923it [1:04:20, 876.48it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


3311909it [1:05:16, 878.68it/s] 

Tokenizatin failed for token: nan


3447856it [1:07:41, 929.06it/s] 

Tokenizatin failed for token: nan


3475890it [1:08:10, 902.78it/s] 

Tokenizatin failed for token: nan


3484160it [1:08:18, 979.55it/s] 

Tokenizatin failed for token: nan


3492043it [1:08:27, 975.45it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


3589493it [1:10:08, 978.23it/s] 

Tokenizatin failed for token: nan


3622592it [1:10:44, 942.20it/s] 

Tokenizatin failed for token: nan


3641268it [1:11:03, 955.74it/s] 

Tokenizatin failed for token: nan


3658675it [1:11:22, 931.45it/s] 

Tokenizatin failed for token: nan


3722219it [1:12:28, 906.70it/s] 

Tokenizatin failed for token: nan


3729798it [1:12:36, 1034.22it/s]

Tokenizatin failed for token: nan


3731860it [1:12:38, 974.45it/s] 

Tokenizatin failed for token: nan


3735110it [1:12:41, 999.05it/s] 

Tokenizatin failed for token: nan


3738680it [1:12:45, 964.06it/s] 

Tokenizatin failed for token: nan


3807743it [1:13:56, 967.76it/s] 

Tokenizatin failed for token: nan


3817712it [1:14:06, 1014.24it/s]

Tokenizatin failed for token: nan


3823057it [1:14:12, 1001.00it/s]

Tokenizatin failed for token: nan


3854920it [1:14:44, 919.10it/s] 

Tokenizatin failed for token: nan


3947662it [1:16:18, 908.81it/s] 

Tokenizatin failed for token: nan


3955445it [1:16:26, 991.03it/s] 

Tokenizatin failed for token: nan


3956615it [1:16:27, 971.08it/s] 

Tokenizatin failed for token: nan


4090892it [1:18:42, 981.20it/s] 

Tokenizatin failed for token: nan


4094459it [1:18:46, 918.72it/s] 

Tokenizatin failed for token: nan


4096656it [1:18:48, 1068.68it/s]

Tokenizatin failed for token: nan


4098064it [1:18:50, 964.83it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4099277it [1:18:51, 1008.49it/s]

Tokenizatin failed for token: nan


4100793it [1:18:52, 1068.70it/s]

Tokenizatin failed for token: nan


4101024it [1:18:53, 1051.92it/s]

Tokenizatin failed for token: nan


4125994it [1:19:18, 995.36it/s] 

Tokenizatin failed for token: nan


4129753it [1:19:22, 1012.65it/s]

Tokenizatin failed for token: nan


4175954it [1:20:09, 948.18it/s] 

Tokenizatin failed for token: nan


4184766it [1:20:18, 1009.51it/s]

Tokenizatin failed for token: nan


4192118it [1:20:26, 963.39it/s] 

Tokenizatin failed for token: nan


4195119it [1:20:29, 982.57it/s] 

Tokenizatin failed for token: nan


4199742it [1:20:34, 962.36it/s] 

Tokenizatin failed for token: nan


4271131it [1:21:51, 1008.90it/s]

Tokenizatin failed for token: nan


4329531it [1:22:50, 974.41it/s] 

Tokenizatin failed for token: nan


4371961it [1:23:34, 925.20it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4375632it [1:23:38, 1001.09it/s]

Tokenizatin failed for token: nan


4387163it [1:23:50, 959.58it/s] 

Tokenizatin failed for token: nan


4388158it [1:23:51, 985.63it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4388376it [1:23:51, 986.40it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4388680it [1:23:51, 985.37it/s]

Tokenizatin failed for token: nan


4402800it [1:24:06, 915.70it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4426500it [1:24:31, 978.00it/s] 

Tokenizatin failed for token: nan


4443648it [1:24:49, 984.86it/s] 

Tokenizatin failed for token: nan


4507124it [1:25:54, 942.62it/s] 

Tokenizatin failed for token: nan


4546418it [1:26:35, 991.40it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4547769it [1:26:36, 1034.87it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4572462it [1:27:01, 1000.79it/s]

Tokenizatin failed for token: nan


4589704it [1:27:19, 933.93it/s] 

Tokenizatin failed for token: nan


4644541it [1:28:16, 920.42it/s] 

Tokenizatin failed for token: nan


4650415it [1:28:22, 959.94it/s] 

Tokenizatin failed for token: nan


4658984it [1:28:31, 1009.04it/s]

Tokenizatin failed for token: nan


4679741it [1:28:52, 978.40it/s] 

Tokenizatin failed for token: nan


4682120it [1:28:55, 959.92it/s] 

Tokenizatin failed for token: nan


4693872it [1:29:06, 973.89it/s] 

Tokenizatin failed for token: nan


4732592it [1:29:46, 954.56it/s] 

Tokenizatin failed for token: nan


4743074it [1:29:56, 997.61it/s] 

Tokenizatin failed for token: nan


4776833it [1:30:30, 948.48it/s] 

Tokenizatin failed for token: nan


4782845it [1:30:37, 989.33it/s] 

Tokenizatin failed for token: nan


4816061it [1:31:11, 976.59it/s] 

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4899159it [1:32:35, 1025.51it/s]

Tokenizatin failed for token: nan


4900967it [1:32:37, 976.42it/s] 

Tokenizatin failed for token: nan


4914080it [1:32:50, 904.12it/s] 

Tokenizatin failed for token: nan


4929243it [1:33:05, 960.84it/s] 

Tokenizatin failed for token: nan


4936746it [1:33:12, 926.22it/s] 

Tokenizatin failed for token: nan


4938987it [1:33:16, 889.12it/s]

Tokenizatin failed for token: nan
Tokenizatin failed for token: nan


4982068it [1:33:54, 1052.37it/s] 

Tokenizatin failed for token: nan


5055177it [1:35:07, 930.76it/s] 

Tokenizatin failed for token: nan


5070450it [1:35:22, 1003.61it/s]

Tokenizatin failed for token: nan


5082272it [1:35:34, 1064.36it/s]

Tokenizatin failed for token: nan


5135124it [1:36:25, 1022.19it/s]

Tokenizatin failed for token: nan


5156223it [1:36:47, 984.82it/s] 

Tokenizatin failed for token: nan


5157027it [1:36:47, 972.96it/s]

Tokenizatin failed for token: nan


5186009it [1:37:16, 994.57it/s] 

Tokenizatin failed for token: nan


5225877it [1:37:56, 992.52it/s] 

In [None]:
# Storing Tokenized Data
with open('pickle/tokenized_sentences_list.pkl', 'wb') as f:
    pickle.dump(tokenized_sentences_list, f)
with open('pickle/tokenized_sent_tag_list.pkl', 'wb') as f:
    pickle.dump(tokenized_sent_tag_list, f)
    
# # Loading Processed Tokenized Data
# with open('pickle/tokenized_sentences_list.pkl', 'rb') as f:
#     tokenized_sentences_list = pickle.load(f)
    
# with open('pickle/tokenized_sent_tag_list.pkl', 'rb') as f:
#     tokenized_sent_tag_list = pickle.load(f)

In [None]:
# Mapping tokens to ids
input_ids = []
for tokenized_sentence in tqdm(tokenized_sentences_list):
    input_ids.append(tokenizer.convert_tokens_to_ids(tokenized_sentence))

In [None]:
len(input_ids), len(tokenized_sent_tag_list)

In [None]:
attention_mask = []
for input_ in tqdm(input_ids):
    attention_mask.append(torch.ones(len(input_)))

padded_attention_mask = torch.nn.utils.rnn.pad_sequence(
    attention_mask, 
    batch_first=True, 
    padding_value=0.0)

padded_input_ids = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(input_) for input_ in input_ids], 
    batch_first=True, 
    padding_value=0.0)

padded_tags = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(tag_) for tag_ in tokenized_sent_tag_list], 
    batch_first=True, 
    padding_value=0.0)

In [None]:
class NERBert(nn.Module):
    
    def __init__(self, tag_count=4):
        super(NERBert, self).__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, tag_count)
    
    def forward(self, input_ids, attention_mask):
        
        output = self.bert(input_ids, attention_mask=attention_mask) # Model gives 'last_hidden_state' and 'pooler_output'
        
        pre_classifier_layer = self.dropout(output.last_hidden_state)
        model_output = self.classifier(pre_classifier_layer)
        
        return model_output

In [None]:
train_tokens, temp_tokens, train_tags, temp_tags, train_mask, temp_mask =  train_test_split(
    padded_input_ids, padded_tags, padded_attention_mask,
    random_state=2018, 
    test_size=0.3
)

val_tokens, test_tokens, val_tags, test_tags, val_mask, test_mask = train_test_split(
    temp_tokens, temp_tags, temp_mask,
    random_state=2018, 
    test_size=0.5
)

In [None]:
batch_size = 32
# wrap tensors
train_data = TensorDataset(train_tokens, train_mask, train_tags)

# sampler for sampling the data during training
train_sampler = RandomSampler(train_data)

# dataLoader for train set
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# wrap tensors
val_data = TensorDataset(val_tokens, val_mask, val_tags)

# sampler for sampling the data during training
val_sampler = SequentialSampler(val_data)

# dataLoader for validation set
val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)

In [None]:
def train(model, optimizer, loss_criteria, train_dataloader):
    try:
        model.train()
        
        total_loss = 0
        total_logits = []
        
        # iterate over batches
        for step,batch in enumerate(train_dataloader):

            # progress update after every 50 batches.
            if step % 50 == 0 and not step == 0:
                print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))

            # push the batch to gpu
            batch = [r.to(device) for r in batch]

            sent_id, mask, labels = batch

            # clear previously calculated gradients 
            model.zero_grad()        

            # get model predictions for the current batch
            logits = model(sent_id, mask)

            # compute the loss between actual and predicted values
            loss = loss_criteria(logits.permute(0, 2, 1), labels)

            # add on to the total loss
            total_loss = total_loss + loss.item()

            # backward pass to calculate the gradients
            loss.backward()

            # clip the the gradients to 1.0. It helps in preventing the exploding gradient problem
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # update parameters
            optimizer.step()

            # model predictions are stored on GPU. So, push it to CPU
            logits = logits.detach().cpu().numpy()

            # append the model predictions
            total_logits.append(logits)
            
        # compute the training loss of the epoch
        avg_loss = total_loss / len(train_dataloader)
        
        total_logits = np.concatenate(total_logits, axis=0)
        

        return avg_loss, total_logits
    except Exception as e:
        print(f"Error during training the model on line: {sys.exc_info()[2].tb_lineno}")
        print(e)

In [None]:
# function for evaluating the model
def evaluate(model, val_dataloader, loss_criteria):
  
    print("\nEvaluating...")

    # deactivate dropout layers
    model.eval()

    total_loss, total_accuracy = 0, 0

    # empty list to save the model predictions
    total_logits = []

    # iterate over batches
    for step,batch in enumerate(val_dataloader):

        # Progress update every 50 batches.
        if step % 50 == 0 and not step == 0:

            # Calculate elapsed time in minutes.
#             elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))

        # push the batch to gpu
        batch = [t.to(device) for t in batch]

        sent_id, mask, labels = batch

        # deactivate autograd
        with torch.no_grad():

            # model predictions
            logits = model(sent_id, mask)

            # compute the validation loss between actual and predicted values
            loss = loss_criteria(logits.permute(0, 2, 1),labels)

            total_loss = total_loss + loss.item()

            logits = logits.detach().cpu().numpy()

            total_logits.append(logits)

    # compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader) 

    # reshape the predictions in form of (number of samples, no. of classes)
    total_logits  = np.concatenate(total_logits, axis=0)

    return avg_loss, total_logits

In [None]:
model = NERBert(tag_count=len(tag2idx))
model = model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5) # learning rate

criterion = nn.CrossEntropyLoss()

In [None]:
epochs=50
# set initial loss to infinite
best_valid_loss = float('inf')

# empty lists to store training and validation loss of each epoch
train_losses=[]
valid_losses=[]

#for each epoch
for epoch in range(epochs):
     
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    
    #train model
    train_loss, _ = train(
        model = model, 
        optimizer = optimizer, 
        loss_criteria = criterion, 
        train_dataloader = train_dataloader 
        
    )
    
    #evaluate model
    valid_loss, _ = evaluate(
        model = model, 
        val_dataloader = val_dataloader, 
        loss_criteria = criterion
    )
    
    #save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'trained_weights/saved_weights.pt')
    
    # append training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

In [None]:
def get_prediction_from_logits(logits):
    try:
        tag_prob = nn.Softmax(dim=2)(logits)
        tag_prediction = torch.argmax(tag_prob, dim=2).detach().cpu().numpy()
        return tag_prediction
    except Exception as e:
        print(f"Error in line: {sys.exc_info()[2].tb_lineno}")
        print(e)

In [None]:
def classification_result(tag2idx, c_tag_id):
    try:
        prediction_result = []
        for sent_ in c_tag_id:
            prediction_result.append(
                list(map(lambda x: list(tag2idx.keys())[list(tag2idx.values()).index(x)], sent_))
            )
            
        tagged_entity = np.concatenate(prediction_result, axis=0)
        return tagged_entity
    except Exception as e:
        print(f"Error in line: {sys.exc_info()[2].tb_lineno}")
        print(e)        

In [None]:
#load weights of best model
path = 'trained_weights/saved_weights.pt'
model.load_state_dict(torch.load(path))

In [None]:
# get predictions for test data
with torch.no_grad():
    logits = model(test_tokens.to(device), test_mask.to(device))
    preds = get_prediction_from_logits(logits=logits)

In [None]:
test_tags = test_tags.detach().cpu().numpy()

In [None]:
test_tags = classification_result(
    tag2idx = tag2idx, 
    c_tag_id = test_tags
)
preds = classification_result(
    tag2idx = tag2idx, 
    c_tag_id = preds
)

In [None]:
test_tags.shape, preds.shape

In [None]:
print(classification_report(np.squeeze(test_tags.reshape(1, -1)), np.squeeze(preds.reshape(1, -1))))