# NER using sklearn-crfsuite
- Reference: eli5 tutorial [Named Entity Recognition using sklearn-crfsuite](https://eli5.readthedocs.io/en/latest/_notebooks/debug-sklearn-crfsuite.html)

# 0. Settings

In [21]:
!pip install sklearn_crfsuite
!pip install eli5



In [22]:
import nltk
import sklearn_crfsuite
import eli5
from sklearn import preprocessing
from itertools import chain
from sklearn.metrics import classification_report, confusion_matrix

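# 1. Training and test data
- MSRA NER corpus (train and test splits) in BIO format: one `token<TAB>tag` pair per line, with blank lines between sentences.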

In [23]:
# Location of the MSRA BIO-tagged corpus files; change DATA_DIR to point elsewhere.
DATA_DIR = './data'
train_set = DATA_DIR + '/msra_train_bio.txt'
test_set = DATA_DIR + '/msra_test_bio.txt'

In [24]:
def raw_data_preprocessing(file_name, encoding='utf-8'):
    """Read a BIO-tagged corpus file into a list of tagged sentences.

    Each non-blank line must hold ``token<TAB>tag``. Blank (or whitespace-only)
    lines separate sentences. Lines consisting of ``0`` followed by a tab and no
    tag are skipped.

    Args:
        file_name: Path of the corpus file.
        encoding: Text encoding of the file. Defaults to UTF-8 so the corpus is
            decoded the same way regardless of the OS locale.

    Returns:
        list[list[tuple[str, str]]]: one list of ``(token, tag)`` pairs per sentence.

    Raises:
        ValueError: if a non-blank line does not split into exactly two fields.
    """
    tagged_sentences = []
    sentence = []
    with open(file_name, 'r', encoding=encoding) as f:
        for line in f:
            if not line.strip():
                # Sentence boundary: flush the sentence collected so far.
                if sentence:
                    tagged_sentences.append(sentence)
                    sentence = []
                continue
            # NOTE(review): presumably malformed corpus entries with no tag — confirm.
            if line == '0\t\n':
                continue
            word, ner_tag = line.strip().split('\t')
            sentence.append((word, ner_tag))  # keep only the token and its NER tag
    # The file may not end with a blank line; don't drop the last sentence.
    if sentence:
        tagged_sentences.append(sentence)
    return tagged_sentences

In [25]:
# Parse both splits with the same reader; each result is a list of sentences,
# and each sentence is a list of (token, tag) pairs.
train_sents, test_sents = (raw_data_preprocessing(path) for path in (train_set, test_set))

# 2. Feature extraction

In [31]:
def word2features(sent, i):
    """Build the CRF feature dict for token ``i`` of ``sent``.

    Args:
        sent: Sequence of ``(token, label)`` pairs. Only the token
            (``sent[k][0]``) is read.
        i: Index of the target token.

    Returns:
        dict: the token's surface features (lower-cased form, 2/3-char suffixes,
        case/digit flags) plus the lower-cased form and case flags of the
        previous/next token. ``BOS``/``EOS`` flags are set at the sentence
        boundaries. For single-character tokens the suffix features equal the
        token itself.
    """
    # sent[k][1] is the gold NER tag in this notebook (see sent2labels), not a
    # POS tag. Never turn it into a feature: it would leak the label.
    word = sent[i][0]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
    }

    if i > 0:
        prev_word = sent[i - 1][0]
        features.update({
            '-1:word.lower()': prev_word.lower(),
            '-1:word.istitle()': prev_word.istitle(),
            '-1:word.isupper()': prev_word.isupper(),
        })
    else:
        features['BOS'] = True  # first token of the sentence

    if i < len(sent) - 1:
        next_word = sent[i + 1][0]
        features.update({
            '+1:word.lower()': next_word.lower(),
            '+1:word.istitle()': next_word.istitle(),
            '+1:word.isupper()': next_word.isupper(),
        })
    else:
        features['EOS'] = True  # last token of the sentence

    return features


def sent2features(sent):
    """Return the feature dict of every token in ``sent``, in sentence order."""
    per_token = []
    for position in range(len(sent)):
        per_token.append(word2features(sent, position))
    return per_token

def sent2labels(sent):
    """Return the gold tag of each ``(token, tag)`` pair in ``sent``."""
    tags = []
    for _token, tag in sent:
        tags.append(tag)
    return tags

def sent2tokens(sent):
    """Return the token of each ``(token, tag)`` pair in ``sent``."""
    tokens = []
    for token, _tag in sent:
        tokens.append(token)
    return tokens

In [32]:
# CRF inputs: X holds one list of per-token feature dicts per sentence,
# y holds the aligned list of gold tags per sentence.
X_train = list(map(sent2features, train_sents))
y_train = list(map(sent2labels, train_sents))

X_test = list(map(sent2features, test_sents))
y_test = list(map(sent2labels, test_sents))

In [33]:
# Feature dict of the second token (index 1) of the first training sentence.
X_train[0][1]

{'bias': 1.0,
 'word.lower()': '希',
 'word[-3:]': '希',
 'word[-2:]': '希',
 'word.isupper()': False,
 'word.istitle()': False,
 'word.isdigit()': False,
 '-1:word.lower()': '当',
 '-1:word.istitle()': False,
 '-1:word.isupper()': False,
 '+1:word.lower()': '望',
 '+1:word.istitle()': False,
 '+1:word.isupper()': False}

# 3. Train a CRF model

In [34]:
# CRF hyperparameters (see the sklearn-crfsuite CRF API docs).
crf_params = {
    'algorithm': 'lbfgs',
    'c1': 1.0,                           # coefficient for the L1 penalty
    'c2': 1e-3,                          # coefficient for the L2 penalty
    'max_iterations': 50,                # cap on optimiser iterations (stops training early)
    'all_possible_transitions': False,   # no transition features for label pairs unseen in training
    'min_freq': 5,                       # ignore features occurring no more than 5 times in training
}
crf = sklearn_crfsuite.CRF(**crf_params)
crf.fit(X_train, y_train);

# 4. Inspect model weights

In [35]:
# Show the learned label-transition weights and, for each label, its top-10 state features.
eli5.show_weights(crf, top=10)



From \ To,O,B-LOC,I-LOC,B-ORG,I-ORG,B-PER,I-PER
O,4.319,2.347,0.0,1.852,0.0,3.261,0.0
B-LOC,-0.985,0.389,5.174,-4.49,0.0,0.0,0.0
I-LOC,-0.391,0.033,4.5,-1.908,0.0,-1.569,0.0
B-ORG,-2.717,0.0,0.0,-1.221,7.424,0.0,0.0
I-ORG,-0.892,-1.455,0.0,-3.014,6.462,-1.066,0.0
B-PER,-2.157,0.0,0.0,0.0,0.0,-0.887,5.33
I-PER,-0.251,0.0,0.0,-2.993,0.0,0.016,5.299

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6
+9.319,EOS,,,,,
+4.425,+1:word.lower():伍,,,,,
+4.031,-1:word.lower():该,,,,,
+3.897,+1:word.lower():雨,,,,,
+3.817,-1:word.lower():署,,,,,
+3.654,+1:word.lower():续,,,,,
+3.580,-1:word.lower():各,,,,,
+3.376,-1:word.lower():我,,,,,
+3.365,-1:word.lower():厂,,,,,
… 7170 more positive …,… 7170 more positive …,,,,,

Weight?,Feature
+9.319,EOS
+4.425,+1:word.lower():伍
+4.031,-1:word.lower():该
+3.897,+1:word.lower():雨
+3.817,-1:word.lower():署
+3.654,+1:word.lower():续
+3.580,-1:word.lower():各
+3.376,-1:word.lower():我
+3.365,-1:word.lower():厂
… 7170 more positive …,… 7170 more positive …

Weight?,Feature
+4.741,-1:word.lower():℃
+4.075,+1:word.lower():胞
+3.727,+1:word.lower():两
+3.648,+1:word.lower():淮
+3.554,+1:word.lower():侨
+3.409,-1:word.lower():赴
+3.335,+1:word.lower():友
… 1722 more positive …,… 1722 more positive …
… 393 more negative …,… 393 more negative …
-3.182,+1:word.lower():共

Weight?,Feature
+3.952,+1:word.lower():寺
+3.662,+1:word.lower():畔
+3.313,+1:word.lower():堂
… 2394 more positive …,… 2394 more positive …
… 823 more negative …,… 823 more negative …
-3.055,-1:word.lower():洲
-3.097,-1:word.lower():本
-3.118,-1:word.lower():村
-3.296,-1:word.lower():市
-3.670,-1:word.lower():省

Weight?,Feature
+3.082,-1:word.lower():兼
+3.019,+1:word.lower():钢
+2.941,-1:word.lower():讯
+2.738,-1:word.lower():原
+2.673,-1:word.lower():任
+2.670,-1:word.lower():见
+2.625,+1:word.lower():禁
+2.412,-1:word.lower():胜
… 1409 more positive …,… 1409 more positive …
… 403 more negative …,… 403 more negative …

Weight?,Feature
+3.107,-1:word.lower():爵
+2.991,+1:word.lower():寰
+2.833,-1:word.lower():敖
+2.338,+1:word.lower():秘
… 3257 more positive …,… 3257 more positive …
… 1377 more negative …,… 1377 more negative …
-2.208,-1:word.lower():会
-2.264,+1:word.lower():本
-2.475,-1:word.lower():处
-2.561,-1:word.lower():社

Weight?,Feature
+5.549,-1:word.lower():臣
+4.625,-1:word.lower():锋
+4.295,-1:word.lower():姓
+4.242,-1:word.lower():卿
+3.965,+1:word.lower():先
+3.776,-1:word.lower():号
+3.674,-1:word.lower():席
+3.453,-1:word.lower():授
+3.323,+1:word.lower():娘
+3.300,-1:word.lower():统

Weight?,Feature
+3.116,-1:word.lower():砚
+3.084,-1:word.lower():冯
+2.706,+1:word.lower():／
+2.609,+1:word.lower():先
+2.514,+1:word.lower():介
+2.415,word.lower():瑛
+2.415,word[-2:]:瑛
… 2660 more positive …,… 2660 more positive …
… 685 more negative …,… 685 more negative …
-2.628,+1:word.lower():地


# 5. BIO classification report

In [36]:
# for tagging performance
def bio_classification_report(y_true, y_pred):
    """
    Classification report for a list of BIO-encoded sequences.
    It computes token-level metrics and discards "O" labels.
    
    Note that it requires scikit-learn 0.15+ (or a version from github master)
    to calculate averages properly!
    """
    lb = preprocessing.LabelBinarizer()
    y_true_combined = lb.fit_transform(list(chain.from_iterable(y_true)))
    y_pred_combined = lb.transform(list(chain.from_iterable(y_pred)))
        
    tagset = set(lb.classes_) - {'O'}
    tagset = sorted(tagset, key=lambda tag: tag.split('-', 1)[::-1])
    class_indices = {cls: idx for idx, cls in enumerate(lb.classes_)}
    
    return classification_report(
        y_true_combined,
        y_pred_combined,
        labels = [class_indices[cls] for cls in tagset],
        target_names = tagset,
        
    )

In [37]:
# Spot-check one held-out sentence: its tokens, the CRF prediction and the gold tags.
ex_sent = test_sents[20]
print(sent2tokens(ex_sent))
ex_predicted = crf.predict_single(sent2features(ex_sent))
print("predicted:", ', '.join(ex_predicted))
ex_gold = sent2labels(ex_sent)
print("Correct:", ', '.join(ex_gold))

['中', '国', '共', '产', '党', '中', '央', '委', '员', '会', '１', '９', '９', '７', '年', '１', '１', '月', '１', '日', '（', '新', '华', '社', '北', '京', '１', '１', '月', '１', '日', '电', '）', '江', '主', '席', '离', '开', '纽', '约', '抵', '波', '士', '顿', '在', '哈', '佛', '大', '学', '发', '表', '重', '要', '演', '讲', '在', '纽', '约', '时', '出', '席', '大', '型', '晚', '宴', '并', '演', '讲', '本', '报', '波', '士', '顿', '１', '１', '月', '１', '日', '电', '记', '者', '陈', '特', '安', '、', '李', '云', '飞', '报', '道', '：', '江', '泽', '民', '主', '席', '一', '行', '今', '天', '上', '午', '乘', '专', '机', '从', '纽', '约', '抵', '达', '波', '士', '顿', '访', '问', '。']
predicted: B-ORG, I-ORG, I-ORG, I-ORG, I-ORG, I-ORG, I-ORG, I-ORG, I-ORG, I-ORG, O, O, O, O, O, O, O, O, O, O, O, B-ORG, I-ORG, I-ORG, B-LOC, I-LOC, O, O, O, O, O, O, O, B-PER, O, O, O, O, B-LOC, I-LOC, O, B-LOC, I-LOC, I-LOC, O, B-ORG, I-ORG, I-ORG, I-ORG, O, O, O, O, O, O, O, B-LOC, I-LOC, O, O, O, O, O, O, O, O, O, O, O, O, B-LOC, I-LOC, I-LOC, O, O, O, O, O, O, O, O, B-PER, I-PER, I-PER, O, B-PER, I-PER, I-PER, 

In [38]:
# Predict a tag sequence for every test sentence. X_test already holds
# sent2features() of each test sentence, so reuse it instead of recomputing the
# features. CRF.predict applies predict_single to each sentence's features.
y_true = y_test
y_pred = crf.predict(X_test)

In [39]:
# The report is a plain string. Print it so the table renders with real line
# breaks instead of as an escaped string repr.
print(bio_classification_report(y_true, y_pred))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


'              precision    recall  f1-score   support\n\n       B-LOC       0.89      0.78      0.83      2886\n       I-LOC       0.84      0.68      0.75      4405\n       B-ORG       0.74      0.66      0.70      1331\n       I-ORG       0.76      0.74      0.75      5646\n       B-PER       0.92      0.66      0.77      1973\n       I-PER       0.83      0.85      0.84      3851\n\n   micro avg       0.82      0.74      0.78     20092\n   macro avg       0.83      0.73      0.77     20092\nweighted avg       0.82      0.74      0.78     20092\n samples avg       0.09      0.09      0.09     20092\n'