<a href="https://colab.research.google.com/github/blawok/named-entity-recognition/blob/master/ner_distilbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/12/b5/ac41e3e95205ebf53439e4dd087c58e9fd371fd8e3724f2b9b4cdb8282e5/transformers-2.10.0-py3-none-any.whl (660kB)
[K     |▌                               | 10kB 29.2MB/s eta 0:00:01[K     |█                               | 20kB 5.6MB/s eta 0:00:01[K     |█▌                              | 30kB 6.4MB/s eta 0:00:01[K     |██                              | 40kB 7.5MB/s eta 0:00:01[K     |██▌                             | 51kB 6.6MB/s eta 0:00:01[K     |███                             | 61kB 7.6MB/s eta 0:00:01[K     |███▌                            | 71kB 7.5MB/s eta 0:00:01[K     |████                            | 81kB 8.4MB/s eta 0:00:01[K     |████▌                           | 92kB 7.7MB/s eta 0:00:01[K     |█████                           | 102kB 7.6MB/s eta 0:00:01[K     |█████▌                          | 112kB 7.6MB/s eta 0:00:01[K     |██████                          | 122kB 7.6

In [2]:
import pandas as pd
import numpy as np
import tensorflow as tf

from transformers import *
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    TimeDistributed,
    Dense,
    Input,
    GlobalAveragePooling1D,
    Dropout
)


data = pd.read_csv("/content/drive/My Drive/Colab Notebooks/NER/ner_dataset.csv", encoding="latin1").fillna(method="ffill")
data.tail(10)

Using TensorFlow backend.


Unnamed: 0,Sentence #,Word,POS,Tag
1048565,Sentence: 47958,impact,NN,O
1048566,Sentence: 47958,.,.,O
1048567,Sentence: 47959,Indian,JJ,B-gpe
1048568,Sentence: 47959,forces,NNS,O
1048569,Sentence: 47959,said,VBD,O
1048570,Sentence: 47959,they,PRP,O
1048571,Sentence: 47959,responded,VBD,O
1048572,Sentence: 47959,to,TO,O
1048573,Sentence: 47959,the,DT,O
1048574,Sentence: 47959,attack,NN,O


Prepare data for preprocessing

In [3]:
class SentenceGetter(object):

    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, p, t) for w, p, t in zip(s["Word"].values.tolist(),
                                                           s["POS"].values.tolist(),
                                                           s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence #").apply(agg_func)
        self.sentences = [s for s in self.grouped]

    def get_next(self):
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
            self.n_sent += 1
            return s
        except:
            return None

getter = SentenceGetter(data)
sentences = [[word[0] for word in sentence] for sentence in getter.sentences]
labels = [[s[2] for s in sentence] for sentence in getter.sentences]


tag_values = list(set(data["Tag"].values))
tag_values.append("PAD")
tag2idx = {t: i for i, t in enumerate(tag_values)}

print(sentences[0])
print(labels[0])
print(len(sentences))
print(len(labels))

['Thousands', 'of', 'demonstrators', 'have', 'marched', 'through', 'London', 'to', 'protest', 'the', 'war', 'in', 'Iraq', 'and', 'demand', 'the', 'withdrawal', 'of', 'British', 'troops', 'from', 'that', 'country', '.']
['O', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'O', 'O', 'O', 'O', 'O', 'B-geo', 'O', 'O', 'O', 'O', 'O', 'B-gpe', 'O', 'O', 'O', 'O', 'O']
47959
47959


Tokenize for DistilBERT

In [4]:
MAX_LEN = 50
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

def tokenize_and_preserve_labels(sentence, text_labels):
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):

        # Tokenize the word and count # of subwords the word is broken into
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        # Add the tokenized word to the final tokenized word list
        tokenized_sentence.extend(tokenized_word)

        # Add the same label to the new list of labels `n_subwords` times
        labels.extend([label] * n_subwords)

    return tokenized_sentence, labels

tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(sent, labs)
    for sent, labs in zip(sentences, labels)
]

tokenized_texts = [token_label_pair[0] for token_label_pair in tokenized_texts_and_labels]
labels = [token_label_pair[1] for token_label_pair in tokenized_texts_and_labels]

input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                          maxlen=MAX_LEN, dtype="long", value=0.0,
                          truncating="post", padding="post")

tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],
                     maxlen=MAX_LEN, value=tag2idx["PAD"], padding="post",
                     dtype="long", truncating="post")

attention_masks = [[float(i != 0.0) for i in ii] for ii in input_ids]

tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(input_ids, tags,
                                                            random_state=2018, test_size=0.1)
tr_masks, val_masks, _, _ = train_test_split(attention_masks, input_ids,
                                             random_state=2018, test_size=0.1)

print(tr_inputs[0])
print(tr_tags[0])
print(len(set(tag2idx)))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…


[ 1335  1655  1421 22452  1138  1151  1841  1107   170 19850  2035  1113
   170  2699  3227  1107  1890   118  4013 13705   119     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0]
[16 16 16 16 16 16 16 16 16 16 16 16 16 16 16 16  2  2  2  9 16 17 17 17
 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17
 17 17]
18


Declare model

In [43]:
def create_distilbert_model(freeze = True):
    input_id = Input((MAX_LEN,), dtype=tf.int32)
    input_mask = Input((MAX_LEN,), dtype=tf.int32)

    config = DistilBertConfig()
    config.output_hidden_states = False
    transformer_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased',
                                                          config=config)

    embedding = transformer_model(input_id, attention_mask=input_mask)[0]
    
    if freeze:
      # freezing layers to skip fine-tuning of BERT layers
      for layer in transformer_model.layers:
          layer.trainable = False

    x = TimeDistributed(Dense(18, activation='softmax'))(embedding)

    model = Model(inputs=[input_id, input_mask], outputs=x)
    
    return model

model = create_distilbert_model(freeze = False)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 50)]         0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 50)]         0                                            
__________________________________________________________________________________________________
tf_distil_bert_model_2 (TFDisti ((None, 50, 768),)   66362880    input_5[0][0]                    
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, 50, 18)       13842       tf_distil_bert_model_2[0][0]     
Total params: 66,376,722
Trainable params: 66,376,722
Non-trainable params: 0
______________

Compile model

In [0]:
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

Check shapes

In [7]:
print('Training inputs')
print(tr_inputs.shape)
print(tr_inputs)
print('Training masks')
print(np.array(tr_masks).shape)
print(np.array(tr_masks))
print('Training labels/tags')
print(tr_tags.shape)
print(tr_tags)

Training inputs
(43163, 50)
[[ 1335  1655  1421 ...     0     0     0]
 [15769  1163  1199 ...     0     0     0]
 [  138  2315  2430 ...     0     0     0]
 ...
 [16228  1144   170 ...     0     0     0]
 [23077   117  1126 ...     0     0     0]
 [ 1109  1938  2078 ...     0     0     0]]
Training masks
(43163, 50)
[[1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]]
Training labels/tags
(43163, 50)
[[16 16 16 ... 17 17 17]
 [15 16 16 ... 17 17 17]
 [15 15 15 ... 17 17 17]
 ...
 [15 16 16 ... 17 17 17]
 [16 16 16 ... 17 17 17]
 [16  5 16 ... 17 17 17]]


In [0]:
train_obs = tr_inputs.shape[0]

In [0]:
train_inputs = [tr_inputs[:train_obs,],np.array(tr_masks)[:train_obs,]]

In [20]:
tr_tags[:train_obs,].shape

(43163, 50)

Train the model with freezed DistilBERT layers

In [21]:
history = model.fit(train_inputs,
                    tr_tags[:train_obs,], 
                    epochs=3,
                    batch_size = 32)

Epoch 1/3
Epoch 2/3
Epoch 3/3


Unfreeze layers

In [45]:
history = model.fit(train_inputs,
                    tr_tags[:train_obs,], 
                    epochs=3,
                    batch_size = 32)

Epoch 1/3
Epoch 2/3
Epoch 3/3


Evaluate

In [0]:
test_inputs = [val_inputs, np.array(val_masks)]
predictions = model.predict(test_inputs)

In [30]:
predictions[0].shape

(50, 18)

In [40]:
predictions[0][0]

array([3.3014374e-13, 1.1610874e-11, 9.2914427e-18, 3.4581113e-15,
       4.6539504e-14, 1.9064605e-15, 2.1595694e-15, 6.5220876e-15,
       3.5513276e-17, 1.7335930e-14, 2.9685181e-14, 7.9044066e-16,
       1.2492772e-14, 2.3777659e-15, 9.0894531e-12, 7.6622076e-16,
       1.0000000e+00, 5.6802061e-16], dtype=float32)

In [41]:
tag2idx

{'B-art': 14,
 'B-eve': 0,
 'B-geo': 2,
 'B-gpe': 5,
 'B-nat': 1,
 'B-org': 15,
 'B-per': 8,
 'B-tim': 11,
 'I-art': 4,
 'I-eve': 10,
 'I-geo': 9,
 'I-gpe': 3,
 'I-nat': 12,
 'I-org': 7,
 'I-per': 6,
 'I-tim': 13,
 'O': 16,
 'PAD': 17}