In [1]:
import sys, os, time
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from gensim.models.wrappers import FastText
import gensim
import word_util as wtil
# import fastText
from collections import Counter
torch.set_printoptions(linewidth=120)
np.set_printoptions(linewidth=120, suppress=True)

In [2]:
# lang = fastText.load_model('../../fastText/wiki.en.bin')

In [3]:
def cosd(x,y):
    """Pairwise cosine distance between rows of x and rows of y, rescaled to [0, 1].

    Args:
        x: tensor of row vectors; a 1-D tensor is treated as a single row.
        y: tensor of row vectors; a 1-D tensor is treated as a single row.

    Returns:
        Tensor whose entry [..., i, j] equals (1 - cos(x[i], y[j])) / 2:
        0 for identical directions, 0.5 for orthogonal, 1 for opposite.
    """
    x_rows = x.unsqueeze(0) if x.ndimension() == 1 else x
    y_rows = y.unsqueeze(0) if y.ndimension() == 1 else y
    # Unit-normalise along the feature axis so the matrix product is a cosine.
    x_unit = F.normalize(x_rows, 2, -1)
    y_unit = F.normalize(y_rows, 2, -1)
    neg_cosine = (-x_unit) @ y_unit.transpose(-1, -2)
    return neg_cosine / 2 + .5
def l2(x,y):
    """Pairwise mean squared difference between rows of x and rows of y.

    Note that this is the mean (not the sum) of squared component differences
    and that no square root is taken.

    Args:
        x: tensor of row vectors; a 1-D tensor is treated as a single row.
        y: tensor of row vectors; a 1-D tensor is treated as a single row.

    Returns:
        Tensor whose entry [..., i, j] is mean((x[i] - y[j]) ** 2).
    """
    x_rows = x.unsqueeze(0) if x.ndimension() == 1 else x
    y_rows = y.unsqueeze(0) if y.ndimension() == 1 else y
    # Insert singleton axes so broadcasting forms every (x row, y row) pair.
    diff = x_rows.unsqueeze(-2) - y_rows.unsqueeze(-3)
    return diff.pow(2).mean(-1)

In [4]:
def filter_tokens(s):
    """Lower-case a sentence, drop one trailing '.' or '?', and split on single spaces.

    Args:
        s: the sentence to tokenise (may be empty).

    Returns:
        List of tokens. Splitting is on the single space character, so
        consecutive spaces produce empty-string tokens and an empty input
        produces [''].
    """
    s = s.lower()
    # endswith() is safe on the empty string, where indexing s[-1] would raise
    # IndexError.
    if s.endswith(('.', '?')):
        s = s[:-1]
    return s.split(' ')
def topk(query,k=5):
    """Return up to k keywords of a sentence as ranked by wtil.tfidf.

    Args:
        query: sentence to extract keywords from.
        k: maximum number of keywords to return.

    Returns:
        List of the first k words of the ranking.
    """
    token_counts = Counter(filter_tokens(query))
    # presumably wtil.tfidf yields (word, score) pairs ordered best-first —
    # confirm in word_util.
    ranked = wtil.tfidf(token_counts, full_bag)
    return [word for word, _score in ranked[:k]]

In [5]:
ds_name = '8th'
# (question file, precomputed token-embedding file) for each supported dataset.
dataset_files = {
    'elem': ('../data/questions/AI2-Elementary-NDMC-Feb2016-Train.jsonl',
             '../../train_elem_tokens_emb.pth.tar'),
    '8th': ('../data/questions/AI2-8thGr-NDMC-Feb2016-Train.jsonl',
            '../../train_8thgr_tokens_emb.pth.tar'),
}
if ds_name not in dataset_files:
    raise Exception('unknown dataset')
root, lookup_path = dataset_files[ds_name]

questions = wtil.load_questions(root)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
lookup = torch.load(lookup_path)

# Reference bag handed to wtil.tfidf by topk().
full_bag = lookup['bag']

# Word -> embedding lookup built from the parallel 'words' / 'vecs' entries.
lang = dict(zip(lookup['words'], lookup['vecs']))
len(questions), len(full_bag), len(lookup)

(293, 2481, 3)

In [6]:
# questions = wtil.load_questions(root)
# full_bag = Counter()
# for q in questions:
#     tokens = set(filter_tokens(q['question']['stem']))
#     for a in q['question']['choices']:
#         tokens.update(filter_tokens(a['text']))
#     full_bag.update(tokens)
# len(full_bag)

In [10]:
# Load the precomputed table checkpoint.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
table = torch.load('../../fast_table.pth.tar')
# Each row is an iterable of entries; rows are referred to by position later on.
rows = table['rows']
# elements[i] is the entry whose embedding is vecs[i] (they are zipped together
# further down).
elements = table['elements']#np.array(table['elements'])
# Embedding matrix as a float32 tensor.
vecs = torch.from_numpy(table['vecs']).float()
len(elements), vecs.shape

(46460, torch.Size([46460, 300]))

In [8]:
# Append the table stored in oie_table.pth.tar to the base table.
# NOTE: `rows` and `elements` are extended in place and `vecs` is rebound, so
# re-running this cell without first re-running the base-table cell appends the
# same data a second time.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
oie = torch.load('../../oie_table.pth.tar')
rows.extend(oie['rows'])
elements.extend(oie['elements'])
# Cast to float32 like the base table so torch.cat always sees matching dtypes.
vecs = torch.cat([vecs, torch.from_numpy(oie['vecs']).float()],0)
len(rows), len(elements), vecs.shape

(286387, 49346, torch.Size([49346, 300]))

In [11]:
# An array (rather than a list) lets `elements` be indexed with a whole batch
# of neighbour indices at once.
elements = np.array(elements)
# Element -> embedding lookup. This rebinds `table`, replacing the loaded
# checkpoint dict.
table = {element: vec for element, vec in zip(elements, vecs)}
len(table), vecs.shape

(46460, torch.Size([46460, 300]))

In [69]:
# Inverted index: entry -> indices of the rows that contain it (an entry that
# occurs several times in one row contributes that row index several times).
mentions = {}
for i, row in enumerate(rows):
    for w in row:
        mentions.setdefault(w, []).append(i)

In [70]:
def get_connections(picks):
    """Collect the indices of all rows that mention any of the given entries.

    Args:
        picks: iterable of entries; each must be a key of `mentions`.

    Returns:
        Set of row indices (the union of the `mentions` lists of all picks).

    Raises:
        KeyError: if a pick is not present in `mentions`.
    """
    return set().union(*(mentions[q] for q in picks))

In [71]:
def get_closest(query, vecs, k=2):
    """Find the k candidates with the smallest l2() distance to each query row.

    Args:
        query: query vector(s), passed to l2() as its first argument.
        vecs: candidate vectors, passed to l2() as its second argument.
        k: number of closest candidates to keep per query row.

    Returns:
        The (values, indices) result of topk over the last axis of the
        distance matrix. The k hits are not guaranteed to be ordered
        (sorted=False).
    """
    distances = l2(query, vecs)
    return distances.topk(k, dim=-1, largest=False, sorted=False)

def convert(words, lang):
    """Embed a list of words as a 2-D tensor with one row per word.

    Args:
        words: sequence of words to embed.
        lang: either a mapping from word to a tensor, or an object exposing
            get_word_vector(word) that returns a NumPy array (for example a
            fastText model).

    Returns:
        Tensor obtained by stacking the word vectors in the order of `words`.
        Vectors from get_word_vector are cast to float32.

    Raises:
        KeyError: if `lang` is a mapping and a word is missing from it.
    """
    # The model branch used to sit behind an unconditional return and could
    # never run; dispatch on the interface so both kinds of `lang` work.
    if hasattr(lang, 'get_word_vector'):
        return torch.from_numpy(np.stack([lang.get_word_vector(w) for w in words])).float()
    return torch.stack([lang[w] for w in words])

In [72]:
# Pick a single question to walk through the pipeline step by step.
q = questions[12]

In [73]:
# Keywords selected from the question stem by topk().
words = topk(q['question']['stem'])
words

['working', 'company', 'testing', 'medicine', 'heal']

In [74]:
# Embed the stem keywords: one row per keyword.
v = convert(words, lang)
v.shape

torch.Size([5, 300])

In [75]:
# get_closest returns (distances, indices); [1] keeps the indices into `vecs`
# of the closest table entries for each keyword (default k=2).
cls = get_closest(v, vecs)[1]
cls.shape

torch.Size([5, 2])

In [76]:
# Display the neighbour indices, one row per stem keyword.
cls

tensor([[ 1852, 21195],
        [ 1087, 11452],
        [ 1434, 21867],
        [ 8084, 48455],
        [31850, 23661]])

In [77]:
# Map the neighbour indices back to entries, then collect the index of every
# row that mentions any of them.
conns = get_connections(elements[cls].reshape(-1))
len(conns)

1906

In [78]:
# Gather the distinct entries of all connected rows; these are the candidates
# the answer choices are compared against.
wopts = set()
for i in conns:
    wopts.update(rows[i])
# Freeze an order so the candidates can be stacked into a matrix.
wopts = list(wopts)
len(wopts)

1302

In [79]:
# Stack the embeddings of the candidate entries into one float32 matrix, with
# rows in the same order as `wopts`.
opts = torch.from_numpy(np.stack([table[w] for w in wopts])).float()
opts.shape

torch.Size([1302, 300])

In [80]:
# Score every answer choice against the candidate vectors gathered from the table.
lbls = []
for a in q['question']['choices']:
    lbl = a['label']
    # NOTE: rebinds `v`, which previously held the stem keyword vectors.
    v = convert(topk(a['text']), lang).view(-1,300)
    # [0] keeps the distances of the 10 closest candidates per choice keyword.
    nb = get_closest(v, opts, k=10)[0]
    # Confidence is the inverse of the mean distance: closer candidates score higher.
    conf = 1/nb.mean()
    lbls.append((lbl,conf))

In [81]:
# Display the (label, confidence) pair of every answer choice.
lbls

[('A', tensor(17.5232)),
 ('B', tensor(21.1525)),
 ('C', tensor(18.2112)),
 ('D', tensor(28.0311))]

In [82]:
# Sort ascending by confidence and take the label of the last (highest) entry.
sol = sorted(lbls, key=lambda x: x[1])[-1][0]

In [83]:
def solve(q, k_table=2, k_choice=10):
    """Answer one multiple-choice question using the table and word embeddings.

    Pipeline: take the stem keywords, find their nearest table entries, collect
    every row mentioning those entries, and use the entries of those rows as
    candidates. Each answer choice is scored by the inverse of the mean
    distance between its keywords and their closest candidates.

    Args:
        q: question record providing q['question']['stem'] and
            q['question']['choices'], where each choice has 'label' and 'text'.
        k_table: number of nearest table entries retrieved per stem keyword.
        k_choice: number of closest candidates averaged per choice keyword.

    Returns:
        The label of the highest-scoring choice.
    """
    stem_vecs = convert(topk(q['question']['stem']), lang)

    # get_closest returns (distances, indices); [1] keeps the indices into vecs.
    nearest = get_closest(stem_vecs, vecs, k=k_table)[1]

    conns = get_connections(elements[nearest].reshape(-1))

    # Distinct entries of all connected rows are the candidates.
    wopts = set()
    for i in conns:
        wopts.update(rows[i])
    wopts = list(wopts)

    # Stack directly in torch; as_tensor accepts tensors and NumPy arrays alike,
    # so no round trip through np.stack is needed.
    opts = torch.stack([torch.as_tensor(table[w]) for w in wopts]).float()

    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(k_choice, opts.shape[0])

    lbls = []
    for a in q['question']['choices']:
        choice_vecs = convert(topk(a['text']), lang)
        # Flatten to (n, dim) using the actual embedding width rather than a
        # hard-coded 300.
        choice_vecs = choice_vecs.view(-1, choice_vecs.shape[-1])
        nb = get_closest(choice_vecs, opts, k=k)[0]
        # Inverse mean distance: closer candidates give a higher confidence.
        lbls.append((a['label'], 1/nb.mean()))

    # Ascending stable sort; the last entry is the most confident choice.
    return sorted(lbls, key=lambda x: x[1])[-1][0]

In [84]:
# Gold answer label of every question, in question order.
true = [q['answerKey'] for q in questions]

In [85]:
# Run the solver over every question, keeping the predictions and a running
# count of correct answers.
sols = []
correct = 0
for i, q in enumerate(questions):
    sol = solve(q)
    sols.append(sol)
    correct += int(sol == true[i])
    # Report the running accuracy every 10 questions.
    if i % 10 == 0:
        print('{}/{} {:.4f}'.format(i+1,len(questions), correct/(i+1)))

1/293 0.0000
11/293 0.3636
21/293 0.3810
31/293 0.3226
41/293 0.2927
51/293 0.2549
61/293 0.2787
71/293 0.2958
81/293 0.2963
91/293 0.2747
101/293 0.2970
111/293 0.2973
121/293 0.3223
131/293 0.3130
141/293 0.3121
151/293 0.3046
161/293 0.2981
171/293 0.3099
181/293 0.3204
191/293 0.3141
201/293 0.3085
211/293 0.3175
221/293 0.3122
231/293 0.3030
241/293 0.3029
251/293 0.3028
261/293 0.3065
271/293 0.3026
281/293 0.2954
291/293 0.2990


In [86]:
# Final accuracy. Relies on `i` left over from the evaluation loop, so it
# reflects however many questions that loop actually processed.
print('Done {:.4f}'.format(correct/(i+1)))

Done 0.2969


In [229]:
# Gold answers as one string, one character per question.
print(''.join(true))

BDDBCCDCBDACDBDCDADBDBABDDBBDCBCBCCBCBCACDBDBABCBAAABBABCADDCBABCBCBDDCCCDCAABCDADABDADCCACCCDDCBBACDBCACCAAABABDCABCCDCDDACBBCBADDBCCACACCDADADDAACDBBCADCCBBDADBBDAABACABADCABDDDCDCBBDDDDACADAACBCACACBBBCBDCBCCACCCBCCACCADDCBBBABCABBACACDCBCCBCDCADADBABDACADABBDDCBBDCBDADBBCCCCCBBDCDCDDBCBCD


In [88]:
# Predicted answers as one string, aligned position by position with the gold string.
print(''.join(sols))

CDDADADDBABADCDCADCDDABDCABCBCDACBDDBCCDCDDCADCAABDACBACBBCDABDBDBCCCAACDABCABABDCCADCCAABBCADBCCBDCCCABCCCCCBCCDDABCDDDDAABAADDDDBCACDACCDDBCDDBBCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDABDBBAAADDACACBDBDAAA


In [None]:
# elem: 0.2894
# true: CCACACCAABCBCCCDDABAAADADCBACBDCDCDDAACACDABAAACCDACBDBAACBCDAABCBABDADAACDBABBCDBADCADDACDBBCADBBDDDABACDDCACDCDAACCABDCBADAACDDBBADBDDACBCBBBDAACBBBBCDDABDBCDBDCDDCDDACCCACCCDBCBAAADCADCBBDADADCBBACCBBBCDBABADBACAACCBCDCBBCCBBADDAACCBDCCADCDCACACADDDDCACDDBBADCBBACBBCBADCBADBDBDAACBDCCBBACAADDCBDDDDDCBBBACADAADCCCBBACCACABCCABCDDCBDDCDDCCBBBABBABCBBBBDBBCCBBCCACBBCAACAAADBCDDCACBCABBDBCBABCDBDBCABCDBDBDAACACCDBDBADBBBBBDBBBBAC
# pred: CBDABDDCBACBCACACCDBBADADDAAAADBADAACBBDACBDCCBCDCABADDACDADDDACDCDCBBBDCBBBCDBCDCDCCCDDACCDBACBBCBBBADBBDBCBCBDAACDCABDBCADCBADCADACACBBBADBCDDABAABACCDBBDAAADBBBBABCDBBAAADCBBBDCBDADBDAAADBCAADCBDBBABCDBABCCADDDDADDDAAADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAAADDCADCDBBDDDCBCCBCCBDBCBDDCDCAADBDBABDDDCDCBCCDCDADAABAADDBDDAABDDABDBBBABDAABACDAABBCDBACCBDBDDBBDCCAABAACACABDCABCACDBBADCCAACDCADBCADBAABCDCCADCDCADCADDADABDDACBDACBADDB

# elem-full: 0.2847
# pred: CBDABDDCBDCBBACACCDBBADADDCAAADAADBACBDDACBDCCDCDCABADDACDADDDACDCACBBBDCBBBCDBCDCDCCCDDACCDCAABBCBBBADBDDBCBCADAACDCABDBCADCBADCADABACBABADBCDBCBAABACADBBCAAADBBBBABCDBBAAADCABBDCBDADBDAAADBCAADCBDBBCACDBABCCADDDDADDDABADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAABDDCADCDBBDDDCBCCBCCBDBBADDCDCAADBDBABDDDCDCDCCDCDADAABAADDADDACBDDABDCBBABDACBDCDAABACDBACCBDBDBBBDCCAABAACACABDCABBACDBBADCCAACDCADBCCBBABBCDCCADBDAADCADDADABDDAABDACBCDDB

# 8th: 0.3072
# true: BDDBCCDCBDACDBDCDADBDBABDDBBDCBCBCCBCBCACDBDBABCBAAABBABCADDCBABCBCBDDCCCDCAABCDADABDADCCACCCDDCBBACDBCACCAAABABDCABCCDCDDACBBCBADDBCCACACCDADADDAACDBBCADCCBBDADBBDAABACABADCABDDDCDCBBDDDDACADAACBCACACBBBCBDCBCCACCCBCCACCADDCBBBABCABBACACDCBCCBCDCADADBABDACADABBDDCBBDCBDADBBCCCCCBBDCDCDDBCBCD
# pred: BDDADABDBDBADCDCADCDCABBCABCBCBACBDDBCCDCDDCADCAABDABBDCBBCDABDBDBCDCAACDABCABCBDDADDCCAAABCAABBBBBCCCABCCCCCBCCDDABCADDDAABAADDDDBCACDACADDBCDBBDCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDADDBBBAADDACACADBDAAD

# 8th-full: 0.2969
# pred: CDDADADDBABADCDCADCDDABDCABCBCDACBDDBCCDCDDCADCAABDACBACBBCDABDBDBCCCAACDABCABABDCCADCCAABBCADBCCBDCCCABCCCCCBCCDDABCDDDDAABAADDDDBCACDACCDDBCDDBBCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDABDBBAAADDACACBDBDAAA
