In [1]:
import json
import torch
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert import BertModel
from tqdm import tqdm

In [10]:
class WikiExample(object):
    """One WikiHop question with its supporting passages and answer candidates.

    Attributes:
        wiki_id: identifier of the example (``item['id']`` as read by
            ``read_wiki_example``).
        passages: list of supporting passage strings.
        choices: list of candidate answer strings.
        label: index into ``choices`` of the gold answer, or None if unknown.
        question: the query string.
    """

    def __init__(self,
                 wiki_id,
                 passages,
                 choices,
                 question,
                 label=None):
        self.wiki_id = wiki_id
        self.passages = passages
        self.choices = choices
        self.label = label
        self.question = question

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # Multi-line summary; the label line only appears when a label is set.
        parts = [f'id: {self.wiki_id}', f'question: {self.question}']
        parts.append(f'passages:[{",".join(self.passages)}]')
        parts.append(f'choices:[{",".join(self.choices)}]')
        if self.label is None:
            return '\n'.join(parts)
        parts.append(f'label: {self.label}')
        return '\n'.join(parts)
        
class InputFeatures(object):
    """Padded BERT inputs for a single tokenized sequence.

    As built by ``convert_context_to_features`` in this notebook:
      * input_ids   -- token ids wrapped in [CLS]/[SEP], zero-padded;
      * input_mask  -- 1 for real tokens, 0 for padding;
      * segment_ids -- all zeros.
    """

    def __init__(self, input_ids, input_mask, segment_ids):
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids, input_mask, segment_ids)
        
class WikiFeatures(object):
    """Tokenized features for one WikiExample.

    Each ``*_features`` argument is a list of InputFeatures-like objects.
    The identifier passed as ``wiki_id`` is stored as ``example_id``, which is
    the name ``make_input`` reads.
    """

    def __init__(self,
                 wiki_id,
                 passages_features,
                 choices_features,
                 question_features,
                 label=None):
        self.passages_features, self.choices_features = passages_features, choices_features
        self.question_features = question_features
        self.label = label
        self.example_id = wiki_id
        
def read_wiki_example(path):
    """Load WikiHop examples from a JSON file.

    Each record is read via the keys ``supports`` (passages), ``candidates``
    (choices), ``query`` (question), ``id`` and ``answer``. The label is the
    index of ``answer`` within ``candidates``. Records whose answer is not
    among the candidates have their id printed and are skipped.

    Args:
        path: path to a JSON file holding a list of such records.

    Returns:
        list of WikiExample, in file order (skipped records omitted).
    """
    # Close the handle deterministically instead of leaking it. JSON text is
    # UTF-8, so don't depend on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
        data = json.load(f)

    examples = []
    for item in data:
        passages = item['supports']
        choices = item['candidates']
        question = item['query']
        wiki_id = item['id']
        answer = item['answer']
        # If the answer appears more than once, the last matching index wins.
        label = -1
        for idx, choice in enumerate(choices):
            if choice == answer:
                label = idx
        if label == -1:
            print(wiki_id)
            continue
        examples.append(WikiExample(wiki_id, passages, choices, question, label))
    return examples

def _truncate_seq(seq, max_seq_length):
    if len(seq) <= max_seq_length:
        return seq
    else:
        return seq[:max_seq_length]
    
def _tokens_to_features(tokens, max_seq_length, tokenizer):
    """Wrap word-piece ``tokens`` in [CLS]/[SEP], map them to ids and zero-pad.

    ``tokens`` are first truncated to ``max_seq_length - 2`` so the two
    special tokens always fit.

    Raises:
        ValueError: if ``max_seq_length`` leaves no room for [CLS] and [SEP].
    """
    if max_seq_length < 2:
        # A smaller budget used to become a negative slice in _truncate_seq
        # and only surface later as an opaque AssertionError.
        raise ValueError(f'max_seq_length must be at least 2, got {max_seq_length}')

    tokens = ["[CLS]"] + _truncate_seq(tokens, max_seq_length - 2) + ["[SEP]"]

    segment_ids = [0] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)

    # Zero-pad all three lists up to the fixed length.
    padding = [0] * (max_seq_length - len(tokens))
    input_ids += padding
    input_mask += padding
    segment_ids += padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return InputFeatures(input_ids, input_mask, segment_ids)


def convert_context_to_features(context, max_seq_length=None, tokenizer=None):
    """Tokenize ``context`` into padded BERT inputs.

    Args:
        context: raw text to encode.
        max_seq_length: total length including [CLS]/[SEP]. When None, the
            length is chosen to fit ``context`` exactly (no truncation, no
            padding).
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
            Defaults to the notebook-level ``tokenizer`` global, preserving the
            original calling convention.

    Returns:
        InputFeatures whose three lists all have length ``max_seq_length``.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 2.
    """
    if tokenizer is None:
        tokenizer = globals()['tokenizer']
    tokens = tokenizer.tokenize(context)
    if max_seq_length is None:
        max_seq_length = len(tokens) + 2
    return _tokens_to_features(tokens, max_seq_length, tokenizer)


def convert_examples_to_features(examples, tokenizer, max_seq_length):
    """Convert WikiExamples into WikiFeatures.

    Choices are padded to the longest choice of their own example. Passages
    are truncated/padded to ``max_seq_length``. The question keeps its
    natural length. Every field is encoded with the ``tokenizer`` passed in.

    Args:
        examples: iterable of WikiExample.
        tokenizer: tokenizer used for all fields.
        max_seq_length: per-passage length including [CLS]/[SEP].

    Returns:
        list of WikiFeatures, one per example, in input order.
    """
    all_features = []
    for example in tqdm(examples):
        # Tokenize every choice once; reuse the tokens for both the length
        # computation and the feature conversion.
        choice_tokens = [tokenizer.tokenize(choice) for choice in example.choices]
        max_choice_len = max(len(tokens) for tokens in choice_tokens)

        # Pass the tokenizer explicitly. Previously only the choice lengths used
        # it and the ids silently came from the module-global tokenizer.
        choices_features = [_tokens_to_features(tokens, max_choice_len + 2, tokenizer)
                            for tokens in choice_tokens]
        passages_features = [convert_context_to_features(passage, max_seq_length, tokenizer)
                             for passage in example.passages]
        question_features = [convert_context_to_features(example.question, tokenizer=tokenizer)]

        all_features.append(WikiFeatures(example.wiki_id,
                                         passages_features,
                                         choices_features,
                                         question_features,
                                         example.label))
    return all_features
    

def convert_to_tensor(input_features):
    """Stack a list of InputFeatures-like objects into three ``torch.long`` tensors.

    Returns:
        tuple ``(input_ids, input_mask, segment_ids)``. Row i of each tensor
        comes from ``input_features[i]``. All features must share one
        sequence length for ``torch.tensor`` to build a rectangular tensor.
    """
    def _stack(attr):
        return torch.tensor([getattr(feat, attr) for feat in input_features],
                            dtype=torch.long)

    return tuple(_stack(name) for name in ('input_ids', 'input_mask', 'segment_ids'))

def make_input(item):
    """Turn one WikiFeatures into a dict of tensors suitable for ``torch.save``.

    Keys:
        'choices', 'question', 'passages': the ``convert_to_tensor`` tuples
            for the corresponding feature lists;
        'label': 1-element ``torch.long`` tensor holding ``item.label``;
        'wiki_id': ``item.example_id``, unchanged.
    """
    return {
        'choices': convert_to_tensor(item.choices_features),
        'question': convert_to_tensor(item.question_features),
        'passages': convert_to_tensor(item.passages_features),
        'label': torch.tensor([item.label], dtype=torch.long),
        'wiki_id': item.example_id,
    }



In [11]:
# Word-piece tokenizer loaded from a local vocab file, with lower-casing on
# (presumably to match the "uncased" vocabulary -- confirm). It is passed to
# convert_examples_to_features and also read as a global by
# convert_context_to_features, so this cell must run before feature extraction.
tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased-vocab.txt', do_lower_case=True)

03/04/2019 22:17:34 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file ./bert-base-uncased-vocab.txt


In [13]:
# Token budget per supporting passage (including [CLS]/[SEP]). It is also
# embedded in the cache file names below so different budgets don't collide.
max_seq_length = 128

# Raw WikiHop splits.
train_json_path = './data/qangaroo_v1.1/wikihop/train.json'
dev_json_path = './data/qangaroo_v1.1/wikihop/dev.json'

# Where the tensorised splits are written by torch.save.
train_bert_path = f'./data/train_data_bert_{max_seq_length}.pt'
dev_bert_path = f'./data/dev_data_bert_{max_seq_length}.pt'

In [14]:

# Parse -> tokenize -> tensorise -> cache, for both splits.
# Slow: the recorded run below took ~40 min for train and ~5 min for dev.
train_examples = read_wiki_example(train_json_path)
dev_examples = read_wiki_example(dev_json_path)

train_features = convert_examples_to_features(train_examples, tokenizer, max_seq_length)
dev_features = convert_examples_to_features(dev_examples, tokenizer, max_seq_length)

train_data = list(map(make_input, train_features))
dev_data = list(map(make_input, dev_features))

torch.save(train_data, train_bert_path)
torch.save(dev_data, dev_bert_path)

100%|██████████| 43738/43738 [40:39<00:00, 16.77it/s]  
100%|██████████| 5129/5129 [05:06<00:00, 16.73it/s]


In [9]:
# Spot-check: the candidate answers of one training example.
train_examples[9].choices

['accounting',
 'artist',
 'band',
 'barber',
 'canada',
 'commercial',
 'indie pop',
 'manufacturer',
 'marketing',
 'scouting',
 'singer',
 'united kingdom']

In [None]:
# Reload the cached tensors instead of recomputing them. The file names must
# match what torch.save wrote above, and those names include max_seq_length.
# The old hard-coded names without the suffix pointed at files this notebook
# never produces. Requires the config cell defining max_seq_length to have run.
# NOTE: torch.load unpickles its input -- only load files you produced yourself.
train_bert_path = f'./data/train_data_bert_{max_seq_length}.pt'
dev_bert_path = f'./data/dev_data_bert_{max_seq_length}.pt'
train_data = torch.load(train_bert_path)
dev_data = torch.load(dev_bert_path)

In [None]:
# Number of supporting passages per training example: the row count of each
# example's passages input_ids tensor (element 0 of the convert_to_tensor tuple).
p_nums = [sample['passages'][0].size(0) for sample in train_data]

In [None]:
import pandas as pd

In [None]:
# One row per training example; the single column (label 0) holds its passage count.
df = pd.DataFrame(p_nums)

In [None]:
# Summary statistics of passages-per-example (before any capping below).
df.describe()

In [None]:
# In place: replace counts above 32 with the sentinel -1 so value_counts()
# groups all "too many passages" examples together. 32 is presumably the
# passage cap intended for the model -- confirm. This mutates df, so the
# describe() output above no longer reflects its contents.
df[df > 32] = -1

In [None]:
# Frequency of each passage count; the -1 row counts examples above the cap.
df[0].value_counts()