Link: https://www.kaggle.com/datasets/abhinavwalia95/entity

In [64]:
import pandas as pd
import numpy as np
import collections
import sklearn_crfsuite
from sklearn_crfsuite import CRF
from sklearn.model_selection import cross_val_predict
from sklearn_crfsuite.metrics import flat_classification_report
import eli5
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

In [65]:
# Load the token-level NER training CSV (path is relative to the working directory),
# decoding the file as Latin-1.
data = pd.read_csv("train_sets/ner_train.csv", encoding="latin1")

In [66]:
# Forward-fill missing cells: each NaN takes the last non-null value above it in its column.
# NOTE(review): presumably "Sentence #" is only populated on the first token of each
# sentence in the raw CSV, which is why this fill is needed — confirm against the file.
data = data.ffill()

In [67]:
# Column dtypes and non-null counts after the forward fill.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1048575 entries, 0 to 1048574
Data columns (total 4 columns):
 #   Column      Non-Null Count    Dtype 
---  ------      --------------    ----- 
 0   Sentence #  1048575 non-null  object
 1   Word        1048575 non-null  object
 2   POS         1048575 non-null  object
 3   Tag         1048575 non-null  object
dtypes: object(4)
memory usage: 32.0+ MB


In [68]:
# Peek at the first rows of the token table.
data.head()

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O


In [69]:
# Vocabulary: the distinct values of the "Word" column (set iteration order is arbitrary).
words = list(set(data["Word"].values))
n_words = len(words)
n_words

35177

In [70]:
# Distinct part-of-speech tags found in the "POS" column (set iteration order is arbitrary).
pos = list(set(data["POS"].values))
print(pos)

['DT', 'NN', ';', 'EX', ',', 'WP$', '``', 'VBD', ':', 'RP', 'JJS', 'LRB', 'IN', 'WP', 'POS', 'WRB', 'NNP', 'VBG', 'VBP', 'CD', 'NNPS', 'RBS', 'WDT', 'CC', 'PDT', 'RBR', 'UH', 'FW', 'MD', 'VB', 'JJ', 'PRP', 'RB', 'RRB', '.', 'JJR', 'VBZ', 'VBN', 'NNS', 'PRP$', 'TO', '$']


In [71]:
# Distinct entity labels found in the "Tag" column (set iteration order is arbitrary).
labels = list(set(data["Tag"].values))
print(labels)

['B-tim', 'B-nat', 'I-nat', 'I-per', 'O', 'I-art', 'I-geo', 'B-eve', 'I-eve', 'B-gpe', 'B-per', 'I-tim', 'I-gpe', 'B-art', 'I-org', 'B-geo', 'B-org']


In [72]:
# Token count per entity label, to show how the "Tag" values are distributed.
tag_values = data["Tag"].values.tolist()
label_counts = collections.Counter(tag_values)
print(label_counts)

Counter({'O': 887908, 'B-geo': 37644, 'B-tim': 20333, 'B-org': 20143, 'I-per': 17251, 'B-per': 16990, 'I-org': 16784, 'B-gpe': 15870, 'I-geo': 7414, 'I-tim': 6528, 'B-art': 402, 'B-eve': 308, 'I-art': 297, 'I-eve': 253, 'B-nat': 201, 'I-gpe': 198, 'I-nat': 51})


In [73]:
class SentenceGetter(object):
    """Group a token-level frame into sentences of ``(word, POS, tag)`` triples.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per token, with the columns "Sentence #", "Word", "POS" and "Tag".

    Attributes
    ----------
    n_sent : int
        1-based number of the sentence that ``get_next`` will return next.
    data : pandas.DataFrame
        The frame passed to the constructor.
    empty : bool
        Initialised to False; not read or updated anywhere in this class.
    grouped : pandas.Series
        Indexed by "Sentence #" value; each element is that sentence's list of triples.
    sentences : list
        The elements of ``grouped`` in its index order.

    Notes
    -----
    ``groupby`` sorts the group keys. ``get_next`` looks sentences up by string keys
    of the form "Sentence: <n>", and string keys sort lexicographically
    ("Sentence: 1", "Sentence: 10", "Sentence: 100", ...), so ``sentences`` is in
    that order rather than in numeric sentence order.
    """

    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        # Select the token columns explicitly so that ``apply`` does not operate on
        # the grouping column; newer pandas versions warn when it does.
        self.grouped = (
            self.data.groupby("Sentence #")[["Word", "POS", "Tag"]]
            .apply(self._to_triples)
        )
        self.sentences = list(self.grouped)

    @staticmethod
    def _to_triples(group):
        """Return the rows of one sentence group as a list of (word, POS, tag) tuples."""
        return list(zip(group["Word"].values.tolist(),
                        group["POS"].values.tolist(),
                        group["Tag"].values.tolist()))

    def get_next(self):
        """Return the next sentence by number, or None when that number is absent.

        The internal counter only advances when a sentence is found, so repeated
        calls after exhaustion keep returning None.
        """
        key = "Sentence: {}".format(self.n_sent)
        try:
            sentence = self.grouped[key]
        except KeyError:
            # Only a missing key means "no more sentences"; any other error is a
            # real problem and is allowed to propagate instead of being swallowed.
            return None
        self.n_sent += 1
        return sentence

In [74]:
# Group the whole token table into sentences; the grouping runs once, in the constructor.
getter = SentenceGetter(data)

  self.grouped = self.data.groupby("Sentence #").apply(agg_func)


In [75]:
# Fetch one sentence through the getter's counter (starts at "Sentence: 1") and inspect
# its list of (word, POS, tag) triples.
sent = getter.get_next()
print(sent)

[('Thousands', 'NNS', 'O'), ('of', 'IN', 'O'), ('demonstrators', 'NNS', 'O'), ('have', 'VBP', 'O'), ('marched', 'VBN', 'O'), ('through', 'IN', 'O'), ('London', 'NNP', 'B-geo'), ('to', 'TO', 'O'), ('protest', 'VB', 'O'), ('the', 'DT', 'O'), ('war', 'NN', 'O'), ('in', 'IN', 'O'), ('Iraq', 'NNP', 'B-geo'), ('and', 'CC', 'O'), ('demand', 'VB', 'O'), ('the', 'DT', 'O'), ('withdrawal', 'NN', 'O'), ('of', 'IN', 'O'), ('British', 'JJ', 'B-gpe'), ('troops', 'NNS', 'O'), ('from', 'IN', 'O'), ('that', 'DT', 'O'), ('country', 'NN', 'O'), ('.', '.', 'O')]


In [76]:
# All sentences as a list of triple-lists, and how many there are.
sentences = getter.sentences
print(len(sentences))

47959


In [77]:
# Input: a sentence in the structure shown above, and the index i of a word in it.
# Returns the features for that i-th word.

def word2features(sent, i):
    """Build the feature dict for the token at position ``i`` of ``sent``.

    ``sent`` is a sequence of tuples whose first two items are the word and its
    POS tag. The result describes the token itself plus its immediate left and
    right neighbours; a missing neighbour is marked with 'BOS' / 'EOS' instead.
    """
    token, pos_tag = sent[i][0], sent[i][1]

    # Features of the token itself.
    feats = {
        'bias': 1.0,
        'word.lower()': token.lower(),      # lower-cased token
        'word[-3:]': token[-3:],            # last three characters
        'word[-2:]': token[-2:],            # last two characters
        'word.isupper()': token.isupper(),  # every cased character is upper case
        'word.istitle()': token.istitle(),  # title-cased, i.e. initial capital
        'word.isdigit()': token.isdigit(),  # consists of digits only
        'postag': pos_tag,
        'postag[:2]': pos_tag[:2],          # first two characters of the POS tag
    }

    # Left context: the previous token, or the beginning-of-sentence marker.
    if i <= 0:
        feats['BOS'] = True
    else:
        prev_token, prev_pos = sent[i - 1][0], sent[i - 1][1]
        feats['-1:word.lower()'] = prev_token.lower()
        feats['-1:word.istitle()'] = prev_token.istitle()
        feats['-1:word.isupper()'] = prev_token.isupper()
        feats['-1:postag'] = prev_pos
        feats['-1:postag[:2]'] = prev_pos[:2]

    # Right context: the next token, or the end-of-sentence marker.
    if i >= len(sent) - 1:
        feats['EOS'] = True
    else:
        next_token, next_pos = sent[i + 1][0], sent[i + 1][1]
        feats['+1:word.lower()'] = next_token.lower()
        feats['+1:word.istitle()'] = next_token.istitle()
        feats['+1:word.isupper()'] = next_token.isupper()
        feats['+1:postag'] = next_pos
        feats['+1:postag[:2]'] = next_pos[:2]

    return feats


def sent2features(sent):
    """Return one ``word2features`` dict per position of ``sent``, in order."""
    features = []
    for position in range(len(sent)):
        features.append(word2features(sent, position))
    return features

def sent2labels(sent):
    """Return the third item (the label) of every 3-tuple in ``sent``, in order."""
    tags = []
    for _word, _pos, tag in sent:
        tags.append(tag)
    return tags

def sent2tokens(sent):
    """Return the first item (the token) of every 3-tuple in ``sent``, in order."""
    words = []
    for word, _pos, _tag in sent:
        words.append(word)
    return words

In [78]:
# Full-dataset variant, kept for reference:
#X = [sent2features(s) for s in sentences]
#y = [sent2labels(s) for s in sentences]

# If your environment breaks on the full dataset, it might be because of very large lists
# being held in memory; build the features from a slice of `sentences` instead.
# NOTE: the active code below uses only the first 10 entries of `sentences` (not 10000),
# so every metric further down is computed on those 10 sentences. Raise the slice bound
# on both lines for a larger training sample.
X = [sent2features(s) for s in sentences[:10]]
y = [sent2labels(s) for s in sentences[:10]]

In [79]:
# Feature dicts (one per token) of the first sentence in X.
print(X[0])

[{'bias': 1.0, 'word.lower()': 'thousands', 'word[-3:]': 'nds', 'word[-2:]': 'ds', 'word.isupper()': False, 'word.istitle()': True, 'word.isdigit()': False, 'postag': 'NNS', 'postag[:2]': 'NN', 'BOS': True, '+1:word.lower()': 'of', '+1:word.istitle()': False, '+1:word.isupper()': False, '+1:postag': 'IN', '+1:postag[:2]': 'IN'}, {'bias': 1.0, 'word.lower()': 'of', 'word[-3:]': 'of', 'word[-2:]': 'of', 'word.isupper()': False, 'word.istitle()': False, 'word.isdigit()': False, 'postag': 'IN', 'postag[:2]': 'IN', '-1:word.lower()': 'thousands', '-1:word.istitle()': True, '-1:word.isupper()': False, '-1:postag': 'NNS', '-1:postag[:2]': 'NN', '+1:word.lower()': 'demonstrators', '+1:word.istitle()': False, '+1:word.isupper()': False, '+1:postag': 'NNS', '+1:postag[:2]': 'NN'}, {'bias': 1.0, 'word.lower()': 'demonstrators', 'word[-3:]': 'ors', 'word[-2:]': 'rs', 'word.isupper()': False, 'word.istitle()': False, 'word.isdigit()': False, 'postag': 'NNS', 'postag[:2]': 'NN', '-1:word.lower()': '

In [80]:
# Hyper-parameters used to train the CRF; see the estimator reference:
# https://sklearn-crfsuite.readthedocs.io/en/latest/api.html?highlight=CRF
crf = CRF(algorithm='lbfgs',
          c1=0.1, # coefficient for L1 regularization
          c2=0.1, # coefficient for L2 regularization
          max_iterations=100,
          # all_possible_transitions: when True, CRFsuite generates transition features
          # for all possible label pairs, including ones that never occur in the
          # training data (L * L transition features for L labels).
          all_possible_transitions=False)

In [81]:
# Out-of-fold predictions from 5-fold cross-validation: every sequence in X is
# predicted by a model that was fitted on the other folds.
pred = cross_val_predict(estimator=crf, X=X, y=y, cv=5)

In [82]:
# Per-label report over the cross-validated predictions, using sklearn-crfsuite's
# "flat" helper (one row per B-/I-/O label, as the printed table shows).
report = flat_classification_report(y_pred=pred, y_true=y)
print(report)

              precision    recall  f1-score   support

       B-geo       0.63      0.92      0.75        13
       B-gpe       0.67      0.33      0.44         6
       B-org       0.00      0.00      0.00         2
       B-per       0.00      0.00      0.00         3
       B-tim       1.00      0.60      0.75         5
       I-geo       0.00      0.00      0.00         1
       I-per       0.00      0.00      0.00         2
           O       0.96      0.99      0.98       207

    accuracy                           0.93       239
   macro avg       0.41      0.36      0.37       239
weighted avg       0.91      0.93      0.91       239



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [83]:
# Fit the CRF on all of X / y so that its learned weights can be inspected below.
crf.fit(X, y)

In [84]:
# Render the fitted model's transition weights and per-label feature weights
# (`top=30` caps how many features eli5 lists).
eli5.show_weights(crf, top=30)

From \ To,O,B-geo,I-geo,B-gpe,B-org,B-per,I-per,B-tim
O,1.98,0.941,0.0,-0.007,0.229,0.0,0.0,0.489
B-geo,0.123,0.0,0.805,0.0,0.0,0.0,0.0,0.0
I-geo,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-gpe,0.375,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-org,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-per,0.0,0.0,0.0,0.781,0.0,0.0,1.391,0.0
I-per,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B-tim,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7
+3.281,bias,,,,,,
+1.125,BOS,,,,,,
+0.911,postag:NN,,,,,,
+0.832,+1:postag[:2]:VB,,,,,,
+0.736,-1:word.istitle(),,,,,,
+0.707,+1:word.istitle(),,,,,,
+0.592,-1:postag[:2]:NN,,,,,,
+0.585,word[-2:]:id,,,,,,
+0.533,postag:NNS,,,,,,
+0.520,-1:word.lower():a,,,,,,

Weight?,Feature
+3.281,bias
+1.125,BOS
+0.911,postag:NN
+0.832,+1:postag[:2]:VB
+0.736,-1:word.istitle()
+0.707,+1:word.istitle()
+0.592,-1:postag[:2]:NN
+0.585,word[-2:]:id
+0.533,postag:NNS
+0.520,-1:word.lower():a

Weight?,Feature
+1.126,postag:NNP
+0.802,word.isupper()
+0.746,word.istitle()
+0.633,-1:word.lower():the
+0.423,+1:word.lower():and
+0.409,+1:postag[:2]:CC
+0.409,+1:postag:CC
+0.409,+1:postag:NN
+0.399,word.lower():u.s.
+0.399,word[-3:]:.S.

Weight?,Feature
0.592,word.lower():waziristan
0.592,word[-3:]:tan
0.592,-1:word.lower():south
0.456,word[-2:]:an
0.356,+1:postag[:2]:.
0.356,+1:word.lower():.
0.356,+1:postag:.
0.181,-1:postag:NNP
0.144,-1:word.istitle()
0.039,-1:postag[:2]:NN

Weight?,Feature
+1.341,postag:JJ
+1.274,postag[:2]:JJ
+1.015,word.istitle()
+0.837,+1:postag:NNS
+0.726,word[-3:]:ian
+0.606,word[-2:]:ka
+0.606,-1:word.lower():sri
+0.606,word[-3:]:nka
+0.606,word.lower():lanka
+0.358,word[-2:]:an

Weight?,Feature
0.702,word.lower():iaea
0.702,+1:word.lower():surveillance
0.702,word[-2:]:EA
0.702,word[-3:]:AEA
0.69,+1:word.lower():militants
0.69,-1:word.lower():many
0.69,word[-3:]:ban
0.69,word.lower():taliban
0.675,-1:word.lower():an
0.58,+1:postag[:2]:NN

Weight?,Feature
0.892,+1:postag:NNP
0.859,+1:word.lower():egeland
0.746,+1:word.istitle()
0.418,+1:postag[:2]:NN
0.348,+1:word.lower():lanka
0.348,word[-2:]:ri
0.348,word[-3:]:Sri
0.348,word.lower():sri
0.253,word.lower():jan
0.253,-1:word.lower():coordinator

Weight?,Feature
0.738,word.lower():egeland
0.636,+1:word.lower():said
0.602,word[-3:]:and
0.598,word[-2:]:nd
0.538,+1:postag:VBD
0.392,+1:postag[:2]:VB
0.221,-1:postag:NNP
0.139,-1:word.istitle()
0.133,-1:word.lower():jan
0.103,-1:word.lower():mr.

Weight?,Feature
1.652,word[-3:]:day
1.634,word[-2:]:ay
0.792,word.lower():indonesia
0.594,"+1:postag:,"
0.594,"+1:word.lower():,"
0.594,"+1:postag[:2]:,"
0.529,word[-3:]:sia
0.392,word[-2:]:ia
0.389,-1:word.lower():in
0.317,word.lower():saturday


In [85]:
# Precision, recall and F1 from seqeval on the cross-validated predictions, printed as percentages.
for template, metric in (
    ("precision-score: {:.1%}", precision_score),
    ("recall-score: {:.1%}", recall_score),
    ("F1-score: {:.1%}", f1_score),
):
    print(template.format(metric(y, pred)))

precision-score: 65.4%
recall-score: 58.6%
F1-score: 61.8%


In [86]:
# `classification_report` here is the seqeval import (not scikit-learn's): it reports
# per entity type (geo, gpe, ...) rather than per B-/I- label.
print(classification_report(y, pred))

              precision    recall  f1-score   support

         geo       0.63      0.92      0.75        13
         gpe       0.67      0.33      0.44         6
         org       0.00      0.00      0.00         2
         per       0.00      0.00      0.00         3
         tim       1.00      0.60      0.75         5

   micro avg       0.65      0.59      0.62        29
   macro avg       0.46      0.37      0.39        29
weighted avg       0.59      0.59      0.56        29



  _warn_prf(average, modifier, msg_start, len(result))
