In [22]:
import gensim
import re
import multiprocessing
import numpy as np
from collections import Counter
import codecs
import nltk

In [23]:
from nltk.corpus import stopwords
stopwords = set(stopwords.words('english') + '. , ! ? !? ?! ... ; : - —'.split())

In [24]:
print(gensim.__version__)

0.12.4


In [25]:
gensim.models.word2vec.FAST_VERSION

1

In [26]:
import os
from os.path import join

In [27]:
DATA_DIR = join(os.environ['HOME'], "data/allen-ai-challenge")

In [28]:
WIKI_DIR = join(DATA_DIR, "parsed_wiki_data")

In [29]:
TRAIN_SET = join(DATA_DIR, "training_set.tsv")

In [30]:
def parse_text(t):
    s = re.sub(r'[^\w\s]', '', t)
    words = [w for w in nltk.word_tokenize(s.lower()) if w not in stopwords]
    return words

In [32]:
parse_text("test a book")

['test', 'book']

In [33]:
def iter_text(directory):
    for fname in os.listdir(directory):
        with codecs.open(os.path.join(directory, fname), encoding="utf-8") as f:
            for l in f:
                r = parse_text(l)
                if r != []:
                    yield r

In [34]:
for i, l in enumerate(iter_text(WIKI_DIR)):
    print(l)
    if i > 2:
        break

[u'churchill', u'inuit', u'kuugjuaq', u'town', u'northern', u'manitoba', u'canada', u'west', u'shore', u'hudson', u'bay', u'roughly', u'110', u'kilometres', u'manitobanunavut', u'border', u'famous', u'many', u'polar', u'bears', u'move', u'toward', u'shore', u'inland', u'autumn', u'leading', u'nickname', u'polar', u'bear', u'capital', u'world', u'helped', u'growing', u'tourism', u'industry']
[u'geographyedit']
[u'churchill', u'located', u'along', u'hudson', u'bay', u'58th', u'parallel', u'north', u'far', u'canadian', u'populated', u'areas', u'located', u'churchill', u'located', u'far', u'towns', u'cities', u'thompson', u'good', u'bit', u'south', u'closest', u'larger', u'settlement', u'province', u'capital', u'winnipeg', u'distant', u'even', u'airplane', u'nine', u'degrees', u'south', u'bit', u'west']
[u'historyedit']


In [35]:
class Sentences(object):
    
    def __init__(self, directory):
        self.directory = directory
        
    def __iter__(self):
        for text in iter_text(self.directory):
            yield text

In [36]:
sentences = Sentences(WIKI_DIR)

In [37]:
multiprocessing.cpu_count()

48

In [None]:
%%time
word_model = gensim.models.Word2Vec(sentences, workers=12,
                                    size=8, iter=20)

In [None]:
len(word_model.vocab)

In [None]:
word_model.most_similar("dwarf")

In [46]:
def quantify_text(t, model):
    words = parse_text(t)
    emb = [word_model[w] for w in words if w in word_model.vocab and len(w) > 0]
    if emb != []:
        return np.mean(emb, axis=0)
    else:
        return np.zeros(model.vector_size)

In [47]:
from gensim import matutils
def similarity(v1, v2):
    return np.dot(matutils.unitvec(np.array(v1)), matutils.unitvec(np.array(v2)))

In [48]:
def range_answers(q, answers, models):
    scores = []
    for model in models:
        question = quantify_text(q, model)
        if (question == 0).all():
            return None
        ps = Counter()
        scores_model = []
        for a in answers:
            a_q = quantify_text(a, model)
            if (a_q == 0).all():
                scores_model.append(0) 
            else:
                si = similarity(question, a_q)
                scores_model.append(si)
        scores.append(scores_model)
    return np.mean(scores, axis=0)

In [75]:
tries = []
with open(TRAIN_SET) as f:
    next(f)
    for i, l in enumerate(f):
        [qid, q, r, aa, ab, ac, ad] = l.strip().split("\t")
        scores = range_answers(q, [aa, ab, ac, ad],
                               [word_model])
        no_scores = (scores == 0).all()
        if not no_scores:
            guess = "ABCD"[np.argmax(scores)]
        else:
            print(q, aa, ab, ac, ad)
        tries.append(1 if guess == r else 0)

A sperm that contains alleles HqT fuses with an egg that contains alleles hqt. Which of the following genotypes will form in the offspring? HHqqTt HhQqTt Hhqqtt HhqqTt


In [76]:
np.mean(tries)

0.3236