In [None]:
import json
import numpy as np
import pandas as pd
import re
from tqdm import tqdm
import cloudpickle
from sklearn.model_selection import train_test_split
from transformers import DebertaTokenizerFast, TFAutoModelForTokenClassification
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers.schedules import PolynomialDecay

In [2]:
class LabelEncoder:
    '''
    Label Encoder to encode and decode the entity labels
    '''
    def __init__(self):
        self.label_mapping = {'O': 0, 
                             'B-geo': 1, 
                             'I-geo': 2, 
                             'B-gpe': 3, 
                             'I-gpe': 4, 
                             'B-per': 5,
                             'I-per': 6,
                             'B-org': 7,
                             'I-org': 8,
                             'B-tim': 9,
                             'I-tim': 10,
                             'B-art': 11, 
                             'I-art': 12,
                             'B-nat': 13,
                             'I-nat': 14,
                             'B-eve': 15,
                             'I-eve': 16,
                             '[CLS]': -100,
                             '[SEP]': -100}
        
        self.inverse_label_mapping = {}
    
    def fit(self, x: pd.Series):
        self.inverse_label_mapping = {value: key for key, value in self.label_mapping.items()}
        return self
        
    def transform(self, x: pd.Series):
        x = x.map(self.label_mapping)
        return x
    
    def inverse_transform(self, x: pd.Series):
        x = x.map(self.inverse_label_mapping)
        return x
        

In [3]:
# Fitting and saving Label Encoder
label_encoder = LabelEncoder()
df = pd.read_csv('general_ner_dataset.csv', encoding='unicode_escape')
label_encoder.fit(df['Tag'])
with open('hf_ner_label_encoder.bin', 'wb') as f:
    cloudpickle.dump(label_encoder, f)

In [4]:
# data source: https://www.kaggle.com/datasets/saurabhprajapat/named-entity-recognition
def get_preprocessed_data(file_path):
    '''
    Function to read the data from CSV and collect tokens and tags of each
    sentence as lists.
    '''
    df = pd.read_csv(file_path, encoding='unicode_escape')
    df = df.groupby('Sentence #', as_index=False).agg({'Tag': lambda x: list(x), 'Word': lambda x: list(x)})
    df.drop(columns='Sentence #', inplace=True)
    df.columns = ['target', 'text']
    return df

In [5]:
def get_inputs_adjusted_labels(list_of_texts, list_of_labels, label_encoder, max_token_length=50):
    '''
    Function to rearrange the entity labels to match with the sub-word tokens and 
    [CLS], [PAD] and [SEP] tokens
    '''
    model_checkpoint = "microsoft/deberta-base"
    tokenizer = DebertaTokenizerFast.from_pretrained(model_checkpoint, add_prefix_space=True)
    
    adjusted_labels = []
    adjusted_encoded_labels = []
    tokenized_inputs = {}
    
    for idx in range(len(list_of_texts)):
        text = list_of_texts[idx]
        labels = list_of_labels[idx]
        
        #####################################
        #     Input Tokenization Start      #
        #####################################
    
        inputs = tokenizer(text, is_split_into_words=True, max_length=max_token_length, truncation=True, padding="max_length")
        word_ids = inputs.word_ids()
        # print(word_ids)
        
        
        if len(tokenized_inputs) == 0:
            tokenized_inputs['input_ids'] = [inputs['input_ids']]
            tokenized_inputs['attention_mask'] = [inputs['attention_mask']]
        else:
            tokenized_inputs['input_ids'].append(inputs['input_ids'])
            tokenized_inputs['attention_mask'].append(inputs['attention_mask'])
        
        #####################################
        #     Input Tokenization End        #
        #####################################
        
        #####################################
        #     Label Rearrangement Start     #
        #####################################
        res = ['[CLS]']
        p = 1
        
        while p < len(word_ids):
            if word_ids[p] is None:
                res.append('[SEP]')
                p += 1
                continue
            prev_label = res[p-1]
            curr_label = labels[word_ids[p]]
            if prev_label.find('-') != -1:
                prev_label_split = prev_label.split("-")
            else:
                prev_label_split = ['PO', 'PO']
            prev_label_prefix, prev_label_suffix = prev_label_split


            if curr_label.find('-') != -1:
                curr_label_split = curr_label.split("-")
            else:
                curr_label_split = ['CO', 'CO']
            curr_label_prefix, curr_label_suffix = curr_label_split

            if ((prev_label_prefix == 'B') or (prev_label_prefix == 'I')) and (prev_label_suffix == curr_label_suffix):
                res.append('I-'+prev_label_suffix)
            else:
                res.append(curr_label)
            p += 1
            
        #####################################
        #     Label Rearrangement End       #
        #####################################
        
        adjusted_labels.append(res)
        adjusted_encoded_labels.append([*label_encoder.transform(pd.Series(res))])
    return tokenized_inputs, adjusted_labels, adjusted_encoded_labels

In [6]:
def get_processed_train_test(file_path='general_ner_dataset.csv', label_encoder_path='hf_ner_label_encoder.bin',
                             test_size: float=0.15, input_col: str='text', target_col: str='target', 
                             max_token_length=50, random_state=42):
    
    '''
    Function to read CSV data and return preprocessed train and test sets
    '''
    df = get_preprocessed_data(file_path)
    x = df[input_col].copy()
    y = df[target_col].copy()
    del(df)
    
    if test_size > 0:
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=random_state)
    else:
        x_train, y_train = x, y
        x_test, y_test = None, None
    
    x_train, y_train  = x_train.to_list(), y_train.to_list()
    if x_test is not None:
        x_test, y_test = x_test.to_list(), y_test.to_list()
    
    with open(label_encoder_path, 'rb') as f:
        label_encoder = cloudpickle.load(f)
    
    x_train, _, y_train = get_inputs_adjusted_labels(x_train, y_train, label_encoder, max_token_length)
    
    train = x_train
    train['labels'] = y_train
    
    if x_test is not None:
        x_test, _, y_test = get_inputs_adjusted_labels(x_test, y_test, label_encoder, max_token_length)
        test = x_test
        test['labels'] = y_test
    else:
        test = None
        
    return train, test

In [7]:
def visualize_tokens_labels(idx):
    '''
    Function to visualize tokens and entity labels
    before and after preprocessing
    '''
    df = get_preprocessed_data('general_ner_dataset.csv')
    tokens = df.iloc[idx, 1]
    labels = df.iloc[idx, 0]
    print('BEFORE PREPROCESSING')
    for tok, lab in zip(tokens, labels):
        print(f'{tok: <20}{lab}')
    
    
    x, _ = get_processed_train_test(max_token_length=50, test_size=0.0)
    input_ids = x['input_ids'][idx]
    labels = x['labels'][idx]
    model_checkpoint = "microsoft/deberta-base"
    tokenizer = DebertaTokenizerFast.from_pretrained(model_checkpoint, add_prefix_space=True)
    tokens = tokenizer.convert_ids_to_tokens(x['input_ids'][idx])

    label_encoder_path='hf_ner_label_encoder.bin'
    with open(label_encoder_path, 'rb') as f:
        label_encoder = cloudpickle.load(f)
    
    print('\nAFTER PREPROCESSING')
    labels = [*label_encoder.inverse_transform(pd.Series(labels))]
    for tok, lab in zip(tokens, labels):
        print(f'{tok: <20}{lab}')

In [8]:
visualize_tokens_labels(123)

BEFORE PREPROCESSING
Spain               B-gpe
has                 O
begun               O
a                   O
trial               O
for                 O
24                  O
suspected           O
al-Qaida            B-org
members             O
,                   O
including           O
three               O
accused             O
of                  O
helping             O
plan                O
the                 O
September           B-tim
11                  I-tim
,                   I-tim
2001                I-tim
terrorist           O
attacks             O
in                  O
the                 O
United              B-geo
States              I-geo
.                   O


Downloading (…)okenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/474 [00:00<?, ?B/s]


AFTER PREPROCESSING
[CLS]               [SEP]
ĠSpain              B-gpe
Ġhas                O
Ġbegun              O
Ġa                  O
Ġtrial              O
Ġfor                O
Ġ24                 O
Ġsuspected          O
Ġal                 B-org
-                   I-org
Qaida               I-org
Ġmembers            O
Ġ,                  O
Ġincluding          O
Ġthree              O
Ġaccused            O
Ġof                 O
Ġhelping            O
Ġplan               O
Ġthe                O
ĠSeptember          B-tim
Ġ11                 I-tim
Ġ,                  I-tim
Ġ2001               I-tim
Ġterrorist          O
Ġattacks            O
Ġin                 O
Ġthe                O
ĠUnited             B-geo
ĠStates             I-geo
Ġ.                  O
[SEP]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SEP]
[PAD]               [SE

In [9]:
def return_tf_tensors(data):
    data = tf.data.Dataset.from_tensor_slices(dict(data))
    data = data.prefetch(tf.data.AUTOTUNE)
    return data

In [10]:
CHECKPOINT = "microsoft/deberta-base"
N_TOKENS = 50
BATCH_SIZE = 32
N_LABELS = 18

In [11]:
train, test = get_processed_train_test(file_path='general_ner_dataset.csv', max_token_length=N_TOKENS, test_size=0.15)

In [12]:
train_tf_data = return_tf_tensors(train)
test_tf_data = return_tf_tensors(test)

In [13]:
for i in train_tf_data.take(1):
    print(i)

{'input_ids': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([    1,    20,  5442,   443,    21,  7225,     7,  1221,    11,
       43289,  2156,    61,  9533,     5,  3285,  4745,   148,   623,
        1771,    38,     8,  1143,     7, 26094,     5,  2771,   911,
         454,  5201,    11, 14873,   479,     2,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0], dtype=int32)>, 'labels': <tf.Tensor: shape=(50,), dtype=int32, numpy=
array([-100,    0,    0,    0,    0,    0,    0,    1,    0,    9,    0,
          0,    0,    0,    0,    0,    0,    0,    0,   15,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    9,    0, -100,
       -1

In [14]:
def fit_model(train_data, val_data, epochs=2, eta=1e-4, early_stopping_patience=1, batch_size=BATCH_SIZE):
    model = TFAutoModelForTokenClassification.from_pretrained(CHECKPOINT, 
                                                              num_labels=N_LABELS, 
                                                              attention_probs_dropout_prob=0.4,
                                                              hidden_dropout_prob=0.4)
    learning_schedule = PolynomialDecay(initial_learning_rate=eta, decay_steps=len(train_data) * epochs, end_learning_rate=0)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_schedule))

    print(model.summary())
    early_stop = EarlyStopping(monitor="val_loss", patience=early_stopping_patience, mode="min")
    model.fit(train_data.shuffle(len(train_data)).batch(batch_size), validation_data=val_data.shuffle(len(val_data)).batch(batch_size), 
          epochs=epochs, callbacks=[early_stop])
    return model

In [15]:
model = fit_model(train_data=train_tf_data, val_data=test_tf_data, epochs=15, early_stopping_patience=2)

Downloading tf_model.h5:   0%|          | 0.00/555M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFDebertaForTokenClassification.

Some layers of TFDebertaForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-base and are newly initialized: ['dropout', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model: "tf_deberta_for_token_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 deberta (TFDebertaMainLaye  multiple                  138601728 
 r)                                                              
                                                                 
 dropout (Dropout)           multiple                  0         
                                                                 
 classifier (Dense)          multiple                  13842     
                                                                 
Total params: 138615570 (528.78 MB)
Trainable params: 138615570 (528.78 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15


### INFERENCE

In [32]:
def softmax(x):
    return tf.exp(x) / tf.math.reduce_sum(tf.exp(x))

In [33]:
def process_output(res):
    '''
    Function to concatenate sub-word tokens, labels and 
    compute mean prediction probability of tokens
    '''
    d = {}
    result = []
    pred_prob = []
    res.append(['-', 'B-b', 0])
    for n, i in enumerate(res):
        try:
            split = i[1].split('-')
            token = i[0]
            token_prob = i[2]
            prefix, suffix = split
            if prefix == 'B':
                if len(d) != 0:
                    result.append([(token.replace("Ġ", " ").strip(), label, np.mean(pred_prob)) for label, token in d.items()][0])
                d = {}
                pred_prob = []
                pred_prob.append(token_prob)
                d[suffix] = token

            else:
                d[suffix] = d[suffix] + token
                pred_prob.append(token_prob)
        except:
            continue
            
    return result


def inference(txt):
    '''
    Function that returns model prediction and prediction probabitliy
    '''
    test_data = [txt]
    tokenizer = DebertaTokenizerFast.from_pretrained(CHECKPOINT, add_prefix_space=True)
    tokens = tokenizer.tokenize(txt)
    tokenized_data = tokenizer(test_data, is_split_into_words=True, max_length=N_TOKENS, 
                               truncation=True, padding="max_length")

    token_idx_to_consider = tokenized_data.word_ids()
    token_idx_to_consider = [i for i in range(len(token_idx_to_consider)) if token_idx_to_consider[i] is not None] 

    input_ = [tokenized_data['input_ids'], tokenized_data['attention_mask']]
    pred_logits = model.predict(input_, verbose=0).logits[0]

    pred_prob = tf.map_fn(softmax, pred_logits)

    pred_idx = tf.argmax(pred_prob, axis=-1).numpy()
    pred_idx = pred_idx[token_idx_to_consider]

    pred_prob = tf.math.reduce_max(pred_prob, axis=-1).numpy()
    pred_prob = np.round(pred_prob[token_idx_to_consider], 3)
    pred_labels = label_encoder.inverse_transform(pd.Series(pred_idx))

    result = [[token, label, prob] for token, label, 
              prob in zip(tokens, pred_labels, pred_prob) if label.find('-') >= 0]
    output = process_output(result)
    return output

In [18]:
txt = '''Xi Jinping arrives in US as his Chinese Dream sputters'''
print(inference(txt))

[('Xi Jinping', 'per', 0.996), ('US', 'geo', 0.995), ('Chinese', 'gpe', 0.915)]


In [19]:
txt = '''Sri Lanka's top court has ruled that ex-president Gotabaya Rajapaksa and his brother Mahinda 
were among 13 former leaders responsible for the country's worst-ever financial crisis.
'''
print(inference(txt))

[('Sri Lanka', 'geo', 0.7195), ('Gotabaya Rajapaksa', 'per', 0.71828574), ('Mahinda Ċ', 'per', 0.98825)]


In [20]:
txt = '''Ukrainian President Volodymyr Zelensky's chief of staff Andriy Yermak has said that 
Ukrainian forces have gained a foothold on the left (eastern) bank of the Dnipro river'''
print(inference(txt))

[('Ukrainian', 'gpe', 0.999), ('President Volodymyr Zelensky', 'per', 0.998), ('Andriy Yermak', 'per', 0.9725), ('Ukrainian', 'gpe', 0.9986667), ('Dnipro', 'geo', 0.98125)]


In [34]:
txt = '''Arinola Omolayo owns a frozen food store in Ogba, a suburb of Lagos, Nigerian commercial nerve centre, 
where she sells mostly imported chicken, fish and turkey.'''
print(inference(txt))

[('Ar', 'org', 0.35), ('Ogba', 'geo', 0.99399996), ('Lagos', 'geo', 0.97650003), ('Nigerian', 'gpe', 0.999)]


In [22]:
txt = '''Friends stars Courteney Cox and Matt Le Blanc have both paid their first 
individual tributes to co-star Mathew Perry, following his death last month.'''
print(inference(txt))

[('Courteney Cox', 'per', 0.98675007), ('Matt Le Blanc', 'per', 0.9906666), ('Mathew Perry', 'per', 0.98233336)]


In [36]:
txt = '''Haris Rauf breaks record for conceding most runs in history of Cricket World Cup.'''
print(inference(txt))

[('Haris Rauf', 'per', 0.97959995), ('Cricket World', 'org', 0.606), ('Cup', 'eve', 0.382)]


In [24]:
txt = '''Alistair Macrow told MPs it had received more than 400 
complaints from workers since July, when the BBC uncovered hundreds of allegations.'''
print(inference(txt))

[('Alistair Macrow', 'per', 0.9342), ('since July', 'tim', 0.923), ('BBC', 'org', 0.997)]


In [25]:
txt = '''Google, Lendlease axe plans for $15 billion development in Bay Area'''
print(inference(txt))

[('Google', 'org', 0.992), ('Lendlease', 'org', 0.921), ('Bay Area', 'geo', 0.99399996)]


In [26]:
txt = '''Sir Richard Branson has told the BBC that he has never faced media coverage as "painful" 
as when he attempted to acquire a loan for his Virgin Group during the pandemic.'''
print(inference(txt))

[('Sir Richard Branson', 'per', 0.996), ('BBC', 'org', 0.995), ('Virgin Group', 'org', 0.996)]


In [27]:
txt = '''Carlos Alcaraz kept alive his semi-final hopes in his maiden ATP Finals campaign with a win 
over Andrey Rublev in which the Russian hit himself so hard with his racquet he drew blood.
'''
print(inference(txt))

[('Carlos Alcaraz', 'per', 0.9885), ('ATP', 'org', 0.784), ('Andrey Rublev', 'per', 0.9375), ('Russian', 'gpe', 0.999)]


In [28]:
txt = '''Rishi Sunak says the government is working on a new treaty with Rwanda, after 
the government's asylum seeker plan was ruled unlawful
'''
print(inference(txt))

[('Rishi Sunak', 'per', 0.99050003), ('Rwanda', 'geo', 0.948)]


In [29]:
txt = '''How many people cross the sea in small boats and how many claim asylum
'''
print(inference(txt))

[]


In [30]:
txt = '''In a blistering letter, she said he had 
repeatedly failed on key policies and broken pledges over immigration.
'''
print(inference(txt))

[]


In [31]:
txt = '''Prince William said he believed this was the decade for collective action to protect the planet. 
He announced the winners of the £1m ($1.2m) prize at a ceremony on Tuesday.
'''
print(inference(txt))

[('Prince William', 'per', 0.981), ('decade', 'tim', 0.84), ('Tuesday', 'tim', 0.995)]


In [37]:
txt = '''Shell is suing Greenpeace for $2.1m (£1.7m) in damages after environmental protesters occupied a 
vessel transporting one of the oil company's floating platforms earlier this year.
'''
print(inference(txt))

[('Shell', 'org', 0.961), ('Greenpeace', 'org', 0.983)]


In [42]:
txt = '''An Israeli-Canadian peace advocate, taken hostage in Gaza, has been killed. Vivian Silver, 74, lived 
close to Israel's border with Gaza in kibbutz Be'eri - attacked by Hamas during 7 October attacks.
'''
print(inference(txt))

[('Gaza', 'geo', 0.995), ('Vivian Silver', 'per', 0.98066664), ('Israel', 'geo', 0.991), ('Gaza', 'geo', 0.99), ("kibbutz Be'eri", 'geo', 0.99200004), ('Hamas', 'org', 0.996), ('7 October', 'tim', 0.9845)]


In [45]:
txt = '''Delhi AQI: Why the Indian capital, Delhi lags behind Beijing in the battle to breathe.
'''
print(inference(txt))

[('Delhi', 'geo', 0.574), ('Indian', 'gpe', 0.999), ('Delhi', 'geo', 0.968), ('Beijing', 'geo', 0.989)]


In [46]:
txt = '''Tiger 3 box office collection day 4: Salman Khan film crosses ₹160 cr in India, likely 
to earn over ₹20 cr on Wednesday
'''
print(inference(txt))

[('day 4', 'tim', 0.9605), ('Salman Khan', 'per', 0.995), ('India', 'geo', 0.996), ('Wednesday', 'tim', 0.994)]


In [47]:
txt = '''Atlee confirms his next film will feature Shah Rukh Khan and Thalapathy Vijay together: Both 
superstars are ready for it
'''
print(inference(txt))

[('Atlee', 'per', 0.845), ('Shah Rukh Khan', 'per', 0.99275), ('Thalapathy Vijay', 'per', 0.8181667)]


In [48]:
txt = '''Mohammed Shami becomes fastest to 50 wickets in ODI World Cup history.
'''
print(inference(txt))

[('Mohammed Shami', 'per', 0.987), ('ODI Cup', 'org', 0.68733335)]
