In [7]:
import gensim
from nltk import word_tokenize
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
from sklearn.metrics import f1_score

In [16]:
# Function to load the graph from file
def load_graph():
    # Preparing the graph
    graph = defaultdict(list)
    for line in open('data/graph.txt'):
        line = eval(line[:-1])
        graph[line[0]].append([line[1], line[2]])
    return graph


# Function to load the queries from file
# Preparing the queries
def load_queries():
    queries = []
    for line in open('data/annotations.txt'):
        line = eval(line[:-1])
        queries.append(line)
    return queries

graph = load_graph()
queries = load_queries()
print(f'{len(queries)} queries, {len(graph)} nodes loaded')

word2vec_model = gensim.models.Word2Vec.load('data/word2vec_train_dev.dat')
def get_rel_score_word2vecbase(rel, query):
    if rel not in word2vec_model.wv:
        return 0.0
    words = word_tokenize(query.lower())
    w_embs = []
    for w in words:
        if w in word2vec_model.wv:
            w_embs.append(word2vec_model.wv[w])
    return np.mean(cosine_similarity(w_embs, [word2vec_model.wv[rel]]))


56 queries, 286 nodes loaded


In [17]:
queries

[[1,
  'what time zones are there in the us',
  'm.09c7w0',
  [['ns:m.09c7w0', 'ns:location.location.time_zones', '?x']],
  'United States of America',
  [{'AnswerType': 'Entity',
    'AnswerArgument': 'm.027wj2_',
    'EntityName': 'Samoa Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.027wjl3',
    'EntityName': 'Chamorro Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02fqwt',
    'EntityName': 'Central Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02hcv8',
    'EntityName': 'Eastern Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02hczc',
    'EntityName': 'Mountain Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02lcqs',
    'EntityName': 'Pacific Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02lcrv',
    'EntityName': 'Alaska Time Zone'},
   {'AnswerType': 'Entity',
    'AnswerArgument': 'm.02lctm',
    'EntityName': 'Hawaii-Aleutian Time Zone'},
   {'AnswerType': 'Enti

In [24]:
def explore(source: str, query: str, graph: dict, depth_limit: int = 3, score_threshold: int = 0.3) -> list:
    destination_nodes = [inner[1] for lis in graph.values() for inner in lis]

    dist = dict.fromkeys(list(graph.keys()) + destination_nodes, 0)

    prev = dict.fromkeys(list(graph.keys()) + destination_nodes, (0, 0))
    dist[source] = 0
    Q = dict()
    [Q.update({v: dist[v]}) for v in list(graph.keys())+destination_nodes]
    Q[source] = 0.1

    while Q:
        max_key = max(Q, key=Q.get)
        node = (max_key, Q.pop(max_key))
        u = node[0]

        for relation, v in graph[u]:
            if v not in Q.keys():
                continue
            score = get_rel_score_word2vecbase('ns:' + relation, query)
            if score < score_threshold:
                continue
            if prev[v][1] >= depth_limit:
                continue
            if v == source:
                continue

            alt = dist[u] + score

            if alt > dist[v]:
                dist[v] = alt
                prev[v] = (u, prev[v][1] + 1)
                Q[v] += alt

    scores = [(node_id, v) for node_id, v in dist.items() if v > 0]
    pointed_to = set([value[0] for key, value in prev.items()])
    end_points = [node_id for node_id, v in scores if node_id not in pointed_to]
    return end_points

In [25]:
prediction = explore('m.054c1', 'what year did michael jordan get drafted', graph)
prediction.sort()
true_ans = [dictionary.get("AnswerArgument") for dictionary in queries[0][5]]
true_ans.sort()
print(f'Found answers: {prediction}')
print(f'Actual answers:{true_ans}')
print(f'F1: {f1_score(true_ans, prediction, average="macro")}')

Found answers: ['m.01g0dx', 'm.02g1n3', 'm.06jncs', 'm.06kd6y', 'm.07qymj', 'm.08knpp', 'm.0b364c', 'm.0bx8pn', 'm.0j17fx_']
Actual answers:['m.027wj2_', 'm.027wjl3', 'm.02fqwt', 'm.02hcv8', 'm.02hczc', 'm.02lcqs', 'm.02lcrv', 'm.02lctm', 'm.042g7t']
F1: 0.0


In [23]:
from statistics import mean
scores = []

for l in tqdm(queries):
    query = l[1]
    source = l[2]
    prediction = explore(source, query, graph)
    prediction.sort()
    true_ans = [dictionary.get("AnswerArgument") for dictionary in l[5]]
    true_ans.sort()
    TP = len([ans for ans in prediction if ans in true_ans])
    FP = len([ans for ans in prediction if ans not in true_ans])
    FN = len([ans for ans in true_ans if ans not in prediction])
    scores.append(TP/(TP+0.5*(FP+FN)))
print(f'Average F1: {mean(scores)}')

100%|██████████| 56/56 [00:47<00:00,  1.19it/s]

Average F1: 0.31299108725629093



