In [None]:
import os
from torchtext import data, datasets
from torchtext.vocab import GloVe
import nltk
import pandas as pd
import json
import numpy as np
import tqdm
from tqdm import tqdm
import torchtext

In [None]:
# Show which files ship with the WikiHop portion of QAngaroo v1.1.
wikihop_dir = './data/qangaroo_v1.1/wikihop/'
os.listdir(wikihop_dir)

In [None]:
# Raw WikiHop train/dev splits (one JSON array per file).
wikihop_dir = './data/qangaroo_v1.1/wikihop/'
train_json_path = wikihop_dir + 'train.json'
dev_json_path = wikihop_dir + 'dev.json'

In [None]:
def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is read as UTF-8 (the encoding JSON interchange requires) rather
    than the platform default. It is closed as soon as parsing finishes instead
    of leaking an open handle.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)


train_data = load_json(train_json_path)
val_data = load_json(dev_json_path)

In [None]:
# Inspect the schema of a raw WikiHop record.
first_example = train_data[0]
first_example.keys()

In [None]:
# Convert the support documents of one example into a single super document.
def convert(item, rng=None):
    """Flatten one WikiHop example into a single-hop "super document".

    The support documents are tokenized with ``nltk.word_tokenize`` and
    lower-cased. They are shuffled into a random order and concatenated, with a
    ``'<sep>'`` token between consecutive documents. The (lower-cased, tokenized)
    answer is then located as a contiguous token span.

    Args:
        item: WikiHop record with ``'supports'`` (list of document strings),
            ``'query'`` and ``'answer'`` strings.
        rng: object with a ``shuffle`` method used to order the documents.
            Defaults to NumPy's global RNG (``np.random``). Pass e.g.
            ``np.random.default_rng(seed)`` for reproducible output.

    Returns:
        ``(super_doc, query, s_idx, e_idx)``. ``super_doc`` and ``query`` are
        token lists. ``s_idx``/``e_idx`` are the inclusive start/end indices of
        the first answer occurrence in ``super_doc``, or ``-1``/``-1`` when the
        answer does not occur (or tokenizes to nothing).
    """
    if rng is None:
        rng = np.random
    order = np.arange(len(item['supports']))
    rng.shuffle(order)
    # Keep plain Python lists: documents have different lengths, and NumPy
    # >= 1.24 raises ValueError when building an array from ragged lists.
    contexts = [[token.lower() for token in nltk.word_tokenize(doc)]
                for doc in item['supports']]
    super_doc = []
    for pos, doc_idx in enumerate(order):
        if pos > 0:
            super_doc.append('<sep>')
        super_doc.extend(contexts[doc_idx])

    # Lower-case the query as well so it is consistent with the lower-cased
    # context and answer tokens (the GloVe 6B vectors are uncased).
    query = nltk.word_tokenize(item['query'].replace('_', ' ').lower())
    answer = nltk.word_tokenize(item['answer'].lower())

    m = len(answer)
    if m == 0:
        return super_doc, query, -1, -1
    for i in range(len(super_doc) - m + 1):
        # Compare token lists directly. Joining tokens without a separator
        # would let different splits (['a', 'bc'] vs ['ab', 'c']) match falsely.
        if super_doc[i:i + m] == answer:
            return super_doc, query, i, i + m - 1
    return super_doc, query, -1, -1


In [None]:
def convert_to_single_hop(dataset, max_len=8192, max_tries=100):
    """Convert WikiHop examples into single-hop span-extraction examples.

    Each record is flattened with ``convert``. Records whose answer cannot be
    found in the support documents are dropped. When the answer span ends at or
    beyond ``max_len`` tokens, the documents are reshuffled, up to
    ``max_tries`` times, so the span fits in a context truncated to ``max_len``
    tokens. Records that still do not fit are dropped.

    Args:
        dataset: iterable of WikiHop records (dicts with at least ``'id'`` and
            ``'answer'``, plus the keys ``convert`` reads).
        max_len: context length, in tokens, the answer span must end before.
        max_tries: maximum number of reshuffles attempted per record.

    Returns:
        List of dicts with keys ``'context'``, ``'id'``, ``'query'``,
        ``'s_idx'``, ``'e_idx'`` and ``'answer'`` (the original answer string).
    """
    n_dataset = []
    for item in tqdm(dataset):
        super_doc, query, s_idx, e_idx = convert(item)
        if s_idx == -1:  # Ignore examples whose answer is not in the supports.
            continue
        # Reshuffle until the answer span fits in the first max_len tokens.
        # The loop is bounded because no document order may satisfy this,
        # e.g. when every occurrence of the answer is more than max_len tokens
        # into its own document.
        tries = 0
        while e_idx >= max_len and tries < max_tries:
            super_doc, query, s_idx, e_idx = convert(item)
            tries += 1
        if e_idx >= max_len:
            continue
        n_dataset.append({
            'context': super_doc,
            'id': item['id'],
            'query': query,
            's_idx': s_idx,
            'e_idx': e_idx,
            'answer': item['answer'],
        })
    return n_dataset

In [None]:
# convert() shuffles the support documents with NumPy's global RNG. Seed it so
# the generated single-hop datasets (and the files saved below) are
# reproducible across runs.
SEED = 42
np.random.seed(SEED)
n_train = convert_to_single_hop(train_data)
n_val = convert_to_single_hop(val_data)

In [None]:
# Raw vs. kept example counts per split (examples without a findable answer are dropped).
for raw_split, kept_split in ((train_data, n_train), (val_data, n_val)):
    print(len(raw_split), len(kept_split))

In [None]:
# Number of training contexts longer than 8192 tokens (truncated further down).
count = sum(len(example['context']) > 8192 for example in n_train)
print(count)

In [None]:
# Cap every validation context at 8192 tokens (examples are mutated in place).
for example in n_val:
    if len(example['context']) > 8192:
        example['context'] = example['context'][:8192]

In [None]:
# Cap every training context at 8192 tokens (examples are mutated in place).
for example in n_train:
    if len(example['context']) > 8192:
        example['context'] = example['context'][:8192]

In [None]:
# Output files for the converted single-hop splits (JSON Lines).
wikihop_dir = './data/qangaroo_v1.1/wikihop/'
n_train_json_path = wikihop_dir + 'train_single.json'
n_val_json_path = wikihop_dir + 'dev_single.json'

def save_json(json_data, path):
    """Write ``json_data`` to ``path`` as JSON Lines (one object per line).

    Every record is converted with ``dict(...)`` before the file is opened.
    Each record is then serialized with the default ``json`` settings,
    followed by a newline. The file is written as UTF-8.
    """
    records = [dict(record) for record in json_data]
    with open(path, 'w', encoding='utf-8') as out:
        for record in records:
            json.dump(record, out)
            out.write('\n')
# Persist both converted splits so torchtext can read them back below.
for records, out_path in ((n_train, n_train_json_path), (n_val, n_val_json_path)):
    save_json(records, out_path)

In [None]:
# Word-level field for context/query tokens; batches are laid out batch-first.
word_field = data.Field(batch_first=True)
# Character-level field: `tokenize=list` splits a word string into its
# characters, and NestedField wraps it so each token of a sequence is handled
# by this inner character field.
char_field_nesting = data.Field(batch_first=True, tokenize=list)
char_field = data.NestedField(char_field_nesting)
# Span indices (s_idx / e_idx) are used as-is: non-sequential, no vocabulary.
label_field = data.Field(sequential=False, use_vocab=False)
# Example ids are kept raw and are not treated as a prediction target.
raw = data.RawField()
raw.is_target = False

In [None]:
# Keep the first converted training example for interactive inspection.
# NOTE(review): `item` is not read by the following cells and is later rebound
# to a batch from the validation iterator.
item = n_train[0]

In [None]:
# Confirm the keys produced by convert_to_single_hop.
first_converted = n_train[0]
first_converted.keys()

In [None]:
# Map each JSON key to the (attribute name, Field) pair(s) torchtext should
# create. Context and query each feed both a word-level and a char-level field.
dict_field = dict(
    id=('id', raw),
    s_idx=('s_idx', label_field),
    e_idx=('e_idx', label_field),
    context=[('c_word', word_field), ('c_char', char_field)],
    query=[('q_word', word_field), ('q_char', char_field)],
)

In [None]:
# Read the JSON Lines files written above. path='' because the file paths
# already include their directory.
split_files = dict(train=n_train_json_path, validation=n_val_json_path)
train_dataset, val_dataset = data.TabularDataset.splits(
    path='', format='json', fields=dict_field, **split_files)

In [None]:
# Character vocabulary is built from both the train and validation splits.
vocab_sources = (train_dataset, val_dataset)
char_field.build_vocab(*vocab_sources)

In [None]:

# Word vocabulary from both splits, with 100-d GloVe (6B) vectors attached.
glove_vectors = torchtext.vocab.GloVe(name='6B', dim=100)
word_field.build_vocab(train_dataset, val_dataset, vectors=glove_vectors)

In [None]:
# Small-batch iterator over the validation split, used to inspect one batch.
val_batches = torchtext.data.BucketIterator(val_dataset, batch_size=2)
iterator = iter(val_batches)

In [None]:
# Pull one batch from the validation iterator, presumably to sanity-check
# numericalization of the fields.
# NOTE(review): this rebinds `item`, which earlier held a converted example
# dict; it now holds whatever the BucketIterator yields (a torchtext Batch).
item = next(iterator)

In [None]:
# Size of the word vocabulary. len(vocab) counts the index-to-string table.
# NOTE(review): in legacy torchtext, `stoi` is a defaultdict that can gain
# entries when out-of-vocabulary tokens are looked up, so len(stoi) may
# overstate the vocabulary size.
len(word_field.vocab)

In [None]:
import torch

In [None]:
# Persist the processed torchtext Example objects. torch.save pickles them, so
# loading requires torchtext to be importable. Only load files you trust,
# because unpickling can execute arbitrary code.
for examples, out_path in ((train_dataset.examples, './train_examples.pt'),
                           (val_dataset.examples, './val_examples.pt')):
    torch.save(examples, out_path)