In [4]:
from __future__ import print_function, division, unicode_literals
import six
import os
from os.path import join
import json
from codecs import open
from collections import defaultdict
from operator import itemgetter
import nltk
import numpy as np
from nltk.corpus import stopwords
import re
import codecs

from nltk.stem import SnowballStemmer
stemmer = SnowballStemmer("english")

In [5]:
from nltk.corpus import stopwords
stopwords = set(stopwords.words('english') + '. , ! ? !? ?! ... ; : - —'.split())

In [6]:
def tokenize(text):
    return [stemmer.stem(w) for w in nltk.word_tokenize(text.lower()) if w not in stopwords]
#     return [stemmer.stem(w.text) for w in nlp(text) if not w.is_stop and not w.is_punct and not w.is_space]

In [7]:
import lucene

In [8]:
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexReader
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.search import Sort, SortField
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.util import Version

In [9]:
from java.io import File

In [10]:
DATA_DIR = join(os.environ['HOME'], 'data/allen-ai-challenge')
WIKI_DIR = join(DATA_DIR, 'wiki_dump')
CK12_DIR = join(DATA_DIR, 'ck12_dump')
TRAINING_SET = join(DATA_DIR, 'training_set.tsv')
VALIDATION_SET = join(DATA_DIR, 'validation_set.tsv')
TRAINING_SET_MERGED = join(DATA_DIR, 'training_set_merged.tsv')
# INDEX_DIR = join(DATA_DIR, 'index-wiki-ck12')
# INDEX_DIR = join(DATA_DIR, 'index-ck12-stem')
# INDEX_DIR = join(DATA_DIR, 'index-all-l_stem_summ')
INDEX_DIR = join(DATA_DIR, 'index-paragraph')
SUBMISSION = join(DATA_DIR, 'submissions/lucene_wiki_ck12_17jan.tsv')

In [11]:
lucene.initVM()

<jcc.JCCEnv at 0x7f8490fb3df8>

Index Creation
-----------

In [9]:
analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)

In [10]:
writerConfig = IndexWriterConfig(Version.LUCENE_4_10_1, StandardAnalyzer())
writer = IndexWriter(SimpleFSDirectory(File(INDEX_DIR)), writerConfig)

def add_document(doc_text):
    doc = Document()
    doc.add(Field("text", " ".join(tokenize(doc_text)), Field.Store.YES, Field.Index.ANALYZED))
    writer.addDocument(doc)

In [11]:
%%time
for i, fn_short in enumerate(os.listdir(CK12_DIR)):
    fn = join(CK12_DIR, fn_short)
    with open(fn, encoding='utf-8', errors='ignore') as f:
        ck12_article = json.load(f)
        content = []
        for subtitle, paragraph in ck12_article['contents'].items():
            content.append(subtitle + '. ' + paragraph)
        add_document(' '.join(content))        

CPU times: user 27.8 s, sys: 124 ms, total: 27.9 s
Wall time: 27.1 s


In [12]:
%%time
for i, fn_short in enumerate(os.listdir(WIKI_DIR)):
    fn = join(WIKI_DIR, fn_short)
    with open(fn, encoding='utf-8', errors='ignore') as f:
        wiki_article = json.load(f)
        _, summary, content = wiki_article
        add_document(summary)
        for p in re.sub('\[ \d* \]', ' ', content).replace('\n\n\n', '.\n ').split('\n\n'):
            add_document(p)


CPU times: user 6min 48s, sys: 688 ms, total: 6min 49s
Wall time: 6min 42s


In [13]:
writer.numDocs()

237951

In [17]:
writer.close()

## Make prediction 

In [12]:
def iter_data(datafile, with_correct=True):
    with open(datafile, encoding='utf-8', errors='ignore') as f:
        next(f)
        for l in f:
            if with_correct:
                idd, q, correct, aa, ab, ac, ad = l.strip().split("\t")
            else:
                idd, q, aa, ab, ac, ad = l.strip().split("\t")
                correct = "no"
            q, aa, ab, ac, ad = [' '.join(tokenize(x)) for x in [q, aa, ab, ac, ad]]
            yield {"idd": idd, "q": q, "correct": correct, "aa": aa, "ab": ab, "ac": ac, "ad": ad}

In [13]:
from collections import defaultdict

In [14]:
%%time
res = defaultdict(list)
MAX = 30
docs_per_q = range(1, 20)

analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
reader = IndexReader.open(SimpleFSDirectory(File(INDEX_DIR)))
searcher = IndexSearcher(reader)

for row in iter_data(TRAINING_SET):
    queries = [row['aa'], row['ab'], row['ac'], row['ad']]
    queries = [row['q'] + ' ' + q  for q in queries]
    scores = defaultdict(list)
    for q in queries:
        query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", q))
        hits = searcher.search(query, MAX)
        doc_importance = [hit.score for hit in hits.scoreDocs]
        for n in docs_per_q:
            scores[n].append(sum(doc_importance[:n]))
      
    for n in docs_per_q:
        res[n].append(['A','B','C','D'][np.argmax(scores[n])] == row["correct"])

CPU times: user 44.7 s, sys: 1.86 s, total: 46.5 s
Wall time: 42.4 s


In [21]:
import telepot
class TelegramStream:
    def __init__(self, token, reciever_id):
        self.bot = telepot.Bot(token)
        self.id = reciever_id
        self.buffer = []
    def write(self, txt):
        self.buffer.append(txt)
        if txt.endswith('\n'):
            msg = ''.join(self.buffer)
            if msg:
                self.bot.sendMessage(self.id, msg)
            self.buffer = []
    
tele = TelegramStream(os.environ['TELEGRAM_BOT'], os.environ['TELEGRAM_ID'])

KeyError: u'TELEGRAM_BOT'

In [15]:
for x in sorted(res):
    print(x, np.mean(res[x]))
#     print(x, np.mean(res[x]), file=tele)

1 0.4208
2 0.4428
3 0.4488
4 0.4452
5 0.4488
6 0.452
7 0.454
8 0.4472
9 0.4496
10 0.4504
11 0.4464
12 0.4492
13 0.4456
14 0.4424
15 0.4448
16 0.4428
17 0.444
18 0.4432
19 0.4432


Submit
-----

In [24]:
%%time
docs_to_consider = 7

analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
reader = IndexReader.open(SimpleFSDirectory(File(INDEX_DIR)))
searcher = IndexSearcher(reader)

with open(SUBMISSION, "w") as s:
    s.write("id,correctAnswer\n")
    for row in iter_data(VALIDATION_SET, False):
        queries = [row['aa'], row['ab'], row['ac'], row['ad']]
        queries = [row['q'] + ' ' + q  for q in queries]
        scores = []
        for q in queries:
            query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", q))
            #query = QueryParser(Version.LUCENE_30, "text", analyzer).parse(re.sub("[/^]", "\^", q))
            hits = searcher.search(query, docs_to_consider)
            doc_importance = [hit.score for hit in hits.scoreDocs]
            scores.append(sum(doc_importance))
        guess = "ABCD"[np.argmax(scores)]
        s.write("%s,%s\n" % (row["idd"], guess))

CPU times: user 1min 31s, sys: 3.33 s, total: 1min 34s
Wall time: 1min 33s


Features
-----

In [31]:
FEATURES_LUCENE_ALL_SCORES = join(DATA_DIR, 'features/lucene_all.tsv')

In [16]:
%%time
MAX = 10
analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
reader = IndexReader.open(SimpleFSDirectory(File(INDEX_DIR)))
searcher = IndexSearcher(reader)

output_file = join(DATA_DIR, 'features', 'lucene_cumsum%d.tsv' % MAX)
with open(output_file, "w") as fs:
    for row in iter_data(TRAINING_SET):
        queries = [row['aa'], row['ab'], row['ac'], row['ad']]
        queries = [row['q'] + ' ' + q  for q in queries]
        features = []
        for q in queries:
            query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", q))
            hits = searcher.search(query, MAX)
            doc_importances = [hit.score for hit in hits.scoreDocs]
            features.append(";".join(str(d) for d in np.cumsum(doc_importances)))
        print(row["idd"], row["correct"], *features, file=fs, sep="\t")

CPU times: user 39.6 s, sys: 1.63 s, total: 41.3 s
Wall time: 40.5 s


In [31]:
features

[u'1.07244801521;0.889635920525;0.848675429821;0.844713747501;0.786205768585;0.765998661518;0.75105714798;0.749342441559;0.704311668873;0.683470845222',
 u'1.08825957775;0.930170536041;0.902752220631;0.861187875271;0.857167780399;0.762130260468;0.747552096844;0.714695692062;0.693547546864;0.681111216545',
 u'1.09217810631;0.924883246422;0.906002759933;0.864288687706;0.86025416851;0.79830878973;0.779356360435;0.764874458313;0.717269122601;0.698661744595',
 u'1.38169002533;0.879681348801;0.839179158211;0.835261821747;0.777408480644;0.74265319109;0.69643086195;0.675823152065;0.632328689098;0.623643994331']

In [41]:
%%time
MAX = 13000
analyzer = StandardAnalyzer(Version.LUCENE_4_10_1)
reader = IndexReader.open(SimpleFSDirectory(File(INDEX_DIR)))
searcher = IndexSearcher(reader)

with open(FEATURES_LUCENE_ALL_SCORES, "w") as fs:
    for row in iter_data(TRAINING_SET):
        queries = [row['aa'], row['ab'], row['ac'], row['ad']]
        queries = [row['q'] + ' ' + q  for q in queries]
        features = []
        for q in queries:
            query = QueryParser(Version.LUCENE_4_10_1, "text", analyzer).parse(re.sub("[^a-zA-Z0-9]"," ", q))      
            hits = searcher.search(query, MAX)
            doc_importances = {hit.doc: hit.score for hit in hits.scoreDocs}
#             features.append(";".join(doc_importances))
#             print(doc_importances)
            break
#         print(row["idd"], row["correct"], *features, file=fs, sep="\t")
        break

CPU times: user 44 ms, sys: 0 ns, total: 44 ms
Wall time: 37.4 ms


In [40]:
sorted(doc_importances.values())


[0.07655282318592072,
 0.07666215300559998,
 0.07670583575963974,
 0.07675909250974655,
 0.07676452398300171,
 0.07677092403173447,
 0.07710880786180496,
 0.07715141028165817,
 0.07718561589717865,
 0.07733228802680969,
 0.07745745778083801,
 0.07757867127656937,
 0.07758105546236038,
 0.07763881981372833,
 0.07765979319810867,
 0.07768512517213821,
 0.07781240344047546,
 0.07786702364683151,
 0.07792502641677856,
 0.07797142118215561,
 0.07798108458518982,
 0.07800722122192383,
 0.07816611975431442,
 0.07818196713924408,
 0.07824768126010895,
 0.07841303944587708,
 0.07841971516609192,
 0.07842428982257843,
 0.0784338042140007,
 0.07845243066549301,
 0.07848235219717026,
 0.07851400226354599,
 0.07859144359827042,
 0.07860704511404037,
 0.07864391803741455,
 0.07871557027101517,
 0.07877980917692184,
 0.07888653874397278,
 0.07888990640640259,
 0.07893432676792145,
 0.07895950227975845,
 0.07904188334941864,
 0.07910335808992386,
 0.07913459092378616,
 0.07921204715967178,
 0.07924040