第一步：先写一个能快速比较字符串、完成候选召回的类，使召回率尽量达到 95% 左右，精度在 30% 左右即可。
字符串相似度函数采用 Jaccard，因为计算快（注：下文 ngram_search 的默认 simFunc 为 "tversky"）。
下一步可尝试增加编辑距离、余弦（cosine）距离，并比较它们的速度。
由于召回阶段要求速度快、召回率高，因此需要牺牲较多精度；精度问题留到下一步的实体消歧阶段处理。

In [1]:
import json
from tqdm import tqdm
import os
from random import choice
from itertools import groupby
import numpy as np 

In [2]:
from tsim import TopSim

In [5]:
class KB(object):
    """In-memory index over a JSON-lines knowledge-base file.

    Every non-blank line of ``kb_directory`` is decoded with ``json.loads``;
    the keys read are ``subject_id``, ``subject``, ``alias`` (optional),
    ``type`` and ``data`` (a list of ``{'predicate': ..., 'object': ...}``
    dicts). Alias, type, predicate and object strings are lower-cased.

    Args:
        kb_directory: path to the KB file (despite the name, it is opened as
            a single UTF-8 text file).

    Attributes:
        id2kb: subject_id -> {'alias': [...], 'data': {predicate: object},
            'type': [...]}. Entities whose ``data`` is empty are not indexed.
        types: every lower-cased type seen, including non-indexed entities.
        predicate: every lower-cased predicate seen.
        kb2id: alias -> list of subject_ids carrying that alias.
        kb: list of all aliases (keys of ``kb2id``).
        id: list of all indexed subject_ids (keys of ``id2kb``).
    """

    def __init__(self, kb_directory):
        print("start loading kb_data...")
        self.kb_directory = kb_directory
        self.id2kb, self.types, self.predicate = self.get_id2kb()
        self.kb2id = self.get_kb2id()
        self.kb = list(self.kb2id.keys())
        self.id = list(self.id2kb.keys())
        self.print_info()

    def print_info(self):
        """Print summary statistics of the loaded KB."""
        print("KB DATA INFORMATION")
        print("TOKEN SIZE:{}".format(self.get_token_size()))
        print("ID SIZE:{}".format(len(self)))
        print("TYPE SIZE:{}".format(len(self.types)))
        print("PREDICATE SIZE:{}".format(len(self.predicate)))

    @staticmethod
    def _parse_record(record):
        """Normalise one decoded KB line.

        Returns:
            (subject_id, aliases, types, data).
            ``aliases`` holds the subject name plus its aliases. They are
            lower-cased and de-duplicated *after* lower-casing, so "Apple"
            and "apple" yield one entry; first-occurrence order is kept.
            ``types`` is the lower-cased type list.
            ``data`` maps lower-cased predicate -> lower-cased object; a
            repeated predicate keeps its last object.
        """
        names = [record['subject']] + record.get('alias', [])
        # dict.fromkeys de-duplicates while preserving order, so the alias
        # list (and thus kb2id/kb) no longer depends on set iteration order.
        aliases = list(dict.fromkeys(name.lower() for name in names))
        types = [t.lower() for t in record['type']]
        data = {item['predicate'].lower(): item['object'].lower()
                for item in record['data']}
        return record['subject_id'], aliases, types, data

    def get_id2kb(self):
        """Read the KB file and build ``(id2kb, types, predicates)``."""
        print("construct id2kb dict...")
        id2kb = {}
        kbtype = set()
        predicate = set()
        # Explicit UTF-8: non-ASCII text must not depend on the platform's
        # default encoding.
        with open(self.kb_directory, encoding='utf-8') as f:
            for line in tqdm(f):
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                subject_id, aliases, types, data = self._parse_record(json.loads(line))
                kbtype.update(types)
                predicate.update(data)  # iterating a dict yields its keys (the predicates)
                if data:
                    id2kb[subject_id] = {'alias': aliases, 'data': data, 'type': types}
        return id2kb, kbtype, predicate

    def get_kb2id(self):
        """Invert ``id2kb`` into alias -> list of subject_ids."""
        print("construct kb2id dict...")
        kb2id = {}
        for subject_id, entry in self.id2kb.items():
            for alias in entry['alias']:
                kb2id.setdefault(alias, []).append(subject_id)
        return kb2id

    def __len__(self):
        """Number of indexed entities."""
        return len(self.id2kb)

    def get_token_size(self):
        """Number of distinct aliases."""
        return len(self.kb)

In [6]:
# Load the CCKS 2019 entity-linking knowledge base (one JSON object per line).
kb_data = KB(kb_directory='./ccks2019_el/kb_data')

1959it [00:00, 19588.85it/s]

start loading kb_data...
construct id2kb dict...


399252it [00:19, 20730.88it/s]


construct kb2id dict...
KB DATA INFORMATION
TOKEN SIZE:303375
ID SIZE:399233
TYPE SIZE:51
PREDICATE SIZE:41841


In [None]:
def load_data(data_directory):
    """Load the (train, dev, test) splits stored together in one JSON file.

    Args:
        data_directory: path to a JSON file whose top-level value unpacks
            into exactly three items (despite the name, a file).

    Returns:
        ``(train_data, dev_data, test_data)`` as decoded by ``json.load``.
    """
    # Explicit UTF-8 so non-ASCII text does not depend on the platform's
    # default encoding.
    with open(data_directory, encoding='utf-8') as f:
        train_data, dev_data, test_data = json.load(f)
    print('traindata size:', len(train_data))
    print('devdata size:', len(dev_data))
    print('testdata size', len(test_data))
    return train_data, dev_data, test_data

In [None]:
# all_data.json holds the train/dev/test splits as one three-element JSON array.
train_data, dev_data, test_data = load_data(data_directory='./data/all_data.json')
print(train_data[0])

In [None]:
import jieba

class ngram_search(object):
    """Candidate-mention recall by fuzzy-matching word n-grams against KB aliases.

    Each text is segmented with ``jieba``. Every contiguous 1..``ngram``-gram
    of the resulting tokens is looked up in a ``TopSim`` index built over
    ``kb``. The first hit of each n-gram becomes a candidate, with light
    de-duplication by character offset.

    Args:
        data: list of raw texts.
        kb: list of alias strings; search hits are mapped back through it.
        ngram: maximum n-gram length, in jieba tokens.
        similarity: forwarded to ``TopSim.search`` as ``worstSim``.
        k, e, simFunc: forwarded unchanged to ``TopSim.search``.

    Attributes (one entry per input text):
        cut_data / offset: n-gram strings and their (start, end) char spans.
        candidates: ``TopSim.search`` result per n-gram ([] when falsy).
        cand_name / cand_off: kept candidate aliases and their spans.
        cand_with_off: str(start) -> list of (end, alias) for kept candidates.
    """

    def __init__(self, data, kb, ngram=4, similarity=0.6, k=1, e=0.7, simFunc="tversky"):
        self.n = ngram
        self.similarity = similarity
        self.data = data
        self.kb = kb
        self.k = k
        self.e = e
        self.simFunc = simFunc
        self.cut_data, self.offset = self.cut_words()
        self.ts = TopSim(self.kb)
        self.candidates = self.get_candidates()
        self.cand_name, self.cand_off, self.cand_with_off = self.get_candidates_name()

    @staticmethod
    def _ngrams_with_offsets(tokens, n):
        """Expand ``tokens`` into all 1..n-grams with character spans.

        Unigrams come first in token order, then all bigrams, trigrams, ...
        ``spans[k]`` is the half-open (start, end) span of ``grams[k]``
        within ``''.join(tokens)``. An empty token list yields ([], []).
        """
        grams = list(tokens)
        spans = []
        pos = 0
        for tok in tokens:
            spans.append((pos, pos + len(tok)))
            pos += len(tok)
        count = len(tokens)
        for size in range(2, n + 1):
            for last in range(size - 1, count):
                first = last - size + 1
                grams.append(''.join(tokens[first:last + 1]))
                spans.append((spans[first][0], spans[last][1]))
        return grams, spans

    def cut_words(self):
        """Segment every text with jieba and expand it into 1..n-grams.

        Returns:
            (result, offset): per text, the n-gram strings and their spans.
        """
        print('starting build ngram list')
        print('ngram', self.n)
        result = []
        offset = []
        for d in tqdm(self.data):
            grams, spans = self._ngrams_with_offsets(list(jieba.cut(d)), self.n)
            result.append(grams)
            offset.append(spans)
        return result, offset

    def get_candidates(self):
        """Run ``TopSim.search`` on every n-gram of every text."""
        print('starting build candidates list')
        print('similarity:', self.similarity)
        candidates = []
        for grams in tqdm(self.cut_data):
            # A falsy search result is normalised to [].
            candidates.append([
                self.ts.search(gram, k=self.k, e=self.e,
                               worstSim=self.similarity, simFunc=self.simFunc) or []
                for gram in grams
            ])
        return candidates

    @staticmethod
    def _select_candidates(hits, spans, kb):
        """Turn one text's per-n-gram search hits into de-duplicated candidates.

        For each n-gram with a non-empty hit list, ``kb[hit[0][1][0]]`` is its
        candidate alias. This assumes TopSim results are ``(score,
        [kb_indices])`` pairs, best first — confirm.

        A candidate is dropped when the same alias was already kept with the
        same end offset, or when it equals the most recently kept alias at
        the same start offset.

        Returns:
            (names, name_spans, by_start); ``by_start`` maps str(start) ->
            list of (end, alias).
        """
        names = []
        name_spans = []
        by_start = {}
        by_end = {}
        for hit, span in zip(hits, spans):
            if not hit:
                continue
            token = kb[hit[0][1][0]]
            begin, end = str(span[0]), str(span[1])
            seen_at_end = by_end.setdefault(end, set())
            if token in seen_at_end:
                continue
            kept_at_start = by_start.setdefault(begin, [])
            if kept_at_start and kept_at_start[-1][1] == token:
                continue
            names.append(token)
            name_spans.append(span)
            kept_at_start.append((span[1], token))
            seen_at_end.add(token)
        return names, name_spans, by_start

    def get_candidates_name(self):
        """Build per-text candidate names, spans and start-offset maps."""
        print('starting get candidates name and offset')
        cand_name = []
        cand_offset = []
        cand_with_off = []
        for hits, spans in tqdm(zip(self.candidates, self.offset), total=len(self.candidates)):
            names, name_spans, by_start = self._select_candidates(hits, spans, self.kb)
            cand_name.append(names)
            cand_offset.append(name_spans)
            cand_with_off.append(by_start)
        return cand_name, cand_offset, cand_with_off

In [None]:
# Dev texts plus, per text, the gold mention strings (en) and the
# (str(offset), mention) pairs (en_off) read from 'mention_data'.
dev_x = [sample['text'] for sample in dev_data]
print(dev_x[0])
en = [[mention[0] for mention in sample['mention_data']] for sample in dev_data]
en_off = [[(str(mention[1]), mention[0]) for mention in sample['mention_data']]
          for sample in dev_data]
print(en[0])
print(en_off[0])

In [None]:
# Recall candidates for the dev texts using the default search parameters.
ns = ngram_search(data=dev_x, kb=kb_data.kb)

In [None]:
# Inspect every recall output for the first dev text.
for attr in ('candidates', 'cand_name', 'cand_off', 'cand_with_off'):
    print(getattr(ns, attr)[0])

In [None]:
from extratools.mathtools import safediv

In [None]:
# Check what kind of value safediv returns for a zero denominator.
print(safediv(1, 0).__class__)

In [None]:
def evaluate_cand(cand_name, ground_truth):
    """Score recalled candidate names against gold mention strings.

    For text ``i``, every gold mention that occurs in ``cand_name[i]`` is a
    hit (``tp``). Then recall = tp / #gold, precision = tp / #candidates and
    F1 = safediv(2*r*p, r+p). Unweighted averages over texts are printed.

    Args:
        cand_name: per-text lists of candidate strings.
        ground_truth: per-text lists of gold mention strings.

    Returns:
        (recall, precision, f1, error): per-text scores, plus the indices
        whose F1 came out NaN or > 1 (their F1 is recorded as 0).
    """
    recall = []
    precision = []
    f1 = []
    error = []
    for i in tqdm(range(len(ground_truth))):
        gold = ground_truth[i]
        cands = cand_name[i]
        tp = sum(1 for mention in gold if mention in cands)
        # Empty denominators score 0 instead of raising ZeroDivisionError,
        # e.g. a text for which no n-gram matched the KB.
        r = tp / len(gold) if gold else 0.0
        p = tp / len(cands) if cands else 0.0
        f = safediv(2 * r * p, r + p)
        recall.append(r)
        precision.append(p)
        # Invalid F1 values are recorded as 0 and reported. Precision (and so
        # F1) can exceed 1 when a gold list repeats a mention.
        if f > 1 or np.isnan(f):
            error.append(i)
            f1.append(0)
        else:
            f1.append(f)
    n_eval = len(recall)
    n_total = len(cand_name)
    av_recl = sum(recall) / n_eval if n_eval else 0.0
    av_pre = sum(precision) / n_eval if n_eval else 0.0
    av_f1 = sum(f1) / n_eval if n_eval else 0.0
    print('average recall: {}'.format(av_recl))
    print('average precision: {}'.format(av_pre))
    print('average f1: {}'.format(av_f1))
    print('error number: {}'.format(len(error)))
    print('total number: {}'.format(n_total))
    print('error rate: {}'.format(len(error) / n_total if n_total else 0.0))
    return recall, precision, f1, error

In [None]:
# Name-only evaluation of the default recall run.
r1, p1, f11, error = evaluate_cand(cand_name=ns.cand_name, ground_truth=en)

In [None]:
# Show text, gold mentions and recalled names for each flagged index.
for i in error:
    print(dev_data[i], en[i], ns.cand_name[i], sep='\n')

In [None]:
# Scratch check of slicing semantics: a[:] is only a shallow copy of the list,
# so a[:][1] is the second tuple (2, 2), not the "second column" [2, 2].
a = [(1,2),(2,2)]
print(a[:][1])

In [None]:
def ajust_para(s, e):
    """Coarse sweep: raise ``similarity`` from ``s`` and lower ``e`` by 0.1 per step.

    Every step's (similarity, e) pair is printed. Only steps with
    similarity <= 0.9 build an ngram_search over the dev texts and are
    scored with evaluate_cand.
    """
    for step in range(1, 10):
        delta = step / 10
        sim, eps = s + delta, e - delta
        print('similarity', sim, 'e', eps)
        if sim > 0.9:
            continue
        ns = ngram_search(dev_x, kb_data.kb, similarity=sim, e=eps)
        evaluate_cand(ns.cand_name, en)

In [None]:
# Coarse sweep starting from similarity 0.2 / e 1.
ajust_para(s=0.2, e=1)

In [None]:
def ajust_para(s, e):
    """Fine sweep: five steps of 0.05, starting at offset 0.

    Each step raises ``similarity`` and lowers ``e`` by the same amount.
    Every step is printed; steps with similarity <= 0.9 are evaluated with
    evaluate_cand.
    """
    for step in range(5):
        delta = step * 0.05
        sim, eps = s + delta, e - delta
        print('similarity', sim, 'e', eps)
        if sim > 0.9:
            continue
        ns = ngram_search(dev_x, kb_data.kb, similarity=sim, e=eps)
        evaluate_cand(ns.cand_name, en)

In [None]:
# Fine sweep around similarity 0.6 / e 0.8.
ajust_para(s=0.6, e=0.8)

In [None]:
def evaluate_cand_off(cand_with_off, ground_truth, cand_name):
    """Offset-aware evaluation of recalled candidates.

    A gold ``(start, mention)`` pair is a hit when ``cand_with_off[i][start]``
    contains a candidate equal to ``mention``. Then recall = tp / #gold,
    precision = tp / #distinct candidate start offsets
    (``len(cand_with_off[i])``) and F1 = safediv(2*r*p, r+p).

    Args:
        cand_with_off: per-text dicts str(start) -> list of (end, name).
        ground_truth: per-text lists of (str(start), mention) pairs.
        cand_name: per-text lists of candidate names.

    Returns:
        (recall, precision, f1, error, error_off).
        recall, precision, f1: per-text scores.
        error: indices whose F1 was NaN or > 1 (recorded as 0).
        error_off: indices (one entry per occurrence) where a gold mention
        string was among ``cand_name[i]`` but no candidate started at its
        gold offset.
    """
    recall = []
    precision = []
    f1 = []
    error = []
    error_off = []
    for i in tqdm(range(len(ground_truth))):
        gold = ground_truth[i]
        by_start = cand_with_off[i]
        tp = 0
        for pair in gold:
            start, mention = pair[0], pair[1]
            if start in by_start:
                if any(name == mention for _end, name in by_start[start]):
                    tp += 1
            elif mention in cand_name[i]:
                # Compare the mention *text* (not its offset string) with the
                # recalled names: recalled, but not at the gold offset.
                error_off.append(i)
        # Empty denominators score 0 instead of raising ZeroDivisionError.
        r = tp / len(gold) if gold else 0.0
        p = tp / len(by_start) if by_start else 0.0
        f = safediv(2 * r * p, r + p)
        recall.append(r)
        precision.append(p)
        if f > 1 or np.isnan(f):
            error.append(i)
            f1.append(0)
        else:
            f1.append(f)
    n_eval = len(recall)
    n_total = len(cand_name)
    av_recl = sum(recall) / n_eval if n_eval else 0.0
    av_pre = sum(precision) / n_eval if n_eval else 0.0
    av_f1 = sum(f1) / n_eval if n_eval else 0.0
    print('average recall: {}'.format(av_recl))
    print('average precision: {}'.format(av_pre))
    print('average f1: {}'.format(av_f1))
    print('error number: {}'.format(len(error)))
    print('total number: {}'.format(n_total))
    print('error rate: {}'.format(len(error) / n_total if n_total else 0.0))
    return recall, precision, f1, error, error_off

In [None]:
# Offset-aware evaluation at similarity 0.5 / e 0.6.
ns = ngram_search(data=dev_x, kb=kb_data.kb, similarity=0.5, e=0.6)
r, p, f1, er, er_o = evaluate_cand_off(
    cand_with_off=ns.cand_with_off, ground_truth=en_off, cand_name=ns.cand_name)

In [None]:
# Show text, gold mentions and recalled names for each flagged index.
for i in er:
    print(dev_data[i], en[i], ns.cand_name[i], sep='\n')

In [None]:
# Dump the texts listed in er_o (the error_off indices from evaluate_cand_off),
# with their gold mention strings and gold (offset, mention) pairs.
for i in er_o:
    print(dev_data[i])
    print(en[i])  # gold mention strings; `en_o` was never defined
    print(en_off[i])

In [None]:
def ajust_para(s, e):
    """Coarse offset-aware sweep: raise ``similarity`` and lower ``e`` by 0.1 per step.

    Every step is printed; steps with similarity <= 0.9 are scored with
    evaluate_cand_off against the dev offset pairs.
    """
    for step in range(1, 10):
        delta = step / 10
        sim, eps = s + delta, e - delta
        print('similarity', sim, 'e', eps)
        if sim > 0.9:
            continue
        ns = ngram_search(dev_x, kb_data.kb, similarity=sim, e=eps)
        evaluate_cand_off(ns.cand_with_off, en_off, ns.cand_name)

In [None]:
# Offset-aware coarse sweep starting from similarity 0.2 / e 1.
ajust_para(s=0.2, e=1)

In [None]:
# Fixed similarity 0.5; sweep e downward from 1.1 to about 0.3 in 0.1 steps.
s, e = 0.5, 1.2
for i in range(1, 10):
    p = i / 10
    print('similarity', s, 'e', e - p)
    ns = ngram_search(dev_x, kb_data.kb, similarity=s, e=e - p)
    evaluate_cand_off(ns.cand_with_off, en_off, ns.cand_name)

In [None]:
# Fixed similarity 0.5; sweep e downward from 0.7 in 0.1 steps.
# NOTE(review): the last steps pass e ~ 0 and e < 0 to TopSim -- confirm that
# is meaningful for TopSim.search.
s, e = 0.5, 0.8
for i in range(1, 10):
    p = i / 10
    print('similarity', s, 'e', e - p)
    ns = ngram_search(dev_x, kb_data.kb, similarity=s, e=e - p)
    evaluate_cand_off(ns.cand_with_off, en_off, ns.cand_name)

In [None]:
# Training texts and their gold mention strings (offsets are not kept here).
train_x = [sample['text'] for sample in train_data]
print(train_x[0])
train_en = [[mention[0] for mention in sample['mention_data']] for sample in train_data]
print(train_en[0])