In [2]:
import sys, os, time
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from gensim.models.wrappers import FastText
import gensim
import word_util as wtil
# import fastText
from collections import Counter
# Wider console output for arrays and tensors; numpy floats shown in fixed-point notation.
np.set_printoptions(suppress=True, linewidth=120)
torch.set_printoptions(linewidth=120)

In [3]:
# lang = fastText.load_model('../../fastText/wiki.en.bin')

In [4]:
def cosd(x, y):
    """Pairwise cosine distance, rescaled to [0, 1].

    Computes ``(1 - cos(x_i, y_j)) / 2`` between every row of ``x`` and every
    row of ``y``: 0 for identical directions, 0.5 for orthogonal, 1 for opposite.
    A 1-D input is treated as a single row.

    Args:
        x: tensor of shape ``(..., n, d)`` or ``(d,)``.
        y: tensor of shape ``(..., m, d)`` or ``(d,)``.

    Returns:
        Tensor of shape ``(..., n, m)``.
    """
    xs = x.unsqueeze(0) if x.ndimension() == 1 else x
    ys = y.unsqueeze(0) if y.ndimension() == 1 else y
    # unit-normalise along the feature axis so the matmul yields cosines
    x_unit = F.normalize(xs, p=2, dim=-1)
    y_unit = F.normalize(ys, p=2, dim=-1)
    cosines = (-x_unit) @ y_unit.transpose(-1, -2)
    return cosines / 2 + .5
def l2(x, y):
    """Pairwise mean squared difference between the rows of ``x`` and ``y``.

    This is the squared Euclidean distance divided by the feature size ``d``
    (a mean, not a sum, over the last axis). A 1-D input is treated as a
    single row.

    Args:
        x: tensor of shape ``(..., n, d)`` or ``(d,)``.
        y: tensor of shape ``(..., m, d)`` or ``(d,)``.

    Returns:
        Tensor of shape ``(..., n, m)``.
    """
    xs = x if x.ndimension() != 1 else x.unsqueeze(0)
    ys = y if y.ndimension() != 1 else y.unsqueeze(0)
    # broadcast (..., n, 1, d) against (..., 1, m, d)
    diff = xs.unsqueeze(-2) - ys.unsqueeze(-3)
    return (diff ** 2).mean(dim=-1)

In [5]:
def filter_tokens(s):
    """Lower-case ``s``, drop one trailing '.' or '?', and split on single spaces.

    Splitting uses ``split(' ')`` (not ``split()``), so consecutive spaces
    produce empty-string tokens; this matches how the token bag was built.

    Args:
        s: the question stem or answer text.

    Returns:
        List of tokens. An empty input yields ``['']`` instead of raising
        IndexError as the previous ``s[-1]`` check did.
    """
    s = s.lower()
    # endswith is safe on '' where indexing s[-1] is not
    if s.endswith(('.', '?')):
        s = s[:-1]
    return s.split(' ')
def topk(query, k=5):
    """Return the first ``k`` words of ``query`` as ranked by ``wtil.tfidf``.

    The query is tokenised with ``filter_tokens``, counted, and scored against
    the corpus bag ``full_bag``. ``wtil.tfidf`` is assumed to return
    ``(word, score)`` pairs in descending score order — TODO confirm in word_util.
    """
    scored = wtil.tfidf(Counter(filter_tokens(query)), full_bag)
    return [word for word, _score in scored[:k]]

In [26]:
# Choose the question set and its precomputed token-embedding lookup file.
ds_name = 'elem'
# dataset name -> (questions jsonl, token embedding lookup checkpoint)
DATASETS = {
    'elem': ('../data/questions/AI2-Elementary-NDMC-Feb2016-Train.jsonl',
             '../../train_elem_tokens_emb.pth.tar'),
    '8th': ('../data/questions/AI2-8thGr-NDMC-Feb2016-Train.jsonl',
            '../../train_8thgr_tokens_emb.pth.tar'),
}
if ds_name not in DATASETS:
    raise ValueError('unknown dataset: {!r}'.format(ds_name))
root, lookup_path = DATASETS[ds_name]

questions = wtil.load_questions(root)
# torch.load unpickles the file; only point it at trusted local checkpoints
lookup = torch.load(lookup_path)

full_bag = lookup['bag']

# word -> embedding vector, used by convert()
lang = dict(zip(lookup['words'], lookup['vecs']))
len(questions), len(full_bag), len(lookup)

(432, 2805, 3)

In [27]:
# questions = wtil.load_questions(root)
# full_bag = Counter()
# for q in questions:
#     tokens = set(filter_tokens(q['question']['stem']))
#     for a in q['question']['choices']:
#         tokens.update(filter_tokens(a['text']))
#     full_bag.update(tokens)
# len(full_bag)

In [28]:
# Base fact table: the tuples (rows), the element vocabulary, and one vector per element.
table = torch.load('../../fast_table.pth.tar')
# elements is kept as loaded (list-like) because a later cell extends it
rows, elements = table['rows'], table['elements']
vecs = torch.from_numpy(table['vecs']).float()
len(elements), vecs.shape

(46460, torch.Size([46460, 300]))

In [29]:
# Checkpoint of extra (filtered) tuples that the next cell appends to rows/elements/vecs.
added_path = '../../filtered.pth.tar'

In [30]:
# Append the extra tuples, elements and element vectors to the base table.
# NOTE: rows/elements are extended in place (which also mutates the lists held
# in the loaded checkpoint dict), so re-running this cell without reloading the
# base table appends the extra tuples a second time.
added = torch.load(added_path)
print('Adding {} tuples'.format(len(added['rows'])))
rows.extend(added['rows'])
elements.extend(added['elements'])
# Cast to float32 like the base vectors. Without this, depending on the torch
# version, torch.cat either rejects the dtype mismatch or promotes all of vecs
# to float64.
vecs = torch.cat([vecs, torch.from_numpy(added['vecs']).float()], 0)
# solve() indexes elements with row indices of vecs; they must stay aligned
assert vecs.shape[0] == len(elements), (vecs.shape, len(elements))
len(rows), len(elements), vecs.shape

Adding 651 tuples


(283245, 47211, torch.Size([47211, 300]))

In [31]:
# Array form lets elements be fancy-indexed by nearest-neighbour indices in solve().
elements = np.array(elements)
# element -> vector (this replaces the checkpoint dict previously bound to `table`).
# A duplicated element keeps the vector of its last occurrence, hence fewer
# keys than rows in vecs.
table = {element: vec for element, vec in zip(elements, vecs)}
len(table.keys()), vecs.shape

(46915, torch.Size([47211, 300]))

In [32]:
# Inverted index: word -> list of indices of the rows it appears in
# (one entry per occurrence, so a word repeated in a row is listed twice).
mentions = {}
for i, row in enumerate(rows):
    for w in row:
        mentions.setdefault(w, []).append(i)

In [33]:
def get_connections(picks):
    """Return the set of row indices that mention at least one word in ``picks``.

    Looks each word up in the global ``mentions`` index; a word that appears
    in no row raises KeyError.
    """
    return set().union(*(mentions[word] for word in picks))

In [34]:
def get_closest(query, vecs, k=2):
    """Find the ``k`` nearest rows of ``vecs`` to each query row under ``l2``.

    Args:
        query: tensor of shape ``(n, d)`` or ``(d,)``.
        vecs: candidate tensor of shape ``(m, d)``.
        k: number of neighbours wanted. It is clamped to ``m`` so a small
            candidate pool no longer makes ``torch.topk`` raise.

    Returns:
        The ``(values, indices)`` result of ``torch.topk``. Values are the
        smallest distances, each of shape ``(n, min(k, m))``. They are NOT sorted
        (``sorted=False``).
    """
    D = l2(query, vecs)
    k = min(k, D.shape[-1])
    return torch.topk(D, k, dim=-1, largest=False, sorted=False)

def convert(words, lang):
    """Stack the embedding of every word in ``words`` into one tensor.

    Args:
        words: sequence of tokens.
        lang: either a mapping ``word -> tensor`` (the dict built from the
            lookup file in this notebook), or a model exposing
            ``get_word_vector(word)`` that returns a numpy vector (the fastText
            Python API). The second case was previously unreachable dead code.

    Returns:
        Tensor stacking one vector per word along a new leading axis. With
        1-D word vectors this is ``(len(words), d)``. The model branch is cast
        to float32; the mapping branch keeps the stored dtype.

    Raises:
        KeyError: a word is missing from a mapping ``lang``.
        RuntimeError / ValueError: ``words`` is empty (nothing to stack).
    """
    if hasattr(lang, 'get_word_vector'):
        return torch.from_numpy(np.stack([lang.get_word_vector(w) for w in words])).float()
    return torch.stack([lang[w] for w in words])

In [35]:
def solve(q):
    """Pick the answer label whose words lie closest to facts retrieved for the stem.

    Steps:
      1. Take the top tf-idf words of the question stem and embed them with ``lang``.
      2. Find the nearest table elements (``get_closest`` with its default k)
         for each stem word.
      3. Gather every row mentioning one of those elements, and collect all
         words in those rows as candidate option vectors.
      4. For each answer choice, embed its top words. Score it by the inverse
         mean distance to its 10 nearest option vectors.

    Args:
        q: one question record. It is read via ``q['question']['stem']`` and
           ``q['question']['choices']``; each choice has ``'label'`` and ``'text'``.

    Returns:
        The ``'label'`` of the highest-scoring choice. On a tie, the later choice
        wins, as in the original ``sorted(...)[-1]``.
    """
    stem_vecs = convert(topk(q['question']['stem']), lang)

    # indices of the nearest table elements for each stem word
    nearest = get_closest(stem_vecs, vecs)[1]
    conns = get_connections(elements[nearest.numpy()].reshape(-1))

    wopts = set()
    for i in conns:
        wopts.update(rows[i])
    # sort for a deterministic order: set iteration order of str depends on the
    # per-session hash seed
    wopts = sorted(wopts)

    opts = torch.stack([table[w] for w in wopts]).float()
    dim = opts.shape[-1]  # previously hard-coded as 300

    lbls = []
    for a in q['question']['choices']:
        choice_vecs = convert(topk(a['text']), lang).view(-1, dim)
        nb = get_closest(choice_vecs, opts, k=10)[0]
        # smaller mean distance -> higher confidence
        conf = 1 / nb.mean()
        lbls.append((a['label'], conf))

    # scanning in reverse makes max() return the LAST of tied maxima, matching
    # the stable sorted(...)[-1] used before
    return max(reversed(lbls), key=lambda pair: pair[1])[0]

In [36]:
# Gold answer label of every question, in question order.
true = [question['answerKey'] for question in questions]

In [37]:
# Solve every question, tracking running accuracy; a progress line is printed
# for the 1st, 11th, 21st, ... question.
sols = []
correct = 0
for i, q in enumerate(questions):
    sol = solve(q)
    correct += int(sol == true[i])
    sols.append(sol)
    if not i % 10:
        print('{}/{} {:.4f}'.format(i + 1, len(questions), correct / (i + 1)))

1/432 1.0000
11/432 0.1818
21/432 0.1905
31/432 0.2903
41/432 0.2195
51/432 0.2157
61/432 0.2295
71/432 0.2254
81/432 0.2469
91/432 0.2857
101/432 0.2772
111/432 0.2883
121/432 0.3058
131/432 0.3053
141/432 0.2979
151/432 0.2914
161/432 0.2981
171/432 0.2865
181/432 0.2873
191/432 0.2827
201/432 0.2886
211/432 0.2891
221/432 0.2805
231/432 0.2771
241/432 0.2780
251/432 0.2749
261/432 0.2759
271/432 0.2804
281/432 0.2776
291/432 0.2784
301/432 0.2757
311/432 0.2797
321/432 0.2804
331/432 0.2749
341/432 0.2698
351/432 0.2764
361/432 0.2825
371/432 0.2830
381/432 0.2835
391/432 0.2864
401/432 0.2868
411/432 0.2920
421/432 0.2898
431/432 0.2877


In [38]:
# Final accuracy over the questions actually solved. The leaked loop index
# `i + 1` was used before; it may be stale (from an earlier loop) or undefined
# when no question ran, and it over-counts by one if the loop was interrupted
# mid-question.
print('Done {:.4f}'.format(correct / len(sols) if sols else 0.0))

Done 0.2870


In [39]:
# Gold labels as one string, one character per question.
print(*true, sep='')

CCACACCAABCBCCCDDABAAADADCBACBDCDCDDAACACDABAAACCDACBDBAACBCDAABCBABDADAACDBABBCDBADCADDACDBBCADBBDDDABACDDCACDCDAACCABDCBADAACDDBBADBDDACBCBBBDAACBBBBCDDABDBCDBDCDDCDDACCCACCCDBCBAAADCADCBBDADADCBBACCBBBCDBABADBACAACCBCDCBBCCBBADDAACCBDCCADCDCACACADDDDCACDDBBADCBBACBBCBADCBADBDBDAACBDCCBBACAADDCBDDDDDCBBBACADAADCCCBBACCACABCCABCDDCBDDCDDCCBBBABBABCBBBBDBBCCBBCCACBBCAACAAADBCDDCACBCABBDBCBABCDBDBCABCDBDBDAACACCDBDBADBBBBBDBBBBAC


In [40]:
# Predicted labels as one string, aligned with the gold string above.
print(*sols, sep='')

CBDABDDCBACBBACACCDBBADADDAAAACBADAACBBDACBDCCDCDCABADDACDADDDACDCACBBBDCBBBCDBCDCDCCADDACCDBACBBCBBBADBBDBCBCBDAACDCABDBCADCBADCADABACBBBADBCDDCBAABACCDBBDAAADBBBBABCDBBAAADCBBBDABDADBDAAADBCAADCBDBBAACDBABCCADDDDADDDAAADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAAADDCADCDBBDDDCBCCBCCBDBCBDDCDCAADBDBABDDDCDCBCCDCCADAABAADDBDDACBDDABDBBBABDACBDCDAABBCDBACCBDBDBBBDCCAABAACACABDCABCACDBBADCCAACDCADBCCBBABBCDCCADCDAADCADDADABBDAABDACBADDB


In [None]:
# elem: 0.2894
# true: CCACACCAABCBCCCDDABAAADADCBACBDCDCDDAACACDABAAACCDACBDBAACBCDAABCBABDADAACDBABBCDBADCADDACDBBCADBBDDDABACDDCACDCDAACCABDCBADAACDDBBADBDDACBCBBBDAACBBBBCDDABDBCDBDCDDCDDACCCACCCDBCBAAADCADCBBDADADCBBACCBBBCDBABADBACAACCBCDCBBCCBBADDAACCBDCCADCDCACACADDDDCACDDBBADCBBACBBCBADCBADBDBDAACBDCCBBACAADDCBDDDDDCBBBACADAADCCCBBACCACABCCABCDDCBDDCDDCCBBBABBABCBBBBDBBCCBBCCACBBCAACAAADBCDDCACBCABBDBCBABCDBDBCABCDBDBDAACACCDBDBADBBBBBDBBBBAC
# pred: CBDABDDCBACBCACACCDBBADADDAAAADBADAACBBDACBDCCBCDCABADDACDADDDACDCDCBBBDCBBBCDBCDCDCCCDDACCDBACBBCBBBADBBDBCBCBDAACDCABDBCADCBADCADACACBBBADBCDDABAABACCDBBDAAADBBBBABCDBBAAADCBBBDCBDADBDAAADBCAADCBDBBABCDBABCCADDDDADDDAAADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAAADDCADCDBBDDDCBCCBCCBDBCBDDCDCAADBDBABDDDCDCBCCDCDADAABAADDBDDAABDDABDBBBABDAABACDAABBCDBACCBDBDDBBDCCAABAACACABDCABCACDBBADCCAACDCADBCADBAABCDCCADCDCADCADDADABDDACBDACBADDB

# elem-full: 0.2847
# pred: CBDABDDCBDCBBACACCDBBADADDCAAADAADBACBDDACBDCCDCDCABADDACDADDDACDCACBBBDCBBBCDBCDCDCCCDDACCDCAABBCBBBADBDDBCBCADAACDCABDBCADCBADCADABACBABADBCDBCBAABACADBBCAAADBBBBABCDBBAAADCABBDCBDADBDAAADBCAADCBDBBCACDBABCCADDDDADDDABADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAABDDCADCDBBDDDCBCCBCCBDBBADDCDCAADBDBABDDDCDCDCCDCDADAABAADDADDACBDDABDCBBABDACBDCDAABACDBACCBDBDBBBDCCAABAACACABDCABBACDBBADCCAACDCADBCCBBABBCDCCADBDAADCADDADABDDAABDACBCDDB

# elem-filter: 0.2870
# pred: CBDABDDCBACBBACACCDBBADADDAAAACBADAACBBDACBDCCDCDCABADDACDADDDACDCACBBBDCBBBCDBCDCDCCADDACCDBACBBCBBBADBBDBCBCBDAACDCABDBCADCBADCADABACBBBADBCDDCBAABACCDBBDAAADBBBBABCDBBAAADCBBBDABDADBDAAADBCAADCBDBBAACDBABCCADDDDADDDAAADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAAADDCADCDBBDDDCBCCBCCBDBCBDDCDCAADBDBABDDDCDCBCCDCCADAABAADDBDDACBDDABDBBBABDACBDCDAABBCDBACCBDBDBBBDCCAABAACACABDCABCACDBBADCCAACDCADBCCBBABBCDCCADCDAADCADDADABBDAABDACBADDB

# 8th: 0.3072
# true: BDDBCCDCBDACDBDCDADBDBABDDBBDCBCBCCBCBCACDBDBABCBAAABBABCADDCBABCBCBDDCCCDCAABCDADABDADCCACCCDDCBBACDBCACCAAABABDCABCCDCDDACBBCBADDBCCACACCDADADDAACDBBCADCCBBDADBBDAABACABADCABDDDCDCBBDDDDACADAACBCACACBBBCBDCBCCACCCBCCACCADDCBBBABCABBACACDCBCCBCDCADADBABDACADABBDDCBBDCBDADBBCCCCCBBDCDCDDBCBCD
# pred: BDDADABDBDBADCDCADCDCABBCABCBCBACBDDBCCDCDDCADCAABDABBDCBBCDABDBDBCDCAACDABCABCBDDADDCCAAABCAABBBBBCCCABCCCCCBCCDDABCADDDAABAADDDDBCACDACADDBCDBBDCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDADDBBBAADDACACADBDAAD

# 8th-full: 0.2969
# pred: CDDADADDBABADCDCADCDDABDCABCBCDACBDDBCCDCDDCADCAABDACBACBBCDABDBDBCCCAACDABCABABDCCADCCAABBCADBCCBDCCCABCCCCCBCCDDABCDDDDAABAADDDDBCACDACCDDBCDDBBCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDABDBBAAADDACACBDBDAAA

# 8th-filter: 0.3106
# pred: BDDADABDBDBADCDCADCDCABBCABCBCBACBDDBCCDCDDCADCAABDABBACBBCDABDBDBCDCAACDABCABCBDCADDCCAAABCAABBBBDCCCABCCCCCBCCDDABCACDDAABAADDDDBCACDACCDDBCDBBBCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDADDBBBAADDACACADBDAAD

