In [None]:
from tqdm import tqdm

# All Wikidata5m files are tab-separated text: the first field is an id and the
# remaining fields are its values.
WIKIDATA5M_DIR = './data/wikidata5m'


def _load_alias_table(path):
    """Map each id (first tab-separated field) to the list of its remaining fields.

    Used for the entity and relation files, whose extra fields are aliases.
    A later line with the same id overwrites an earlier one.
    """
    table = {}
    with open(path, encoding='utf-8') as fh:
        for line in tqdm(fh):
            idx, *aliases = line.strip().split('\t')
            table[idx] = aliases
    return table


# One [head, relation, tail] list per line. The file is streamed so the raw
# lines and the split lists are never held in memory at the same time.
with open(f'{WIKIDATA5M_DIR}/wikidata5m_all_triplet.txt', encoding='utf-8') as f:
    triplets = [line.strip().split('\t') for line in tqdm(f)]

entities = _load_alias_table(f'{WIKIDATA5M_DIR}/wikidata5m_entity.txt')
relations = _load_alias_table(f'{WIKIDATA5M_DIR}/wikidata5m_relation.txt')

# id -> description text (second tab-separated field).
desc = {}
with open(f'{WIKIDATA5M_DIR}/wikidata5m_text.txt', encoding='utf-8') as f:
    for line in tqdm(f):
        fields = line.strip().split('\t')
        # A line without a description field would make fields[1] raise
        # IndexError; leave such ids out so `id not in desc` means "no description".
        if len(fields) > 1:
            desc[fields[0]] = fields[1]

In [None]:
from nltk.corpus import stopwords

# NLTK English stopwords; has_key skips candidate words found here. A frozenset
# makes the repeated `k in stopw` checks constant-time instead of a list scan.
stopw = frozenset(stopwords.words('english'))

def alias2id_entity(a):
    """Return the id of the first entity (in `entities` order) that lists `a` as an alias.

    Returns '' when no entity has that alias. `a` must be hashable (titles are
    strings).

    A reverse alias -> id index is built lazily on the first call. It is reused
    while `entities` is the same dict object with the same length, which avoids
    a scan over every entity on each lookup. `setdefault` keeps the first id seen
    for each alias, matching the first-match result of a linear scan.

    Trade-offs: the index holds one entry per distinct alias (extra memory), and
    in-place edits to alias lists that leave len(entities) unchanged are not
    detected.
    """
    cache = getattr(alias2id_entity, '_alias_index', None)
    # Holding a reference to the dict (not its id()) keeps the identity check reliable.
    if cache is None or cache[0] is not entities or cache[1] != len(entities):
        index = {}
        for idx, aliases in entities.items():
            for alias in aliases:
                index.setdefault(alias, idx)
        cache = (entities, len(entities), index)
        alias2id_entity._alias_index = cache
    return cache[2].get(a, '')

def alias2id_relation(a):
    """Look up a relation id by alias.

    Walks `relations` in insertion order and returns the first id whose alias
    list contains `a`; returns '' when no relation has that alias.
    """
    return next(
        (rel_id for rel_id, names in relations.items() if a in names),
        '',
    )

def get_triplets(e_idx):
    """Collect every triplet that touches entity `e_idx`.

    Each element of the global `triplets` must unpack into exactly
    (head, relation, tail). A triplet is kept when its head or tail equals
    `e_idx`. The matching triplet objects themselves are returned, in
    `triplets` order. This is a linear scan over all triplets.
    """
    def _touches(triplet):
        head, _relation, tail = triplet
        return head == e_idx or tail == e_idx

    return [triplet for triplet in triplets if _touches(triplet)]

def has_key(key, obj, loose=False):
    """Report whether `key`, or any whitespace-separated word of it, matches an alias.

    Candidates are the full `key` followed by each of its words. Candidates
    that are a single character or appear in `stopw` are skipped.

    obj: an iterable of alias lists (one list per triplet element).
    loose=False: a candidate matches an alias that is exactly equal to it.
    loose=True: a candidate matches when it equals one of the
        whitespace-separated words of an alias.
    """
    for candidate in [key, *key.split()]:
        if len(candidate) == 1 or candidate in stopw:
            continue
        for alias_list in obj:
            for alias in alias_list:
                hit = candidate in alias.split() if loose else alias == candidate
                if hit:
                    return True
    return False

def get_triplets_with_key(key, t):
    """Filter id triplets `t` down to those whose aliases mention `key`.

    Each (head, rel, tail) is checked with has_key against the alias lists
    [entities[head], relations[rel], entities[tail]]. A strict pass runs first.
    Only when it keeps nothing is a loose pass run. Matches are returned as
    new [head, rel, tail] lists; the result is [] if neither pass matches.
    """
    def _matching(loose):
        return [
            [head, rel, tail]
            for head, rel, tail in t
            if has_key(key, [entities[head], relations[rel], entities[tail]], loose=loose)
        ]

    strict = _matching(False)
    return strict if strict else _matching(True)

def search(title, key):
    """Return the KG triplets touching the entity named `title`, narrowed by `key`.

    The title is resolved with alias2id_entity and all of its triplets are
    fetched with get_triplets. If there is more than one, they are narrowed with
    get_triplets_with_key, but only when that keeps at least one triplet;
    otherwise the unfiltered list is returned. Returns [] when `title` does not
    resolve to an entity.
    """
    entity_id = alias2id_entity(title)
    if entity_id == '':
        # Unresolved title: skip the linear scan over every triplet.
        return []
    found = get_triplets(entity_id)
    # A single triplet is returned as-is; filtering cannot improve on it.
    if len(found) > 1:
        narrowed = get_triplets_with_key(key, found)
        if narrowed:
            found = narrowed
    return found


In [None]:
import json

def _load_json(path):
    """Parse a JSON file as UTF-8.

    The encoding is explicit (like the Wikidata5m reads) so that results do not
    depend on the platform's default locale encoding.
    """
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


train = _load_json('./data/train.json')
iid_test = _load_json('./data/iid_test.json')
ood_test = _load_json('./data/ood_test.json')


In [None]:
def _first_sample_per_title_run(samples):
    """Keep the first sample of each run of consecutive samples sharing a 'title'.

    Only adjacent duplicates collapse. A title that reappears after a different
    one starts a new run and is kept again.
    """
    firsts = []
    # A fresh object never equals a real title, so the first sample is always
    # kept (the old '' sentinel silently dropped a leading sample titled '').
    prev_title = object()
    for sample in samples:
        title = sample['title']
        if title != prev_title:
            prev_title = title
            firsts.append(sample)
    return firsts


iid_diff_title = _first_sample_per_title_run(iid_test)
ood_diff_title = _first_sample_per_title_run(ood_test)

In [None]:
# Resolve each IID title to an entity id and collect titles that either did not
# resolve ('') or resolved to an id with no entry in `desc`.
e_list_iid = {}
title_not_found = []
for sample in tqdm(iid_diff_title):
    title = sample['title']
    entity_id = alias2id_entity(title)
    e_list_iid[title] = entity_id
    unresolved = entity_id == '' or entity_id not in desc
    if unresolved:
        title_not_found.append([entity_id, title])
c = len(title_not_found)  # number of IID titles without a usable description
c

In [None]:
import re
import string

def _blank_punctuation(text, keep_start=0, keep_end=0):
    """Return `text` with ASCII punctuation and the en dash '–' replaced by spaces.

    Characters whose index lies in [keep_start, keep_end) are copied unchanged,
    so a protected span (the bracketed target) keeps its punctuation.
    """
    return ''.join(
        ch if keep_start <= i < keep_end or (ch not in string.punctuation and ch != '–') else ' '
        for i, ch in enumerate(text)
    )


def _matches_at(words, pos, pattern):
    """Return True if the word list `pattern` occurs in `words` starting at index `pos`."""
    return words[pos:pos + len(pattern)] == pattern


def get_replace(sample, e_list, log=False):
    """Extract the description words that fill the passage's bracketed slot.

    The first "[...]" span in sample['passage'] is the target. Its neighbouring
    passage word(s) (one on each side) are located in the punctuation-stripped
    description of the sample's entity. The description words lying between
    them are returned.

    Args:
        sample: dict with 'title' and 'passage' keys.
        e_list: mapping title -> entity id ('' when unresolved), as built with
            alias2id_entity.
        log: print intermediate values for debugging.

    Returns:
        'NoEntity' if the title has no entity or the entity has no description.
        'No []' if the passage has no bracketed target, or the target cannot be
        re-found as a token after cleaning.
        Otherwise, the extracted words joined by spaces ('' when the context
        words are not found in the description).
    """
    title = sample['title']
    p = sample['passage']
    e = e_list[title]
    # NOTE(review): Q1750563 is excluded explicitly; the reason is not visible here.
    # Ids missing from `desc` have nothing to search, so treat them like an
    # unresolved title instead of raising KeyError on desc[e].
    if e == '' or e == 'Q1750563' or e not in desc:
        return 'NoEntity'
    d = desc[e]

    if log:
        print(p)
        print(d)
        print(p.split())
        print(d.split())

    # The first "[...]" span on a single line is the slot to fill.
    match = re.search(r'\[.*?\]', p)
    if not match:
        return 'No []'

    span = match.span()
    target = p[span[0]:span[1]]
    if log:
        print(span, target)

    # Turn the bracketed target into a single token (whitespace runs -> '_').
    # The token can be shorter than the original span, so the protected range
    # is recomputed from the token length instead of reusing span[1].
    target_token = '_'.join(target.split())
    p = p.replace(target, target_token)
    p_clean = _blank_punctuation(p, span[0], span[0] + len(target_token))
    d_clean = _blank_punctuation(d)

    if log:
        print(p_clean)
        print(d_clean)
        print(p_clean.split())
        print(d_clean.split())

    d_w = d_clean.split()
    p_w = p_clean.split()
    # Position of the last token equal to the target token (-1 if absent).
    target_idx = -1
    for i, w in enumerate(p_w):
        if w == target_token:
            target_idx = i
    if target_idx == -1:
        return 'No []'

    # Number of context words taken from each side of the target in the passage.
    num_prefix = 1
    num_suffix = 1
    prefix = p_w[max(0, target_idx - num_prefix):target_idx]
    suffix = p_w[target_idx + 1:target_idx + 1 + num_suffix]

    if log:
        print(prefix, suffix)

    res = ''
    n = len(d_w)
    if prefix:
        lp = len(prefix)
        # `+ 1` so a prefix at the very end of the description is still considered.
        for i in range(n - lp + 1):
            if not _matches_at(d_w, i, prefix):
                continue
            start = i + lp
            if suffix:
                ls = len(suffix)
                # Require at least one description word between prefix and suffix,
                # and allow the suffix to be the last word(s) of the description.
                end = next(
                    (k for k in range(start + 1, n - ls + 1) if _matches_at(d_w, k, suffix)),
                    None,
                )
                # Accept the window only when the suffix was actually found;
                # otherwise try the next occurrence of the prefix.
                if end is not None:
                    res = ' '.join(d_w[start:end])
                    break
            elif start < n:
                # NOTE(review): no `break`, so the LAST prefix occurrence wins here,
                # while the other branches keep the first — confirm this is intended.
                res = ' '.join(d_w[start:])
    elif suffix:
        ls = len(suffix)
        # Start at 1 so at least one description word precedes the suffix.
        for i in range(1, n - ls + 1):
            if _matches_at(d_w, i, suffix):
                res = ' '.join(d_w[:i])
                break

    return res

# Trace get_replace step by step on a single IID sample (index 1) for debugging.
s = iid_diff_title[1]
print(s)
get_replace(sample=s, e_list=e_list_iid, log=True)

In [None]:
from tqdm import tqdm

# `e_list_ood` was used here without ever being defined in this notebook, so the
# cell raised NameError on a fresh kernel. Build it the same way as `e_list_iid`:
# one alias2id_entity lookup per OOD title ('' when unresolved).
e_list_ood = {}
for s in tqdm(ood_diff_title):
    e_list_ood[s['title']] = alias2id_entity(s['title'])

# One get_replace result per entry of ood_diff_title, in the same order.
res = []
for s in tqdm(ood_diff_title):
    res.append(get_replace(s, e_list_ood))


In [None]:
# One line per entry of `res`: "<index>\t<extracted words>". The 'No []' and
# 'NoEntity' failure markers are written as an empty field. UTF-8 is explicit
# because the extracted words come from UTF-8 descriptions and may be non-ASCII.
with open('ood_question_entity.txt', 'w', encoding='utf-8') as f:
    for i, r in enumerate(res):
        r = '' if r in ('No []', 'NoEntity') else r
        f.write(f'{i}\t{r}\n')

In [None]:
# Re-run get_replace on IID samples whose entry in `res` is '' and collect
# [position, result] pairs.
# NOTE(review): `res` is built from `ood_diff_title` in an earlier cell, but this
# loop walks `iid_diff_title` and indexes `res` by the IID position. The positions
# presumably do not correspond, and this raises IndexError if there are more IID
# samples than entries in `res`. Confirm which split was intended before relying
# on `res_add`.
res_add = []
for i, s in tqdm(enumerate(iid_diff_title)):
    if res[i] == '':
        res_add.append([i, get_replace(s, e_list_iid)])