In [1]:
from itertools import chain

import nltk
import sklearn
import scipy.stats
from sklearn.metrics import make_scorer
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics



In [2]:
# Transform CoNLL-formatted lines into sentences of (word, POS, NER) tuples
# Tutorial: https://sklearn-crfsuite.readthedocs.io/en/latest/tutorial.html

def get_sentences_and_NER(filename):
    """Group CoNLL-style token lines into sentences of ``(word, pos, ner)`` tuples.

    Despite its name, ``filename`` is an iterable of text lines (e.g. an open
    file object, as passed by the training cell), not a path. Each non-blank
    line holds whitespace-separated columns; columns 0, 1 and 3 are kept as
    the word, its POS tag and its NER label. A blank line ends the current
    sentence.

    Returns:
        list of sentences, each a non-empty list of ``(word, pos, ner)`` tuples.

    Raises:
        IndexError: if a non-blank line has fewer than four columns.
    """
    sentences = []
    sentence = []
    for raw_line in filename:
        # split() with no argument drops '\n' / '\r\n' and collapses repeated
        # whitespace, so the NER label never carries a stray '\r'.
        columns = raw_line.split()
        if columns:
            sentence.append((columns[0], columns[1], columns[3]))
        elif sentence:
            # Blank line closes the sentence; skipping when `sentence` is empty
            # avoids emitting [] for consecutive blank lines.
            sentences.append(sentence)
            sentence = []
    # Keep the final sentence even when the input lacks a trailing blank line.
    if sentence:
        sentences.append(sentence)
    return sentences

In [3]:
# Features

def word2features(sent, i):
    """Return the CRF feature dict for the token at position ``i`` of ``sent``.

    Each item of ``sent`` is indexed as ``(word, pos, ...)``. The dict combines
    lexical/shape features of the current token with a smaller feature set
    for the previous and next tokens. When there is no previous token the
    ``'BOS'`` flag is set instead; when there is no next token the ``'EOS'``
    flag is set instead.
    """
    token = sent[i][0]
    tag = sent[i][1]

    feats = {}
    feats['bias'] = 1.0
    # Lower-cased form and short suffixes of the current token.
    feats['word.lower()'] = token.lower()
    feats['word[-3:]'] = token[-3:]
    feats['word[-2:]'] = token[-2:]
    # Orthographic shape of the current token.
    feats['word.isupper()'] = token.isupper()
    feats['word.istitle()'] = token.istitle()
    feats['word.isdigit()'] = token.isdigit()
    feats['postag'] = tag
    feats['postag[:2]'] = tag[:2]

    # Left context: the previous token, or a beginning-of-sentence marker.
    if i <= 0:
        feats['BOS'] = True
    else:
        left_token, left_tag = sent[i - 1][0], sent[i - 1][1]
        feats['-1:word.lower()'] = left_token.lower()
        feats['-1:word.istitle()'] = left_token.istitle()
        feats['-1:word.isupper()'] = left_token.isupper()
        feats['-1:postag'] = left_tag
        feats['-1:postag[:2]'] = left_tag[:2]

    # Right context: the next token, or an end-of-sentence marker.
    if i >= len(sent) - 1:
        feats['EOS'] = True
    else:
        right_token, right_tag = sent[i + 1][0], sent[i + 1][1]
        feats['+1:word.lower()'] = right_token.lower()
        feats['+1:word.istitle()'] = right_token.istitle()
        feats['+1:word.isupper()'] = right_token.isupper()
        feats['+1:postag'] = right_tag
        feats['+1:postag[:2]'] = right_tag[:2]

    return feats


def sent2features(sent):
    """Return one feature dict per token of ``sent``, built by ``word2features``."""
    feature_seq = []
    for position in range(len(sent)):
        feature_seq.append(word2features(sent, position))
    return feature_seq


def sent2labels(sent):
    """Return the label (third field) of every ``(word, pos, label)`` triple in ``sent``."""
    label_seq = []
    for _word, _pos, ner in sent:
        label_seq.append(ner)
    return label_seq


def sent2tokens(sent):
    """Return the word (first field) of every ``(word, pos, label)`` triple in ``sent``."""
    token_seq = []
    for word, _pos, _label in sent:
        token_seq.append(word)
    return token_seq

In [9]:
# Training
# CoNLL-formatted data files, relative to the notebook's working directory.
TRAIN_PATH = '../eng.train'
TEST_PATH = '../eng.test'

# CRF trained with L-BFGS; c1/c2 are the L1/L2 regularization coefficients
# (see the sklearn-crfsuite CRF documentation).
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=200,
    all_possible_transitions=True
)

# NOTE(review): sys is unused in this cell; kept in case later cells rely on it.
import sys

# `with` closes each file once parsed, so no handles are left open in the kernel.
with open(TRAIN_PATH, 'r') as file_train:
    train_sents = get_sentences_and_NER(file_train)
with open(TEST_PATH, 'r') as file_test:
    test_sents = get_sentences_and_NER(file_test)

X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_test = [sent2features(s) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

crf.fit(X_train, y_train)

# Score only the entity labels: the non-entity label 'O' is excluded.
# A comprehension (rather than list.remove) does not raise if 'O' is absent.
labels = [label for label in crf.classes_ if label != 'O']
print(labels)

y_pred = crf.predict(X_test)
# Bind and print the score; previously it was computed mid-cell and discarded.
f1_weighted = metrics.flat_f1_score(y_test, y_pred,
                                    average='weighted', labels=labels)
print('weighted F1 (entity labels):', f1_weighted)

# Sort by entity type first, then by B-/I- prefix, so both variants of a
# type appear next to each other in the report.
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))


<zip object at 0x1a167c7048>
['I-ORG', 'I-MISC', 'I-PER', 'I-LOC', 'B-LOC', 'B-MISC', 'B-ORG']


  'precision', 'predicted', average, warn_for)


             precision    recall  f1-score   support

      B-LOC      0.000     0.000     0.000         6
      I-LOC      0.885     0.821     0.852      4013
     B-MISC      0.500     0.154     0.235        13
     I-MISC      0.832     0.794     0.813      2175
      B-ORG      0.000     0.000     0.000         5
      I-ORG      0.822     0.788     0.805      4590
      I-PER      0.889     0.907     0.898      5924

avg / total      0.862     0.838     0.849     16726



  'precision', 'predicted', average, warn_for)
