In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from datasets import load_dataset
from collections import Counter
from conlleval import evaluate
import tensorflow_hub as hub
import tensorflow_text as text
import pandas


  from .autonotebook import tqdm as notebook_tqdm


## Bert Model

In [2]:
tfhub_handle_encoder="https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
tfhub_handle_preprocess="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)

2023-06-09 15:28:32.977507: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-09 15:28:32.990950: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-09 15:28:32.992694: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-09 15:28:32.995179: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

## Preprocess data for bert model


In [3]:
conll_data = load_dataset("conll2003")

100%|██████████| 3/3 [00:00<00:00, 497.62it/s]


In [4]:
def make_tag_lookup_table():
    iob_labels = ["B", "I"]
    ner_labels = ["PER", "ORG", "LOC", "MISC"]
    all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]
    all_labels = ["-".join([a, b]) for a, b in all_labels]
    all_labels = ["[PAD]", "O"] + all_labels
    return dict(zip(range(0, len(all_labels) + 1), all_labels))

In [5]:
mapping = make_tag_lookup_table()
print(mapping)


{0: '[PAD]', 1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-ORG', 6: 'B-LOC', 7: 'I-LOC', 8: 'B-MISC', 9: 'I-MISC'}


In [6]:
num_tags = len(mapping)

## One time costly preprocessing

In [7]:
# Uncomment for use

# def add_bert_len(pd_df):
#     bert_len = []
#     for id in pd_df.index:
#         tokens = pd_df['tokens'][id]
#         bert_len.append([np.sum(out)-2 for out in bert_preprocess_model(tokens)['input_mask']])
    
#     pd_df['bert_len'] = bert_len
#     return pd_df

# pd_train = conll_data["train"].data.to_pandas()
# pd_train_b_upd = add_bert_len(pd_train)   
# pd_train_b_upd.to_pickle("bert_data/train.csv")

# pd_val = conll_data["validation"].data.to_pandas()
# pd_val_b_upd = add_bert_len(pd_val)  
# pd_val_b_upd.to_pickle("bert_data/val.csv")

In [8]:
pd_train_b_upd_imp = pandas.read_pickle("bert_data/train.csv")
pd_train_b_upd_imp

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags,bert_len
0,0,"[EU, rejects, German, call, to, boycott, Briti...","[22, 42, 16, 21, 35, 37, 16, 21, 7]","[11, 21, 11, 12, 21, 22, 11, 12, 0]","[3, 0, 7, 0, 0, 0, 7, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1]"
1,1,"[Peter, Blackburn]","[22, 22]","[11, 12]","[1, 2]","[1, 1]"
2,2,"[BRUSSELS, 1996-08-22]","[22, 11]","[11, 12]","[5, 0]","[1, 5]"
3,3,"[The, European, Commission, said, on, Thursday...","[12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 3...","[11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 1...","[0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, ..."
4,4,"[Germany, 's, representative, to, the, Europea...","[22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 2...","[11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 1...","[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, ...","[1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, ..."
...,...,...,...,...,...,...
14036,14036,"[on, Friday, :]","[15, 22, 8]","[13, 11, 0]","[0, 0, 0]","[1, 1, 1]"
14037,14037,"[Division, two]","[21, 11]","[11, 12]","[0, 0]","[1, 1]"
14038,14038,"[Plymouth, 2, Preston, 1]","[21, 11, 22, 11]","[11, 12, 12, 12]","[3, 0, 3, 0]","[1, 1, 1, 1]"
14039,14039,"[Division, three]","[21, 11]","[11, 12]","[0, 0]","[1, 1]"


In [9]:
pd_val_b_upd_imp = pandas.read_pickle("bert_data/val.csv")
pd_val_b_upd_imp

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags,bert_len
0,0,"[CRICKET, -, LEICESTERSHIRE, TAKE, OVER, AT, T...","[22, 8, 22, 22, 15, 22, 22, 22, 22, 21, 7]","[11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 0]","[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
1,1,"[LONDON, 1996-08-30]","[22, 11]","[11, 12]","[5, 0]","[1, 5]"
2,2,"[West, Indian, all-rounder, Phil, Simmons, too...","[22, 22, 21, 22, 22, 38, 11, 15, 11, 15, 22, 1...","[11, 12, 12, 12, 12, 21, 11, 13, 11, 13, 11, 1...","[7, 8, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 3, ...","[1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,3,"[Their, stay, on, top, ,, though, ,, may, be, ...","[29, 21, 15, 21, 6, 30, 6, 20, 37, 16, 15, 21,...","[11, 12, 13, 11, 0, 3, 0, 21, 22, 1, 13, 11, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, ..."
4,4,"[After, bowling, Somerset, out, for, 83, on, t...","[15, 39, 22, 33, 15, 11, 15, 12, 21, 21, 15, 2...","[13, 11, 12, 15, 13, 11, 13, 11, 12, 12, 13, 1...","[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 3, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...
3245,3245,"[But, the, prices, may, move, in, a, close, ra...","[10, 12, 24, 20, 37, 15, 12, 16, 21, 39, 12, 4...","[0, 11, 12, 21, 22, 13, 11, 12, 12, 13, 11, 12...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1]"
3246,3246,"[Brokers, said, blue, chips, like, IDLC, ,, Ba...","[24, 38, 16, 24, 15, 22, 6, 22, 22, 6, 22, 22,...","[11, 21, 11, 12, 13, 11, 0, 11, 12, 0, 11, 12,...","[0, 0, 0, 0, 0, 3, 0, 3, 4, 0, 3, 4, 0, 3, 4, ...","[2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 3, 1, 1, 1, 1, ..."
3247,3247,"[They, said, there, was, still, demand, for, b...","[28, 38, 13, 38, 30, 37, 15, 16, 24, 15, 21, 2...","[11, 21, 11, 21, 22, 22, 13, 11, 12, 13, 11, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3248,3248,"[The, DSE, all, share, price, index, closed, 2...","[12, 21, 12, 21, 21, 21, 38, 11, 24, 10, 11, 2...","[11, 12, 11, 12, 12, 12, 21, 11, 12, 0, 11, 12...","[0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 2, 1, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 1, 5, ..."


In [10]:
def generate_ner_tags_for_bert(pd_df, max_seq_len):
    ner_tags_bert = []
    sentence = []
    for id in pd_df.index :
        ner_tags_b = []
        ner_tags_col = pd_df['ner_tags'][id]
        bert_len_col = pd_df['bert_len'][id]

        # {0: '[PAD]', 1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-ORG', 6: 'B-LOC', 7: 'I-LOC', 8: 'B-MISC', 9: 'I-MISC'}        
        b2I_mappig = [0,1,3,3,5,5,7,7,9,9]
        
        for t_id,tag in enumerate(ner_tags_col):
            tag = tag+1
            x = [tag] +  [b2I_mappig[tag]] * (bert_len_col[t_id] - 1)
            ner_tags_b +=x
        ner_tags_b = [0] + ner_tags_b + [0]    
        ner_tags_b += [0] * (max_seq_len - len(ner_tags_b))
        ner_tags_np = np.array([int(x) for x in ner_tags_b])
        ner_tags_bert.append(ner_tags_np)
        if len(ner_tags_b) > 128:
            print(str(id) + "-" + str(len(ner_tags_b)))
        tokens_col = pd_df['tokens'][id]
        sentence.append(" ".join(tokens_col))
    
    pd_df['ner_tags_bert'] = ner_tags_bert
    pd_df['sentence'] = sentence    
    return pd_df[pd_df['ner_tags_bert'].apply(lambda x: len(x)) == 128]


In [11]:
pd_train_upd = generate_ner_tags_for_bert(pd_train_b_upd_imp, max_seq_len=128)
pd_train_upd

13068-164


Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags,bert_len,ner_tags_bert,sentence
0,0,"[EU, rejects, German, call, to, boycott, Briti...","[22, 42, 16, 21, 35, 37, 16, 21, 7]","[11, 21, 11, 12, 21, 22, 11, 12, 0]","[3, 0, 7, 0, 0, 0, 7, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1]","[0, 4, 1, 8, 1, 1, 1, 8, 1, 1, 0, 0, 0, 0, 0, ...",EU rejects German call to boycott British lamb .
1,1,"[Peter, Blackburn]","[22, 22]","[11, 12]","[1, 2]","[1, 1]","[0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Peter Blackburn
2,2,"[BRUSSELS, 1996-08-22]","[22, 11]","[11, 12]","[5, 0]","[1, 5]","[0, 6, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",BRUSSELS 1996-08-22
3,3,"[The, European, Commission, said, on, Thursday...","[12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 3...","[11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 1...","[0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, ...","[0, 1, 4, 5, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, ...",The European Commission said on Thursday it di...
4,4,"[Germany, 's, representative, to, the, Europea...","[22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 2...","[11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 1...","[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, ...","[1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, ...","[0, 6, 1, 1, 1, 1, 1, 4, 5, 1, 1, 1, 1, 2, 3, ...",Germany 's representative to the European Unio...
...,...,...,...,...,...,...,...,...
14036,14036,"[on, Friday, :]","[15, 22, 8]","[13, 11, 0]","[0, 0, 0]","[1, 1, 1]","[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",on Friday :
14037,14037,"[Division, two]","[21, 11]","[11, 12]","[0, 0]","[1, 1]","[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Division two
14038,14038,"[Plymouth, 2, Preston, 1]","[21, 11, 22, 11]","[11, 12, 12, 12]","[3, 0, 3, 0]","[1, 1, 1, 1]","[0, 4, 1, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Plymouth 2 Preston 1
14039,14039,"[Division, three]","[21, 11]","[11, 12]","[0, 0]","[1, 1]","[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Division three


In [12]:
pd_val_upd = generate_ner_tags_for_bert(pd_val_b_upd_imp, max_seq_len=128)
pd_val_upd

2184-146
2594-132
2595-133


Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags,bert_len,ner_tags_bert,sentence
0,0,"[CRICKET, -, LEICESTERSHIRE, TAKE, OVER, AT, T...","[22, 8, 22, 22, 15, 22, 22, 22, 22, 21, 7]","[11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 0]","[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[0, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...",CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTE...
1,1,"[LONDON, 1996-08-30]","[22, 11]","[11, 12]","[5, 0]","[1, 5]","[0, 6, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",LONDON 1996-08-30
2,2,"[West, Indian, all-rounder, Phil, Simmons, too...","[22, 22, 21, 22, 22, 38, 11, 15, 11, 15, 22, 1...","[11, 12, 12, 12, 12, 21, 11, 13, 11, 13, 11, 1...","[7, 8, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 3, ...","[1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 8, 9, 1, 1, 1, 1, 2, 3, 1, 1, 1, 1, 1, 1, ...",West Indian all-rounder Phil Simmons took four...
3,3,"[Their, stay, on, top, ,, though, ,, may, be, ...","[29, 21, 15, 21, 6, 30, 6, 20, 37, 16, 15, 21,...","[11, 12, 13, 11, 0, 3, 0, 21, 22, 1, 13, 11, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","Their stay on top , though , may be short-live..."
4,4,"[After, bowling, Somerset, out, for, 83, on, t...","[15, 39, 22, 33, 15, 11, 15, 12, 21, 21, 15, 2...","[13, 11, 12, 15, 13, 11, 13, 11, 12, 12, 13, 1...","[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 3, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 6, 7, 1, ...",After bowling Somerset out for 83 on the openi...
...,...,...,...,...,...,...,...,...
3245,3245,"[But, the, prices, may, move, in, a, close, ra...","[10, 12, 24, 20, 37, 15, 12, 16, 21, 39, 12, 4...","[0, 11, 12, 21, 22, 13, 11, 12, 12, 13, 11, 12...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1]","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",But the prices may move in a close range follo...
3246,3246,"[Brokers, said, blue, chips, like, IDLC, ,, Ba...","[24, 38, 16, 24, 15, 22, 6, 22, 22, 6, 22, 22,...","[11, 21, 11, 12, 13, 11, 0, 11, 12, 0, 11, 12,...","[0, 0, 0, 0, 0, 3, 0, 3, 4, 0, 3, 4, 0, 3, 4, ...","[2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 3, 1, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 4, 5, 1, 4, 5, 1, 4, 5, ...","Brokers said blue chips like IDLC , Bangladesh..."
3247,3247,"[They, said, there, was, still, demand, for, b...","[28, 38, 13, 38, 30, 37, 15, 16, 24, 15, 21, 2...","[11, 21, 11, 21, 22, 22, 13, 11, 12, 13, 11, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",They said there was still demand for blue chip...
3248,3248,"[The, DSE, all, share, price, index, closed, 2...","[12, 21, 12, 21, 21, 21, 38, 11, 24, 10, 11, 2...","[11, 12, 11, 12, 12, 12, 21, 11, 12, 0, 11, 12...","[0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 2, 1, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 1, 5, ...","[0, 1, 4, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",The DSE all share price index closed 2.73 poin...


In [13]:
train_target_inp = pd_train_upd['ner_tags_bert']
train_target = [[x for x in row] for row in train_target_inp]
train_features_inp = pd_train_upd['sentence']
train_features = [row for row in train_features_inp]

In [14]:
ff = tf.convert_to_tensor(train_features)
ff[0:4]

<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'EU rejects German call to boycott British lamb .',
       b'Peter Blackburn', b'BRUSSELS 1996-08-22',
       b'The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep .'],
      dtype=object)>

In [15]:
ll = tf.convert_to_tensor(train_target)
ll[0:4]

<tf.Tensor: shape=(4, 128), dtype=int32, numpy=
array([[0, 4, 1, 8, 1, 1, 1, 8, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 6, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0

In [16]:
class NERModelBert(keras.Model):
    def __init__(self, 
                 num_tags,
                 dropout_rate=0.1, 
                 tfhub_handle_encoder="https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
                 tfhub_handle_preprocess="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
                ):
        super().__init__()
        # text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name="text")
        self.preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name="preprocessing")
        self.encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name="BERT_encoder")
        self.dropout = layers.Dropout(dropout_rate)
        self.ff_final = layers.Dense(num_tags, activation="softmax")                                                  
    
    def call(self, text_input):
        encoder_inputs = self.preprocessing_layer(text_input)
        outputs = self.encoder(encoder_inputs)
        net = outputs["sequence_output"]
        net = self.dropout(net)
        net = self.ff_final(net)
        return net




In [17]:
class CustomNonPaddingTokenLoss(keras.losses.Loss):
    def __init__(self, name="custom_ner_loss"):
        super().__init__(name=name)
    
    def call(self, y_true, y_pred):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=keras.losses.Reduction.NONE
        )
        loss = loss_fn(y_true, y_pred)
        mask = tf.cast((y_true > 0), dtype=tf.float32)
        loss = loss * mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)
    

In [18]:
loss = CustomNonPaddingTokenLoss()

In [19]:
ner_model_bert = NERModelBert(num_tags)

In [20]:
ner_model_bert.compile(optimizer="adam", loss=loss)

In [None]:
ner_model_bert.fit(ff, ll, epochs=5, batch_size=32)

Epoch 1/5


  return dispatch_target(*args, **kwargs)




In [298]:
output = ner_model_bert.predict(["hello hi"])
output

array([[[1.22572155e-05, 8.82906914e-01, 4.77831364e-02, ...,
         5.67343878e-03, 2.65362556e-03, 1.15086837e-03],
        [5.16935779e-06, 9.46937978e-01, 3.89702730e-02, ...,
         3.56830657e-04, 2.19833944e-03, 2.34802690e-04],
        [2.02438400e-06, 9.94865835e-01, 5.92860742e-04, ...,
         6.95358380e-04, 1.30436209e-04, 4.02434816e-04],
        ...,
        [1.20012703e-06, 9.95955348e-01, 3.78617464e-04, ...,
         6.37739897e-04, 5.90124437e-05, 2.22384435e-04],
        [5.39399434e-06, 9.59100246e-01, 1.44434161e-02, ...,
         2.54221587e-03, 8.75894562e-04, 5.95670776e-04],
        [9.75556304e-06, 9.20625031e-01, 2.80396007e-02, ...,
         4.19480354e-03, 1.84733886e-03, 9.39223217e-04]]], dtype=float32)

In [299]:
prediction = np.argmax(output, axis=-1)[0]
prediction = [mapping[i] for i in prediction]
prediction

['O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O']

## Metrics calculation

Here is a function to calculate the metrics. The function calculates F1 score for the
overall NER dataset as well as individual scores for each NER tag.

In [300]:
def calculate_metrics(dataset):
    all_true_tag_ids, all_predicted_tag_ids = [], []

    for x, y in dataset:
        output = ner_model_bert.predict(x)
        predictions = np.argmax(output, axis=-1)
        predictions = np.reshape(predictions, [-1])

        true_tag_ids = np.reshape(y, [-1])

        mask = (true_tag_ids > 0) & (predictions > 0)
        true_tag_ids = true_tag_ids[mask]
        predicted_tag_ids = predictions[mask]

        all_true_tag_ids.append(true_tag_ids)
        all_predicted_tag_ids.append(predicted_tag_ids)

    all_true_tag_ids = np.concatenate(all_true_tag_ids)
    all_predicted_tag_ids = np.concatenate(all_predicted_tag_ids)

    predicted_tags = [mapping[tag] for tag in all_predicted_tag_ids]
    real_tags = [mapping[tag] for tag in all_true_tag_ids]

    evaluate(real_tags, predicted_tags)

In [None]:
calculate_metrics(val_dataset)