In [60]:
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
import numpy as np
import pickle as pkl
import json
nlp = spacy.load('en')
s_index = json.load(open("data/core_index.json","r"))
sub_embeddings = np.genfromtxt(fname="data/out.tsv", delimiter="\t", skip_header=0)[:,1:]
pred_embeddings = np.genfromtxt(fname="data/filtered_props.tsv", delimiter="\t", skip_header=0)[:,1:]

NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=None, n_neighbors=2, p=2,
                 radius=1.0)

In [121]:
import requests
import urllib
import faiss
import scipy

In [170]:
#nbrs = NearestNeighbors(10)
#nbrs.fit(sub_embeddings)
index = faiss.IndexFlatIP(200)
print(index.is_trained)
index.add(np.ascontiguousarray(sub_embeddings.astype('float32')))

True


In [53]:
entityuri2id = json.load(open("e2id.json","r"))
predicateuri2id = json.load(open("p2id.json","r"))
id2entity = json.load(open("id2e.json","r"))
id2predicate = json.load(open("id2p.json","r"))

In [212]:
input_sentence = "Who is the father of Barack Obama?"

def run_ner(sentence):
    doc = nlp(sentence)
    ents_extrd = []
    for ent in doc.ents:
        print(ent.label_)
        if ent.label_ in ["PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT", "EVENT", "WORK_OF_ART"]:
            ents_extrd.append(ent.text)       
    return ents_extrd

def subject_lookup(entity):
    return s_index[entity][0]

def find_relation(sentence):
    return "http://www.wikidata.org/entity/P25"

def get_embeddings(e, r):
    entity_ind = entityuri2id["<"+e+">"]
    relation_index = predicateuri2id["<"+r+">"]
    
    return sub_embeddings[entity_ind], pred_embeddings[relation_index]

def find_objects(emb, number):
    emb = emb.reshape(1,-1)
    #indices = nbrs.kneighbors(emb, return_distance=True
    d, indices = index.search(emb.astype('float32'),30)
    # print(indices)
    print(d)
    uri = [id2entity[str(e)] for e in indices[0]]
    return [get_label(u) for u in uri]

def get_label(uri):
    query = """
    SELECT ?l WHERE {{
       {0!s} rdfs:label ?l FILTER (lang(?l)='en')
    }}
    """.format(uri)
    address = "http://query.wikidata.org/sparql"
    params = urllib.parse.urlencode({"query": query, "format": "json"})
    r = requests.get(address, params=params, headers={'User-Agent':'CISS2 agent'})
    try:
        results = json.loads(r.text)
        return results["results"]["bindings"][0]["l"]["value"]
    except Exception:
        raise Exception("Smth is wrong with the endpoint")

def nlg(s, p, o):
    return "{0!s} {1!s} {2!s}".format(s,p,o)

def dm(sentence):
    
    # run coref
    # entity_uri, relation_uri = run_coref(sentence)
    # if/else
    
    # run NER
    entity = run_ner(sentence)[0]
    print(u"Identified entity: {0!s}".format(entity))
    
    # index lookup
    entity_uri = subject_lookup(entity)
    print(u"Identified entity URI: {0!s}".format(entity_uri))
    
    # relation linking
    relation_uri = find_relation(sentence)
    print(u"Identified relation URI: {0!s}".format(relation_uri))
    
    # get s + p embeddings
    e, r = get_embeddings(entity_uri, relation_uri)
    us = sub_embeddings[entityuri2id["<http://www.wikidata.org/entity/Q766106>"]]
    
    # object emb
    #print(e)
    #print(r)
    object_embedding = e + r
    os = sub_embeddings + r
    scores = e.reshape((1, 200)) @ os.transpose()
    print([id2entity[str(r)] for r in np.argsort(scores.flatten())[-25:]])
    ranks = np.argsort(np.argsort(scores.flatten()))
    #print(ranks[entityuri2id["<http://www.wikidata.org/entity/Q766106>"]])
#     print(object_embedding @ us.transpose())
    
    # kNN in embedding space
    #closest_objects = find_objects(object_embedding, 20)
    #print(u"Closest object: {}".format(closest_objects))
    
    # get labels of p, o
    p_label = get_label("<"+relation_uri+">")
    #o_label = get_label("<+o+">")
    
    # verbalize
    answer = nlg(entity, p_label, closest_objects)
    
    return answer

In [213]:
dm(input_sentence)

PERSON
Identified entity: Barack Obama
Identified entity URI: http://www.wikidata.org/entity/Q76
Identified relation URI: http://www.wikidata.org/entity/P25
['<http://www.wikidata.org/entity/Q11696>', '<http://www.wikidata.org/entity/Q11201>', '<http://www.wikidata.org/entity/Q34296>', '<http://www.wikidata.org/entity/Q476068>', '<http://www.wikidata.org/entity/Q13133>', '<http://www.wikidata.org/entity/Q35041>', '<http://www.wikidata.org/entity/Q699872>', '<http://www.wikidata.org/entity/Q9582>', '<http://www.wikidata.org/entity/Q9696>', '<http://www.wikidata.org/entity/Q649593>', '<http://www.wikidata.org/entity/Q33866>', '<http://www.wikidata.org/entity/Q22686>', '<http://www.wikidata.org/entity/Q6294>', '<http://www.wikidata.org/entity/Q9640>', '<http://www.wikidata.org/entity/Q35236>', '<http://www.wikidata.org/entity/Q9588>', '<http://www.wikidata.org/entity/Q9916>', '<http://www.wikidata.org/entity/Q23685>', '<http://www.wikidata.org/entity/Q8007>', '<http://www.wikidata.org/ent

NameError: name 'closest_objects' is not defined

In [100]:
sub_embeddings.shape

(159995, 200)

In [177]:
obama = sub_embeddings[entityuri2id["<http://www.wikidata.org/entity/Q76>"]]
mother = sub_embeddings[entityuri2id["<http://www.wikidata.org/entity/Q766106>"]]
m_rel = pred_embeddings[predicateuri2id["<http://www.wikidata.org/entity/P25>"]]
print(obama-mother)
print(m_rel)

[ 0.1075 -0.1261 -0.7455 -0.2402  0.1471  0.2101  0.0534 -0.341  -0.1609
 -0.1139 -0.2344 -0.4218 -0.104  -0.1964 -0.0394  0.2543 -0.0348 -0.7975
 -0.417  -0.6311 -0.3675  0.1926 -0.0577  0.6249  0.1825 -0.0861 -0.3451
  0.1605  0.1689  0.4871 -0.1394 -0.7255 -0.2328 -0.3595  0.2584  0.0932
 -0.6473  0.5285  0.5691  0.4674 -0.4019 -0.2163  0.5829 -0.1235 -0.1218
 -0.3541  0.5138  0.3963 -0.2188  0.2783 -0.4704  0.7306  0.4429 -0.3103
 -0.5473 -0.0343  0.212   0.4081  0.1151 -0.4411  0.1249 -0.206  -0.2597
 -0.5915 -0.9693 -0.4091 -0.1723 -0.5988 -0.3103  0.0166 -0.2886  0.9184
 -0.5986  0.54   -0.3211  0.0288  0.0716 -0.0038 -0.0458 -0.3905  0.3551
 -0.1882  0.16    0.6068 -0.4796 -0.2616 -0.8932  0.4808 -0.2699  0.2524
  0.1405 -0.065   0.3391  0.4548  0.4399 -0.2892 -0.541   0.1967 -0.2979
 -0.188   0.5701 -0.3633 -0.9001  0.1739 -0.3426  0.2632  0.2393  0.6649
  0.3528 -0.174   0.4254 -0.6697 -0.2279  0.9489  0.9914  0.1456 -0.436
 -0.1931 -0.6125 -0.0726  0.0493 -0.2028 -0.6533  0.