In [1]:
import sys, os, time
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from gensim.models.wrappers import FastText
import gensim
import word_util as wtil
import fastText
from collections import Counter
torch.set_printoptions(linewidth=120)
np.set_printoptions(linewidth=120, suppress=True)

In [2]:
# Pre-trained fastText model used by convert() to embed tokens
# (presumably English Wikipedia vectors, judging by the filename).
FASTTEXT_MODEL_PATH = '../../fastText/wiki.en.bin'
lang = fastText.load_model(FASTTEXT_MODEL_PATH)

In [2]:
def cosd(x, y):
    """Pairwise cosine distance between rows of ``x`` and ``y``, rescaled to [0, 1].

    Both inputs are L2-normalised along their last dimension, and the
    result is ``(1 - cos_sim) / 2``. That is 0 for identical directions,
    0.5 for orthogonal vectors and 1 for opposite ones.

    Args:
        x: tensor of shape (..., N, D), or (D,) which is treated as (1, D).
        y: tensor of shape (..., M, D), or (D,) which is treated as (1, D).

    Returns:
        Tensor of shape (..., N, M).
    """
    # Promote single vectors to one-row matrices so the matmul below works.
    x = x.unsqueeze(0) if x.ndimension() == 1 else x
    y = y.unsqueeze(0) if y.ndimension() == 1 else y
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    similarity = torch.matmul(x_unit, y_unit.transpose(-1, -2))
    return 0.5 - similarity / 2
def l2(x, y):
    """Pairwise *mean squared* difference between rows of ``x`` and ``y``.

    Despite the name this is not the Euclidean norm. It returns
    ``((x_i - y_j) ** 2).mean()`` over the last dimension, i.e. the squared
    L2 distance divided by D. That is monotonic in the Euclidean distance,
    so nearest-neighbour rankings are unchanged.

    The full (..., N, M, D) difference tensor is materialised, so memory
    grows with N * M * D.

    Args:
        x: tensor of shape (..., N, D), or (D,) which is treated as (1, D).
        y: tensor of shape (..., M, D), or (D,) which is treated as (1, D).

    Returns:
        Tensor of shape (..., N, M).
    """
    if x.ndimension() == 1:
        x = x.unsqueeze(0)
    if y.ndimension() == 1:
        y = y.unsqueeze(0)
    # Broadcast (..., N, 1, D) against (..., 1, M, D).
    lhs = x[..., :, None, :]
    rhs = y[..., None, :, :]
    return ((lhs - rhs) ** 2).mean(dim=-1)

In [7]:
def filter_tokens(s):
    """Lower-case ``s``, drop one trailing '.' or '?', and split on spaces.

    Args:
        s: question stem or answer-choice text.

    Returns:
        List of tokens. Splitting uses the literal ' ' character rather than
        arbitrary whitespace, so consecutive spaces produce empty-string
        tokens. This is kept from the original so token lookups stay
        consistent. An empty input returns ``[]`` instead of raising
        IndexError.
    """
    if not s:
        return []
    s = s.lower()
    if s.endswith(('.', '?')):
        s = s[:-1]
    return s.split(' ')
def topk(query, k=5):
    """Return the ``k`` best-ranked tokens of ``query``.

    The tokens come from filter_tokens(). Ranking is delegated to
    ``wtil.tfidf`` against the global ``full_bag`` Counter. Presumably
    tfidf returns (word, score) pairs sorted best-first; confirm in
    word_util.
    """
    counts = Counter(filter_tokens(query))
    ranked = wtil.tfidf(counts, full_bag)
    return [word for word, _score in ranked[:k]]

In [23]:
ds_name = 'elem'
# Question file and cached token-embedding lookup for each supported split.
DATASETS = {
    'elem': ('../data/questions/AI2-Elementary-NDMC-Feb2016-Train.jsonl',
             '../../train_elem_tokens_emb.pth.tar'),
    '8th': ('../data/questions/AI2-8thGr-NDMC-Feb2016-Train.jsonl',
            '../../train_8thgr_tokens_emb.pth.tar'),
}
if ds_name not in DATASETS:
    raise ValueError('unknown dataset: {!r}'.format(ds_name))
root, lookup_path = DATASETS[ds_name]

questions = wtil.load_questions(root)
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
lookup_ckpt = torch.load(lookup_path)

full_bag = lookup_ckpt['bag']

# zip() silently truncates, so check that words and vectors line up first.
assert len(lookup_ckpt['words']) == len(lookup_ckpt['vecs']), 'words/vecs length mismatch'
lookup = dict(zip(lookup_ckpt['words'], lookup_ckpt['vecs']))
len(questions), len(full_bag), len(lookup)

(432, 2805, 2805)

In [9]:
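# Reference recipe for rebuilding full_bag from the questions; the loading cell above reads it from the cached lookup file instead.
# Each token is counted at most once per question (stem plus all choices).
# questions = wtil.load_questions(root)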
# full_bag = Counter()
# for q in questions:
#     tokens = set(filter_tokens(q['question']['stem']))
#     for a in q['question']['choices']:
#         tokens.update(filter_tokens(a['text']))
#     full_bag.update(tokens)
# len(full_bag)

2805

In [12]:
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
table_ckpt = torch.load('../../fast_table.pth.tar')
rows = table_ckpt['rows']
elements = np.array(table_ckpt['elements'])
vecs = torch.from_numpy(table_ckpt['vecs']).float()
# zip() would silently drop unmatched entries; make sure every element has a vector.
assert len(elements) == vecs.shape[0], 'elements/vecs length mismatch'
# Map each table element to its float32 vector row.
table = dict(zip(elements, vecs))
len(table.keys()), vecs.shape

(46460, torch.Size([46460, 300]))

In [13]:
# Inverted index: word -> indices of the rows it appears in. An index is
# appended once per occurrence, so a word repeated within a row repeats it.
mentions = {}
for i, row in enumerate(rows):
    for w in row:
        mentions.setdefault(w, []).append(i)

In [14]:
def get_connections(picks):
    """Return the set of row indices in which any word in ``picks`` appears.

    Uses the global ``mentions`` index. Raises KeyError if a pick never
    occurs in any row.
    """
    return set().union(*(mentions[word] for word in picks))

In [15]:
def get_closest(query, vecs, k=2, dist=None):
    """Find the ``k`` rows of ``vecs`` closest to each row of ``query``.

    Args:
        query: query vectors. With the default distance this is (..., N, D)
            or (D,).
        vecs: candidate vectors, (..., M, D) with the default distance.
        k: number of neighbours. It is clamped to M, so a small candidate
            set returns every row instead of making torch.topk raise.
        dist: pairwise distance ``dist(query, vecs) -> (..., N, M)``.
            Defaults to l2 (mean squared difference); cosd also fits.

    Returns:
        The torch.topk result ``(values, indices)``, each (..., N, min(k, M)).
        Because ``sorted=False``, neighbours within a row are NOT guaranteed
        to be ordered by distance.
    """
    if dist is None:
        dist = l2
    distances = dist(query, vecs)
    k = min(k, distances.shape[-1])
    return torch.topk(distances, k, dim=-1, largest=False, sorted=False)

def convert(words, lang):
    """Embed each word with ``lang.get_word_vector`` and stack the results.

    Returns a float32 tensor with one row per word. This assumes
    get_word_vector returns a 1-D array, giving (len(words), D). An empty
    ``words`` makes np.stack raise ValueError.
    """
    rows = []
    for word in words:
        rows.append(lang.get_word_vector(word))
    return torch.from_numpy(np.stack(rows)).float()

In [16]:
# Sample question for the step-by-step walkthrough below (the same steps solve() performs).
q = questions[11]

In [17]:
# Step 1: the stem tokens ranked highest by topk() (wtil.tfidf against full_bag).
words = topk(q['question']['stem'])
words

['carry', 'berries', 'parent', 'another', 'seeds']

In [18]:
# Step 2: fastText vectors for those tokens. This needs `lang` from the
# model-loading cell; in the saved run it raised NameError because that cell
# had not been executed in the current kernel session.
v = convert(words, lang)
v.shape

NameError: name 'lang' is not defined

In [188]:
# Step 3: indices of the nearest table vectors (default k=2) for each stem token.
cls = get_closest(v, vecs).indices
cls.shape

torch.Size([5, 2])

In [189]:
# Inspect the neighbour indices (row positions in `vecs` / `elements`).
cls

tensor([[  885, 19549],
        [   79, 30609],
        [ 2134, 31223],
        [21995, 30439],
        [16914, 42276]])

In [190]:
# Step 4: every table row that mentions one of the neighbouring elements.
# The torch indices are converted explicitly to a numpy index array.
conns = get_connections(elements[cls.numpy()].ravel())
len(conns)

2868

In [191]:
# Step 5: vocabulary of all connected rows, deduplicated via a set
# (list order is whatever the set iteration order happens to be).
# NOTE(review): the loop rebinds the global `i`.
wopts = set()
for i in conns:
    wopts.update(rows[i])
wopts = list(wopts)
len(wopts)

1999

In [192]:
# Step 6: stack the table vectors of that vocabulary into one (len(wopts), D) tensor.
# The values of `table` are already float32 tensors (rows of `vecs`), so stack
# them directly instead of converting every row to numpy and back.
opts = torch.stack([table[w] for w in wopts])
opts.shape

torch.Size([1999, 300])

In [196]:
# Step 7: score each answer choice as 1 / (mean l2() distance from the choice's
# token vectors to their 10 nearest rows of `opts`); higher means closer.
# The vectors go into `choice_vecs` instead of overwriting `v`, which holds the
# stem-token vectors that the neighbour cell (Step 3) depends on.
lbls = []
for a in q['question']['choices']:
    lbl = a['label']
    choice_vecs = convert(topk(a['text']), lang).view(-1, opts.shape[-1])
    nb = get_closest(choice_vecs, opts, k=10)[0]
    conf = 1/nb.mean()
    lbls.append((lbl, conf))

In [197]:
# (label, confidence) pairs for each answer choice; the next cell picks the highest.
lbls

[('A', tensor(15.7981)),
 ('B', tensor(18.0087)),
 ('C', tensor(17.7777)),
 ('D', tensor(17.0350))]

In [199]:
# Predicted label: the choice with the highest confidence. On ties the stable
# sort makes the last such choice win. `sol` is displayed explicitly because
# an assignment-only cell shows no output.
sol = sorted(lbls, key=lambda x: x[1])[-1][0]
sol

'B'

In [225]:
def _neighbourhood_vectors(stem):
    """Stack the table vectors of every word in rows linked to the stem's top tokens."""
    stem_vecs = convert(topk(stem), lang)
    # Indices of the nearest table elements (default k=2) for each stem token.
    nearest = get_closest(stem_vecs, vecs)[1]
    conns = get_connections(elements[nearest].reshape(-1))
    wopts = set()
    for row_idx in conns:
        wopts.update(rows[row_idx])
    # `table` values are float32 tensors, so they can be stacked directly.
    return torch.stack([table[w] for w in wopts])


def _choice_confidence(text, opts, k=10):
    """Confidence that answer ``text`` fits the neighbourhood ``opts`` (higher is better).

    The score is 1 / mean l2() distance from the choice's token vectors to
    their ``k`` nearest rows of ``opts``. An exact match gives inf. If
    topk() yields no tokens, the choice gets -inf so it is never preferred,
    instead of np.stack crashing inside convert().
    """
    words = topk(text)
    if not words:
        return float('-inf')
    choice_vecs = convert(words, lang).view(-1, opts.shape[-1])
    nb = get_closest(choice_vecs, opts, k=k)[0]
    return float(1 / nb.mean())


def solve(q):
    """Predict the answer label for question ``q``.

    ``q`` must provide ``q['question']['stem']`` and
    ``q['question']['choices']``, each choice having 'label' and 'text'.
    The stem's top tokens select a neighbourhood of table vectors, and each
    choice is scored by how close its tokens are to that neighbourhood.

    Returns:
        The label of the highest-confidence choice. On ties the last such
        choice wins, matching the earlier ``sorted(...)[-1]`` selection.
    """
    opts = _neighbourhood_vectors(q['question']['stem'])
    scored = [(choice['label'], _choice_confidence(choice['text'], opts))
              for choice in q['question']['choices']]
    # max() keeps the first maximum it sees, so scan in reverse to keep the
    # "last choice wins ties" behaviour.
    return max(reversed(scored), key=lambda pair: pair[1])[0]

In [226]:
# Gold answer label for every question, in dataset order.
true = [question['answerKey'] for question in questions]

In [227]:
# Run solve() on every question, printing running accuracy every 10 questions.
# The loop uses `question`/`pred` rather than `q`/`sol`, so the sample question
# and prediction from the walkthrough cells are not overwritten.
sols = []
correct = 0
for i, question in enumerate(questions):
    pred = solve(question)
    if pred == true[i]:
        correct += 1
    sols.append(pred)
    if i % 10 == 0:
        print('{}/{} {:.4f}'.format(i+1, len(questions), correct/(i+1)))

1/293 1.0000
11/293 0.4545
21/293 0.3810
31/293 0.3871
41/293 0.3415
51/293 0.2941
61/293 0.3115
71/293 0.3239
81/293 0.3333
91/293 0.3407
101/293 0.3465
111/293 0.3423
121/293 0.3636
131/293 0.3511
141/293 0.3404
151/293 0.3245
161/293 0.3168
171/293 0.3275
181/293 0.3370
191/293 0.3298
201/293 0.3234
211/293 0.3318
221/293 0.3258
231/293 0.3160
241/293 0.3154
251/293 0.3147
261/293 0.3180
271/293 0.3137
281/293 0.3025
291/293 0.3058


In [228]:
# Final accuracy over the questions actually solved. len(sols) stays correct
# if the loop above was interrupted, unlike the leaked loop index `i`, and an
# empty run reports 0 instead of failing.
print('Done {:.4f}'.format(correct / len(sols) if sols else 0.0))

Done 0.3072


In [229]:
# Gold labels as one string, for comparison with the predictions below.
print(*true, sep='')

BDDBCCDCBDACDBDCDADBDBABDDBBDCBCBCCBCBCACDBDBABCBAAABBABCADDCBABCBCBDDCCCDCAABCDADABDADCCACCCDDCBBACDBCACCAAABABDCABCCDCDDACBBCBADDBCCACACCDADADDAACDBBCADCCBBDADBBDAABACABADCABDDDCDCBBDDDDACADAACBCACACBBBCBDCBCCACCCBCCACCADDCBBBABCABBACACDCBCCBCDCADADBABDACADABBDDCBBDCBDADBBCCCCCBBDCDCDDBCBCD


In [230]:
# Predicted labels as one string, aligned position-by-position with the gold labels above.
print(*sols, sep='')

BDDADABDBDBADCDCADCDCABBCABCBCBACBDDBCCDCDDCADCAABDABBDCBBCDABDBDBCDCAACDABCABCBDDADDCCAAABCAABBBBBCCCABCCCCCBCCDDABCADDDAABAADDDDBCACDACADDBCDBBDCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDADDBBBAADDACACADBDAAD


In [None]:
# elem split: accuracy 0.2894 (gold answer keys vs. solve() predictions)
# true: CCACACCAABCBCCCDDABAAADADCBACBDCDCDDAACACDABAAACCDACBDBAACBCDAABCBABDADAACDBABBCDBADCADDACDBBCADBBDDDABACDDCACDCDAACCABDCBADAACDDBBADBDDACBCBBBDAACBBBBCDDABDBCDBDCDDCDDACCCACCCDBCBAAADCADCBBDADADCBBACCBBBCDBABADBACAACCBCDCBBCCBBADDAACCBDCCADCDCACACADDDDCACDDBBADCBBACBBCBADCBADBDBDAACBDCCBBACAADDCBDDDDDCBBBACADAADCCCBBACCACABCCABCDDCBDDCDDCCBBBABBABCBBBBDBBCCBBCCACBBCAACAAADBCDDCACBCABBDBCBABCDBDBCABCDBDBDAACACCDBDBADBBBBBDBBBBAC
# pred: CBDABDDCBACBCACACCDBBADADDAAAADBADAACBBDACBDCCBCDCABADDACDADDDACDCDCBBBDCBBBCDBCDCDCCCDDACCDBACBBCBBBADBBDBCBCBDAACDCABDBCADCBADCADACACBBBADBCDDABAABACCDBBDAAADBBBBABCDBBAAADCBBBDCBDADBDAAADBCAADCBDBBABCDBABCCADDDDADDDAAADBDBBBDBBBBCCCACCDDBAACBADAAABBACACCAAADDCADCDBBDDDCBCCBCCBDBCBDDCDCAADBDBABDDDCDCBCCDCDADAABAADDBDDAABDDABDBBBABDAABACDAABBCDBACCBDBDDBBDCCAABAACACABDCABCACDBBADCCAACDCADBCADBAABCDCCADCDCADCADDADABDDACBDACBADDB

# 8th-grade split: accuracy 0.3072 (gold answer keys vs. solve() predictions)
# true: BDDBCCDCBDACDBDCDADBDBABDDBBDCBCBCCBCBCACDBDBABCBAAABBABCADDCBABCBCBDDCCCDCAABCDADABDADCCACCCDDCBBACDBCACCAAABABDCABCCDCDDACBBCBADDBCCACACCDADADDAACDBBCADCCBBDADBBDAABACABADCABDDDCDCBBDDDDACADAACBCACACBBBCBDCBCCACCCBCCACCADDCBBBABCABBACACDCBCCBCDCADADBABDACADABBDDCBBDCBDADBBCCCCCBBDCDCDDBCBCD
# pred: BDDADABDBDBADCDCADCDCABBCABCBCBACBDDBCCDCDDCADCAABDABBDCBBCDABDBDBCDCAACDABCABCBDDADDCCAAABCAABBBBBCCCABCCCCCBCCDDABCADDDAABAADDDDBCACDACADDBCDBBDCABBCABDCDCDADCBAAADBADDBBDCDBBDCDDDADBBBCCCAADBBCCDCDBDBBDBDDCDCBACBACADDBBABBDDBDAAACBBACBAABBCDACDADCCBCCDDCAADAABBBBCDDCBDADDBBBAADDACACADBDAAD
