In [1]:
import torch
import torch.nn as nn
import torchtext
import numpy as np
import os
import json
from torchtext import data, datasets
from dataset import DataHandler, BertField
import pandas as pd
from dataset import QDataset
from pytorch_pretrained_bert import BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [22]:
# Raw QAngaroo WikiHop v1.1 splits that the conversion cell reads.
train_json_path, dev_json_path = (
    './data/qangaroo_v1.1/wikihop/train.json',
    './data/qangaroo_v1.1/wikihop/dev.json',
)

In [23]:
def get_label(row):
    """Return the position of ``row['answer']`` within ``row['candidates']``.

    ``row`` is anything indexable by ``'candidates'`` and ``'answer'``; in this
    notebook it is a row handed over by ``DataFrame.apply(..., axis=1)``.

    Returns:
        The index of the first candidate equal to the answer, or -1 when the
        answer does not appear among the candidates (-1 is a sentinel, not a
        valid class index).
    """
    answer = row['answer']
    return next(
        (i for i, candidate in enumerate(row['candidates']) if candidate == answer),
        -1,
    )

In [24]:
def convert_wikihop_split(src_path, dst_path):
    """Convert one raw WikiHop split into the record file consumed later.

    Replaces underscores in each ``query`` with spaces, adds an integer
    ``label`` column holding the position of ``answer`` within ``candidates``
    (computed by ``get_label``), and writes the frame to ``dst_path`` as a
    JSON list of records.

    Args:
        src_path: path of the raw split JSON (e.g. train.json / dev.json).
        dst_path: path of the converted JSON to write.

    Returns:
        The converted DataFrame.

    Raises:
        ValueError: if any example's answer is missing from its candidates.
            ``get_label`` marks those rows with -1, which would otherwise be
            written out silently as an invalid target label.
    """
    split_df = pd.read_json(src_path)
    split_df['query'] = split_df['query'].str.replace('_', ' ', regex=False)
    split_df['label'] = split_df.apply(get_label, axis=1)

    n_missing = int((split_df['label'] < 0).sum())
    if n_missing:
        raise ValueError(
            '{}: {} example(s) have an answer that is not among their candidates'
            .format(src_path, n_missing))

    split_df.to_json(dst_path, orient='records')
    return split_df


for raw_split_path, converted_split_path in [
    (train_json_path, './data/qangaroo_v1.1/wikihop/train_convert.json'),
    (dev_json_path, './data/qangaroo_v1.1/wikihop/val_convert.json'),
]:
    convert_wikihop_split(raw_split_path, converted_split_path)


In [26]:
# Reload the converted dev split from disk to inspect what was actually written.
df = pd.read_json(path_or_buf='./data/qangaroo_v1.1/wikihop/val_convert.json')

In [27]:
# Peek at the first five rows of the reloaded dev split.
# NOTE(review): the saved output of this cell shows an 'Unnamed: 0' column and
# 'WH_train_*' ids, although the conversion cell writes this file from dev.json
# with orient='records' (no index column). That output looks stale -- re-run
# the notebook top to bottom on a fresh kernel.
df.head(n=5)

Unnamed: 0,answer,candidates,id,label,query,supports
0,1996 summer olympics,"[1996 summer olympics, olympic games, sport]",WH_train_0,0,participant of juan rossell,"[The 2004 Summer Olympic Games, officially kno..."
1,english,"[english, greek, koine greek, nahuatl, spanish]",WH_train_1,0,languages spoken or written john osteen,[A Christian (or ) is a person who follows or ...
2,lepidosauria,"[alligatoridae, amphibia, amphisbaenia, animal...",WH_train_2,14,parent taxon proaigialosaurus,[Reptiles are tetrapod (four-limbed vertebrate...
3,crocodilia,"[animal, area, crocodile, crocodilia, homo, me...",WH_train_3,3,parent taxon australosuchus,[Mekosuchinae was a subfamily of crocodiles fr...
4,physicist,"[academic, builder, chancellor, classics, conf...",WH_train_4,20,occupation cao chong,"[Wu (222280), commonly known as Eastern Wu or ..."


In [2]:
# Word-piece tokenizer loaded from the local bert-base-uncased vocabulary file.
tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased-vocab.txt')

# BERT-side inputs: BertField (project-local) for a single text, and a
# NestedField wrapping it for lists of texts (supports, candidates).
bert_field = BertField(tokenizer)
multi_bert_field = data.NestedField(bert_field)

# "glove"-side inputs: a torchtext Field that reuses the BERT tokenizer.
# The same Field instance is used for query and answer and is the nesting
# field for the list-valued keys, presumably so all of them share one vocab.
# NOTE(review): tokenizer.tokenize yields word pieces (e.g. '##ing'), which
# may match a GloVe vocabulary poorly -- confirm this is intended.
word_field = data.Field(
    sequential=True,
    batch_first=True,
    lower=True,
    tokenize=tokenizer.tokenize,
)
multi_word_field = data.NestedField(word_field)

# Example ids are kept as raw values and explicitly marked as not a target.
raw = data.RawField()
raw.is_target = False

# Gold candidate index: already an integer in the converted JSON, so no vocab.
label_field = data.Field(use_vocab=False, sequential=False, is_target=True)

# JSON key -> (attribute name, field); a list maps one key to several
# attributes. Presumably QDataset consumes this like torchtext's JSON
# TabularDataset fields mapping -- confirm against dataset.py.
dict_field = {
    'id': ('id', raw),
    'supports': [
        ('s_glove', multi_word_field),
        ('s_bert', multi_bert_field),
    ],
    'query': [
        ('q_glove', word_field),
        ('q_bert', bert_field),
    ],
    'answer': [
        ('a_glove', word_field),
        ('a_bert', bert_field),
    ],
    'candidates': [
        ('c_glove', multi_word_field),
        ('c_bert', multi_bert_field),
    ],
    'label': ('label', label_field),
}

In [3]:
# Converted (record-format) splits produced by the conversion cell.
train_path, val_path = (
    './data/qangaroo_v1.1/wikihop/train_convert.json',
    './data/qangaroo_v1.1/wikihop/val_convert.json',
)

In [4]:
# Build the training dataset from the converted records and the field mapping.
# NOTE(review): this took ~66 min in the recorded run; the examples are saved to
# ./train_examples.pt in a later cell -- consider reusing that cache when present.
trainset = QDataset(
    train_path,
    dict_field,
)

100%|██████████| 43738/43738 [1:06:02<00:00, 12.09it/s]


In [5]:
# Build the validation dataset with the same field mapping as the training set.
valset = QDataset(
    val_path,
    dict_field,
)

100%|██████████| 5129/5129 [08:19<00:00, 10.27it/s]


In [6]:
train_examples_path, val_examples_path = './train_examples.pt', './val_examples.pt'

# Persist the parsed Example lists so the slow dataset construction can be
# skipped later. torch.save pickles these objects, so only torch.load files
# produced by a trusted run of this notebook.
for dataset_to_save, examples_path in (
    (trainset, train_examples_path),
    (valset, val_examples_path),
):
    torch.save(dataset_to_save.examples, examples_path)