# BERT NE recognition LOC/ORG/FACILITY

## 1)  Preprocess / format data

In [1]:
from __future__ import absolute_import, division, print_function

In [2]:
import codecs
import glob
import logging
import multiprocessing
import os
import pprint
import re
import keras
import nltk
import spacy

Using TensorFlow backend.


In [3]:
from nltk import ne_chunk, pos_tag
from nltk.tokenize import sent_tokenize, word_tokenize, PunktSentenceTokenizer
from nltk.corpus import stopwords
import gensim.models.word2vec as w2v
from gensim import logging
import sklearn.manifold
import numpy as np
import matplotlib.pyplot as plt
import pandas as  pd
import seaborn as sns
import pickle
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import json
from nltk.tokenize import WhitespaceTokenizer
from nltk.tokenize import MWETokenizer

In [4]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [5]:
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

In [6]:
import nltk
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/chantana/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

## 2) load keyword: LOC/ORG from text

In [337]:
word_files = ['./keywords/list_hotel.txt']
word_list_name = []
for word_file in word_files:
    with open(word_file) as f:
        lines = f.readlines()
    word_list2 = [x.strip() for x in lines] 
    
    word_list_name.extend(word_list2)
word_list_name[0:5]

['Petpimarn Boutique Resort',
 'Airport Suite Bangkok (Don Muang Airport)',
 'The Riche Residence',
 'Regent Home1 at Donmuang',
 'Charoenpong Apartment']

In [338]:
word_files = ['./keywords/facility.txt']
word_list_fac = []
for word_file in word_files:
    with open(word_file) as f:
        lines = f.readlines()
    word_list2 = [x.strip() for x in lines] 
    
    word_list_fac.extend(word_list2)
word_list_fac[0:5]

['24-hour Receptionist',
 'Laundry service',
 'Luggage storage',
 'Tours',
 'Parking']

In [339]:
word_hotels = ['./keywords/list_location.txt']
lists_location = []
for word_hotel in word_hotels:
    with open(word_hotel) as f:
        lines = f.readlines()
    lists_l= [x.strip() for x in lines] 
    
    lists_location.extend(lists_l)
lists_location[0:5]

['99/89 Moo 6',
 'VIPHAVADEE RANDSIT RD',
 '111/1 Soi Ngam Wong Wan 47 Yaek 1',
 '310/72 Soi Phaholyothin 67/1',
 '3/2899 Soi Yak']

In [340]:
#remove duplicates
def ordered_set(in_list):
    out_list = []
    added = set()
    for val in in_list:
        if not val in added:
            out_list.append(val)
            added.add(val)
    return out_list

In [341]:
word_list_name = ordered_set(word_list_name)

In [342]:
word_list_fac = ordered_set(word_list_fac)

In [343]:
word_location= ordered_set(lists_location)

word_locates = ['../keywords/hotel_location.txt']
word_hotel_location = []
for word_locate in word_locates:
    with open(word_locate) as f:
        lines = f.readlines()
    word_location2 = [x.strip() for x in lines] 
    
    word_hotel_location.extend(word_location2)
word_hotel_location[0:5]

In [344]:
mw_list = []
for x in word_list_name:
    ws = x.split() 
    if len(ws) > 1:
        tws = tuple(ws)
        mw_list.append(tws)
        

for x in word_location:
    ws = x.split() 
    if len(ws) > 1:
        tws = tuple(ws)
        mw_list.append(tws)
        

        
for x in word_list_fac:
    ws = x.split() 
    if len(ws) > 1:
        tws = tuple(ws)
        mw_list.append(tws)
         
            
mw_list = ordered_set(mw_list)
mw_list[0:5]

[('Petpimarn', 'Boutique', 'Resort'),
 ('Airport', 'Suite', 'Bangkok', '(Don', 'Muang', 'Airport)'),
 ('The', 'Riche', 'Residence'),
 ('Regent', 'Home1', 'at', 'Donmuang'),
 ('Charoenpong', 'Apartment')]

In [345]:
my_idx = {}
for i,w in enumerate(mw_list):
    my_idx[" ".join(w)] = i
print(len(my_idx))

35968


In [86]:
for i in mw_list:
    if i == ('M&A','Guesthouse'):
        print('found')
    if i == ('Central','Pattaya'):
        print('found')

found
found


In [88]:
mwe = MWETokenizer(mw_list,separator=' ')

In [89]:
reviews = pd.read_csv("./keywords/all_loc_fac.csv")

In [90]:
reviews.shape

(26541, 1)

In [91]:
reviews.head()

Unnamed: 0,location
0,Staying at @ Home Executive Apartment is a goo...
1,Staying at @ Oasis Resort is a good choice whe...
2,Staying at @ T Boutique Hotel is a good choice...
3,Staying at AC Resort is a good choice when you...
4,Staying at Acqua Condominium is a good choice ...


In [92]:
reviews = reviews.dropna()
reviews = reviews.reset_index(drop=True)

In [93]:
def spans(txt):
    tokens=mwe.tokenize(word_tokenize(txt))
    offset = 0
    for token in tokens:
        offset = txt.find(token, offset)
        yield token, offset, offset+len(token)
        offset += len(token)

In [117]:
from tqdm import tqdm_notebook as tqdm
all_item = []

for i in tqdm(range(len(reviews))):
    word_ls = []
    #print("sentence:",i+1)
    
    for token in spans(reviews.location[i]):
        #print(token)
        #assert token[0]==reviews.location[i][token[1]:token[2]]
        my_tuple = token[0]
        #print("###",token)
        
        #my_tuples = ' , '.join(map(str, my_tuple))
        if token[0] in word_location:
            #word_ls.append(my_tuple)
            subwords = token[0].split()
            pos_list = [nltk.pos_tag([w]) for w in subwords]
            tag_list = ['I-LOC']*len(pos_list)
            tag_list[0] = 'B-LOC' 
            
            
            for s,p,t in zip(subwords,pos_list,tag_list):           
                if type(p) == list:
                    p = p[0][1]
                    #print('list')
                new_item = dict({'Sentence #': i+1, 'Tag' : t, 'Word': s,'POS': p})
                all_item.append(new_item)
                
                
            #lis_lo = nltk.pos_tag(word_ls),LOC
            #print(' , '.join(map(str, lis_lo)))
 
        elif token[0] in word_list_name:
            
            subwords = token[0].split()
            pos_list = [nltk.pos_tag([w]) for w in subwords]
            tag_list = ['I-ORG']*len(pos_list)
            tag_list[0] = 'B-ORG' 
            
            #print(subwords)
            for s,p,t in zip(subwords,pos_list,tag_list):           
                new_item = dict({'Sentence #': i+1, 'Tag' : t, 'Word': s,'POS': p[0][1]})
                all_item.append(new_item)
                #print(new_item)
            #word_ls.append(my_tuple)
            #print("found LOC")
            #lis_lo = nltk.pos_tag(word_ls),"LOC"
            #print(' , '.join(map(str, lis_lo)))
             
        elif token[0] in word_list_fac:
            #word_ls.append(my_tuple)
            #lis_lo = nltk.pos_tag(word_ls),"FACILITY"
            #print(' , '.join(map(str, lis_lo)))
            
            #my_pos = nltk.pos_tag([my_tuple])[0][1]
            #new_item =  dict({'Sentence #': i+1, 'Tag' : 'FACILITY', 'Word': my_tuple,'POS': my_pos})
            
            subwords = token[0].split()
            pos_list = [nltk.pos_tag([w]) for w in subwords]
            tag_list = ['I-FAC']*len(pos_list)
            tag_list[0] = 'B-FAC' 
            
             
            for s,p,t in zip(subwords,pos_list,tag_list):   
                if type(p) == list:
                    p = p[0][1]
                new_item = dict({'Sentence #': i+1, 'Tag' : t, 'Word': s,'POS': p})
                all_item.append(new_item)
                

             
        else:
            #print(type(my_tuple))
            my_pos = nltk.pos_tag([my_tuple])[0][1]
            new_item = dict({'Sentence #': i+1, 'Tag' : 'O', 'Word': my_tuple.lower(),'POS': my_pos})
            all_item.append(new_item)
            
            #if i+1 == 5177:
                #print(new_item)
             
            #print(nltk.pos_tag([my_tuple]),',','O')
        #print(new_item)
        #if not(my_pos == '.' or  my_pos == ',' or my_pos == ':' or my_pos == '(' or my_pos == ')') :
            
            #all_item.append(new_item) 

        

        
    

          

HBox(children=(IntProgress(value=0, max=26541), HTML(value=u'')))

## 3) save preproc data

In [119]:
file_csv = 'data-loc-fac-bert.csv'
with open(file_csv, 'w') as csv_file:
    csv_file.write('Sentence #*Word*POS*Tag\n')
    for item in all_item:
        
        #print(item['POS'])    
        csv_file.write(str(item['Sentence #'])+'*'+item['Word']+'*'+item['POS']+'*'+item['Tag']+'\n')
csv_file.close()

## 4) Load preprocessed data

In [7]:
import pandas as pd
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
#load data
# explore data
file_csv = 'data-loc-fac-bert.csv'
data = pd.read_csv(file_csv, sep='*',encoding="utf-8").fillna(method="ffill")

data.head(10)
 


Unnamed: 0,Sentence #,Word,POS,Tag
0,1,staying,VBG,O
1,1,at,IN,O
2,1,@,NN,B-ORG
3,1,Home,NN,I-ORG
4,1,Executive,NN,I-ORG
5,1,Apartment,NN,I-ORG
6,1,is,VBZ,O
7,1,a,DT,O
8,1,good,JJ,O
9,1,choice,NN,O


for i in range(100):
    locate_lis = []
    print("sentence:",i+1)
    for token in spans(reviews.location[i]):
        #print(token)
        #assert token[0]==reviews.location[i][token[1]:token[2]]
        my_tuple = token[0]
        print(token)
        #my_tuples = ' , '.join(map(str, my_tuple))
        if token[0] in word_location:
            locate_lis.append(my_tuple)
            lis_lo1 = "\n".join(map(str, (nltk.pos_tag(locate_lis))))
            print(lis_lo1,',',"LOC")
            continue
        print(nltk.pos_tag([my_tuple]),',','O')

### Load word2vec representation


In [20]:
filename_50 ="trained/hotel2vec-desc-50.w2v"
filename_g300 ="trained/hotel2vec-glove-desc-300.w2v"
filename_num ="trained/hotel2vec-num-desc-300.w2v"
filename_g50 ="trained/hotel2vec-glove-desc-50.w2v"
filename_gg ="trained/hotel2vec-gg-desc-300.w2v"

In [21]:
file_num = './model/numberbatch-en-17.06.txt' 
file_gg = 'GoogleNews-vectors-negative300.bin'
file_glove_50 = 'glove.6B.300d.txt'
file_glove_300 = 'glove.6B.300d.txt'

In [22]:
#load based model
filename = file_glove_50

embeddings_index = {}
seq = 0
with open(filename) as f:
#with open('./model/numberbatch-en-17.06.txt', encoding='utf-8') as f:
    for line in f:
        values = line.split(' ')
        #print(values)
        word = values[0].decode('utf-8')
        embedding = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = (seq,embedding)
        seq += 1

print('Word embeddings:', len(embeddings_index))

#load desc model
filename = filename_g50



model_d =  w2v.Word2Vec.load(filename)
word_vectors = model_d.wv

vocab_gg = word_vectors.vocab
 
for v in vocab_gg:
     
    word = v 
    embedding = np.asarray(model_d.wv[v], dtype='float32')
    embeddings_index[word] = (seq,embedding)
    seq +=1

    
print('Word embeddings:', len(embeddings_index))



2019-03-16 13:42:41,552 : INFO : loading Word2Vec object from trained/hotel2vec-glove-desc-50.w2v
2019-03-16 13:42:41,564 : INFO : loading vocabulary recursively from trained/hotel2vec-glove-desc-50.w2v.vocabulary.* with mmap=None
2019-03-16 13:42:41,567 : INFO : loading wv recursively from trained/hotel2vec-glove-desc-50.w2v.wv.* with mmap=None
2019-03-16 13:42:41,569 : INFO : setting ignored attribute vectors_norm to None
2019-03-16 13:42:41,576 : INFO : loading trainables recursively from trained/hotel2vec-glove-desc-50.w2v.trainables.* with mmap=None
2019-03-16 13:42:41,581 : INFO : setting ignored attribute cum_table to None
2019-03-16 13:42:41,584 : INFO : loaded trained/hotel2vec-glove-desc-50.w2v


Word embeddings: 400000
Word embeddings: 401831


### 5) BERT trained-formating

In [9]:
class SentenceGetter(object):
    
    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, p, t) for w, p, t in zip(s["Word"].values.tolist(),
                                                           s["POS"].values.tolist(),
                                                           s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence #").apply(agg_func)
        self.sentences = [s for s in self.grouped]
    
    def get_next(self):
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
            self.n_sent += 1
            return s
        except:
            return None

In [10]:
 
#concat sentence
getter = SentenceGetter(data)

In [377]:
word_list = [ [s[0] for s in sent] for sent in getter.sentences] 

In [11]:
sentences = [" ".join([s[0] for s in sent]) for sent in getter.sentences]
sentences[0]

u'staying at @ Home Executive Apartment is a good choice when you are visiting Central Pattaya .'

In [379]:
print(len(sentences))

26541


In [12]:
labels = [[s[2] for s in sent] for sent in getter.sentences]
print(len(labels[0]))

17


In [351]:
print(len(word_list[0]))

17


In [352]:
print(len(labels))

26541


In [13]:
tags_vals = list(set(data["Tag"].values))
tag2idx = {t: i for i, t in enumerate(tags_vals)}
idx2tag = {i: t for i, t in enumerate(tags_vals) }

In [128]:
tags_vals_save = tags_vals

In [129]:
idx2tag[2]

u'B-ORG'

In [92]:
def convert_word2idx (text_list):
    
    values = []
    #print(text_list)
     
    for word in text_list:
        if word in my_idx:
            values.append(my_idx[word])
        else:
            if word.lower() in embeddings_index:
                v = len(my_idx)+embeddings_index[word.lower()][0] 
            else:
                val = list(embeddings_index.keys())[0]
                embeddings_index[word.lower()] = (len(embeddings_index),[0.5]*len(val[1]))
                v = len(my_idx) +len(embeddings_index)
            values.append(v)
            
    return values

def convert_idx2word (idx_list):
    words_list = []
    for idx in idx_list:
        found = False
        for key, value in my_idx.items(): 
            if idx == value:
                words_list.append(key)
                found = True
                break
        if not found:
            for key, value in embeddings_index.items(): 
                if idx == value[0]+len(my_idx):
                    words_list.append(key)

                    break

            
         
        
    return words_list


# Training

In [14]:
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [70]:
MAX_LEN = 120
bs = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0) 

u'TITAN Xp'

In [71]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

2019-03-18 13:42:57,328 : INFO : loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/chantana/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


mwe = MWETokenizer(mw_list,separator=' ')

tokenized_texts = [mwe.tokenize(word_tokenize(sent)) for sent in sentences]


print(tokenized_texts[0],len(tokenized_texts[0]))

In [72]:
tokenized_texts = [ tokenizer.tokenize(sent)   for sent in sentences]
print(tokenized_texts[0])

[u'staying', u'at', u'@', u'home', u'executive', u'apartment', u'is', u'a', u'good', u'choice', u'when', u'you', u'are', u'visiting', u'central', u'pat', u'##ta', u'##ya', u'.']


In [73]:
sentences[0]

u'staying at @ Home Executive Apartment is a good choice when you are visiting Central Pattaya .'

In [74]:
print(tokenized_texts[0])

[u'staying', u'at', u'@', u'home', u'executive', u'apartment', u'is', u'a', u'good', u'choice', u'when', u'you', u'are', u'visiting', u'central', u'pat', u'##ta', u'##ya', u'.']


### Load embedding

In [75]:
from nltk.stem import PorterStemmer 

In [76]:
ps = PorterStemmer() 

In [77]:
def convert_word2idx (text_list):
    
    values = []
    #print(text_list)
    start = len(tokenizer.vocab)
    for word in text_list:
        
        if tokenizer.vocab.has_key(word):
            values.append(tokenizer.vocab[word])
        else:
            #print(word,"not found")
            tokenizer.vocab[word] = start
            values.append(tokenizer.vocab[word])
            start = start + 1
        
            
    return values


In [78]:
len(tokenizer.vocab)

30522

In [79]:
tokenizer_save =  tokenizer


In [80]:
tokenizer.vocab 

OrderedDict([(u'[PAD]', 0),
             (u'[unused0]', 1),
             (u'[unused1]', 2),
             (u'[unused2]', 3),
             (u'[unused3]', 4),
             (u'[unused4]', 5),
             (u'[unused5]', 6),
             (u'[unused6]', 7),
             (u'[unused7]', 8),
             (u'[unused8]', 9),
             (u'[unused9]', 10),
             (u'[unused10]', 11),
             (u'[unused11]', 12),
             (u'[unused12]', 13),
             (u'[unused13]', 14),
             (u'[unused14]', 15),
             (u'[unused15]', 16),
             (u'[unused16]', 17),
             (u'[unused17]', 18),
             (u'[unused18]', 19),
             (u'[unused19]', 20),
             (u'[unused20]', 21),
             (u'[unused21]', 22),
             (u'[unused22]', 23),
             (u'[unused23]', 24),
             (u'[unused24]', 25),
             (u'[unused25]', 26),
             (u'[unused26]', 27),
             (u'[unused27]', 28),
             (u'[unused28]', 29),
     

In [81]:
print(mw_list[:5])
for i in mw_list:
    if " ".join(i) == '@ Home Executive Apartment':
        print('found')

NameError: name 'mw_list' is not defined

In [227]:
def append_vocab(vocab_tuple_list):
    start = len(tokenizer.vocab)
    
    for v in vocab_tuple_list:
        ww = " ".join(v)
        if ww not in tokenizer.vocab:
            #print(v)
            tokenizer.vocab[ww] = start
            print(ww)
            start += 1
    return tokenizer.vocab


In [21]:
tokens_ids = [  tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts]


In [320]:
tokenizer.vocab['Pattya']

41276

In [321]:
 print(len(tokens_ids[0]),len(labels[0]))
 

17 17


In [325]:
#input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
input_ids = pad_sequences(tokens_ids,
                          maxlen=MAX_LEN, dtype="int64", truncating="post", padding="post")
#input_ids = pad_sequences([convert_word2idx(txt) for txt in tokenized_texts],
                          #maxlen=MAX_LEN, dtype="int64", truncating="post", padding="post")

In [82]:
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                          maxlen=MAX_LEN, dtype="int64", truncating="post", padding="post")


In [83]:
print(MAX_LEN)
for i in tokens_ids:
    if len(i) > MAX_LEN:
        #print(tokens_ids)
        print("need more")
        MAX_LEN = len(i)
print(MAX_LEN)

120
120


In [84]:
t_list = [[tag2idx.get(l) for l in lab] for lab in labels]
len(t_list[0])

17

In [85]:
print(input_ids[0])

[6595 2012 1030 2188 3237 4545 2003 1037 2204 3601 2043 2017 2024 5873
 2430 6986 2696 3148 1012    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0]


In [86]:
tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],
                     maxlen=MAX_LEN, value=tag2idx["O"], padding="post",
                     dtype="int64", truncating="post")

In [87]:
attention_masks = [[float(i>0) for i in ii] for ii in input_ids]

In [88]:
#split train test
tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(input_ids, tags, 
                                                            random_state=2018, test_size=0.3)
tr_masks, val_masks, _, _ = train_test_split(attention_masks, input_ids,
                                             random_state=2018, test_size=0.3)

In [89]:
#change to torhc tensor
tr_inputs = torch.tensor(tr_inputs)
val_inputs = torch.tensor(val_inputs)
tr_tags = torch.tensor(tr_tags)
val_tags = torch.tensor(val_tags)
tr_masks = torch.tensor(tr_masks)
val_masks = torch.tensor(val_masks)

In [90]:
print(tr_inputs.shape)
print(tr_masks.shape)
print(tr_tags.shape)

torch.Size([18578, 120])
torch.Size([18578, 120])
torch.Size([18578, 120])


In [91]:
print(val_inputs.shape)
print(val_masks.shape)
print(val_tags.shape)

torch.Size([7963, 120])
torch.Size([7963, 120])
torch.Size([7963, 120])


In [92]:
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [93]:
print(len(tag2idx))

7


In [94]:
model = BertForTokenClassification.from_pretrained(u"bert-base-uncased", num_labels=len(tag2idx))

2019-03-18 13:44:06,944 : INFO : loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/chantana/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
2019-03-18 13:44:06,948 : INFO : extracting archive file /home/chantana/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpyWtLQx
2019-03-18 13:44:10,503 : INFO : Model config {
  "attention_probs_dropout_prob": 0.1, 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "type_vocab_size": 2, 
  "vocab_size": 30522
}

2019-03-18 13:44:13,586 : INFO : Weights of BertForTokenClass

In [95]:
model.cuda();

In [96]:
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters()) 
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)

In [97]:
#Fine tune BERT
from seqeval.metrics import f1_score

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat) 

In [98]:
epochs = 5
max_grad_norm = 1.0

for _ in trange(epochs, desc="Epoch"):
    # TRAIN loop
    model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(train_dataloader):
        # add batch to gpu
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        loss = model(b_input_ids, token_type_ids=None,
                     attention_mask=b_input_mask, labels=b_labels)
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        model.zero_grad()
    # print train loss per epoch
    print("Train loss: {}".format(tr_loss/nb_tr_steps))
    # VALIDATION on validation set
    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_mask, labels=b_labels)
            logits = model(b_input_ids, token_type_ids=None,
                           attention_mask=b_input_mask)
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.append(label_ids)
        
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy
        
        nb_eval_examples += b_input_ids.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss/nb_eval_steps
    print("Validation loss: {}".format(eval_loss))
    print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]
    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]
    print("F1-Score: {}".format(f1_score(pred_tags, valid_tags)))





Epoch:   0%|          | 0/5 [00:00<?, ?it/s][A[A[A[A

Train loss: 0.381081105619
Validation loss: 0.255078543651
Validation Accuracy: 0.800869450022






Epoch:  20%|██        | 1/5 [08:23<33:34, 503.75s/it][A[A[A[A

F1-Score: 0.267616214074
Train loss: 0.221693574455
Validation loss: 0.171089398483
Validation Accuracy: 0.833422579206






Epoch:  40%|████      | 2/5 [27:59<35:16, 705.39s/it][A[A[A[A

F1-Score: 0.366658020763
Train loss: 0.169982631978
Validation loss: 0.155627313331
Validation Accuracy: 0.840877398483






Epoch:  60%|██████    | 3/5 [36:43<21:41, 650.88s/it][A[A[A[A

F1-Score: 0.384366901261
Train loss: 0.141815415096
Validation loss: 0.139297030737
Validation Accuracy: 0.842304635207






Epoch:  80%|████████  | 4/5 [44:12<09:50, 590.32s/it][A[A[A[A

F1-Score: 0.391983509768
Train loss: 0.122520411009
Validation loss: 0.1095122506
Validation Accuracy: 0.866085174029






Epoch: 100%|██████████| 5/5 [50:58<00:00, 535.07s/it][A[A[A[A

F1-Score: 0.464652200901


In [99]:
#evaluate model
model.eval()
predictions = []
true_labels = []
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
print(len(valid_dataloader))
for batch in tqdm(valid_dataloader):
    #print(len(batch))
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch

    with torch.no_grad():
        tmp_eval_loss = model(b_input_ids, token_type_ids=None,
                              attention_mask=b_input_mask, labels=b_labels)
        logits = model(b_input_ids, token_type_ids=None,
                       attention_mask=b_input_mask)
        
    logits = logits.detach().cpu().numpy()
    predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
    label_ids = b_labels.to('cpu').numpy()
    true_labels.append(label_ids)
    tmp_eval_accuracy = flat_accuracy(logits, label_ids)

    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy

    nb_eval_examples += b_input_ids.size(0)
    nb_eval_steps += 1

pred_tags = [[tags_vals[p_i] for p_i in p] for p in predictions]
valid_tags = [[tags_vals[l_ii] for l_ii in l_i] for l in true_labels for l_i in l ]
print("Validation loss: {}".format(eval_loss/nb_eval_steps))
print("Validation Accuracy: {}".format(eval_accuracy/nb_eval_steps))
print("Validation F1-Score: {}".format(f1_score(pred_tags, valid_tags)))





  0%|          | 0/996 [00:00<?, ?it/s][A[A[A[A

996






  0%|          | 1/996 [00:00<05:22,  3.09it/s][A[A[A[A



  0%|          | 2/996 [00:00<05:24,  3.06it/s][A[A[A[A



  0%|          | 3/996 [00:00<05:24,  3.06it/s][A[A[A[A



  0%|          | 4/996 [00:01<05:11,  3.18it/s][A[A[A[A



  1%|          | 5/996 [00:01<05:03,  3.27it/s][A[A[A[A



  1%|          | 6/996 [00:01<05:01,  3.28it/s][A[A[A[A



  1%|          | 7/996 [00:02<05:06,  3.23it/s][A[A[A[A



  1%|          | 8/996 [00:02<05:09,  3.19it/s][A[A[A[A



  1%|          | 9/996 [00:02<04:51,  3.39it/s][A[A[A[A



  1%|          | 10/996 [00:03<04:56,  3.33it/s][A[A[A[A



  1%|          | 11/996 [00:03<04:59,  3.29it/s][A[A[A[A



  1%|          | 12/996 [00:03<05:00,  3.28it/s][A[A[A[A



  1%|▏         | 13/996 [00:03<04:55,  3.33it/s][A[A[A[A



  1%|▏         | 14/996 [00:04<04:55,  3.33it/s][A[A[A[A



  2%|▏         | 15/996 [00:04<04:59,  3.28it/s][A[A[A[A



  2%|▏         | 16/996 [00:04<04:57,  3.30it

 45%|████▍     | 447/996 [00:46<00:39, 14.07it/s][A[A[A[A



 45%|████▌     | 449/996 [00:46<00:38, 14.17it/s][A[A[A[A



 45%|████▌     | 451/996 [00:46<00:38, 14.02it/s][A[A[A[A



 45%|████▌     | 453/996 [00:46<00:38, 14.02it/s][A[A[A[A



 46%|████▌     | 455/996 [00:46<00:38, 14.05it/s][A[A[A[A



 46%|████▌     | 457/996 [00:46<00:38, 14.04it/s][A[A[A[A



 46%|████▌     | 459/996 [00:46<00:37, 14.15it/s][A[A[A[A



 46%|████▋     | 461/996 [00:47<00:37, 14.11it/s][A[A[A[A



 46%|████▋     | 463/996 [00:47<00:37, 14.20it/s][A[A[A[A



 47%|████▋     | 465/996 [00:47<00:36, 14.38it/s][A[A[A[A



 47%|████▋     | 467/996 [00:47<00:36, 14.37it/s][A[A[A[A



 47%|████▋     | 469/996 [00:47<00:36, 14.40it/s][A[A[A[A



 47%|████▋     | 471/996 [00:47<00:38, 13.70it/s][A[A[A[A



 47%|████▋     | 473/996 [00:47<00:38, 13.69it/s][A[A[A[A



 48%|████▊     | 475/996 [00:48<00:37, 14.00it/s][A[A[A[A



 48%|████▊     | 477/996 

 95%|█████████▌| 951/996 [01:23<00:03, 12.44it/s][A[A[A[A



 96%|█████████▌| 953/996 [01:23<00:03, 12.40it/s][A[A[A[A



 96%|█████████▌| 955/996 [01:24<00:03, 12.32it/s][A[A[A[A



 96%|█████████▌| 957/996 [01:24<00:03, 12.75it/s][A[A[A[A



 96%|█████████▋| 959/996 [01:24<00:02, 13.14it/s][A[A[A[A



 96%|█████████▋| 961/996 [01:24<00:02, 13.24it/s][A[A[A[A



 97%|█████████▋| 963/996 [01:24<00:02, 13.03it/s][A[A[A[A



 97%|█████████▋| 965/996 [01:24<00:02, 13.19it/s][A[A[A[A



 97%|█████████▋| 967/996 [01:24<00:02, 13.23it/s][A[A[A[A



 97%|█████████▋| 969/996 [01:25<00:01, 13.57it/s][A[A[A[A



 97%|█████████▋| 971/996 [01:25<00:01, 13.92it/s][A[A[A[A



 98%|█████████▊| 973/996 [01:25<00:01, 14.16it/s][A[A[A[A



 98%|█████████▊| 975/996 [01:25<00:01, 14.34it/s][A[A[A[A



 98%|█████████▊| 977/996 [01:25<00:01, 13.75it/s][A[A[A[A



 98%|█████████▊| 979/996 [01:25<00:01, 13.83it/s][A[A[A[A



 98%|█████████▊| 981/996 

Validation loss: 0.1095122506
Validation Accuracy: 0.866085174029
Validation F1-Score: 0.464652200901


In [293]:
pred_tags

[[u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O'],
 [u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
 

for lab in b_labels:
    for l in lab:
        print(idx2tag[int(l)])
        

In [384]:
[[ idx2tag[int(l)] for l in lab] for lab in b_labels]

[[u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O'],
 [u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'ORG',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',
  u'O',