In [16]:
from __future__ import print_function, division, unicode_literals
import os, sys
from os.path import join
import json
from codecs import open
import nltk
import numpy as np
from nltk.corpus import stopwords
import re
import random
from itertools import chain
from time import time
from nltk.stem import SnowballStemmer
stemmer = SnowballStemmer("english")

In [17]:
# Root directory that every dataset path in this notebook is built from.
# expanduser("~") resolves the home directory portably; indexing
# os.environ["HOME"] directly raises KeyError wherever HOME is not set
# (e.g. on Windows).
ROOT_DATA = join(os.path.expanduser("~"), "data/allen-ai-challenge")

Datasets
----

In [18]:
# Directories holding the raw text sources.
WIKI, CK12, QUIZLET = (
    join(ROOT_DATA, dirname)
    for dirname in (
        "parsed_wiki_data",  # top-5 wiki search hits
        "ck12_dump",         # parsed ck12 dump
        "quizlet",
    )
)

# Question sets: the raw TSV files and the cleaned files written from them.
TRAINING, TRAINING_CLEANED, VALIDATION, VALIDATION_CLEANED = (
    join(ROOT_DATA, filename)
    for filename in (
        "training_set.tsv",
        "training_set_cleaned.tsv",
        "validation_set.tsv",
        "validation_set_cleaned.tsv",
    )
)

# Output file that receives the merged, cleaned corpus.
MERGED = join(ROOT_DATA, "merged_corpus.txt")

In [20]:
def tokenize(text):
    """Lower-case ``text``, split it into NLTK word tokens and stem the survivors.

    A token is dropped when it is found in the module-level ``stopwords``
    name; the check happens before stemming. Returns a list of stems.

    NOTE(review): ``stopwords`` is looked up at call time and is only rebound
    to a set of words in the "Cleaning datasets" section, so this function
    must not be called before that cell has run.
    """
    stems = []
    for token in nltk.word_tokenize(text.lower()):
        if token in stopwords:
            continue
        stems.append(stemmer.stem(token))
    return stems

In [21]:
%%time
# Collect the paragraphs of every JSON article in the CK12 folder into one
# flat list.
ck12_paragraphs = []
for ck12_name in os.listdir(CK12):
    ck12_path = join(CK12, ck12_name)
    with open(ck12_path, encoding='utf-8', errors='ignore') as ck12_file:
        article = json.load(ck12_file)
    # 'contents' maps a subtitle to its paragraph text; only the text is
    # kept, with surrounding whitespace removed.
    ck12_paragraphs.extend(text.strip() for text in article['contents'].values())

CPU times: user 160 ms, sys: 12 ms, total: 172 ms
Wall time: 172 ms


In [22]:
# Sanity check: number of CK-12 paragraphs collected above.
len(ck12_paragraphs)

7148

In [23]:
%%time
# Gather the text lines of every file in the WIKI folder, skipping blank
# lines and lines that begin with '==' (presumably section headings).
wiki_paragraphs = []
for wiki_name in os.listdir(WIKI):
    with open(join(WIKI, wiki_name), encoding='utf-8', errors='ignore') as wiki_file:
        for raw_line in wiki_file:
            stripped = raw_line.strip()
            if stripped and not stripped.startswith('=='):
                wiki_paragraphs.append(stripped)

CPU times: user 5.81 s, sys: 324 ms, total: 6.13 s
Wall time: 6.13 s


In [24]:
# Sanity check: number of wiki lines kept as paragraphs.
len(wiki_paragraphs)

561764

In [25]:
%%time
from StringIO import StringIO
# Map each Quizlet term to its definition. A term that occurs more than once
# keeps the definition read last, so the per-file counts printed below can add
# up to more than the number of distinct terms.
terms = {}
for fn_short in os.listdir(QUIZLET):
    term_count = 0
    fn = join(QUIZLET, fn_short)
    with open(fn, encoding='utf-8', errors='ignore') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing with a JSON decoding error.
            if not line.strip():
                continue
            # Every remaining line is one JSON document with a 'terms' list;
            # json.loads parses the string directly, no file-like wrapper needed.
            j = json.loads(line)
            for t in j['terms']:
                terms[t['term']] = t['definition']
                term_count += 1
    print(fn_short, term_count, sep=': ')
print()
# One "term definition" string per distinct term.
quizlet_paragraphs = [t + ' ' + d for t, d in terms.items()]

biology.txt: 238449
ck-12.txt: 9811
ck 12.txt: 9811
earth science.txt: 180614
astronomy.txt: 190219
chemistry.txt: 209191
physics.txt: 196988
ck12.txt: 4632
CPU times: user 8.68 s, sys: 212 ms, total: 8.89 s
Wall time: 8.89 s


In [26]:
# Sanity check: number of distinct Quizlet terms turned into paragraphs.
len(quizlet_paragraphs)

386836

Cleaning datasets
-----

In [27]:
# Re-import here so the cell can be re-run: the assignment below rebinds the
# name `stopwords` from the NLTK corpus reader to a plain set of words, which
# is what tokenize() and text_clean() test membership against.
from nltk.corpus import stopwords
# NLTK's English stop words plus punctuation tokens and a few extra words
# (summary, youtube, www).
stopwords = set(stopwords.words('english') + '. , ! ? !? ?! ... ; : - — summary youtube www'.split())
# Keep "no", "not" and "above" as real tokens (presumably because they carry
# meaning in answers such as "all of the above"). Set difference is used
# instead of set.remove so a word missing from the installed NLTK list does
# not raise KeyError.
stopwords -= set("no not above".split())

In [28]:
def text_clean(text):
    """Strip punctuation from ``text`` and return its stemmed, stop-word-free tokens.

    Steps, in order: delete every character that is neither a word character
    nor whitespace, lower-case the result, split it with NLTK's word
    tokenizer, drop tokens found in the module-level ``stopwords`` set and
    stem what is left.

    Returns a list of stems (possibly empty).
    """
    # re.UNICODE makes \w and \s Unicode-aware on Python 2 as well; without
    # the flag they are ASCII-only there, so non-ASCII letters were silently
    # deleted (u'caf\xe9' became u'caf'). \d is omitted from the class because
    # \w already covers digits.
    s = re.sub(r'[^\w\s]', '', text, flags=re.UNICODE)
    return [stemmer.stem(w) for w in nltk.word_tokenize(s.lower()) if w not in stopwords]

In [None]:
# Smoke test for text_clean: the backslash is stripped, stop words and the
# extra words (summary, www, youtube) are dropped, "above" is kept, and the
# surviving tokens are lower-cased and stemmed.
text_clean("hello Died \\ went I Summary all of the above 45 5 www youtube")

[u'hello', u'die', u'went', u'abov', u'45', u'5']

In [None]:
%%time
# Write the merged corpus: one cleaned document per line, tokens separated by
# single spaces, in the order CK-12, wiki, Quizlet.
with open(MERGED, encoding="utf-8", mode="w") as merged_file:
    for document in chain(ck12_paragraphs, wiki_paragraphs, quizlet_paragraphs):
        merged_file.write(" ".join(text_clean(document)) + "\n")

Train/validation
-----

In [None]:
def text_clean_join(t):
    """Return the tokens ``text_clean`` produces for ``t`` as one space-separated string."""
    cleaned_tokens = text_clean(t)
    return " ".join(cleaned_tokens)

In [None]:
# Rewrite the training set with the question and the four answers cleaned;
# the id and the third column (`r`, presumably the correct-answer label) are
# copied through unchanged.
with open(TRAINING_CLEANED, encoding="utf-8", mode="w") as fo:
    with open(TRAINING, encoding="utf-8") as f:
        next(f)  # skip the first line (presumably the header); it is not written out
        for l in f:
            # Strip only the line terminator: a bare strip() also eats a
            # trailing tab, so a row whose last answer is empty would lose a
            # field and fail to unpack.
            qid, q, r, aa, ab, ac, ad = l.rstrip("\r\n").split("\t")
            cleaned = [text_clean_join(field) for field in (q, aa, ab, ac, ad)]
            # qid.strip() keeps the old behaviour of dropping leading
            # whitespace in front of the id.
            print(qid.strip(), cleaned[0], r, *cleaned[1:], sep="\t", file=fo)

In [None]:
# Rewrite the validation set with the question and the four answers cleaned;
# the id is copied through unchanged. Rows here have six columns.
with open(VALIDATION_CLEANED, encoding="utf-8", mode="w") as fo:
    with open(VALIDATION, encoding="utf-8") as f:
        next(f)  # skip the first line (presumably the header); it is not written out
        for l in f:
            # Strip only the line terminator: a bare strip() also eats a
            # trailing tab, so a row whose last answer is empty would lose a
            # field and fail to unpack.
            qid, q, aa, ab, ac, ad = l.rstrip("\r\n").split("\t")
            cleaned = [text_clean_join(field) for field in (q, aa, ab, ac, ad)]
            # qid.strip() keeps the old behaviour of dropping leading
            # whitespace in front of the id.
            print(qid.strip(), *cleaned, sep="\t", file=fo)