In [271]:
from __future__ import print_function, division
import numpy as np
from textblob import TextBlob
import codecs
import random
import nltk
from time import time

In [8]:
# DATA = "/Users/Pavel/Code/allen-ai-challenge/data/ck12_tokens.txt"
DATA = "/home/marat/ck12_tokens.txt"

In [12]:
GLOVE_D = 50

In [13]:
GLOVE_FILE = "/home/marat/data/glove.6B/glove.6B.%dd.txt" % GLOVE_D

In [14]:
def read_dict(fname):
    d = {}
    with codecs.open(fname, encoding="utf-8") as f:
        for row in (line.strip().split() for line in f):
            w = row[0]
            v = np.array(row[1:], dtype='float32')
            assert v.shape == (GLOVE_D,)
            d[w] = v
    return d

In [15]:
GLOVE = read_dict(GLOVE_FILE)

In [16]:
len(GLOVE)

400000

In [216]:
def text2vec(text, glove_dict, seq_length, debug=False):
    vecs = []
    for w in text.split():
        try:
            vecs.append(glove_dict[w][np.newaxis, :])
        except KeyError:
            if debug:
                print("%s is not found" % w)
            continue
    if not vecs:
        print("no vector for '%s'" % text)
    rec = np.concatenate(vecs, axis=0).astype('float32')
    if rec.shape[0] > seq_length:
        # trim long sentences
        rec = rec[rec.shape[0] - seq_length:, :]
    elif rec.shape[0] < seq_length:
        # extend short sentences with zeros
        rec = np.vstack([np.zeros((seq_length - rec.shape[0], rec.shape[1])), rec])
    assert rec.shape[0] == seq_length
    return rec

In [19]:
#text2vec("hello , die", GLOVE, 3)

In [151]:
%%time
i=0
sentences = []
with codecs.open(DATA, encoding="utf-8") as f:
    for line in f:
        sentences.append(line.strip())
        i += 1
        
blob = TextBlob(' '.join(sentences))
all_pos_tags = blob.tags        

i = 0
pos_tags = []
for l, line in enumerate(sentences):    
    tags = []
    for word in line.split():
#         if l == 63798:
#             print(word, i, all_pos_tags[i])
        if word.replace(".", "") == all_pos_tags[i][0].replace(".", ""):
            tags.append(all_pos_tags[i][1])
            i += 1
        else:
            tags.append('.')
            
    pos_tags.append(tags)
#     if l == 63798:
#         print(line, tags)
#         break
    
del all_pos_tags

assert len(sentences) == len(pos_tags)

CPU times: user 2min 18s, sys: 1.65 s, total: 2min 20s
Wall time: 3min 38s


CC Coordinating conjunction
CD Cardinal number
DT Determiner
EX Existential there
FW Foreign word
IN Preposition or subordinating conjunction
JJ Adjective
JJR Adjective, comparative
JJS Adjective, superlative
LS List item marker
MD Modal
NN Noun, singular or mass
NNS Noun, plural
NNP Proper noun, singular
NNPS Proper noun, plural
PDT Predeterminer
POS Possessive ending
PRP Personal pronoun
PRP$ Possessive pronoun
RB Adverb
RBR Adverb, comparative
RBS Adverb, superlative
RP Particle
SYM Symbol
TO to
UH Interjection
VB Verb, base form
VBD Verb, past tense
VBG Verb, gerund or present participle
VBN Verb, past participle
VBP Verb, non­3rd person singular present
VBZ Verb, 3rd person singular present
WDT Wh­determiner
WP Wh­pronoun
WP$ Possessive wh­pronoun
WRB Wh­adverb

In [198]:
CORRUPT_WINDOW = 10
def corrupt(sentences, pos_tags, index):
    s = sentences[index].split()
    noun_indices = [i for i, tag in enumerate(pos_tags[index]) if tag.startswith('NN')]
    if not noun_indices:
        return None
    noun_to_replace_index = random.choice(noun_indices)
    all_donor_indices = range(max(0, index-CORRUPT_WINDOW), index) + \
                        range(index+1, min(len(sentences), index+CORRUPT_WINDOW+1))
    donor_index = random.choice(all_donor_indices)
    all_nouns_to_insert = [sentences[donor_index].split()[i] for i, tag in enumerate(pos_tags[donor_index])
                                    if tag.startswith('NN')]
    if not all_nouns_to_insert:
        return None
    noun_to_insert = random.choice(all_nouns_to_insert)
    s[noun_to_replace_index] = noun_to_insert
    return s

In [166]:
c = 75000
sentences[c], corrupt(sentences, pos_tags, c)

(u'hydrogen gas is bubbled through the liquid oil and reacts with the carbon-carbon double bonds present in the long-chain fatty acids .',
 None)

In [180]:
pos_tags[63798]

[u'CD',
 '.',
 '.',
 u'VB',
 u'JJ',
 u'NN',
 '.',
 u'NN',
 '.',
 '.',
 u'VB',
 u'JJ',
 u'NN',
 '.',
 '.']

In [86]:
from nltk.corpus import wordnet as wn

In [75]:
wn.synset("small.a.01").lemmas()[1].antonyms()

[Lemma('large.a.01.big')]

In [267]:
%%time
import theano
import theano.tensor as T
import lasagne

SEQ_LENGTH = 30
MARGIN = 1

l_in = lasagne.layers.InputLayer(shape=(None, GLOVE_D, SEQ_LENGTH))
n = lasagne.layers.Conv1DLayer(l_in, 20, filter_size=3)  # None x 10 x 28
n = lasagne.layers.MaxPool1DLayer(n, 2)  # None x 30 x 14
n = lasagne.layers.Conv1DLayer(l_in, 40, filter_size=3)  # None x 20 x 12
n = lasagne.layers.MaxPool1DLayer(n, 2)  # None x 60 x 6
n = lasagne.layers.Conv1DLayer(l_in, 70, filter_size=3)  # None x 30 x 4
n = lasagne.layers.MaxPool1DLayer(n, 2)  # None x 90 x 2
n = lasagne.layers.reshape(n, ([0], -1))
n = lasagne.layers.DropoutLayer(n, 0.5)
n = lasagne.layers.DenseLayer(n, 100)
n = lasagne.layers.DenseLayer(n, 1)

output = lasagne.layers.get_output(n)
params = lasagne.layers.get_all_params(n)

correct_energy = output[0::2][0]  # 50
corrupt_energy = output[1::2][0]  # 50

energy = T.maximum(0, MARGIN + correct_energy - corrupt_energy).mean()
updates = lasagne.updates.adam(energy, params)

print('Compiling functions ...')
forward_fn = theano.function([l_in.input_var], output)
train_fn = theano.function([l_in.input_var], energy, updates=updates)
test_fn = theano.function([l_in.input_var], energy)

Compiling functions ...
CPU times: user 2.35 s, sys: 152 ms, total: 2.5 s
Wall time: 6.55 s


In [268]:
def energy_fn(txt):
    sents = nltk.sent_tokenize(txt)
    data = [text2vec(s, GLOVE, SEQ_LENGTH).T[np.newaxis] for s in sents]
    return forward_fn(np.concatenate(data, axis=0)).mean()

In [None]:
%%time
BATCH_SIZE = 50
EPOCH_COUNT = 10
indices = np.arange(60000)
for e in range(EPOCH_COUNT):
    epoch_start = time()
    np.random.shuffle(indices)
    errors = []
    for i in xrange(0, indices.shape[0], BATCH_SIZE):
        train_sent_idx = [k for k in indices[i:i+BATCH_SIZE]]
        train_data = []
        for correct_idx in train_sent_idx:
            corrupted = corrupt(sentences, pos_tags, correct_idx)
            if corrupted:
                train_data.append(text2vec(sentences[correct_idx], GLOVE, SEQ_LENGTH).T[np.newaxis])
                train_data.append(text2vec(' '.join(corrupted), GLOVE, SEQ_LENGTH).T[np.newaxis])
        train_data = np.concatenate(train_data, axis=0)
        error = train_fn(train_data)
        errors.append(error)
#         print('\t', np.mean(errors[-1000:]))
    time_passed = time() - epoch_start
    print(e, np.mean(errors), '%.0fsec' % time_passed)

0 1.0 48sec
1

In [263]:
v1 = 'the sun is the main source of energy for the water cycle .'
v2 = 'fossil fuels is the main source of energy for the water cycle .'
v3 = 'clouds is the main source of energy for the water cycle .'
v4 = 'the ocean is the main source of energy for the water cycle .'

# v1 = 'tension has the greatest effect on aiding the movement of blood through the human body .'
# v2 = 'friction has the greatest effect on aiding the movement of blood through the human body .'
# v3 = 'density has the greatest effect on aiding the movement of blood through the human body .'
# v4 = 'gravity has the greatest effect on aiding the movement of blood through the human body .'

print(energy_fn(v1))
print(energy_fn(v2))
print(energy_fn(v3))
print(energy_fn(v4))

10.1283543537
11.7416905774
10.3514525111
11.0380955383


In [265]:
tries = []

with codecs.open("/home/marat/Downloads/training_set_merged.tsv", encoding="utf-8") as f:
    for i, l in enumerate(f):
        q_id, correct, a1, a2, a3, a4 = l.strip().split("\t")
        energies = [energy_fn(v) for v in [a1, a2, a3, a4]]
        guess = "ABCD"[np.argmin(energies)]
#         print(guess, correct, q_id, zip([a1, a2, a3, a4], energies))
        tries.append(guess == correct)
#         if i > 10:
#             break

In [266]:
np.mean(tries)

0.25040000000000001