In [20]:
import pickle 
import sklearn
import sklearn_crfsuite
import string

def isName(word):
    if ' ' in word:
        broken = word.split(' ')
        for i in range(len(broken)):
            if broken[i].islower():
                return False
        return True
    return False

def isMixCase(word):
    if len(word) > 2:
        if word[0].islower() and word[1].istitle():
            return True
        return False
    return False

def wordShape(word):
    shape = ''
    for character in word:
        if character.istitle():
            shape += 'U'
        elif character.islower():
            shape += 'L'
        elif character.isdigit():
            shape += 'D'
        else:
            shape += character
    return shape

def parse_file(file_name):
    f = open(file_name, 'r')
    sentences = []
    sent = []
    for line in f:
        if line.strip():
            tokens = line.strip().split('\t')
            if len(tokens) == 4:
                token = tokens[0]
                pos = tokens[1]
                chunk = tokens[2]
                label = tokens[-1]
                sent.append((token, pos, chunk, label))
        else:
            sentences.append(sent)
            sent = []
    return sentences

def word2feature(sent, i):
    word = sent[i][0]
    if word in string.punctuation:
        word = '<punct>'
    if word[0].isdigit():
        word = '<number>'
    pos = sent[i][1]
    chunk = sent[i][2]
    features = {
        'w(0)': word,
        'w(0)[:1]': word[:1],
        'w(0)[:2]': word[:2],
        'w(0)[:3]': word[:3],
        'w(0)[:4]': word[:4],
        'w(0)[-1:]': word[-1:],
        'w(0)[-2:]': word[-2:],
        'w(0)[-3:]': word[-3:],
        'w(0)[-4:]': word[-4:],
        'word.islower':word.islower(),
        'word.lower': word.lower(),
        'isTitle': word[0].istitle(),
        'isNumber': word.isdigit(),
        'isUpper': word.isupper(),
        'isCapWithPeriod': word[0].istitle() and word[-1] == '.',
        'endsInDigit': word[-1].isdigit(),
        'containHyphen': '-' in word,
        'isDate': word[0].isdigit() and word[-1].isdigit() and '/' in word,
        'isCode': word[0].isdigit() and word[-1].istitle(),
        'isName': isName(word),
        'isMixCase': isMixCase(word),
        'd&comma': word[0].isdigit() and word[-1].isdigit() and ',' in word,
        'd&period': word[0].isdigit() and word[-1].isdigit() and '.' in word,
        'wordShape': wordShape(word),
        
        'pos(0)': pos,
        'chunk(0)': chunk
    }
    if ' ' in word:
        for idx, _ in enumerate(word.split(' ')):
            features.update({
                '{}thword'.format(idx): word.split(' ')[idx]
            })
    if (i > 0):
        prev_word = sent[i-1][0]
        if prev_word in string.punctuation:
            prev_word = '<punct>'
        if prev_word[0].isdigit():
            prev_word = '<number>'
        prev_pos = sent[i-1][1]
        prev_chunk = sent[i-1][2]
        features.update({
            'w(-1)': prev_word,
            'w(-1).lower':prev_word.lower(),
            'isTitle(-1)': prev_word[0].istitle(),
            'isNumber(-1)': prev_word.isdigit(),
            'isCapWithPeriod(-1)': prev_word[0].istitle() and prev_word[-1] == '.',
            'isName(-1)': isName(prev_word),
            'wordShape(-1)': wordShape(prev_word),
            'w(-1)+w(0)': prev_word + ' ' + word,
            
            'pos(-1)': prev_pos,
            'chunk(-1)': prev_chunk,
            'pos(-1) + pos(0)': prev_pos + ' ' + pos,
            'chunk(-1) + chunk(0)': prev_chunk + ' ' + chunk
        })
    else:
        features['BOS'] = True
    
    if i > 1:
        prev_2_word = sent[i-2][0]
        if prev_2_word in string.punctuation:
            prev_2_word = '<punction>'
        if prev_2_word[0].isdigit():
            prev_2_word = '<number>'
        prev_2_pos = sent[i-2][1]
        prev_2_chunk = sent[i-2][2]
        features.update({
            'w(-2)': prev_2_word,
            'w(-2)+w(-1)': prev_2_word + ' ' + prev_word,
            'w(-2).isTitle()': prev_2_word[0].istitle(),
            'w(-2).isdigit': prev_2_word[0].isdigit(),
            
            'pos(-2)': prev_2_pos,
            'chunk(-2)': prev_2_chunk,
            'pos(-2) + pos(-1)': prev_2_pos + ' ' + prev_pos,
            'chunk(-2) + chunk(-1)': prev_2_chunk + ' ' + prev_chunk
        })
    if i < (len(sent) - 1):
        next_word = sent[i+1][0]
        if next_word in string.punctuation:
            next_word = '<punct>'
        if next_word[0].isdigit():
            next_word = '<number>'
        next_pos = sent[i+1][1]
        next_chunk = sent[i+1][2]
        features.update({
            'w(1)': next_word,
            'w(1).lower': next_word.lower(),
            'isTitle(1)': next_word[0].istitle(),
            'isNumber(1)': next_word.isdigit(),
            'isCapWithPeriod(1)': next_word[0].istitle() and next_word[-1] == '.',
            'isName(1)': isName(next_word),
            'wordShape(1)': wordShape(next_word),
            'w(0)+w(1)': word + ' ' + next_word,
            
            'pos(1)': next_pos,
            'chunk(1)': next_chunk,
            'pos(0)+pos(1)': pos + ' ' + next_pos,
            'chunk(0)+chunk(1)': chunk +' '+ next_chunk,
        })
    
    else:
        features['EOS'] = True
    if i < (len(sent) - 2):
        next_2_word = sent[i+2][0]
        if next_2_word in string.punctuation:
            next_2_word = '<punct>'
        if next_2_word[0].isdigit():
            next_2_word = '<number>'
        next_2_pos = sent[i+2][1]
        next_2_chunk = sent[i+2][2]
        features.update({
            'w(2)': next_2_word,
            'w(1) + w(2)': word + ' ' + next_word,
            'w(2).isTitle()': next_2_word[0].istitle(),
            'w(2).isdigit': next_2_word[0].isdigit(),
            
            'pos(2)': next_2_pos,
            'chunk(2)': next_2_chunk,
            'pos(1) + pos(2)': next_pos + ' ' + next_2_pos,
            'chunk(1) + chunk(2)': next_2_chunk
        })
    return features

def get_features(sent):
    return [word2feature(sent, i) for i in range(len(sent))]
def get_labels(sent):
    return [label for token, _, _, label in sent]
def get_tokens(sent):
    return [token for token, _, _, label in sent]

In [21]:
train_sent = parse_file('vlsp2016/train.txt')
test_sent = parse_file('vlsp2016/test.txt')

In [22]:
X_train = [get_features(s) for s in train_sent]
y_train = [get_labels(s) for s in train_sent]
X_test = [get_features(s) for s in test_sent]
y_test = [get_labels(s) for s in test_sent]

In [23]:
crf = sklearn_crfsuite.CRF(
    algorithm = 'lbfgs',
    c1 = 0.06,
    c2 = 0.1,
    max_iterations = 100,
)

In [24]:
crf.fit(X_train, y_train)
labels = list(crf.classes_)
labels.remove('O')
y_pred = crf.predict(X_test)

In [25]:
file_name = 'trained_model.pkl'
pickle.dump(crf, open(file_name, 'wb'))

In [26]:
from sklearn_crfsuite import metrics

metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)

0.9290437013731012

In [27]:
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))

              precision    recall  f1-score   support

       B-LOC      0.919     0.936     0.927      1376
       I-LOC      0.904     0.933     0.918       463
      B-MISC      1.000     0.878     0.935        49
      I-MISC      1.000     0.878     0.935        49
       B-ORG      0.865     0.653     0.744       274
       I-ORG      0.944     0.899     0.921       397
       B-PER      0.934     0.944     0.939      1294
       I-PER      0.975     0.981     0.978       983

   micro avg      0.934     0.927     0.930      4885
   macro avg      0.942     0.888     0.912      4885
weighted avg      0.933     0.927     0.929      4885

