In [None]:
import json
import numpy as np
import pandas as pd
import re
from tqdm import tqdm
import cloudpickle
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizerFast, TFAutoModelForTokenClassification
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers.schedules import PolynomialDecay

In [2]:
class LabelEncoder:
    '''
    Label Encoder to encode and decode the entity labels.

    The mapping is fixed (it is not learned from data): the BIO tags get the
    ids 0-16, and the special-token placeholders '[CLS]' and '[SEP]' both get
    -100.

    Because '[CLS]' and '[SEP]' share the id -100, decoding is lossy for that
    id: -100 always decodes to '[SEP]' (the last entry in `label_mapping`).
    '''
    def __init__(self):
        self.label_mapping = {'O': 0,
                              'B-geo': 1,
                              'I-geo': 2,
                              'B-gpe': 3,
                              'I-gpe': 4,
                              'B-per': 5,
                              'I-per': 6,
                              'B-org': 7,
                              'I-org': 8,
                              'B-tim': 9,
                              'I-tim': 10,
                              'B-art': 11,
                              'I-art': 12,
                              'B-nat': 13,
                              'I-nat': 14,
                              'B-eve': 15,
                              'I-eve': 16,
                              '[CLS]': -100,
                              '[SEP]': -100}

        # Built eagerly so that inverse_transform() works on a fresh instance;
        # previously it returned all-NaN until fit() had been called.
        self.inverse_label_mapping = self._build_inverse_mapping()

    def _build_inverse_mapping(self):
        '''Return the id -> label dict; for duplicated ids the last label wins.'''
        return {value: key for key, value in self.label_mapping.items()}

    def fit(self, x: pd.Series = None):
        '''
        Rebuild the inverse mapping and return self.

        `x` is accepted for API compatibility only; it is not inspected because
        the label set is hard-coded in `label_mapping`.
        '''
        self.inverse_label_mapping = self._build_inverse_mapping()
        return self

    def transform(self, x: pd.Series):
        '''Map a Series of label strings to ids (labels not in the mapping become NaN).'''
        return x.map(self.label_mapping)

    def inverse_transform(self, x: pd.Series):
        '''Map a Series of ids back to label strings (ids not in the mapping become NaN).'''
        return x.map(self.inverse_label_mapping)
        

In [3]:
# Fit the label encoder and persist it with cloudpickle so that the
# preprocessing functions can reload it from 'hf_ner_label_encoder.bin'.
# `label_encoder` also stays in the notebook namespace, where inference() reads it.
label_encoder = LabelEncoder()
df = pd.read_csv('ner_dataset.csv', encoding='unicode_escape')
# fit() only inverts the hard-coded label_mapping; the 'Tag' column passed here is not inspected.
label_encoder.fit(df['Tag'])
with open('hf_ner_label_encoder.bin', 'wb') as f:
    cloudpickle.dump(label_encoder, f)

In [4]:
# data source: https://www.kaggle.com/datasets/saurabhprajapat/named-entity-recognition
def get_preprocessed_data(file_path):
    '''
    Load the token-level CSV and collapse it to one row per sentence.

    Rows are grouped on the 'Sentence #' column; the 'Tag' and 'Word' values of
    each group are gathered into Python lists. The returned frame has exactly
    two columns, in this order: 'target' (list of tags) and 'text' (list of words).
    '''
    token_rows = pd.read_csv(file_path, encoding='unicode_escape')
    sentence_rows = (
        token_rows
        .groupby('Sentence #', as_index=False)
        .agg({'Tag': list, 'Word': list})
        .drop(columns='Sentence #')
        .rename(columns={'Tag': 'target', 'Word': 'text'})
    )
    return sentence_rows

In [5]:
def _align_labels_with_word_ids(word_ids, labels):
    '''
    Expand word-level labels to one label per token position.

    word_ids: per-token word index as returned by the tokenizer's word_ids()
              (None for special and padding tokens).
    labels:   word-level BIO tags, indexed by word id.

    Position 0 is labelled '[CLS]'; every later position whose word id is None
    ([SEP] and padding) is labelled '[SEP]'. The first sub-word token of a word
    keeps the word's tag. Later sub-word tokens of the SAME word get the 'I-'
    form of that tag, so a word split into pieces is still a single entity.

    Example: word_ids [None, 0, 0, 1, None], labels ['B-org', 'B-org']
             -> ['[CLS]', 'B-org', 'I-org', 'B-org', '[SEP]']
    '''
    aligned = ['[CLS]']
    for pos in range(1, len(word_ids)):
        word_id = word_ids[pos]
        if word_id is None:
            aligned.append('[SEP]')
            continue
        label = labels[word_id]
        # Only a continuation of the same word may turn B- into I-. Keying this
        # on the previous *label* (as was done before) also rewrote the B- tag of
        # a new word that followed an entity of the same type, merging two
        # adjacent entities into one.
        same_word_as_previous = word_ids[pos - 1] == word_id
        if same_word_as_previous and label.startswith('B-'):
            label = 'I-' + label.split('-', 1)[1]
        aligned.append(label)
    return aligned


def get_inputs_adjusted_labels(list_of_texts, list_of_labels, label_encoder, max_token_length=50):
    '''
    Function to rearrange the entity labels to match with the sub-word tokens and 
    [CLS], [PAD] and [SEP] tokens

    list_of_texts:    list of sentences, each a list of words.
    list_of_labels:   list of tag lists, parallel to list_of_texts.
    label_encoder:    object whose transform(pd.Series) maps label strings to ids.
    max_token_length: every sentence is truncated/padded to this many tokens.

    Returns (tokenized_inputs, adjusted_labels, adjusted_encoded_labels):
      - tokenized_inputs: dict with 'input_ids' and 'attention_mask', each a
        list with one entry per sentence;
      - adjusted_labels: per-sentence list of label strings, one per token;
      - adjusted_encoded_labels: the same labels passed through label_encoder.

    Raises ValueError if the two input lists differ in length.
    '''
    if len(list_of_texts) != len(list_of_labels):
        raise ValueError('list_of_texts and list_of_labels must have the same length')

    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

    adjusted_labels = []
    adjusted_encoded_labels = []
    # Both keys exist up front, so an empty input yields empty lists rather
    # than a dict with missing keys.
    tokenized_inputs = {'input_ids': [], 'attention_mask': []}

    for text, labels in zip(list_of_texts, list_of_labels):
        inputs = tokenizer(text, is_split_into_words=True, max_length=max_token_length,
                           truncation=True, padding="max_length")
        tokenized_inputs['input_ids'].append(inputs['input_ids'])
        tokenized_inputs['attention_mask'].append(inputs['attention_mask'])

        res = _align_labels_with_word_ids(inputs.word_ids(), labels)
        adjusted_labels.append(res)
        adjusted_encoded_labels.append([*label_encoder.transform(pd.Series(res))])
    return tokenized_inputs, adjusted_labels, adjusted_encoded_labels

In [6]:
def get_processed_train_test(file_path='ner_dataset.csv', label_encoder_path='hf_ner_label_encoder.bin',
                             test_size: float=0.15, input_col: str='text', target_col: str='target', 
                             max_token_length=50, random_state=42):
    
    '''
    Build model-ready train and test dictionaries from the CSV at `file_path`.

    The sentence-level frame is split with train_test_split when `test_size`
    is greater than zero; otherwise everything goes to the train set and the
    test set is None. Each split is tokenized and label-aligned with
    get_inputs_adjusted_labels, using the label encoder unpickled from
    `label_encoder_path`.

    Returns (train, test): dicts with the keys 'input_ids', 'attention_mask'
    and 'labels' (test is None when no split was requested).
    '''
    sentences = get_preprocessed_data(file_path)
    texts = sentences[input_col].copy()
    targets = sentences[target_col].copy()
    del sentences

    wants_test_split = test_size > 0
    if wants_test_split:
        texts_train, texts_test, targets_train, targets_test = train_test_split(
            texts, targets, test_size=test_size, random_state=random_state)
        test_lists = (texts_test.to_list(), targets_test.to_list())
    else:
        texts_train, targets_train = texts, targets
        test_lists = None
    train_lists = (texts_train.to_list(), targets_train.to_list())

    with open(label_encoder_path, 'rb') as encoder_file:
        encoder = cloudpickle.load(encoder_file)

    def _to_model_inputs(split_lists):
        # Tokenize one split and attach the encoded labels under 'labels'.
        features, _, encoded = get_inputs_adjusted_labels(
            split_lists[0], split_lists[1], encoder, max_token_length)
        features['labels'] = encoded
        return features

    train = _to_model_inputs(train_lists)
    test = _to_model_inputs(test_lists) if test_lists is not None else None
    return train, test

In [7]:
def visualize_tokens_labels(idx, file_path='ner_dataset.csv',
                            label_encoder_path='hf_ner_label_encoder.bin', max_token_length=50):
    '''
    Function to visualize tokens and entity labels
    before and after preprocessing

    idx:                positional row of the sentence in the grouped data.
    file_path:          CSV to read (default matches the rest of the notebook).
    label_encoder_path: pickled label encoder used to encode/decode labels.
    max_token_length:   padded/truncated token length used for the "after" view.

    Only the requested sentence is tokenized. The earlier version preprocessed
    the whole corpus (and read the CSV twice) just to print one row.

    The "after" labels are decoded from their ids, so every -100 position
    ([CLS], [SEP] and padding) prints as whatever the encoder maps -100 back to.
    '''
    df = get_preprocessed_data(file_path)
    tokens = df['text'].iloc[idx]
    labels = df['target'].iloc[idx]
    print('BEFORE PREPROCESSING')
    for tok, lab in zip(tokens, labels):
        print(f'{tok: <20}{lab}')

    print('\nAFTER PREPROCESSING')
    with open(label_encoder_path, 'rb') as f:
        label_encoder = cloudpickle.load(f)

    # A one-sentence batch gives the same result for this row as processing the
    # full dataset, because sentences are tokenized independently.
    inputs, _, encoded_labels = get_inputs_adjusted_labels([tokens], [labels], label_encoder, max_token_length)

    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    sub_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    decoded_labels = [*label_encoder.inverse_transform(pd.Series(encoded_labels[0]))]
    for tok, lab in zip(sub_tokens, decoded_labels):
        print(f'{tok: <20}{lab}')

In [8]:
# Print the sentence at positional row 123 of the grouped data, before and after label alignment.
visualize_tokens_labels(123)

BEFORE PREPROCESSING
Spain               B-gpe
has                 O
begun               O
a                   O
trial               O
for                 O
24                  O
suspected           O
al-Qaida            B-org
members             O
,                   O
including           O
three               O
accused             O
of                  O
helping             O
plan                O
the                 O
September           B-tim
11                  I-tim
,                   I-tim
2001                I-tim
terrorist           O
attacks             O
in                  O
the                 O
United              B-geo
States              I-geo
.                   O

AFTER PREPROCESSING


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

[CLS]               [SEP]
spain               B-gpe
has                 O
begun               O
a                   O
trial               O
for                 O
24                  O
suspected           O
al                  B-org
-                   I-org
q                   I-org
##aid               I-org
##a                 I-org
members             O
,                   O
including           O
three               O
accused             O
of                  O
helping             O
plan                O
the                 O
september           B-tim
11                  I-tim
,                   I-tim
2001                I-tim
terrorist           O
attacks             O
in                  O
the                 O
united              B-geo
states              I-geo
.                   O
[SEP]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]             

In [9]:
def return_tf_tensors(data):
    '''
    Wrap a dict-like of equally long sequences in a tf.data.Dataset.

    The dataset is sliced along the first axis of every value, so it yields one
    dict per example, and prefetching is enabled with tf.data.AUTOTUNE.
    '''
    features = dict(data)
    dataset = tf.data.Dataset.from_tensor_slices(features)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [10]:
# Hugging Face checkpoint used for the model in fit_model() and the tokenizer in inference().
CHECKPOINT = "distilbert-base-uncased"
# Token length every sentence is padded/truncated to.
N_TOKENS = 50
# Default mini-batch size for fit_model().
BATCH_SIZE = 32
# Width of the classification head (passed as num_labels in fit_model()).
# NOTE(review): LabelEncoder.label_mapping defines ids 0-16, i.e. 17 classes, so one
# output unit never appears as a target - confirm whether 18 is intentional.
N_LABELS = 18

In [11]:
# Tokenize and label-align the corpus, holding out 15% of the sentences as `test`.
train, test = get_processed_train_test(file_path='ner_dataset.csv', max_token_length=N_TOKENS, test_size=0.15)

In [12]:
# Wrap both splits as unbatched tf.data datasets; batching is applied later inside fit_model().
train_tf_data = return_tf_tensors(train)
test_tf_data = return_tf_tensors(test)

In [13]:
# Sanity check: print a single training example (input_ids, attention_mask, labels).
for i in train_tf_data.take(1):
    print(i)

{'input_ids': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([  101,  1996,  3732,  2181,  2001,  4015,  2000,  2660,  1999,
        5774,  1010,  2029,  4548,  1996,  2642,  4664,  2076,  2088,
        2162,  1045,  1998,  2506,  2000, 21497,  1996,  4117,  2752,
        2127,  4336,  1999,  3339,  1012,   102,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0], dtype=int32)>, 'labels': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([-100,    0,    0,    0,    0,    0,    0,    1,    0,    9,    0,
          0,    0,    0,    0,    0,    0,    0,    0,   15,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    9,    0, -100,
       -1

In [14]:
def fit_model(train_data, val_data, epochs=2, eta=1e-4, early_stopping_patience=1, batch_size=BATCH_SIZE):
    '''
    Fine-tune a token-classification model from CHECKPOINT and return it.

    train_data / val_data:   UNBATCHED tf.data datasets of example dicts
                             (len() must be defined); batching happens here.
    epochs:                  maximum number of epochs.
    eta:                     initial learning rate; decays linearly to 0 over
                             the full training run.
    early_stopping_patience: epochs without val_loss improvement before
                             training stops.
    batch_size:              mini-batch size for both datasets.

    model.compile() is called without a loss, as in the original code, so the
    model's own loss computation is relied on.
    '''
    model = TFAutoModelForTokenClassification.from_pretrained(CHECKPOINT, num_labels=N_LABELS)

    # The schedule advances once per optimizer update, i.e. once per BATCH.
    # len(train_data) counts examples, so it must be converted to batches
    # (ceil division); otherwise the rate decays batch_size times too slowly
    # and never gets near end_learning_rate.
    steps_per_epoch = -(-len(train_data) // batch_size)
    total_steps = max(1, steps_per_epoch * epochs)
    learning_schedule = PolynomialDecay(initial_learning_rate=eta, decay_steps=total_steps, end_learning_rate=0)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_schedule))

    # summary() prints by itself and returns None; wrapping it in print()
    # emitted a stray "None" line.
    model.summary()
    early_stop = EarlyStopping(monitor="val_loss", patience=early_stopping_patience, mode="min")
    model.fit(train_data.shuffle(len(train_data)).batch(batch_size),
              # Validation is evaluated in a fixed order; shuffling it is not needed.
              validation_data=val_data.batch(batch_size),
              epochs=epochs, callbacks=[early_stop])
    return model

In [15]:
# Train for at most 25 epochs; early stopping (patience 2 on val_loss) may end the run sooner.
# NOTE(review): the held-out `test` split is used as the validation set here, so it also
# drives early stopping - keep that in mind when reporting scores on it.
model = fit_model(train_data=train_tf_data, val_data=test_tf_data, epochs=25, early_stopping_patience=2)

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertForTokenClassification: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing TFDistilBertForTokenClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForTokenClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFDistilBertForTokenClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able t

Model: "tf_distil_bert_for_token_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 distilbert (TFDistilBertMa  multiple                  66362880  
 inLayer)                                                        
                                                                 
 dropout_19 (Dropout)        multiple                  0         
                                                                 
 classifier (Dense)          multiple                  13842     
                                                                 
Total params: 66376722 (253.21 MB)
Trainable params: 66376722 (253.21 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25


### INFERENCE

In [329]:
def process_output(res):
    '''
    Function to concatenate sub-word tokens and labels

    res: iterable of [token, tag] pairs, where tag looks like 'B-xxx' or
         'I-xxx'. The input is NOT modified.

    Returns a list of (text, entity_type) tuples, one per entity:
      - a 'B-xxx' tag closes the entity being built (if any) and opens a new one;
      - any other 'p-xxx' tag extends the open entity when xxx matches its type;
      - continuation tags with no matching open entity are dropped, as are
        malformed items (not a pair, non-string fields, tag without exactly one '-').
    Tokens of an entity are joined with spaces and ' ##' is removed, so
    word-piece continuations are glued back onto the preceding piece.
    '''
    result = []
    open_type = None    # entity type currently being assembled (None = nothing open)
    open_tokens = []    # tokens collected so far for that entity

    def _as_entity(tokens, entity_type):
        # Join the pieces and undo the word-piece '##' continuation marker.
        return (' '.join(tokens).replace(" ##", ""), entity_type)

    for item in res:
        try:
            token, tag = item[0], item[1]
        except (IndexError, KeyError, TypeError):
            continue
        if not isinstance(token, str) or not isinstance(tag, str):
            continue
        parts = tag.split('-')
        if len(parts) != 2:
            continue
        prefix, suffix = parts

        if prefix == 'B':
            if open_tokens:
                result.append(_as_entity(open_tokens, open_type))
            open_type, open_tokens = suffix, [token]
        elif open_tokens and suffix == open_type:
            open_tokens.append(token)
        # else: continuation tag that does not match the open entity -> ignored

    # Flush the last entity explicitly instead of appending a sentinel to `res`.
    if open_tokens:
        result.append(_as_entity(open_tokens, open_type))

    return result


def inference(txt):
    '''
    Function that returns model prediction

    txt: raw input text.

    Returns the list of (entity_text, entity_type) tuples produced by
    process_output(). At most N_TOKENS tokens (special tokens included) are
    scored; longer input is truncated.

    Relies on the notebook-level names `model` (the trained classifier) and
    `label_encoder` (for decoding class ids).
    '''
    tokenizer = DistilBertTokenizerFast.from_pretrained(CHECKPOINT)
    # The whole text is passed as a single pre-split "word", so every real token
    # has a non-None word id while [CLS], [SEP] and padding have None.
    tokenized_data = tokenizer([txt], is_split_into_words=True, max_length=N_TOKENS,
                               truncation=True, padding="max_length")
    word_ids = tokenized_data.word_ids()
    token_idx_to_consider = [i for i, word_id in enumerate(word_ids) if word_id is not None]

    # Take the tokens from the very ids that are fed to the model, so tokens and
    # predictions stay aligned even when the text was truncated.
    input_ids = tokenized_data['input_ids']
    tokens = tokenizer.convert_ids_to_tokens([input_ids[i] for i in token_idx_to_consider])

    # Pass named inputs with an explicit batch axis. A bare list of two Python
    # lists is ambiguous to Keras (it can be read as one 2 x N batch of ids
    # instead of ids + mask), in which case the padding mask is never applied.
    model_inputs = {
        'input_ids': np.asarray([input_ids], dtype=np.int32),
        'attention_mask': np.asarray([tokenized_data['attention_mask']], dtype=np.int32),
    }
    pred_logits = model.predict(model_inputs, verbose=0).logits
    pred = tf.argmax(pred_logits, axis=-1)[0].numpy()
    pred = pred[token_idx_to_consider]
    pred_labels = label_encoder.inverse_transform(pd.Series(pred))
    # Class ids missing from the encoder decode to NaN; skip those instead of
    # failing on NaN.find(). Labels without '-' (i.e. 'O') are not entities.
    result = [[token, label] for token, label in zip(tokens, pred_labels)
              if isinstance(label, str) and label.find('-') >= 0]
    output = process_output(result)
    return output

In [315]:
# Demo: headline containing a date, a nationality and a place name.
txt = '''
On October 9, the Israeli Defense Minister ordered a "complete siege" of Gaza
'''
print(inference(txt))

[('october 9', 'tim'), ('israeli', 'gpe'), ('gaza', 'geo')]


In [316]:
# Demo: headline containing a country, a hyphenated "US-based" and a person's name.
txt = '''
Venezuela issues arrest warrant for US-based opposition leader Juan Guaido
'''
print(inference(txt))

[('venezuela', 'geo'), ('juan guaido', 'per')]


In [317]:
# Demo: headline containing a country and two surnames.
txt = '''
Argentina presidential election heading to run-off with Massa leading Milei
'''
print(inference(txt))

[('argentina', 'geo'), ('massa', 'per'), ('mile', 'per')]


In [319]:
# Demo: headline containing a multi-word organisation name and a country.
txt = '''
UN Security Council approves sending foreign forces to Haiti
'''
print(inference(txt))

[('un security council', 'org'), ('haiti', 'geo')]


In [322]:
# Demo: headline with a typographic apostrophe inside a company name and an all-caps rating code.
txt = '''
Moody’s sends a warning to America: Your last AAA credit rating is at risk
'''
print(inference(txt))

[('moody ’ s', 'org'), ('america', 'geo'), ('aaa', 'org')]


In [323]:
# Demo: title-cased headline containing a person's name and an event name.
txt = '''
Haris Rauf Breaks Record For Conceding Most Runs In History Of Cricket World Cup
'''
print(inference(txt))

[('haris rauf', 'per'), ('cricket', 'eve'), ('world cup', 'eve')]


In [330]:
# Demo: two-sentence market summary with abbreviations, symbols and percentages.
txt = '''
North and South American markets finished broadly higher on Friday with shares in U.S. leading the region. 
The S&P 500 is up 1.56% while Brazil's Bovespa is up 1.29% and Mexico's IPC is up 0.37%.
'''
print(inference(txt))

[('american', 'gpe'), ('friday', 'tim'), ('u . s .', 'geo'), ('s & p', 'org'), ('brazil', 'geo'), ('bovespa', 'geo'), ('mexico', 'geo')]


In [332]:
# Demo: headline containing two company names, a dollar amount and a region.
txt = '''
Google, Lendlease axe plans for $15 billion development in Bay Area
'''
print(inference(txt))

[('google', 'org'), ('lendlease axe', 'org'), ('bay area', 'geo')]


In [338]:
# Demo: headline containing a nationality, a dollar amount and a year.
txt = '''
American minimum wage has been $7.25 since 2009. What that means for the economy
'''
print(inference(txt))

[('american', 'gpe'), ('2009', 'tim')]


In [340]:
# Demo: headline containing two person names and a quotation in typographic quotes.
txt = '''
Hedge fund billionaire Leon Cooperman, in rare public rebuke of a Republican candidate, says Trump ‘belongs in jail’
'''
print(inference(txt))

[('hedge fund', 'org'), ('leon cooperman', 'per'), ('trump', 'per')]


In [341]:
# Demo: headline containing a number with a thousands separator and a university name.
txt = '''
More than 1,600 Jewish Harvard alumni threaten to withdraw donations over antisemitism concerns
'''
print(inference(txt))

[('harvard', 'org')]


In [346]:
# Demo: two-line sentence containing several person names, titles and nationalities.
txt = '''
Isa Soares speaks to David Culver about expectations for next week's meeting 
between American President Biden and Chinese President Xi Jinping.
'''
print(inference(txt))

[('isa soares', 'per'), ('david culver', 'per'), ('american', 'gpe'), ('president biden', 'per'), ('chinese', 'gpe'), ('president xi jinping', 'per')]
