In [177]:
import json, nltk, io, pickle
import numpy as np
from itertools import chain

### Read data

In [197]:
# SQuAD v1.1 training split; its 'data' entries hold 'paragraphs', each with 'qas'.
with io.open('/pio/data/data/squad/train-v1.1.json', mode='r', encoding='utf-8') as squad_file:
    train = json.load(squad_file)

In [191]:
# SQuAD v1.1 development split, same nested layout as the training file.
with io.open('/pio/data/data/squad/dev-v1.1.json', mode='r', encoding='utf-8') as squad_file:
    dev = json.load(squad_file)

### Data structure

In [192]:
# Reference answers of the first dev question: a list of {answer_start, text} dicts
# (the dev set may list the same answer several times, as shown below).
dev['data'][0]['paragraphs'][0]['qas'][0]['answers']

[{u'answer_start': 177, u'text': u'Denver Broncos'},
 {u'answer_start': 177, u'text': u'Denver Broncos'},
 {u'answer_start': 177, u'text': u'Denver Broncos'}]

In [190]:
# Preview of sentence splitting: tokenize one training paragraph, re-join with spaces
# and split on ' . '. The final sentence keeps its trailing ' .' because no space follows it.
' '.join(nltk.word_tokenize(train['data'][10]['paragraphs'][60]['context'])).split(' . ')

[u"The State Council declared a three-day period of national mourning for the quake victims starting from May 19 , 2008 ; the PRC 's National Flag and Regional Flags of Hong Kong and Macau Special Administrative Regions flown at half mast",
 u'It was the first time that a national mourning period had been declared for something other than the death of a state leader , and many have called it the biggest display of mourning since the death of Mao Zedong',
 u'At 14:28 CST on May 19 , 2008 , a week after the earthquake , the Chinese public held a moment of silence',
 u'People stood silent for three minutes while air defense , police and fire sirens , and the horns of vehicles , vessels and trains sounded',
 u"Cars and trucks on Beijing 's roads also came to a halt",
 u"People spontaneously burst into cheering `` Zhongguo jiayou ! '' ( Let 's go , China ! ) and `` Sichuan jiayou '' ( Let 's go , Sichuan ! ) afterwards ."]

### Collect all question-answer pairs and build the vocabulary (word list)

In [65]:
# Tokenize (NLTK, lower-cased) every context, question and answer of the training
# set and collect the vocabulary.
# Each row of `data` is [answers, question_tokens, context_tokens], where
# answers = [(answer_start, answer_tokens), ...].
# Only context and question tokens enter `words`. Answer tokens that occur in
# neither therefore become '<unk>' when numericalized later.
words = set()
data = []


def lower(x):
    """Return `x` lower-cased."""
    return x.lower()


for par in train['data']:
    for con in par['paragraphs']:
        # Build a real list, not a lazy map(): the tokens are consumed twice
        # (vocabulary update and storage in `data`).
        context_tok = [lower(tok) for tok in nltk.word_tokenize(con['context'])]
        words.update(context_tok)

        for q in con['qas']:
            question_tok = [lower(tok) for tok in nltk.word_tokenize(q['question'])]
            words.update(question_tok)

            answers = [(ans['answer_start'],
                        [lower(tok) for tok in nltk.word_tokenize(ans['text'])])
                       for ans in q['answers']]

            # All questions of one paragraph share the same context token list object.
            data.append([answers, question_tok, context_tok])

words.add('<unk>')

In [83]:
# Number of training questions and vocabulary size (the latter includes '<unk>').
print('%d %d' % (len(data), len(words)))

87599 102802


In [109]:
# Print the first training question that has more than one reference answer, if any.
for d in (row for row in data if len(row[0]) > 1):
    print(d)
    break

### Turn words into numbers

In [85]:
def split_on_dot(s, delim=u'.'):
    """Split a token sequence into sentences.

    Each sentence is a list of tokens that ends with (and includes) `delim`.
    Tokens after the last delimiter form a final sentence. Empty input gives [].
    """
    sentences = [[]]
    for w in s:
        sentences[-1].append(w)
        if w == delim:
            sentences.append([])
    # Drop the empty sentence opened after a trailing delimiter.
    return sentences if sentences[-1] else sentences[:-1]


def words_to_num(s, vocab=None, unk=u'<unk>'):
    """Map tokens to integer ids; tokens missing from the vocabulary get the id of `unk`.

    `vocab` defaults to the notebook-global `w_to_i`, looked up at call time, so
    existing calls are unchanged. Pass another mapping to encode with a different
    vocabulary. The `unk` entry is only looked up when an unknown token appears.
    """
    if vocab is None:
        vocab = w_to_i
    return [vocab[w] if w in vocab else vocab[unk] for w in s]

In [84]:
# Give every vocabulary word an integer id 0..len(words)-1 and build the inverse map.
i_to_w = {idx: word for idx, word in enumerate(words)}
w_to_i = dict(zip(i_to_w.values(), i_to_w.keys()))

In [70]:
# Split each tokenized context into sentences with split_on_dot.
# Guarded so that re-running this cell is harmless: a context that is already a
# list of sentences is left alone instead of being wrapped in another list level.
for row in data:
    ctx_tokens = row[2]
    if ctx_tokens and isinstance(ctx_tokens[0], list):
        continue
    row[2] = split_on_dot(ctx_tokens)

In [145]:
# Replace every token by its vocabulary id; contexts keep their sentence structure.
# Row layout: [[(answer_start, answer_ids), ...], question_ids, [sentence_ids, ...]].
data_num = [
    [[(start, words_to_num(toks)) for start, toks in answers],
     words_to_num(question),
     [words_to_num(sentence) for sentence in context]]
    for answers, question, context in data
]

In [146]:
# Count answers that contain an out-of-vocabulary token. The vocabulary holds only
# context and question tokens, and answers were tokenized on their own, so some
# answer tokens never made it into the vocabulary.
unk_id = w_to_i['<unk>']
bugged_answers = sum(1 for answers, _, _ in data_num
                     for _, answer_ids in answers
                     if unk_id in answer_ids)
bugged_answers

80

In [147]:
# Merge the question with the context sentences: each row becomes
# [answers, [question_ids, sentence_1_ids, sentence_2_ids, ...]].
# NOTE: rebinds data_num in place; run this cell only once.
data_num = [[answers, [question] + sentences] for answers, question, sentences in data_num]

In [148]:
# Drop the character offsets (answer_start); keep only each answer's token ids.
# NOTE: rebinds data_num in place; run this cell only once.
data_num = [[[toks for _, toks in answers], rest] for answers, rest in data_num]

In [149]:
# Count tokens of each example's first answer that occur nowhere in its context
# sentences. Such answers cannot be matched as a word span, presumably because
# tokenizing the answer on its own splits it differently than inside the context.
k = 0
for a, q in data_num:
    # q[0] is the question and q[1:] are the context sentences. Build the lookup
    # set once per example rather than re-flattening the context for every token.
    context_ids = set(chain(*q[1:]))
    for w in a[0]:
        if w not in context_ids:
            k += 1
k

1028

### Locate each answer as a span of word indices (instead of character offsets)

In [151]:
# For every example, look for each answer's token ids as a contiguous run in the
# flattened context (the question q[0] is excluded) and record the word positions.
# Only the first occurrence is used. Answers with no exact match are skipped.
# NOTE(review): answer_start was dropped earlier, so a phrase that occurs several
# times in the context may be mapped to the wrong occurrence.
inds = []

for a, q in data_num:
    context_ids = list(chain(*q[1:]))
    spans = []
    for answer_ids in a:
        width = len(answer_ids)
        start = next((pos for pos in xrange(len(context_ids))
                      if answer_ids == context_ids[pos:pos + width]), None)
        if start is not None:
            spans.append(list(xrange(start, start + width)))
    inds.append(spans)

In [152]:
# Replace each example's answer token ids with the word-index spans found above.
for row, spans in zip(data_num, inds):
    row[0] = spans

### Save processed data

In [165]:
# Vocabulary words ordered by id, so position k in the list holds the word with id k.
sorted_words = [word for word, _ in sorted(w_to_i.items(), key=lambda item: item[1])]

In [166]:
# One word per line in id order. unicode() is needed because '<unk>' was added
# as a byte string, and the text-mode file only accepts unicode.
with io.open('/pio/data/data/squad/wordlist.txt', 'w', encoding='utf-8') as wordlist_file:
    wordlist_file.writelines(unicode(token + '\n') for token in sorted_words)

In [176]:
# First processed example: [answer spans as word indices, [question ids, sentence ids, ...]].
data_num[0]

[[[102, 103, 104]],
 [[78406,
   50177,
   45612,
   67711,
   87146,
   71884,
   58619,
   29587,
   94338,
   55795,
   94338,
   72312,
   97133,
   83077],
  [100780, 44968, 67711, 60695, 608, 43315, 89195, 32610, 492],
  [40701,
   67711,
   45830,
   54332,
   55792,
   78791,
   78506,
   19554,
   43315,
   51341,
   10820,
   23764,
   67711,
   87146,
   71884,
   492],
  [83032,
   94338,
   95628,
   23764,
   67711,
   45830,
   54332,
   49698,
   33470,
   19557,
   44968,
   19554,
   43315,
   85100,
   10820,
   23764,
   99569,
   1485,
   96317,
   37478,
   1485,
   67711,
   78483,
   64851,
   101002,
   14122,
   32833,
   66547,
   77561,
   492],
  [19445,
   78406,
   67711,
   45830,
   54332,
   19554,
   67711,
   32756,
   23764,
   67711,
   7991,
   80913,
   492],
  [83032,
   71615,
   67711,
   32756,
   19554,
   67711,
   16921,
   44968,
   43315,
   50307,
   87507,
   23764,
   55111,
   49698,
   18814,
   492],
  [19557,
   19554,
   43315,
 

In [170]:
# The file is larger than necessary because every question row carries its own
# copy of the context ids.
# Write in binary mode with binary protocol 2: it is smaller and faster to load
# than the default ASCII protocol 0, and readable from both Python 2 and Python 3.
with open('/pio/data/data/squad/train.pkl', 'wb') as f:
    pickle.dump(data_num, f, protocol=2)

In [193]:
# The file holds a plain pickled list, so unpickle it directly.
# np.load only falls back to pickle, and since NumPy 1.16.3 refuses to unless
# allow_pickle=True. Only unpickle files you produced yourself, as here.
with open('/pio/data/data/squad/train.pkl', 'rb') as f:
    a = pickle.load(f)

In [194]:
# Should be identical to data_num[0] above if the pickle round trip is lossless.
a[0]

[[[102, 103, 104]],
 [[78406,
   50177,
   45612,
   67711,
   87146,
   71884,
   58619,
   29587,
   94338,
   55795,
   94338,
   72312,
   97133,
   83077],
  [100780, 44968, 67711, 60695, 608, 43315, 89195, 32610, 492],
  [40701,
   67711,
   45830,
   54332,
   55792,
   78791,
   78506,
   19554,
   43315,
   51341,
   10820,
   23764,
   67711,
   87146,
   71884,
   492],
  [83032,
   94338,
   95628,
   23764,
   67711,
   45830,
   54332,
   49698,
   33470,
   19557,
   44968,
   19554,
   43315,
   85100,
   10820,
   23764,
   99569,
   1485,
   96317,
   37478,
   1485,
   67711,
   78483,
   64851,
   101002,
   14122,
   32833,
   66547,
   77561,
   492],
  [19445,
   78406,
   67711,
   45830,
   54332,
   19554,
   67711,
   32756,
   23764,
   67711,
   7991,
   80913,
   492],
  [83032,
   71615,
   67711,
   32756,
   19554,
   67711,
   16921,
   44968,
   43315,
   50307,
   87507,
   23764,
   55111,
   49698,
   18814,
   492],
  [19557,
   19554,
   43315,
 

In [174]:
# Number of processed training examples (one row per question).
len(data_num)

87599

In [195]:
# Rebuild the word -> id map from the saved wordlist (line k holds the word with
# id k) and its inverse id -> word.
b = {}
idx = 0

with io.open('/pio/data/data/squad/wordlist.txt', 'r', encoding='utf-8') as f:
    for line in f:
        # Strip only the line terminator. line[:-1] would chop a real character
        # off a final line that has no trailing newline.
        b[line.rstrip(u'\n')] = idx
        idx += 1

# A repeated word would silently reassign its id and break the round trip.
assert len(b) == idx, 'duplicate entries in wordlist.txt'

rev_b = {v: k for (k, v) in b.items()}

In [196]:
# Round-trip sanity check: decode the last token of the first question in the
# reloaded pickle through the reloaded wordlist. In the run shown this was id
# 83077 -> u'?'. The id is derived from the data rather than hard-coded, because
# ids depend on the vocabulary built in this run.
rev_b[a[0][1][0][-1]]

u'?'