In [1]:
from collections import Counter, defaultdict
from itertools import chain, count

import torch
import torchtext.data
import torchtext.vocab

# Special tokens handed to the torchtext fields built in get_fields():
# padding for every text field, begin/end-of-sentence markers for "tgt".
PAD_WORD = "<blank>"
BOS_WORD = "<s>"
EOS_WORD = "</s>"

# NOTE(review): presumably the vocabulary index of the unknown token -- confirm.
UNK = 0

In [2]:
def get_fields(nFeatures=0):
    """Build the dictionary of torchtext fields used by this notebook.

    Args:
        nFeatures: how many extra source-feature fields to create; they are
            stored under the keys "src_feat_0", "src_feat_1", ...

    Returns:
        A dict mapping field name to ``torchtext.data.Field``. Keys are
        inserted in the order "src", "src_feat_*", "tgt", "src_map",
        "alignment", "indices".
    """
    Field = torchtext.data.Field

    def make_src(data, _):
        # Expand a batch of index sequences into a 0/1 float tensor of shape
        # (longest sequence, number of sequences, largest value + 1), where
        # entry [pos, seq, value] is 1 iff sequence `seq` holds `value` at
        # position `pos`. Positions past the end of a sequence stay all-zero.
        longest = max(seq.size(0) for seq in data)
        n_values = max(seq.max() for seq in data) + 1
        one_hot = torch.zeros(longest, len(data), n_values)
        for seq_idx, seq in enumerate(data):
            for pos, value in enumerate(seq):
                one_hot[pos, seq_idx, value] = 1
        return one_hot

    def make_tgt(data, _):
        # Lay the sequences out column by column in a zero-filled long tensor
        # of shape (longest sequence, number of sequences).
        longest = max(seq.size(0) for seq in data)
        padded = torch.zeros(longest, len(data)).long()
        for seq_idx, seq in enumerate(data):
            padded[:seq.size(0), seq_idx] = seq
        return padded

    # Source text: padded, and batches also carry the sequence lengths.
    fields = {"src": Field(pad_token=PAD_WORD, include_lengths=True)}

    # One padded field per source feature.
    for feat_idx in range(nFeatures):
        fields["src_feat_{}".format(feat_idx)] = Field(pad_token=PAD_WORD)

    # Target text: wrapped in begin/end-of-sentence tokens and padded.
    fields["tgt"] = Field(
        init_token=BOS_WORD,
        eos_token=EOS_WORD,
        pad_token=PAD_WORD)

    # Already-numeric fields (no vocabulary lookup); the first two are turned
    # into batch tensors by the postprocessing helpers defined above.
    fields["src_map"] = Field(
        use_vocab=False,
        tensor_type=torch.FloatTensor,
        postprocessing=make_src,
        sequential=False)
    fields["alignment"] = Field(
        use_vocab=False,
        tensor_type=torch.LongTensor,
        postprocessing=make_tgt,
        sequential=False)
    fields["indices"] = Field(
        use_vocab=False,
        tensor_type=torch.LongTensor,
        sequential=False)

    return fields

def collect_features(fields, side="src"):
    """List the consecutive feature keys present in ``fields`` for one side.

    Keys are probed as "<side>_feat_0", "<side>_feat_1", ... and the scan
    stops at the first index that is missing, so later keys after a gap are
    not reported.

    Args:
        fields: any container supporting ``in`` on string keys.
        side: "src" or "tgt"; anything else fails the assertion.

    Returns:
        The matching keys, in increasing index order (possibly empty).
    """
    assert side in ["src", "tgt"]
    prefix = side + "_feat_"
    found = []
    idx = 0
    name = prefix + str(idx)
    while name in fields:
        found.append(name)
        idx += 1
        name = prefix + str(idx)
    return found

In [5]:
# Load the saved vocabulary; the path is relative to the notebook's working
# directory.
# NOTE(review): torch.load unpickles the file, so only point this at files
# from a trusted source.
# NOTE(review): the recorded run of this cell ended in FileNotFoundError while
# later cells still display `vocab`, so the saved outputs come from an earlier
# kernel state -- re-run from a fresh kernel with the data file in place.
vocab = torch.load("../data/fixed.vocab.pt")
# Disabled code: would attach each loaded vocabulary to a freshly built field.
# vocab = dict(vocab)
# fields = get_fields(
#     len(collect_features(vocab)))
# for k, v in vocab.items():
#     # Hack. Can't pickle defaultdict :(
#     v.stoi = defaultdict(lambda: 0, v.stoi)
#     fields[k].vocab = v

FileNotFoundError: [Errno 2] No such file or directory: '../data/fixed.vocab.pt'

In [11]:
# Show what was loaded: the recorded output is a list of (field name, Vocab)
# pairs for 'src' and 'tgt'; both print the same object address.
print(vocab)

[('src', <torchtext.vocab.Vocab object at 0x2b95b1735b38>), ('tgt', <torchtext.vocab.Vocab object at 0x2b95b1735b38>)]


In [16]:
# Token-frequency Counter of the first list entry (the 'src' vocabulary).
vocab[0][1].freqs

Counter({'__start_name__': 1,
         'The': 2,
         'Vaults': 2,
         '__end_name__': 1,
         '__start_eatType__': 1,
         'pub': 2,
         '__end_eatType__': 1,
         '__start_priceRange__': 1,
         'more': 2,
         'than': 2,
         '£': 2,
         '30': 2,
         '__end_priceRange__': 1,
         '__start_customerrating__': 1,
         '5': 2,
         'out': 2,
         'of': 2,
         '__end_customerrating__': 1,
         '__start_near__': 1,
         'Café': 2,
         'Adriatic': 2,
         '__end_near__': 1,
         '__start_additional_words__': 1,
         'star': 2,
         'Prices': 2,
         'start': 2,
         '__end_additional_words__': 1,
         'Cambridge': 2,
         'Blue': 2,
         '__start_food__': 1,
         'English': 2,
         '__end_food__': 1,
         'cheap': 2,
         'Brazil': 2,
         'Close': 2,
         'serves': 2,
         'delicious': 2,
         'Tuscan': 2,
         'Beef': 2,
         'Delic