In [1]:
from torchtext import data
from torchtext.data import Dataset, Iterator, Field
import numpy as np
from collections import defaultdict, Counter

In [2]:
tok_fun = lambda s: s.split()

src_field = data.Field(init_token=None, eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           batch_first=True, lower=True,
                           unk_token="<UNK>",
                           include_lengths=True)

trg_field = data.Field(init_token="<BOS>", eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           unk_token="<UNK>",
                           batch_first=True, lower=True,
                           include_lengths=True)
edge_org_field = data.Field(init_token=None, eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           unk_token="<UNK>",
                           batch_first=True, lower=True,
                           include_lengths=True)
edge_trg_field = data.Field(init_token=None, eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           unk_token="<UNK>",
                           batch_first=True, lower=True,
                           include_lengths=True)

In [3]:
print(src_field)

<torchtext.data.field.Field object at 0x7fb5607da6a0>


In [10]:


class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation."""

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, src_file, trg_file, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.
        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1]),\
                     ('edge_org', fields[2]), ('edge_trg', fields[3])]

        examples = []

        for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                if src_line != '' and trg_line != '':
                    print(src_line)
                    print(trg_line)
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line,"0 1 2 3 4","1 2 3 4 0"], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, path=None, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.
        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        if path is None:
            path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)

In [11]:
DEFAULT_UNK_ID = lambda: 0
UNK_TOKEN = "<UNK>"
PAD_TOKEN = "<PAD>"
BOS_TOKEN = "<BOS>"
EOS_TOKEN = "<EOS>"
class Vocabulary:
    """ Vocabulary represents mapping between tokens and indices. """

    def __init__(self, tokens = None, file: str = None) -> None:
        """
        Create vocabulary from list of tokens or file.

        Special tokens are added if not already in file or list.
        File format: token with index i is in line i.

        :param tokens: list of tokens
        :param file: file to load vocabulary from
        """
        # don't rename stoi and itos since needed for torchtext
        # warning: stoi grows with unknown tokens, don't use for saving or size

        # special symbols
        self.specials = ["<UNK>", "<PAD>", "<BOS>", "<EOS>"]

        self.stoi = defaultdict(DEFAULT_UNK_ID)
        self.itos = []
        if tokens is not None:
            self._from_list(tokens)
        elif file is not None:
            self._from_file(file)

    def _from_list(self, tokens = None) -> None:
        """
        Make vocabulary from list of tokens.
        Tokens are assumed to be unique and pre-selected.
        Special symbols are added if not in list.

        :param tokens: list of tokens
        """
        self.add_tokens(tokens=self.specials+tokens)
        assert len(self.stoi) == len(self.itos)

    def _from_file(self, file: str) -> None:
        """
        Make vocabulary from contents of file.
        File format: token with index i is in line i.

        :param file: path to file where the vocabulary is loaded from
        """
        tokens = []
        with open(file, "r") as open_file:
            for line in open_file:
                tokens.append(line.strip("\n"))
        self._from_list(tokens)

    def __str__(self) -> str:
        return self.stoi.__str__()

    def to_file(self, file: str) -> None:
        """
        Save the vocabulary to a file, by writing token with index i in line i.

        :param file: path to file where the vocabulary is written
        """
        with open(file, "w") as open_file:
            for t in self.itos:
                open_file.write("{}\n".format(t))

    def add_tokens(self, tokens) -> None:
        """
        Add list of tokens to vocabulary

        :param tokens: list of tokens to add to the vocabulary
        """
        for t in tokens:
            new_index = len(self.itos)
            # add to vocab if not already there
            if t not in self.itos:
                self.itos.append(t)
                self.stoi[t] = new_index

    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is covered by the vocabulary

        :param token:
        :return: True if covered, False otherwise
        """
        return self.stoi[token] == DEFAULT_UNK_ID()

    def __len__(self) -> int:
        return len(self.itos)

    def array_to_sentence(self, array: np.array, cut_at_eos=True,
                          skip_pad=True):
        """
        Converts an array of IDs to a sentence, optionally cutting the result
        off at the end-of-sequence token.

        :param array: 1D array containing indices
        :param cut_at_eos: cut the decoded sentences at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of strings (tokens)
        """
        sentence = []
        for i in array:
            s = self.itos[i]
            if cut_at_eos and s == EOS_TOKEN:
                break
            if skip_pad and s == PAD_TOKEN:
                continue
            sentence.append(s)
        return sentence

    def arrays_to_sentences(self, arrays: np.array, cut_at_eos=True,
                            skip_pad=True):
        """
        Convert multiple arrays containing sequences of token IDs to their
        sentences, optionally cutting them off at the end-of-sequence token.

        :param arrays: 2D array containing indices
        :param cut_at_eos: cut the decoded sentences at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of list of strings (tokens)
        """
        sentences = []
        for array in arrays:
            sentences.append(
                self.array_to_sentence(array=array, cut_at_eos=cut_at_eos,
                                       skip_pad=skip_pad))
        return sentences


def build_vocab(field: str, max_size: int, min_freq: int, dataset: Dataset,
                vocab_file: str = None) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from given`dataset` or
    `vocab_file`.

    :param field: attribute e.g. "src"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
    else:
        # create newly
        def filter_min(counter: Counter, min_freq: int):
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter: Counter, limit: int):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in dataset.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size

        vocab = Vocabulary(tokens=vocab_tokens)
        assert len(vocab) <= max_size + len(vocab.specials)
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    # check for all except for UNK token whether they are OOVs
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)

    return vocab

In [12]:
english_sentences=["I am happy to be here","This is the test sentence number 2","the last test sentence is  this one"]
spanish_sentences=["estoy feliz de estar aquí","ésta es la oración de prueba número 2","La última ouración de prueba es ésta"]
edge_indexes=[[[0,1,2,3,4,5],[0,1,2,3,4,5]],[[0,1,2,3,4],[0,1,2,3,4]],[[0,1,2,3],[0,1,2,3]]]
dataset=TranslationDataset(english_sentences,spanish_sentences,fields=(src_field, trg_field,edge_org_field,edge_trg_field))

I am happy to be here
estoy feliz de estar aquí
This is the test sentence number 2
ésta es la oración de prueba número 2
the last test sentence is  this one
La última ouración de prueba es ésta


In [13]:
src_vocab = build_vocab(field="src", min_freq=1,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
trg_vocab = build_vocab(field="trg", min_freq=1,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
edge_org_vocab = build_vocab(field="edge_org", min_freq=1,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
edge_trg_vocab = build_vocab(field="edge_trg", min_freq=1,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
src_field.vocab = src_vocab
trg_field.vocab = trg_vocab
edge_org_field.vocab = edge_org_vocab
edge_trg_field.vocab = edge_trg_vocab

In [25]:
edge_trg_vocab.stoi

defaultdict(<function __main__.<lambda>()>,
            {'<UNK>': 0,
             '<PAD>': 1,
             '<BOS>': 2,
             '<EOS>': 3,
             '1': 0,
             '2': 0,
             '3': 0,
             '4': 0,
             '0': 0})

In [14]:
data_iter = data.BucketIterator(
            repeat=False, sort=False, dataset=dataset,
            batch_size=3, batch_size_fn=None,
            train=True, sort_within_batch=True,
            sort_key=lambda x: len(x.src), shuffle=True)

In [16]:
batch=next(iter(data_iter))

In [20]:
batch.edge_org

(tensor([[0, 0, 0, 0, 0, 3],
         [0, 0, 0, 0, 0, 3],
         [0, 0, 0, 0, 0, 3]]),
 tensor([6, 6, 6]))

In [None]:

fields = [('edge_org', data.Field(sequential=False)), ('edge_trg', data.Field(sequential=False))]
data.Example.fromlist([[0,1,2,3,4,5],\
                       [1,2,3,4,5,0]],\
                      fields)


In [None]:
ex=data.Field(sequential=True)

In [None]:
ex.vocab