In [1]:
def read_classification_report(filename):
    with open(filename) as f:
        s = f.read()
    lines = [ss.strip().split('     ') for ss in s.split('\n') if ss]
    micro = []
    macro = []
    wmicro = []
    for i, line in enumerate(lines):
        if i % 3 == 0:
            micro.append({'p': line[1], 'r': line[2], 'f1': line[3], 's': line[4]})
        elif i % 3 == 1:
            macro.append({'p': line[1], 'r': line[2], 'f1': line[3], 's': line[4]})
        else:
            wmicro.append({'p': line[1], 'r': line[2], 'f1': line[3], 's': line[4]})
    return {'micro': micro[-1], 'macro': macro[-1], 'wmicro': wmicro[-1]}

d = read_classification_report('classification_base_ep9.txt')
print('base')
print('micro avg', d['micro'])
print('macro avg', d['macro'])
print('wmicro avg', d['wmicro'])

d = read_classification_report('classification_kb_ohe_ep9.txt')
print('kb_ohe')
print('micro avg', d['micro'])
print('macro avg', d['macro'])
print('wmicro avg', d['wmicro'])

d = read_classification_report('classification_kb_ce_ep9.txt')
print('kb_ce')
print('micro avg', d['micro'])
print('macro avg', d['macro'])
print('wmicro avg', d['wmicro'])

base
micro avg {'p': ' 0.764', 'r': '0.762', 'f1': '0.763', 's': '12721'}
macro avg {'p': ' 0.674', 'r': '0.559', 'f1': '0.597', 's': '12721'}
wmicro avg {'p': ' 0.763', 'r': '0.762', 'f1': '0.760', 's': '12721'}
kb_ohe
micro avg {'p': ' 0.755', 'r': '0.758', 'f1': '0.757', 's': '12721'}
macro avg {'p': ' 0.666', 'r': '0.547', 'f1': '0.579', 's': '12721'}
wmicro avg {'p': ' 0.757', 'r': '0.758', 'f1': '0.754', 's': '12721'}
kb_ce
micro avg {'p': ' 0.749', 'r': '0.757', 'f1': '0.753', 's': '12721'}
macro avg {'p': ' 0.647', 'r': '0.564', 'f1': '0.589', 's': '12721'}
wmicro avg {'p': ' 0.747', 'r': '0.757', 'f1': '0.749', 's': '12721'}


In [2]:
with open('crf_outputs/classification_report_crf_beam_fix.txt') as f:
    for l in [[ll.strip() for ll in l.split('     ') if ll] for l in f.readlines() if l.strip()][-3:]:
        print(l[0], {'p': l[1], 'r': l[2], 'f1': l[3], 's': l[4]})

micro avg {'p': '0.846', 'r': '0.667', 'f1': '0.746', 's': '12852'}
macro avg {'p': '0.768', 'r': '0.476', 'f1': '0.567', 's': '12852'}
weighted avg {'p': '0.842', 'r': '0.667', 'f1': '0.736', 's': '12852'}


In [3]:
import random

def select(x, k=10):
    return random.sample(x, k)

def select_dict(x, k=10):
    return sorted(x.items(), key=lambda x: -x[1])[:10]

In [4]:
from sklearn.metrics import confusion_matrix

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
LABELS = ["ant", "bird", "cat"]
print(list(enumerate(LABELS)))

def show_confusion(y_true, y_pred, LABELS, min_count=0):
    cm = confusion_matrix(y_true, y_pred, labels=LABELS)  # axis=(gold, pred)
    print('gold -> pred : support')
    for gold_i, a in enumerate(cm):
        for pred_i, c in enumerate(a):
            if gold_i != pred_i and c > min_count:
                print(f'{LABELS[gold_i]} -> {LABELS[pred_i]} : {c}')
show_confusion(y_true, y_pred, LABELS, min_count=0)

[(0, 'ant'), (1, 'bird'), (2, 'cat')]
gold -> pred : support
bird -> cat : 1
cat -> ant : 1


In [5]:
from collections import Counter

def read_token_label_list(filename):
    with open(filename) as f:
        sentences = [[l.split('\t') for l in s.split('\n') if len(l.split('\t')) == 3] for s in f.read().split('\n\n')]
    print(len(sentences))
    sentences_err = [s for s in sentences if any(pred != gold for surf, pred, gold in s)]
    print(len(sentences_err))
    sentences_pos = [s for s in sentences if all(pred == gold for surf, pred, gold in s)]
    return sentences_err, sentences_pos

def categorize_errors(sentences_err):
    fns = []
    fps_o = []
    fps_conf = []

    for i, s in enumerate(sentences_err):
        sent = '\n'.join( ' '.join(line) for line in s)
    #     print(sent)
        for surf, pred, gold in s:
            if pred == 'O' and gold != 'O':  # gold != 'O'; FN(ひろいこぼし)
                fns.append({'target':(surf, pred, gold), 'sentence': sent , 'sid': i})
            elif pred != 'O':  # FP
                if gold == 'O':
                    fps_o.append({'target':(surf, pred, gold), 'sentence': sent , 'sid': i})
                elif pred != gold:  # gold != 'O'
                    fps_conf.append({'target':(surf, pred, gold), 'sentence': sent , 'sid': i})
    print(len(sentences_err), len(fns), len(fps_o), len(fps_conf))
    return fns, fps_o, fps_conf


def categorize_errors_detail(fns, fps_conf, fps_o):

    # FN
    fns_counter = Counter(d['target'][2].split('-')[1] for d in fns)
    fns_entries = {d['target'][2].split('-')[1]: [] for d in fns}
    fns_uniq = {d['sid']: [] for d in fns}
    for d in fns:
        fns_uniq[ d['sid'] ].append(d['target'])
        fns_entries[ d['target'][2].split('-')[1] ].append(d['target'])

    # FP: confusion
    fps_conf_g_counter = Counter(d['target'][2].split('-')[1] for d in fps_conf)
    fps_conf_g_entries = {d['target'][2].split('-')[1]: [] for d in fps_conf}
    fps_conf_p_counter = Counter(d['target'][1].split('-')[1] for d in fps_conf if d['target'][1] != 'X')
    fps_conf_p_entries = {d['target'][1].split('-')[1]: [] for d in fps_conf if d['target'][1] != 'X'}
    fps_conf_uniq = {d['sid']: [] for d in fps_conf}

    y_gold, y_pred = [], []
    labels = set()
    fps_confusion = {(d['target'][2], d['target'][1]): [] for d in fps_conf}

    for d in fps_conf:
        fps_conf_uniq[ d['sid'] ].append(d['target'])
        fps_conf_g_entries[ d['target'][2].split('-')[1] ].append(d['target'])
        if d['target'][1] != 'X':
            fps_conf_p_entries[ d['target'][1].split('-')[1] ].append(d['target'])
            y_gold.append(d['target'][2])
            y_pred.append(d['target'][1])
            labels.add(d['target'][2])
            labels.add(d['target'][1])
            fps_confusion[(d['target'][2], d['target'][1])].append(d['target'])
    labels = sorted(labels, key=lambda x: x.split('-')[1]+x[0].split('-')[0])

    # FP: O as positive
    fps_o_counter = Counter(d['target'][1] for d in fps_o)  # .split('-')[1]
    fps_o_entries = {d['target'][1]: [] for d in fps_o}  # .split('-')[1]
    fps_o_uniq = {d['sid']: [] for d in fps_o}
    for d in fps_o:
        fps_o_uniq[ d['sid'] ].append(d['target'])
        fps_o_entries[ d['target'][1] ].append(d['target'])  # .split('-')[1]

    return {
'fn': {
'sid2result': fns_uniq,
'counter_gold': fns_counter,
'entry_map_gold': fns_entries,
},

'fp_confusion': {
'sid2result': fps_conf_uniq,
'counter_gold': fps_conf_g_counter, 'counter_pred': fps_conf_p_counter,
'entry_map_gold': fps_conf_g_entries, 'entry_map_pred': fps_conf_p_entries,
'confusion_info': {'y_gold': y_gold, 'y_pred': y_pred, 'labels': labels, 'confusion_map': fps_confusion}
},
    
'fp_o': {
'sid2result': fps_o_uniq,
'counter_pred': fps_o_counter,
'entry_map_pred': fps_o_entries,
}
    }


def error_analysis(filename):
    sentences_err, sentences_pos = read_token_label_list(filename)
    fns, fps_o, fps_conf = categorize_errors(sentences_err)
    return categorize_errors_detail(fns, fps_conf, fps_o)


In [8]:

def read_token_label_list(filename):
    with open(filename) as f:
        sentences = [[l.split('\t') for l in s.split('\n') if len(l.split('\t')) == 3] for s in f.read().split('\n\n')]
    print(len(sentences))
    sentences_err = [s for s in sentences if any(pred != gold for surf, pred, gold in s)]
    print(len(sentences_err))
    sentences_pos = [s for s in sentences if all(pred == gold for surf, pred, gold in s)]
    return sentences_err, sentences_pos

filename = 'output_result_base_ep9/token_label_list_epoch9_beam_fix.txt'
sentences_err, sentences_pos = read_token_label_list(filename)

filename = 'crf_outputs/token_label_pred_crf_beam_fix.txt'
sentences_err_crf, sentences_pos_crf = read_token_label_list(filename)


3369
1559
3374
1601


In [9]:


def extract_chunk(sentence, from_gold=True):
    chunks = []
    chunk = []
    for surf, pred, gold in sentence:
        if from_gold:
            if '-' in gold:
                bio, netype = gold.split('-')
                if bio == 'B':
                    chunk = [(surf, pred, gold)]
                elif bio == 'I':
                    chunk.append((surf, pred, gold))
            elif chunk:
                chunks.append(chunk)
                chunk = []
        else:
            if '-' in pred:
                bio, netype = pred.split('-')
                if bio == 'B':
                    chunk = [(surf, pred, gold)]
                elif bio == 'I':
                    chunk.append((surf, pred, gold))
            elif chunk:
                chunks.append(chunk)
                chunk = []
    if chunk:
        chunks.append(chunk)
    return chunks


In [21]:
def is_exact_match(chunk):
    return all(pred == gold for _, pred, gold in chunk)

def check_chunks_match(chunks_pos, chunks_err_gold, chunks_err_pred):
    tp = 0
    for chunks in chunks_pos:
        for chunk in chunks:
            assert is_exact_match(chunk)
            tp += 1
        # sentence
    for chunks in chunks_err_gold:
        for chunk in chunks:
            if is_exact_match(chunk):
                tp += 1

    fn = 0
    for chunks in chunks_err_gold:
        for chunk in chunks:
            if not is_exact_match(chunk):
                fn += 1

    fp = 0
    for chunks in chunks_err_pred:
        for chunk in chunks:
            if not is_exact_match(chunk):
                fp += 1

    p = tp / (tp + fp)
    r = tp / (tp + fn)
    f1 = 2 / (1 / p + 1 / r)
    return {'TP': tp, 'FP': fp, 'FN': fn, 'P': p, 'R': r, 'F1': f1}

In [22]:

chunks_pos = [extract_chunk(sentence) for sentence in sentences_pos]
chunks_err_gold = [extract_chunk(sentence) for sentence in sentences_err]
chunks_err_pred = [extract_chunk(sentence, from_gold=False) for sentence in sentences_err]

In [23]:

chunks_pos_crf = [extract_chunk(sentence) for sentence in sentences_pos_crf]
chunks_err_gold_crf = [extract_chunk(sentence) for sentence in sentences_err_crf]
chunks_err_pred_crf = [extract_chunk(sentence, from_gold=False) for sentence in sentences_err_crf]

In [24]:
check_chunks_match(chunks_pos, chunks_err_gold, chunks_err_pred)

{'TP': 5166,
 'FP': 1590,
 'FN': 1663,
 'P': 0.7646536412078153,
 'R': 0.7564797188460975,
 'F1': 0.7605447184394553}

In [25]:
check_chunks_match(chunks_pos_crf, chunks_err_gold_crf, chunks_err_pred_crf)

{'TP': 4713,
 'FP': 752,
 'FN': 2238,
 'P': 0.8623970722781336,
 'R': 0.678031937850669,
 'F1': 0.7591817010309277}

In [28]:
def chunk_error_detail(chunks_err_gold):
    fn_partial = 0  # 部分一致誤り
    fn_confusion = 0  # クラス誤り
    fn_o = 0  # 未抽出誤り
    rest = []
    for chunks in chunks_err_gold:
        for chunk in chunks:
            if any(pred.split('-')[-1] == gold.split('-')[-1] for _, pred, gold in chunk):
                fn_partial += 1
            elif all(pred == 'O' for _, pred, gold in chunk):
                fn_o += 1
            else:
                fn_confusion += 1
                rest.append(chunk)
    return {'partial': fn_partial, 'confusion': fn_confusion, 'O': fn_o}, rest  # {'partial': 2882, 'confusion': 499, 'O': 785}

In [29]:
result, rest = chunk_error_detail(chunks_err_gold)
result

{'partial': 2921, 'confusion': 403, 'O': 772}

In [30]:
result_crf, rest_crf = chunk_error_detail(chunks_err_gold_crf)
result_crf

{'partial': 2336, 'confusion': 243, 'O': 1723}

In [31]:
rest[:10]

[[('京都', 'B-City', 'B-Province')],
 [('小', 'B-Person', 'B-City'), ('鹿田', 'O', 'I-City')],
 [('葛屋', 'B-Ethnic_Group_Other', 'B-Position_Vocation')],
 [('本長', 'B-Person', 'B-Worship_Place'),
  ('谷寺', 'I-Person', 'I-Worship_Place')],
 [('日本', 'B-Country', 'B-Organization_Other'),
  ('三', 'O', 'I-Organization_Other'),
  ('大', 'O', 'I-Organization_Other'),
  ('荒神', 'O', 'I-Organization_Other')],
 [('笠', 'B-GOE_Other', 'B-City')],
 [('笠', 'B-GOE_Other', 'B-City')],
 [('京都', 'B-City', 'B-Province')],
 [('道路', 'B-Corporation_Other', 'B-Government'),
  ('公団', 'O', 'I-Government'),
  ('民営', 'O', 'I-Government'),
  ('化', 'O', 'I-Government'),
  ('委員', 'O', 'I-Government'),
  ('会', 'O', 'I-Government')],
 [('日本', 'B-Nationality', 'B-Country')]]

In [32]:
rest_crf[:10]

[[('京都', 'B-City', 'B-Province')],
 [('京都', 'B-City', 'B-Province')],
 [('道路', 'B-Position_Vocation', 'B-Government'),
  ('公団', 'I-Position_Vocation', 'I-Government'),
  ('民営', 'I-Position_Vocation', 'I-Government'),
  ('化', 'I-Position_Vocation', 'I-Government'),
  ('委員', 'I-Position_Vocation', 'I-Government'),
  ('会', 'O', 'I-Government')],
 [('学習', 'B-Person', 'B-School'), ('院', 'I-Person', 'I-School')],
 [('宮内', 'B-Position_Vocation', 'B-Government'),
  ('庁', 'I-Position_Vocation', 'I-Government')],
 [('宮内', 'B-Government', 'B-Position_Vocation'),
  ('庁', 'I-Government', 'I-Position_Vocation'),
  ('ＯＢ', 'O', 'I-Position_Vocation')],
 [('サッカー', 'B-Sports_Organization_Other', 'B-Position_Vocation'),
  ('日本', 'I-Sports_Organization_Other', 'I-Position_Vocation'),
  ('代表', 'I-Sports_Organization_Other', 'I-Position_Vocation')],
 [('大阪', 'B-City', 'B-Province')],
 [('マルコ', 'B-Company', 'B-Person'), ('ポーロ', 'I-Company', 'I-Person')],
 [('日', 'B-Person', 'B-Country'), ('朝', 'I-Person', 'I

In [190]:

error_detail = error_analysis('output_result_base_ep9/token_label_list_epoch9_greedy_fix.txt')


3369
1566
1566 2217 1578 1120


In [211]:
# 未抽出誤り; 知識依存
fn = error_detail['fn']
len(fn['sid2result']), select_dict(fn['counter_gold'])

(886,
 [('Position_Vocation', 566),
  ('Person', 289),
  ('Company', 161),
  ('Corporation_Other', 143),
  ('Government', 118),
  ('Organization_Other', 112),
  ('GOE_Other', 107),
  ('City', 102),
  ('International_Organization', 84),
  ('Country', 57)])

In [212]:
# FP混同; 文脈・語彙の問題?; KB的に対処可能
# - 正解との部分一致: 処理が面倒（一致するところも保持）
# - クラス誤り

fp_confusion = error_detail['fp_confusion']
len(fp_confusion['sid2result']), select_dict(fp_confusion['counter_gold']), select_dict(fp_confusion['counter_pred'])
#x[0].split('-')[1]+x[0].split('-')[0])

(637,
 [('Position_Vocation', 257),
  ('Person', 106),
  ('Company', 83),
  ('Government', 79),
  ('City', 67),
  ('GOE_Other', 53),
  ('Corporation_Other', 51),
  ('Organization_Other', 47),
  ('Public_Institution', 31),
  ('Sports_Organization_Other', 30)],
 [('Position_Vocation', 224),
  ('Person', 163),
  ('GOE_Other', 85),
  ('Corporation_Other', 84),
  ('Government', 84),
  ('City', 80),
  ('Company', 74),
  ('Public_Institution', 37),
  ('Country', 36),
  ('Province', 32)])

In [213]:
# FP; 過抽出誤り; 文脈・語彙の問題?

fp_o = error_detail['fp_o']
len(fp_o['sid2result']), select_dict(fp_o['counter_pred'])
#x[0].split('-')[1]+x[0].split('-')[0])

(799,
 [('B-Position_Vocation', 244),
  ('B-Person', 213),
  ('I-Position_Vocation', 182),
  ('X', 177),
  ('I-Person', 76),
  ('B-City', 69),
  ('I-GOE_Other', 59),
  ('B-Country', 57),
  ('B-Company', 56),
  ('B-GOE_Other', 40)])

In [214]:
# FN errors
sum([v for k, v in fn['counter_gold'].items()]), sum(v for k, v in fp_confusion['counter_gold'].items())

(2217, 1120)

In [215]:
# FP error
sum([v for k, v in fp_confusion['counter_pred'].items()]), sum(v for k, v in fp_o['counter_pred'].items())

(1105, 1578)

In [220]:
from sklearn_crfsuite import metrics

metrics.flat_classification_report

<function sklearn_crfsuite.metrics.flat_classification_report(y_true, y_pred, labels=None, **kwargs)>

# FN: 未抽出誤り
- 知識ベースで改善する見込みがある

In [None]:
select(list(fn['entry_map_gold'].items()))

In [193]:
select(fn['entry_map_gold']['Person'])

[('[UNK]', 'O', 'B-Person'),
 ('パウンド', 'O', 'B-Person'),
 ('ガトームソン', 'O', 'B-Person'),
 ('Ｐ', 'O', 'I-Person'),
 ('竹蔵', 'O', 'B-Person'),
 ('美紗', 'O', 'B-Person'),
 ('\u3000', 'O', 'I-Person'),
 ('トルシエ', 'O', 'B-Person'),
 ('サザエ', 'O', 'B-Person'),
 ('文公', 'O', 'B-Person')]

# FP混同
## - 正解との部分一致
## - クラス誤り

In [None]:
select(list(fp_confusion['entry_map_gold'].items()))

In [196]:
select(fp_confusion['entry_map_gold']['Person'])

[('武寿', 'B-City', 'B-Person'),
 ('家宝', 'B-Person', 'I-Person'),
 ('赤松', 'B-Family', 'B-Person'),
 ('シー', 'I-Position_Vocation', 'B-Person'),
 ('新太', 'B-Person', 'I-Person'),
 ('教夫', 'B-Person', 'I-Person'),
 ('深町', 'B-City', 'B-Person'),
 ('高崎', 'B-Position_Vocation', 'B-Person'),
 ('新渡戸', 'B-Station', 'B-Person'),
 ('信雄', 'B-Person', 'I-Person')]

In [197]:
# 'confusion_info': {'y_gold': y_gold, 'y_pred': y_pred, 'labels': labels, 'confusion_map': fps_confusion}
confusion_info = fp_confusion['confusion_info']
y_gold, y_pred, labels = confusion_info['y_gold'], confusion_info['y_pred'], confusion_info['labels'], 

fps_confusion_s = {f'{k[0]} -> {k[1]}': v for k, v in confusion_info['confusion_map'].items()}

show_confusion(y_gold, y_pred, labels, min_count=5)

gold -> pred : support
B-City -> B-GOE_Other : 11
B-City -> B-Person : 14
B-City -> B-Province : 13
B-Company -> B-Corporation_Other : 7
B-Company -> B-GOE_Other : 9
B-Company -> B-Person : 9
I-Company -> B-Company : 6
I-Company -> I-GOE_Other : 9
B-Company_Group -> B-Company : 7
I-Corporation_Other -> B-Corporation_Other : 6
I-Corporation_Other -> I-Government : 18
B-Country -> B-Nationality : 6
B-Family -> B-Person : 6
B-GOE_Other -> B-City : 11
I-GOE_Other -> B-GOE_Other : 8
B-GPE_Other -> B-City : 6
B-Government -> B-Corporation_Other : 8
B-Government -> I-Government : 9
B-Government -> B-Position_Vocation : 10
I-Government -> I-Corporation_Other : 20
I-Government -> I-Position_Vocation : 6
I-Organization_Other -> I-Corporation_Other : 6
I-Organization_Other -> I-Government : 7
B-Person -> B-City : 13
B-Person -> B-Company : 10
B-Person -> I-Person : 9
B-Person -> B-Position_Vocation : 7
I-Person -> B-Person : 28
B-Position_Vocation -> B-Person : 20
B-Position_Vocation -> I-Person 

In [198]:
fps_confusion_s['B-Company -> B-GOE_Other']

[('吉野', 'B-GOE_Other', 'B-Company'),
 ('吉野', 'B-GOE_Other', 'B-Company'),
 ('吉野', 'B-GOE_Other', 'B-Company'),
 ('トゥモロー', 'B-GOE_Other', 'B-Company'),
 ('ジャパレン', 'B-GOE_Other', 'B-Company'),
 ('ＳＯＧＯ', 'B-GOE_Other', 'B-Company'),
 ('大丸', 'B-GOE_Other', 'B-Company'),
 ('飯塚', 'B-GOE_Other', 'B-Company'),
 ('ラヂオ', 'B-GOE_Other', 'B-Company')]

# FP; 過抽出誤り

In [None]:
select(list(fp_o['entry_map_pred'].items()))