In [1]:
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
import numpy as np
import pickle as pkl
import json
# spaCy pipeline used by run_ner.
# NOTE(review): the 'en' shortcut link only exists in spaCy v2; spaCy v3 requires a
# full package name such as 'en_core_web_sm'.
nlp = spacy.load('en')

# Maps entity mention text to a list of candidate URIs (subject_lookup uses the first).
# A context manager closes the file instead of leaking the handle.
with open("data/core_index.json", "r") as core_index_file:
    s_index = json.load(core_index_file)

# Entity / predicate embedding tables. Column 0 is dropped; presumably it holds a row
# identifier rather than an embedding dimension -- TODO confirm against the TSV export.
sub_embeddings = np.genfromtxt(fname="data/out.tsv", delimiter="\t", skip_header=0)[:,1:]
pred_embeddings = np.genfromtxt(fname="data/filtered_props.tsv", delimiter="\t", skip_header=0)[:,1:]

In [3]:
import requests
import urllib
import faiss
import scipy
from data.nlg.lib import NLG, Config, Utils

In [4]:
# Exhaustive inner-product index over the entity embeddings (queried by find_objects).
# The dimension is taken from the data so it cannot drift from the embedding file.
index = faiss.IndexFlatIP(sub_embeddings.shape[1])
print(index.is_trained)
# faiss expects a C-contiguous float32 matrix.
index.add(np.ascontiguousarray(sub_embeddings, dtype='float32'))

True


In [5]:
def _read_json(path):
    """Load a JSON file, closing the handle afterwards."""
    with open(path, "r") as fh:
        return json.load(fh)


# URI <-> integer-id mappings. The *uri2id maps are keyed by URIs wrapped in angle
# brackets (see get_embeddings); the id2* maps are keyed by the id as a string,
# because JSON object keys are always strings (callers use str(id)).
entityuri2id = _read_json("e2id.json")
predicateuri2id = _read_json("p2id.json")
id2entity = _read_json("id2e.json")
id2predicate = _read_json("id2p.json")

In [22]:
# Example question that is passed to dm() to exercise the whole pipeline.
input_sentence = "Who is the father of Barack Obama?"

NER_LABELS = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT", "EVENT", "WORK_OF_ART")


def run_ner(sentence, labels=NER_LABELS, verbose=True):
    """Extract named-entity mentions from ``sentence`` using the spaCy pipeline ``nlp``.

    Args:
        sentence: Raw input text.
        labels: spaCy entity labels to keep. Defaults to the set that was
            previously hard-coded.
        verbose: When True (the default, matching the previous behaviour), print
            the label of every entity spaCy finds, including filtered-out ones.

    Returns:
        List of entity surface strings, in document order, whose label is in
        ``labels``. May be empty.
    """
    keep = set(labels)
    doc = nlp(sentence)
    mentions = []
    for ent in doc.ents:
        if verbose:
            print(ent.label_)
        if ent.label_ in keep:
            mentions.append(ent.text)
    return mentions

def subject_lookup(entity):
    """Return the first candidate URI that ``s_index`` lists for the mention ``entity``.

    Lookup failures (unknown mention, empty candidate list) propagate from the
    indexing operations unchanged.
    """
    candidates = s_index[entity]
    return candidates[0]

def find_relation(sentence):
    """Relation-linking placeholder.

    ``sentence`` is currently ignored: the same Wikidata property URI is
    returned for every input.
    NOTE(review): the example question asks for a father; confirm P735 is the
    property actually intended here.
    """
    del sentence  # unused until real relation linking is implemented
    relation_uri = "http://www.wikidata.org/entity/P735"
    return relation_uri

def get_embeddings(e, r):
    """Return the (entity, relation) embedding rows for URIs ``e`` and ``r``.

    The URIs are passed without angle brackets; the id maps store them
    bracketed, so the brackets are added before lookup.
    """
    ent_id, rel_id = entityuri2id["<" + e + ">"], predicateuri2id["<" + r + ">"]
    ent_vec = sub_embeddings[ent_id]
    rel_vec = pred_embeddings[rel_id]
    return ent_vec, rel_vec

def find_objects(emb, number):
    """Return labels of the ``number`` entities closest to ``emb`` in the faiss ``index``.

    Args:
        emb: One embedding vector; it is reshaped to a single query row.
        number: How many neighbours to retrieve (previously ignored in favour of
            a hard-coded 30).

    Returns:
        English labels (via get_label) of the neighbours, in the order faiss
        returns them (best inner-product score first).
    """
    query = np.ascontiguousarray(np.reshape(emb, (1, -1)), dtype='float32')
    scores, neighbour_ids = index.search(query, number)
    print(scores)
    # faiss pads with id -1 when the index holds fewer than `number` vectors;
    # those have no entry in id2entity.
    uris = [id2entity[str(i)] for i in neighbour_ids[0] if i != -1]
    return [get_label(u) for u in uris]

def find_objects_kg(s, p, timeout=60):
    """Query Wikidata for English labels of every object ``?o`` in ``<s> wdt:P… ?o``.

    Args:
        s: Subject entity URI, without angle brackets.
        p: Property URI of the form ``http://www.wikidata.org/entity/P<number>``;
            only the numeric suffix is used to build the ``wdt:`` predicate.
        timeout: Seconds to wait for the SPARQL endpoint (the original call had
            no timeout and could block forever).

    Returns:
        List of distinct English labels; empty when nothing matches.

    Raises:
        Exception: If the endpoint returns an HTTP error or an unexpected payload
            (the underlying error is chained as ``__cause__``).
    """
    prop_number = p.split("http://www.wikidata.org/entity/P")[1]
    query = """
    SELECT distinct ?label WHERE {{
       <{0!s}> wdt:P{1!s} ?o . ?o rdfs:label ?label FILTER(lang(?label)='en').
    }}
    """.format(s, prop_number)
    address = "https://query.wikidata.org/sparql"
    # Let requests encode the query string; the original relied on urllib.parse
    # being loaded as a side effect of another import.
    response = requests.get(
        address,
        params={"query": query, "format": "json"},
        headers={'User-Agent': 'CISS2 agent'},
        timeout=timeout,
    )
    try:
        response.raise_for_status()
        bindings = response.json()["results"]["bindings"]
        return [binding["label"]["value"] for binding in bindings]
    except Exception as exc:
        raise Exception("Smth is wrong with the endpoint") from exc

def get_label(uri, timeout=60):
    """Fetch an English ``rdfs:label`` for ``uri`` from the Wikidata SPARQL endpoint.

    Args:
        uri: IRI inserted verbatim into the query, so it must already be wrapped
            in angle brackets (callers pass e.g. ``"<" + relation_uri + ">"``).
        timeout: Seconds to wait for the endpoint (the original had no timeout).

    Returns:
        The label from the first result binding (SPARQL does not guarantee which
        one comes first if several exist).

    Raises:
        LookupError: If the endpoint answers but reports no English label.
        Exception: If the endpoint returns an HTTP error or an unexpected payload
            (the underlying error is chained as ``__cause__``).
    """
    query = """
    SELECT ?l WHERE {{
       {0!s} rdfs:label ?l FILTER (lang(?l)='en')
    }}
    """.format(uri)
    address = "https://query.wikidata.org/sparql"
    response = requests.get(
        address,
        params={"query": query, "format": "json"},
        headers={'User-Agent': 'CISS2 agent'},
        timeout=timeout,
    )
    try:
        response.raise_for_status()
        bindings = response.json()["results"]["bindings"]
    except Exception as exc:
        raise Exception("Smth is wrong with the endpoint") from exc
    if not bindings:
        # The endpoint worked; the entity simply has no English label.
        raise LookupError("No English label found for {0!s}".format(uri))
    return bindings[0]["l"]["value"]

def nlg(s, p, o):
    """Verbalise the triple ``(s, p, o)`` with the template-based NLG library.

    The NLG configuration is re-read from ``data/nlg/resources/dev.yml`` on every
    call. Returns whatever ``NLG.templates`` produces (possibly None).
    """
    nlg_cfg = Config('data/nlg/resources/dev.yml').cfg['nlg']
    verbalisations_csv = 'data/nlg/' + nlg_cfg['dbp2018_verbalisations']
    return NLG.templates(
        csv_path=verbalisations_csv,
        triple=(s, p, o),
        nlg_cfg=nlg_cfg,
        version=2018,
        lang="en",
    )


def dm(sentence):
    """Answer a single-hop factoid question.

    Pipeline: run_ner -> subject_lookup -> find_relation -> find_objects_kg -> nlg.
    Intermediate results are printed as they are identified.

    Args:
        sentence: Natural-language question.

    Returns:
        The verbalised answer produced by nlg (may be None).

    Raises:
        IndexError: If run_ner finds no supported named entity in ``sentence``.
    """
    # TODO: coreference resolution (run_coref) is not implemented yet.

    # NER: take the first supported entity mention.
    entities = run_ner(sentence)
    if not entities:
        raise IndexError("No supported named entity found in: {0!s}".format(sentence))
    entity = entities[0]
    print(u"Identified entity: {0!s}".format(entity))

    # Index lookup: mention text -> entity URI.
    entity_uri = subject_lookup(entity)
    print(u"Identified entity URI: {0!s}".format(entity_uri))

    # Relation linking.
    relation_uri = find_relation(sentence)
    print(u"Identified relation URI: {0!s}".format(relation_uri))

    # Objects come straight from the knowledge graph. The embedding-based path
    # (get_embeddings + find_objects) is not part of this pipeline, so no
    # embedding lookups are done here.
    objects = find_objects_kg(entity_uri, relation_uri)

    # Verbalise.
    answer = nlg(entity, relation_uri, ",".join(objects))
    print(answer)
    return answer

In [None]:
# Run the full question-answering pipeline on the example question.
dm(input_sentence)

PERSON
Identified entity: Barack Obama
Identified entity URI: http://www.wikidata.org/entity/Q76
Identified relation URI: http://www.wikidata.org/entity/P735


In [100]:
# Size of the entity embedding table: (entity rows, embedding dimension).
sub_embeddings.shape

(159995, 200)

In [295]:
from torchbiggraph.model import TranslationOperator, DotComparator
# Score every entity against the Q13133 embedding, using a PyTorch-BigGraph
# TranslationOperator initialised with the P26 relation embedding and a DotComparator.
obama = sub_embeddings[entityuri2id["<http://www.wikidata.org/entity/Q76>"]]
spouse = sub_embeddings[entityuri2id["<http://www.wikidata.org/entity/Q13133>"]]
m_rel = pred_embeddings[predicateuri2id["<http://www.wikidata.org/entity/P26>"]]

# Derive sizes from the embedding table instead of hard-coding 159995 / 200.
num_entities, emb_dim = sub_embeddings.shape
top_k = 20

operator = TranslationOperator(emb_dim)
operator.load_state_dict({"translation": torch.from_numpy(m_rel).float()})
comparator = DotComparator()

# Inference only: avoid building an autograd graph over the whole entity table.
with torch.no_grad():
    lhs = torch.from_numpy(spouse).float().view(1, 1, emb_dim).expand(1, num_entities, emb_dim)
    # Convert straight to float32 (the original made an extra float64 copy first).
    rhs = operator(torch.from_numpy(sub_embeddings.astype('float32'))).view(1, num_entities, emb_dim)
    scores, _, _ = comparator(
        comparator.prepare(lhs),
        comparator.prepare(rhs),
        torch.empty(1, 0, emb_dim),  # Left-hand side negatives, not needed
        torch.empty(1, 0, emb_dim),  # Right-hand side negatives, not needed
    )

# Sort the entities by their score, best first.
permutation = scores.flatten().argsort(descending=True)
print(permutation)
# Variable name kept for later cells; it holds the top_k (20) labels, not 5.
# `ent_id` avoids shadowing the global faiss `index`.
top5_entities = [get_label(id2entity[str(ent_id.item())]) for ent_id in permutation[:top_k]]
print(top5_entities)

tensor([12717,  3091,  6427,  ...,  5274,   277,   856])
['Michelle Obama', 'Barack Obama', 'Hillary Clinton', 'Bill Clinton', 'Richard Nixon', 'George W. Bush', 'Harry S. Truman', 'Dwight D. Eisenhower', 'George H. W. Bush', 'Jimmy Carter', 'Donald Trump', 'Lyndon B. Johnson', 'Ronald Reagan', 'Herbert Hoover', 'Woodrow Wilson', 'Mary Robinson', 'Gerald Ford', 'Sasha Obama', 'Condoleezza Rice', 'Franklin Delano Roosevelt']


{'<http://www.wikidata.org/entity/P131>': 0,
 '<http://www.wikidata.org/entity/P31>': 1,
 '<http://www.wikidata.org/entity/P571>': 2,
 '<http://www.wikidata.org/entity/P18>': 3,
 '<http://www.wikidata.org/entity/P4227>': 4,
 '<http://www.wikidata.org/entity/P856>': 5,
 '<http://www.wikidata.org/entity/P17>': 6,
 '<http://www.wikidata.org/entity/P527>': 7,
 '<http://www.wikidata.org/entity/P3373>': 8,
 '<http://www.wikidata.org/entity/P26>': 9,
 '<http://www.wikidata.org/entity/P6>': 10,
 '<http://www.wikidata.org/entity/P40>': 11,
 '<http://www.wikidata.org/entity/P625>': 12,
 '<http://www.wikidata.org/entity/P21>': 13,
 '<http://www.wikidata.org/entity/P577>': 14,
 '<http://www.wikidata.org/entity/P569>': 15,
 '<http://www.wikidata.org/entity/P106>': 16,
 '<http://www.wikidata.org/entity/P276>': 17,
 '<http://www.wikidata.org/entity/P580>': 18,
 '<http://www.wikidata.org/entity/P19>': 19,
 '<http://www.wikidata.org/entity/P570>': 20,
 '<http://www.wikidata.org/entity/P361>': 21,
 '<ht