In [1]:
import sys
sys.path.append("../module/")
sys.path.append("../learning/")

import pickle
import menconn
from typelinking import *
import time
import marisa_trie
import pickle
from dataset import *
from wikidata_linker_utils.offset_array import OffsetArray
import train_type as tp
from collections import defaultdict
import re
from functools import partial
from tqdm import tqdm_notebook
import os

  from ._conv import register_converters as _register_converters


In [2]:
with open("test_data.pkl", 'rb') as f:
    data = pickle.load(f)

In [15]:
def load_settings(dataroot, lang='en'):
    tagger = tp.SequenceTagger(os.path.join(dataroot,'en_model/'))
    type_oracle = load_oracle_classification(os.path.join(dataroot, "data/classifications/type_classification"))
    trie_index2indices_values, trie_index2indices_counts, trie = load_trie(os.path.join(dataroot, 'data/{}_trie'.format(lang)))
    with open(os.path.join(dataroot, 'data/wikidata/indices2title.pkl'), 'rb') as hdl:
        indices2title = pickle.load(hdl)
    return tagger, indices2title, type_oracle, trie, trie_index2indices_values, trie_index2indices_counts

In [16]:
settings = load_settings('..')

INFO:tensorflow:Restoring parameters from ../en_model/model.ckpt


# 1. LinkProbだけで実行

In [4]:
for d in tqdm_notebook(data):
    if len(d[2]) == 0:
        continue
    sentence = d[1]
    ts = [str(t[0]).lower() for t in d[2]]
    true_entities = [str(t[1]).replace('_', ' ') for t in d[2]]
    entities = run(ts, None, None, indices2title, None, trie, trie_index2indices_values, trie_index2indices_counts, only_link=True)
    preds = []
    for entity in entities:
        if entity is not None:
            preds.append(entity['en'])
        else:
            preds.append(None)
    results += [{'doc_id':d[0], 'mention':x, 'true': y, 'pred': z} for x,y,z in zip(ts, true_entities, preds)]




In [5]:
import pandas as pd
train = [result for result in results if result['doc_id'] <= 946]
testa = [result for result in results if 947 <= result['doc_id'] and result['doc_id'] <= 1162]
testb = [result for result in results if 1163 <= result['doc_id']]

df_train = pd.DataFrame(train)
df_testa = pd.DataFrame(testa)
df_testb = pd.DataFrame(testb)
df = pd.DataFrame(results)

def calc_acc(df):
    matched = df['pred'] == df['true']
    length = df['pred'].shape[0]
    accuracy = float(sum(matched))/float(length)
    return accuracy

print('train:{}, testa:{}, testb:{}, all:{}'.format(
    calc_acc(df_train), calc_acc(df_testa), calc_acc(df_testb), calc_acc(df)))

train:0.6942641105934468, testa:0.6561004784688995, testb:0.6599241466498104, all:0.6822100913218714


# LinkProb + TypeProbで実行

In [6]:
for d in tqdm_notebook(data):
    if len(d[2]) == 0:
        continue
    sentence = d[1]
    ts = [str(t[0]).lower() for t in d[2]]
    true_entities = [str(t[1]).replace('_', ' ') for t in d[2]]
    tokenize = partial(menconn.en_tokenize, ts=ts)
    sent_splits, model_probs = solve_model_probs(sentence, tagger, tokenize=tokenize)
    entities = run(ts, sent_splits, model_probs, indices2title, type_oracle, trie, trie_index2indices_values, trie_index2indices_counts)
    preds = []
    for entity in entities:
        if entity is not None:
            preds.append(entity['en'])
        else:
            preds.append(None)
    results += [{'doc_id':d[0], 'mention':x, 'true': y, 'pred': z} for x,y,z in zip(ts, true_entities, preds)]




In [7]:
import pandas as pd
train = [result for result in results if result['doc_id'] <= 946]
testa = [result for result in results if 947 <= result['doc_id'] and result['doc_id'] <= 1162]
testb = [result for result in results if 1163 <= result['doc_id']]

df_train = pd.DataFrame(train)
df_testa = pd.DataFrame(testa)
df_testb = pd.DataFrame(testb)
df = pd.DataFrame(results)

def calc_acc(df):
    matched = df['pred'] == df['true']
    length = df['pred'].shape[0]
    accuracy = float(sum(matched))/float(length)
    return accuracy

print('train:{}, testa:{}, testb:{}, all:{}'.format(
    calc_acc(df_train), calc_acc(df_testa), calc_acc(df_testb), calc_acc(df)))

train:0.6968991063927289, testa:0.6569976076555024, testb:0.6610303413400759, all:0.6843018213356461


In [8]:
df.to_csv("results.csv", index=False)

In [9]:
df

Unnamed: 0,doc_id,mention,pred,true
0,1,european commission,European Commission,European Commission
1,1,spanish,Spanish language,Spain
2,1,brussels,Brussels,Brussels
3,1,france,France,France
4,1,bonn,Bonn,Bonn
5,1,loyola de palacio,Loyola de Palacio,Loyola de Palacio
6,1,european union,European Union,European Union
7,1,germany,Germany,Germany
8,1,german,Germany,Germany
9,1,commission,Ship commissioning,European Commission


# テスト文章で実行

In [10]:
tokenize = partial(menconn.en_tokenize, ts=ts)

def analyze(sentence, ts, tagger, indices2title, type_oracle, trie, trie_index2indices_values, trie_index2indices_counts, tokenize, lang='en'):
    if not ts:
        return []
    sent_splits, model_probs = solve_model_probs(sentence, tagger, tokenize=tokenize)
    entities = run(ts, sent_splits, model_probs, indices2title, type_oracle, trie, trie_index2indices_values, trie_index2indices_counts)
    preds = []
    for entity in entities:
        if entity is not None:
            preds.append(entity[lang])
        else:
            preds.append(None)
    return [{'mention':x, 'pred': y} for x,y in zip(ts, preds)]

In [17]:
sentence = "The man saw a Jaguar speed on the highway"
ts = ['jaguar']
print(analyze(sentence, ts, *settings))

[{'pred': 'Jaguar Cars', 'mention': 'jaguar'}]


In [18]:
sentence = "The prey saw the jaguar cross the jungle"
ts = ['jaguar']
print(analyze(sentence, ts, *settings))

[{'pred': 'Jaguar', 'mention': 'jaguar'}]


'ciaが'