In [19]:
from __future__ import print_function, division, unicode_literals
import six
import os, sys
from os.path import join
import json
from codecs import open
from collections import defaultdict
from operator import itemgetter
import nltk
import numpy as np
from nltk.corpus import stopwords
import re
import codecs
import random
from time import time

from nltk.stem import SnowballStemmer
stemmer = SnowballStemmer("english")

from nltk.corpus import stopwords


In [2]:
import lucene
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexReader
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.search import Sort, SortField
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.util import Version
from java.io import File
lucene.initVM()

<jcc.JCCEnv at 0x7f17249ebbb8>

In [3]:
DATA_DIR = join(os.environ['HOME'], 'data/allen-ai-challenge')
WIKI_DIR = join(DATA_DIR, 'wiki_dump')
CK12_DIR = join(DATA_DIR, 'ck12_dump')
TRAINING_SET = join(DATA_DIR, 'training_set.tsv')
VALIDATION_SET = join(DATA_DIR, 'validation_set.tsv')
TRAINING_SET_MERGED = join(DATA_DIR, 'training_set_merged.tsv')
# INDEX_DIR = join(DATA_DIR, 'index-wiki-ck12')
# INDEX_DIR = join(DATA_DIR, 'index-ck12-stem')
# INDEX_DIR = join(DATA_DIR, 'index-all-l_stem_summ')
INDEX_DIR = join(DATA_DIR, 'index-ck12-stem')
SUBMISSION = join(DATA_DIR, 'submissions/lucene_wiki_ck12_17jan.tsv')
VOCABULARY = join(DATA_DIR, 'vocabulary', 'w2v_a2_5.tsv')
SENT_DELIM = ' | '

In [4]:
from nltk.corpus import stopwords
stopwords = set(stopwords.words('english') + '. , ! ? !? ?! ... ; : - â€”'.split())
def cleanup_text(text):
    sents = []
    for s in nltk.sent_tokenize(text):        
#         words = [stemmer.stem(w) for w in nltk.word_tokenize(s.lower()) if w not in stopwords]
        words = [w for w in nltk.word_tokenize(s.lower()) if w not in stopwords]
        if words:
            sents.append(words)
    return SENT_DELIM.join([' '.join(s) for s in sents])
#     return [stemmer.stem(w.text) for w in nlp(text) if not w.is_stop and not w.is_punct and not w.is_space]

In [5]:
%%time
vocab = {}
with open(VOCABULARY, encoding='utf8') as f:
    for word, vec in (line.strip().split('\t', 1) for line in f):
        vocab[word] = np.fromstring(vec, sep='\t')
vocab_dim = 300

CPU times: user 1.15 s, sys: 16 ms, total: 1.17 s
Wall time: 1.17 s


In [None]:
vocab['earth'].shape

In [None]:
cleanup_text('Here is some text. What the heck more place?')

Index Creation
----------

In [6]:
analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
writerConfig = IndexWriterConfig(Version.LUCENE_4_10_1, StandardAnalyzer())
writer = IndexWriter(SimpleFSDirectory(File(INDEX_DIR)), writerConfig)

In [7]:
def add_document(doc_text):
    doc = Document()
    doc.add(Field("text", cleanup_text(doc_text), Field.Store.YES, Field.Index.ANALYZED))
    writer.addDocument(doc)

In [8]:
%%time
for i, fn_short in enumerate(os.listdir(CK12_DIR)):
    fn = join(CK12_DIR, fn_short)
    with open(fn, encoding='utf-8', errors='ignore') as f:
        ck12_article = json.load(f)
#         content = []
#         for subtitle, paragraph in ck12_article['contents'].items():
#             content.append(subtitle + '. ' + paragraph)
#         add_document(' '.join(content))    

        for subtitle, paragraph in ck12_article['contents'].items():
            add_document(subtitle + '. ' + paragraph)


CPU times: user 18.8 s, sys: 68 ms, total: 18.9 s
Wall time: 17.9 s


In [9]:
doc_count = writer.numDocs()
writer.close()
doc_count

7148

Build NN
----------

In [26]:
import theano
import theano.tensor as T
import lasagne
import lasagne.layers as LL
from lasagne.nonlinearities import elu, rectify

M = 2.0

input_context = LL.InputLayer((None, 300))
input_hyp = LL.InputLayer((None, 300))

# l_context = LL.ReshapeLayer(input_context, ([0], [1]))
# l_hyp = LL.ReshapeLayer(input_hyp, ([0], [1]))

l_diff = LL.ElemwiseMergeLayer([input_context, input_hyp], merge_function=T.sub)
l_mult = LL.ElemwiseMergeLayer([input_context, input_hyp], merge_function=T.mul)
nn = LL.concat([l_diff, l_mult])
nn = LL.DenseLayer(nn, 100, nonlinearity=elu)
nn = LL.DenseLayer(nn, 1, nonlinearity=rectify)
t_output = LL.get_output(nn)[:, 0]

t_must_be_less = t_output[0::2]
t_must_be_more = t_output[1::2]

t_cost = (T.sqr(t_must_be_less) + T.sqr(T.maximum(0, M - t_must_be_more))).mean()

params = LL.get_all_params(nn)

updates = lasagne.updates.adam(t_cost, params)

train_fn = theano.function([input_hyp.input_var, input_context.input_var], t_cost, updates=updates)
cost_fn = theano.function([input_hyp.input_var, input_context.input_var], t_cost)
energy_fn = theano.function([input_hyp.input_var, input_context.input_var], t_output)

Read index
-----------

In [11]:
analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
reader = IndexReader.open(SimpleFSDirectory(File(INDEX_DIR)))
searcher = IndexSearcher(reader)

In [12]:
indexi = np.arange(reader.maxDoc())

In [13]:
def drop_random(seq):
    '''Changes seq!'''
    idx = random.randint(0, len(seq)-1)
    d = seq[idx]
    del seq[idx]
    return d, seq    

def corrupt_context(index, window_nearest, window_farest):    
    artice = reader.document(index)['text']
    sentsA = artice.split(SENT_DELIM)
    if len(sentsA) < 2:
        return None, None
    hypA, restA = drop_random(sentsA)
    others = [i for i in range(index - window_farest, index + window_farest + 1) 
              if (np.abs(index-i) > window_nearest) and 
                 (i >= 0) and (i < doc_count)]
    idxB = random.choice(others)
    sentsB = reader.document(idxB)['text'].split(SENT_DELIM)
    return [(hypA, ' '.join(restA)), (hypA, ' '.join(sentsB))]

def mean_w2v(text):
    vec = np.zeros((vocab_dim,), dtype='float64')
    c = 1
    for w in nltk.word_tokenize(text):
        if w in vocab:
            vec += vocab[w]
            c += 1
    return (vec/c).astype('float32')

In [14]:
%%time
questions = []
with open(TRAINING_SET, encoding='utf8') as f:
    f.readline()
    for line in f:
        qid, q, correct, aa, ab, ac, ad = line.strip().split('\t')

        query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", cleanup_text(q)))
        hits = searcher.search(query, 20).scoreDocs
        doc = reader.document(hits[0].doc)['text']
        q_vec = mean_w2v(doc)
        vecs_context = np.zeros((4, vocab_dim), dtype='float32')
        vecs_context += q_vec
        vecs_hyp = np.zeros((4, vocab_dim), dtype='float32')
        for i, a in enumerate([aa, ab, ac, ad]):
            try:
                query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", cleanup_text(a)))
                hits = searcher.search(query, 20).scoreDocs
                doc = reader.document(hits[0].doc)['text']
            except:
                doc = ''
            vecs_hyp[i] = mean_w2v(doc)
            questions.append((qid, 'ABCD'.index(correct), vecs_hyp, vecs_context))


def check():
    scores = []
    for qid, idx_correct, vecs_hyp, vecs_context in questions:
        energies = energy_fn(vecs_hyp, vecs_context)
        scores.append(np.argmin(energies) == idx_correct)
    return np.mean(scores)

CPU times: user 16.1 s, sys: 256 ms, total: 16.4 s
Wall time: 13.9 s


In [27]:
%%time

BATCH = 20
EPOCHS = 100
    
min_win = [10, 5, 4, 3, 2, 1] + [0] * EPOCHS    

for e in range(EPOCHS):
    time_started = time()
    indices = np.arange(doc_count, dtype=int)
    np.random.shuffle(indices)
    costs = []
    for i in xrange(0, indices.shape[0], BATCH):
        batch_idx = indices[i:i+BATCH]
        batch_hyp = np.zeros((BATCH*2, vocab_dim), dtype='float32')
        batch_context = np.zeros((BATCH*2, vocab_dim), dtype='float32')
        for b, idx_b in enumerate(batch_idx):
            right, corrupted = corrupt_context(int(idx_b), min_win[e], min_win[e] + 5)
            batch_hyp[b*2] = mean_w2v(right[0])
            batch_context[b*2] = mean_w2v(right[1])

            batch_hyp[b*2+1] = mean_w2v(corrupted[0])
            batch_context[b*2+1] = mean_w2v(corrupted[1])
        batch_cost = train_fn(batch_hyp[:b*2], batch_context[:b*2])
        costs.append(batch_cost)
#         print(batch_energy)
    print('%d: %.3f (%.2f%%) in %.0fs' % (e, np.mean(costs), check() * 100, time() - time_started))
    sys.stdout.flush()
    

0: 1.230 (33.64%) in 20s
1: 0.768 (33.44%) in 20s
2: 0.776 (33.04%) in 20s
3: 0.797 (34.44%) in 20s
4: 0.914 (33.28%) in 20s
5: 1.115 (34.08%) in 20s
6: 1.367 (34.84%) in 20s
7: 1.307 (33.68%) in 20s


KeyboardInterrupt: 

In [None]:
reader.document(553)['text']

In [None]:
def check():
    with open(TRAINING_SET, encoding='utf8') as f:
        f.readline()
        for line in f:
            qid, q, correct, aa, ab, ac, ad = line.strip().split('\t')
            
            print(cleanup_text(q))
            break
check()