In [1]:
from collections import Counter
from itertools import chain
import os
import torch
import json

In [2]:
# Location of the WikiHop files inside the QAngaroo v1.1 directory.
# NOTE(review): this is an absolute, machine-specific path; adjust it locally.
root = '/home/zzs/data/qangaroo_v1.1/wikihop/'
train_path = 'train.json'
dev_path = 'dev.json'

dev_json_path = os.path.join(root, dev_path)
# Same location as the dev JSON, with '.json' swapped for '_test.pt'.
dev_ex_path = dev_json_path.replace('.json','_test.pt')
# Read through a context manager so the file handle is closed right away,
# and decode explicitly as UTF-8 instead of relying on the locale default.
with open(dev_json_path, encoding='utf-8') as dev_file:
    dev_data = json.load(dev_file)

In [3]:
from torchtext import data as textdata, vocab
from torchtext.data import Field
import copy
from collections import Counter
from typing import List

class CharField(Field):
    """Field that represents text as fixed-size grids of character tokens.

    ``pad`` turns every sentence (a sequence of words) into a grid of
    ``max_sentence_length`` rows by ``max_word_length`` characters, and
    ``numericalize`` converts those grids into a stacked tensor of ids.
    """

    def __init__(
        self,
        pad_token='<pad>',
        unk_token='<unk>',
        batch_first=True,
        max_word_length=20,
        max_sentence_length=128,
        lower=True,
        **kwargs):
        """Configure the underlying ``Field`` and remember the size limits.

        Args:
            pad_token: token used to fill missing characters and words.
            unk_token: token passed to the base field as the unknown token.
            batch_first: forwarded to the base field.
            max_word_length: characters kept per word; ``None`` makes ``pad``
                use the longest word of the batch.
            max_sentence_length: words kept per sentence; ``None`` makes
                ``pad`` use the longest sentence of the batch.
            lower: forwarded to the base field.
            **kwargs: any further base-field options.
        """
        super().__init__(
            sequential=True,  # Otherwise pad is set to None in textdata.Field
            batch_first=batch_first,
            use_vocab=True,
            pad_token=pad_token,
            unk_token=unk_token,
            lower=lower,
            **kwargs
        )
        self.max_word_length = max_word_length
        self.max_sentence_length = max_sentence_length

    def build_vocab(self, *args, **kwargs):
        """Build a character-level vocabulary and store it in ``self.vocab``.

        Args:
            *args: ``Dataset`` objects (every column bound to this field is
                used) or plain iterables of preprocessed examples.
            **kwargs: forwarded to ``vocab.Vocab``. A ``specials`` entry is
                merged after the unk/pad tokens instead of clashing with them.
        """
        sources = []
        for arg in args:
            if isinstance(arg, textdata.Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            # data is the return value of preprocess().
            for para in data:
                # An empty example has nothing to count; skip it before
                # peeking at para[0], which would raise IndexError.
                if not para:
                    continue
                if isinstance(para[0], list):
                    # Nested example: a list of sentences, each a list of words.
                    for sentence in para:
                        for word_chars in sentence:
                            counter.update(word_chars)
                else:
                    # Flat example: a single list of words.
                    for word_chars in para:
                        counter.update(word_chars)

        # unk first, pad second; caller-supplied specials follow without
        # duplicates. Popping the key avoids passing `specials` twice to Vocab.
        specials = [self.unk_token, self.pad_token]
        for token in kwargs.pop('specials', None) or []:
            if token not in specials:
                specials.append(token)

        self.vocab = vocab.Vocab(counter, specials=specials, **kwargs)

    def pad(self, minibatch: List[List[List[str]]]) -> List[List[List[str]]]:
        """Pad/truncate every sentence to a fixed grid of characters.

        Each word is converted with ``list(word)``, cut to the word limit and
        filled with ``self.pad_token``; each sentence is cut to the sentence
        limit and filled with all-pad words. The limits come from
        ``self.max_sentence_length`` / ``self.max_word_length``, or from the
        longest item in ``minibatch`` when the attribute is ``None``.

        Example of the returned value (word limit 6, default pad token):
        ::
            [[['p', 'l', 'a', 'y', '<pad>', '<pad>'],
              ['t', 'h', 'a', 't', '<pad>', '<pad>'],
              ['t', 'r', 'a', 'c', 'k', '<pad>'],
              ['o', 'n', '<pad>', '<pad>', '<pad>', '<pad>'],
              ['r', 'e', 'p', 'e', 'a', 't']
             ], ...
            ]
        """
        if self.max_sentence_length is not None:
            max_sentence_length = self.max_sentence_length
        else:
            max_sentence_length = max(len(sent) for sent in minibatch)

        if self.max_word_length is not None:
            max_word_length = self.max_word_length
        else:
            max_word_length = max(len(word) for sent in minibatch for word in sent)

        # New lists are built for every word and sentence, so the caller's
        # minibatch is never modified.
        padded_minibatch = []
        for sentence in minibatch:
            rows = []
            for word in sentence[:max_sentence_length]:
                chars = list(word)[:max_word_length]
                chars.extend([self.pad_token] * (max_word_length - len(chars)))
                rows.append(chars)
            # Fill short sentences with words made only of pad tokens.
            for _ in range(max_sentence_length - len(rows)):
                rows.append([self.pad_token] * max_word_length)
            padded_minibatch.append(rows)

        return padded_minibatch

    def numericalize(self, batch, device=None):
        """Numericalize each padded sentence and stack the results on dim 0.

        Every element of ``batch`` is handed to the base field's
        ``numericalize`` on its own; the resulting tensors are stacked along a
        new leading dimension.
        """
        batch_char_ids = []
        for sentence in batch:
            sentence_char_ids = super().numericalize(sentence, device=device)
            batch_char_ids.append(sentence_char_ids)
        return torch.stack(batch_char_ids, dim=0)
    
    
    
import torchtext
from torchtext.data import Field, Dataset,Iterator, RawField, BucketIterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab, FastText


class DocCharField(CharField):
    """Character field for documents: a list of sentences per example.

    Handles both nested examples (a list of strings, e.g. several support
    passages) and flat examples (a single string). Nested batches are padded
    along three axes: sentences per example, words per sentence and
    characters per word.
    """

    def __init__(
        self, 
        memory_size=None,
        max_word_length=20,
        max_sentence_length=128,
        keep_sent_len=128,
        keep_word_len=10,
        **kwargs
        ):
        """Create the field with a spaCy ``en_core_web_sm`` tokenizer.

        Args:
            memory_size: sentences kept per example; ``None`` uses the
                longest example of each batch.
            max_word_length: initial word limit forwarded to ``CharField``
                (``pad`` recomputes it for every batch).
            max_sentence_length: initial sentence limit forwarded to
                ``CharField`` (``pad`` recomputes it for every batch).
            keep_sent_len: upper bound on words per sentence for nested
                batches, or ``None`` for no bound.
            keep_word_len: upper bound on characters per word, or ``None``
                for no bound.
            **kwargs: further ``CharField`` options.
        """
        tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
        # Forward the length limits instead of silently discarding them.
        super(DocCharField, self).__init__(
            tokenize=tokenizer,
            max_word_length=max_word_length,
            max_sentence_length=max_sentence_length,
            **kwargs
        )
        self.memory_size = memory_size
        self.keep_sent_len = keep_sent_len
        self.keep_word_len = keep_word_len


    def preprocess(self, x):
        """Preprocess a single string or a list of strings.

        A list goes through the base ``preprocess`` once as a whole and then
        once more per element; anything else is preprocessed directly.
        """
        if isinstance(x, list):
            ss =  super(DocCharField, self).preprocess(x)
            return [super(DocCharField, self).preprocess(s) for s in ss]
        else:
            return super(DocCharField, self).preprocess(x)

    def pad(self, minibatch):
        """Pad a batch of nested or flat examples to a rectangular shape.

        Side effect: ``self.max_sentence_length`` and ``self.max_word_length``
        are overwritten with the limits computed for this batch, because
        ``CharField.pad`` reads them from the instance.
        """
        # Decide nested vs. flat from the first non-empty example so that an
        # empty leading example does not raise IndexError.
        first = next((ex for ex in minibatch if ex), None)
        if first is not None and isinstance(first[0], list):
            # Flattened maxima tolerate individual empty examples.
            self.max_sentence_length = max(len(sent) for ex in minibatch for sent in ex)
            if self.keep_sent_len is not None:
                self.max_sentence_length = min(self.keep_sent_len, self.max_sentence_length)
            self.max_word_length = max(len(word) for para in minibatch for sent in para for word in sent)
            if self.keep_word_len is not None:
                self.max_word_length = min(self.keep_word_len, self.max_word_length)

            if self.memory_size is None:
                memory_size = max(len(ex) for ex in minibatch)
            else:
                memory_size = self.memory_size
            padded = []
            for ex in minibatch:
                # Keep only the first `memory_size` sentences of the example.
                nex = ex[:memory_size]
                padded.append(
                    super(DocCharField, self).pad(nex)
                )
                # Top up short examples with sentences made only of pad tokens.
                for _ in range(memory_size - len(nex)):
                    padded_sent = [[self.pad_token]*self.max_word_length for _ in range(self.max_sentence_length)]
                    padded[-1].append(padded_sent)
            return padded
        else:
            # Flat batch: let CharField.pad use the longest sentence of the batch.
            self.max_sentence_length = None
            self.max_word_length = max([len(word) for sent in minibatch for word in sent ])
            if self.keep_word_len is not None:
                self.max_word_length = min(self.keep_word_len, self.max_word_length)
            return super(DocCharField, self).pad(minibatch)

    def numericalize(self, arr, device=None):
        """Convert a padded batch to a tensor.

        Nested batches (four levels of lists) are numericalized example by
        example and stacked; flat batches go straight to
        ``CharField.numericalize``.
        """
        if isinstance(arr[0][0][0], list):
            tmp = [
                super(DocCharField, self).numericalize(x, device=device).data
                for x in arr
            ]
            arr = torch.stack(tmp)
            if self.sequential:
                arr = arr.contiguous()
            return arr
        else:
            return super(DocCharField, self).numericalize(arr, device=device)    

In [4]:
# NOTE(review): the following cell rebinds `doc_char_field` to a differently
# configured instance, so this object is only live until that cell runs.
doc_char_field = DocCharField(
    keep_sent_len=128,
    keep_word_len=30,
    lower=True,
)

In [5]:

# A single character field is shared by all three text columns, so they end
# up with one common character vocabulary.
doc_char_field = DocCharField(
    keep_sent_len=256,
    keep_word_len=16,
    lower=True,
)

# JSON key -> (attribute name on the Example, field that processes it).
fields = {
    'candidates': ('candidates_char', doc_char_field),
    'supports': ('supports_char', doc_char_field),
    'query': ('query_char', doc_char_field),
    'id': ('id', RawField()),
}

In [6]:
from tqdm import tqdm
# Short alias for the torchtext constructor that builds an Example from a dict.
make_example = torchtext.data.example.Example.fromdict


In [7]:
examples = []
def preprocess(item):
    """Return a copy of a raw record with a label and a cleaned query.

    Args:
        item: dict with at least the keys 'answer', 'candidates' and 'query'.

    Returns:
        A shallow copy of ``item`` in which 'label' is the position of the
        answer within 'candidates' and underscores in 'query' are replaced by
        spaces. The input dict itself is left untouched, so the loaded data
        is not altered by running this.

    Raises:
        ValueError: if the answer is not one of the candidates.
    """
    record = dict(item)
    record['label'] = record['candidates'].index(record['answer'])
    record['query'] = record['query'].replace('_',' ')
    return record

# Convert every raw dev record into a torchtext Example, showing progress.
for raw_item in tqdm(dev_data):
    examples.append(make_example(preprocess(raw_item), fields))

100%|██████████| 5129/5129 [01:40<00:00, 51.20it/s]


In [8]:
# Dataset expects a flat list of (name, field) pairs: unpack the dict values,
# splicing in any value that is itself a list of pairs.
# The isinstance guard makes re-running this cell a no-op.
if isinstance(fields, dict):
    field_dict = fields
    fields = []
    for entry in field_dict.values():
        fields.extend(entry if isinstance(entry, list) else [entry])

In [9]:
# Bundle the examples with the flattened (name, field) list.
dataset = Dataset(examples, fields)

In [10]:
# Count characters over every dataset column bound to doc_char_field and
# store the resulting vocabulary on the field.
doc_char_field.build_vocab(dataset)

In [11]:
# Batches of four examples with no sort key.
# Alternative tried earlier: sort_key=lambda x: len(x.supports_char) together
# with sort_within_batch=True, to order examples by number of supports.
test_iter = BucketIterator(dataset, batch_size=4, device=None)

In [12]:
# Print the first ten batches as a smoke test of padding and numericalization.
i = 0
for i, batch in enumerate(test_iter, start=1):
    if i > 10:
        break
    print(batch)


[torchtext.data.batch.Batch of size 4]
	[.candidates_char]:[torch.LongTensor of size 4x46x3x12]
	[.supports_char]:[torch.LongTensor of size 4x25x240x16]
	[.query_char]:[torch.LongTensor of size 4x8x9]
	[.id]:['WH_dev_3042', 'WH_dev_2533', 'WH_dev_2341', 'WH_dev_2735']

[torchtext.data.batch.Batch of size 4]
	[.candidates_char]:[torch.LongTensor of size 4x48x3x13]
	[.supports_char]:[torch.LongTensor of size 4x20x256x16]
	[.query_char]:[torch.LongTensor of size 4x5x10]
	[.id]:['WH_dev_593', 'WH_dev_2066', 'WH_dev_4297', 'WH_dev_4078']

[torchtext.data.batch.Batch of size 4]
	[.candidates_char]:[torch.LongTensor of size 4x47x4x15]
	[.supports_char]:[torch.LongTensor of size 4x27x236x15]
	[.query_char]:[torch.LongTensor of size 4x15x14]
	[.id]:['WH_dev_4753', 'WH_dev_2939', 'WH_dev_2956', 'WH_dev_3823']

[torchtext.data.batch.Batch of size 4]
	[.candidates_char]:[torch.LongTensor of size 4x27x3x12]
	[.supports_char]:[torch.LongTensor of size 4x19x256x16]
	[.query_char]:[torch.LongTensor o

In [31]:
# Number of batches per pass: 5129 examples at batch size 4 gives 1283.
len(test_iter)

1283

In [22]:
class data_prefetcher():
    """Wrap an iterable and keep the next item fetched one step ahead.

    ``next()`` returns the item fetched by the previous call and immediately
    fetches the following one. Once the underlying iterator is exhausted,
    ``next()`` returns ``None``.
    """

    def __init__(self, loader):
        """Start iterating ``loader`` and fetch its first item.

        Args:
            loader: any iterable of batches.
        """
        self.loader = iter(loader)
        # Only create a CUDA stream when CUDA is usable, so the prefetcher can
        # also be constructed on CPU-only machines.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.next_data = None
        self.preload()

    def preload(self):
        """Fetch the next item into ``self.next_data`` (``None`` at the end)."""
        # Reset next_data on exhaustion; otherwise the last batch would be
        # returned again on every later call to next().
        self.next_data = next(self.loader, None)

    def next(self):
        """Return the prefetched item, or ``None`` when the data is exhausted."""
        if self.stream is not None:
            # Nothing in this class queues work on self.stream, so this only
            # orders the current stream after whatever callers put on it.
            torch.cuda.current_stream().wait_stream(self.stream)
        data = self.next_data
        self.preload()
        return data

In [23]:
# Wrap the iterator; the constructor already fetches the first batch.
prefetcher = data_prefetcher(test_iter)

In [27]:
# NOTE(review): `i` is reset here but not used anywhere in this cell.
i = 0
# Take the prefetched batch; this also triggers fetching the following one.
batch = prefetcher.next()

In [30]:
# Inspect the padded character-id tensor of the candidates in this batch.
batch.candidates_char

tensor([[[[ 3,  6,  6,  ...,  1,  1,  1],
          [ 3,  9, 20,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         [[20,  7,  8,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         [[20,  9,  3,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         ...,

         [[ 8,  3, 19,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         [[ 8,  4,  9,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         [[14,  6,  5,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]]],


        [[[ 3,  9,  4,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1],
          [ 1,  1,  1,  ...,  1,  1,  1]],

         [[ 3, 14,  4,  ..., 17, 10, 19],
          [ 1,  1,  1,  ...,  1,  1,  1],
   