In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import requests
import json
import pandas as pd
from database import db, cursor
from util import json_load, json_dump
import itertools
from tqdm.notebook import tqdm
import csv
from multiprocessing import Pool
from concurrent import futures
import logging

In [3]:
# KGTK similarity API endpoint and request settings used by call_semantic_similarity.
url = 'https://kgtk.isi.edu/similarity_api'
# Browser-like User-Agent. The value must not repeat the "User-Agent:" header name,
# otherwise the header value itself is malformed.
headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36'}
# HTTP(S) proxy used to reach the API. Override with the KGTK_PROXY environment variable;
# set it to an empty string to connect directly (requests accepts proxies=None).
_proxy = os.environ.get('KGTK_PROXY', 'http://114.212.82.174:10809')
proxies = {'http': _proxy, 'https': _proxy} if _proxy else None
# Log to kgtk_similarity.log (relative to the working directory); the error lines written by
# the scoring code are parsed later to find batches to retry.
logging.basicConfig(filename='kgtk_similarity.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


In [4]:
def call_semantic_similarity(input_file, timeout=None):
    """Upload a pair file to the KGTK similarity API and return the decoded scores.

    Args:
        input_file: path of the file sent as the multipart ``file`` field. The callers in
            this notebook write a ``q1``/``q2`` header line followed by one tab-separated
            entity-id pair per line.
        timeout: optional ``requests`` timeout in seconds. ``None`` (the default) waits
            indefinitely, matching the behaviour before this parameter existed.

    Returns:
        ``json.loads(resp.json())`` -- the response body appears to be JSON that itself encodes
        a JSON string (TODO confirm against the API docs) -- or ``[]`` if the request fails,
        returns an HTTP error status, or cannot be decoded. Such failures are logged, not raised.

    Raises:
        OSError (e.g. FileNotFoundError): if ``input_file`` cannot be opened.
    """
    file_name = os.path.basename(input_file)
    # The context manager closes the upload handle after the request; previously it leaked
    # one file descriptor per call.
    with open(input_file, mode='rb') as upload:
        files = {
            'file': (file_name, upload, 'application/octet-stream')
        }
        try:
            resp = requests.post(url, files=files, params={'similarity_types': 'text'},
                                 proxies=proxies, headers=headers, timeout=timeout)
            # Surface HTTP errors explicitly instead of as an opaque decode error.
            resp.raise_for_status()
            s = json.loads(resp.json())
        except Exception as e:
            logging.error(f'Error in calling {url} with {input_file}: {e}')
            s = []
    return s

In [None]:
from collections import defaultdict


def _load_dataset_entities(table):
    """Return a defaultdict mapping int(dataset_id) -> [entity_id, ...] read from ``table``.

    Rows with a NULL entity_id are filtered out by the SQL itself.
    """
    cursor.execute(f'SELECT dataset_id, entity_id FROM `{table}` WHERE entity_id is not NULL;')
    ds_to_entities = defaultdict(list)
    for ds_id, ent_id in cursor.fetchall():
        ds_to_entities[int(ds_id)].append(ent_id)
    return ds_to_entities


acordar1_dataset_entity = _load_dataset_entities('acordar1_metadata_NPR')
ntcir_dataset_entity = _load_dataset_entities('ntcir_metadata_NPR')

In [None]:
def _load_query_entities(table):
    """Return a defaultdict mapping str(query_id) -> [entity_id, ...] read from ``table``.

    Rows with a NULL entity_id are filtered out by the SQL itself.
    """
    query_to_entities = defaultdict(list)
    cursor.execute(f'SELECT query_id, entity_id FROM `{table}` WHERE entity_id is not NULL;')
    for q_id, ent_id in cursor.fetchall():
        query_to_entities[str(q_id)].append(ent_id)
    return query_to_entities


acordar1_query_entity = _load_query_entities('acordar1_query_NPR')
ntcir15_query_entity = _load_query_entities('ntcir15_query_NPR')
ntcir16_query_entity = _load_query_entities('ntcir16_query_NPR')

In [None]:
# Sizes of the five entity maps loaded above (dataset maps first, then query maps).
tuple(len(entity_map) for entity_map in (
    acordar1_dataset_entity, ntcir_dataset_entity,
    acordar1_query_entity, ntcir15_query_entity, ntcir16_query_entity,
))

In [None]:
# Peek at the (int) dataset ids present in the acordar1 dataset -> entity map.
acordar1_dataset_entity.keys()

In [None]:
# Build (query entity, dataset entity) candidate pairs from the BM25 top-k results of each
# test collection. Each collection's pairs go to its own TSV; the de-duplicated union goes to
# all_pairs.tsv (written under the same reproduce_keds tree as the per-collection files).
PAIRS_DIR = '/home/xxx/code/reproduce_keds/data/kgtk_similarity/pairs'
RETRIEVE_DIR = '/home/xxx/code/erm/data/retrieve_results'
TOP_K = 10  # number of BM25 candidate datasets per query to pair with

query_entity_by_collection = {
    'acordar1': acordar1_query_entity,
    'ntcir15': ntcir15_query_entity,
    'ntcir16': ntcir16_query_entity,
}


def _write_pairs(path, pairs):
    """Write (query_uri, dataset_uri) tuples as tab-separated lines to ``path``."""
    with open(path, 'w') as f:
        for query_uri, dataset_uri in pairs:
            f.write(f'{query_uri}\t{dataset_uri}\n')


all_pairs = {}  # dict used as an insertion-ordered set of pairs across collections
for test_collection in ['acordar1', 'ntcir15', 'ntcir16']:
    print('start', test_collection)
    sparse_res_path = f'{RETRIEVE_DIR}/{test_collection}/candidates/BM25 [m] test_top100_sorted.json'
    # Keep only the ids (first element of each entry) of the top-k datasets per query.
    sparse_res = {
        key: [str(entry[0]) for entry in value[:TOP_K]]
        for key, value in json_load(sparse_res_path).items()
    }
    query_entity = query_entity_by_collection[test_collection]
    dataset_entity = ntcir_dataset_entity if test_collection.startswith('ntcir') else acordar1_dataset_entity

    collection_pairs = []  # reset per collection so each TSV holds only its own pairs
    for query_id, dataset_id_list in tqdm(sparse_res.items(), total=len(sparse_res), ncols=100, leave=True):
        # dict.fromkeys de-duplicates with a deterministic order (set order changes between runs,
        # and single_run addresses pairs by position). .get avoids inserting empty keys into the
        # defaultdicts.
        query_uri_list = list(dict.fromkeys(query_entity.get(query_id, [])))
        dataset_uri_list = list(dict.fromkeys(
            uri for dataset_id in dataset_id_list for uri in dataset_entity.get(int(dataset_id), [])
        ))
        collection_pairs.extend(itertools.product(query_uri_list, dataset_uri_list))

    _write_pairs(f'{PAIRS_DIR}/{test_collection}_pairs.tsv', collection_pairs)
    all_pairs.update(dict.fromkeys(collection_pairs))
    print(f'finish {test_collection} pairs')

pair_list = list(all_pairs)
_write_pairs(f'{PAIRS_DIR}/all_pairs.tsv', pair_list)
print('finish all pairs')

In [None]:
# Merge the per-collection pair files into all_pairs.tsv, writing each distinct
# (stripped) line once, at its first occurrence.
data_path = '/home/xxx/code/reproduce_keds/data/kgtk_similarity/pairs'
pairs = set()
with open(f'{data_path}/all_pairs.tsv', 'w+') as fp:
    for test_collection in ['acordar1', 'ntcir15', 'ntcir16']:
        with open(f'{data_path}/{test_collection}_pairs.tsv', 'r') as f:
            for line in map(str.strip, f):
                if line not in pairs:
                    pairs.add(line)
                    fp.write(line + '\n')

In [5]:
target_list = []  # similarity records accumulated across single_run calls
error_list = []   # start indices of batches whose call failed or returned an unexpected record count


def _dedupe_records(records):
    """Drop duplicate dict records, keeping the first occurrence (dict values must be hashable)."""
    return [dict(items) for items in dict.fromkeys(tuple(d.items()) for d in records)]


def _load_pairs(pair_path):
    """Read (q1, q2) entity-id pairs from a two-column tab-separated file."""
    with open(pair_path, 'r', newline='') as f:
        return [(row[0], row[1]) for row in csv.reader(f, delimiter='\t')]


def _score_batch(batch, start):
    """Score one batch of pairs with the KGTK API.

    Returns the de-duplicated records, or [] (after logging and appending ``start`` to
    ``error_list``) when the call fails or the record count differs from ``len(batch)``.
    """
    with open('tmp.tsv', 'w') as fp:
        fp.write('q1\tq2\n')
        for q1, q2 in batch:
            fp.write(f'{q1}\t{q2}\n')
    s = call_semantic_similarity('tmp.tsv')
    try:
        s = _dedupe_records(s)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(s) != len(batch):
            raise ValueError(f'expected {len(batch)} records')
    except Exception as e:
        n = len(s) if hasattr(s, '__len__') else 'n/a'
        # The '[start:end]' range and 'len s:' marker are parsed back out of the log for retries.
        logging.error(f'Error in calling {url} with [{start}:{start + len(batch)}]: {e} len s: {n}')
        error_list.append(start)
        return []
    return s


def single_run(unhandled_list=None, start=110000, batch_size=25, checkpoint_every=10000):
    """Score the pairs in all_pairs.tsv with the KGTK similarity API, batch by batch.

    Args:
        unhandled_list: optional batch start indices (ints or numeric strings) to retry. When
            truthy, only those batches are scored and ``target_list`` is dumped to
            all_unhandled_similarity.json.
        start: pair index a full run resumes from.
        batch_size: number of pairs sent per API call.
        checkpoint_every: a slice ``<end index>.json`` is written whenever a batch's absolute end
            index is a multiple of this; records left after the last checkpoint are flushed at
            the end so no scored record is missing from the slices.

    Scored records are appended to the module-level ``target_list``; failed batch starts go to
    ``error_list``. A full run ends by dumping the de-duplicated ``target_list`` to
    all_similarity.json.
    """
    global target_list
    base = '/home/xxx/code/reproduce_keds/data/kgtk_similarity'
    similarity_path = f'{base}/similarity/slices'
    pair_list = _load_pairs(f'{base}/pairs/all_pairs.tsv')
    print('pair len:', len(pair_list))

    if unhandled_list:
        for index in map(int, unhandled_list):
            target_list.extend(_score_batch(pair_list[index:index + batch_size], index))
            logging.info(f'finish {index + batch_size}')
        json_dump(target_list, f'{base}/similarity/all_unhandled_similarity.json')
        return

    pending = []  # records scored in this run since the last slice file was written
    with tqdm(total=max(len(pair_list) - start, 0), ncols=100, leave=True) as pbar:
        for i in range(start, len(pair_list), batch_size):
            batch = pair_list[i:i + batch_size]
            records = _score_batch(batch, i)
            target_list.extend(records)
            pending.extend(records)
            logging.info(f'finish {i + batch_size}')
            end = i + len(batch)
            # Slice by what this run produced, not by absolute pair index into target_list:
            # when resuming mid-way those indices do not line up with target_list positions.
            if end % checkpoint_every == 0:
                json_dump(pending, f'{similarity_path}/{end}.json')
                pending = []
            pbar.update(len(batch))
    if pending:  # tail after the last checkpoint boundary
        json_dump(pending, f'{similarity_path}/{len(pair_list)}.json')
    target_list = _dedupe_records(target_list)
    json_dump(target_list, f'{base}/similarity/all_similarity.json')

In [6]:
# Full scoring run over all_pairs.tsv, resuming from the start offset configured in single_run.
single_run() 

pair len: 149099


  0%|                                                                     | 0/39099 [00:00<?, ?it/s]

In [8]:
# Rebuild all_similarity.json by concatenating the checkpoint slices written by single_run.
# NOTE: this overwrites the all_similarity.json that single_run dumps at the end of a full run.
similarity_path = '/home/xxx/code/reproduce_keds/data/kgtk_similarity/similarity/slices'


def _slice_order(file_name):
    """Sort key: numerically named slices ('<end index>.json') first, in index order."""
    stem = os.path.splitext(file_name)[0]
    return (0, int(stem), file_name) if stem.isdigit() else (1, 0, file_name)


all_sim = []
# os.listdir order is arbitrary; sort so records follow pair order, and skip non-JSON entries.
for file in sorted((name for name in os.listdir(similarity_path) if name.endswith('.json')), key=_slice_order):
    all_sim.extend(json_load(f'{similarity_path}/{file}'))

json_dump(all_sim, '/home/xxx/code/reproduce_keds/data/kgtk_similarity/similarity/all_similarity.json')

In [7]:
import re

# Collect the start indices of batches logged as failed by the scoring code: its error lines
# contain a '[start:end]' range and a 'len s:' marker.
log_path = '/home/xxx/code/reproduce_keds/kgtk_similarity.log'
batch_range = re.compile(r'\[(\d+):\d+\]')
failed_starts = {}  # insertion-ordered set: a batch logged as failed on several runs is retried once
with open(log_path, 'r') as fp:
    for line in fp:
        if 'len s:' not in line:
            continue
        match = batch_range.search(line)
        if match is None:  # unexpected line shape; skip rather than crash on .group()
            continue
        failed_starts[match.group(1)] = None
unhandled_list = list(failed_starts)
print(unhandled_list)

[]


In [None]:
# Retry only the batches recovered from the log. single_run treats a falsy unhandled_list as
# "no retry list" and would start a full scoring run instead, so skip the call when it is empty.
if unhandled_list:
    single_run(unhandled_list)
else:
    print('no failed batches to retry')

In [None]:
# Load the KGTK similarity records and insert them into the database.
# Read from the reproduce_keds location that the scoring cells write all_similarity.json to.
src_path = '/home/xxx/code/reproduce_keds/data/kgtk_similarity/similarity/all_similarity.json'
WIKIDATA_ENTITY_PREFIX = 'http://www.wikidata.org/entity/'


def _text_similarity(value):
    """Return the 'text' score as a float, or -1.0 when no score was given (None or '')."""
    # Test for missing values explicitly: a truthiness test would turn a genuine 0.0 score into -1.
    return -1.0 if value is None or value == '' else float(value)


data = json_load(src_path)
row_list = [
    (
        WIKIDATA_ENTITY_PREFIX + item['q1'],
        WIKIDATA_ENTITY_PREFIX + item['q2'],
        item['q1_label'] or '',
        item['q2_label'] or '',
        _text_similarity(item['text']),
    )
    for item in data
]

# NOTE(review): the table name's spelling ("similartity") presumably matches the existing
# table; verify before renaming it.
cursor.executemany(
    'INSERT INTO query_dataset_entity_similartity_kgtk '
    '(query_entity_uri, dataset_entity_uri, query_entity_label, dataset_entity_label, text_similarity) '
    'VALUES (%s,%s,%s,%s,%s);',
    row_list,
)

db.commit()