In [1]:
import os

import spacy
import torch
from spacy.lang.en.examples import sentences
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator

# import torchtext.datasets as datasets
# from multi30k import Multi30k


def _read_parallel_split(split, language_pair, data_dir):
    """Read ``{data_dir}/{split}.{lang}`` for every language and zip the lines into tuples.

    Returns a list with one tuple per line index, e.g. ``[(de_0, en_0), (de_1, en_1), ...]``.
    """
    per_language_lines = []
    for lan in language_pair:
        path = os.path.join(data_dir, "{}.{}".format(split, lan))
        # Explicit UTF-8: the corpus contains non-ASCII text (e.g. German umlauts)
        # that a platform default encoding such as cp1252 would mis-decode.
        with open(path, "r", encoding="utf-8") as file:
            per_language_lines.append(file.read().splitlines())
        # end
    # end
    # zip(*...) pairs the i-th line of every file. It silently stops at the
    # shortest file, so the files are assumed to be line-aligned.
    return list(zip(*per_language_lines))
# end


def Multi30k(language_pair=None, data_dir="text"):
    """Load the local Multi30k parallel corpus as lists of aligned sentence tuples.

    Args:
        language_pair: sequence of file suffixes, e.g. ``("de", "en")``; the
            order determines the order inside each returned tuple.
            ``None`` means ``("de", "en")``.
        data_dir: directory containing ``train.<lang>`` and ``val.<lang>``
            files with one sentence per line.

    Returns:
        ``(train, val, test)``. ``train`` and ``val`` are lists of tuples
        ``(sentence_in_lang0, sentence_in_lang1, ...)``. ``test`` is always
        ``None`` because no test split is read.
    """
    if language_pair is None:
        language_pair = ("de", "en")

    train = _read_parallel_split("train", language_pair, data_dir)
    val = _read_parallel_split("val", language_pair, data_dir)
    return train, val, None
# end
def load_tokenizers():
    """Load the spaCy ``en_core_web_sm`` pipeline, downloading it on first use.

    Returns:
        The object returned by ``spacy.load("en_core_web_sm")``.

    Raises:
        subprocess.CalledProcessError: if the model download command fails.
        IOError: if the model still cannot be loaded after downloading.
    """
    import importlib
    import subprocess
    import sys

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        # Use the running interpreter (the notebook kernel), not whatever
        # `python` is first on PATH. Otherwise the model can land in a different
        # environment. check=True surfaces a failed download instead of
        # silently ignoring the exit status.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        # The package was installed after this process started; refresh the
        # import system's directory caches so it can be found.
        importlib.invalidate_caches()
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en
# end

def tokenize(text, tokenizer):
    """Split ``text`` into token strings using ``tokenizer.tokenizer``.

    Args:
        text: the sentence to tokenize.
        tokenizer: an object exposing a ``tokenizer`` callable whose result
            iterates over tokens that each have a ``.text`` attribute
            (e.g. a loaded spaCy pipeline).

    Returns:
        List of token strings in order.
    """
    words = []
    for token in tokenizer.tokenizer(text):
        words.append(token.text)
    return words
# end


def yield_tokens(data_iter, tokenizer, index):
    """Lazily tokenize one side of every sentence tuple in ``data_iter``.

    Args:
        data_iter: iterable of indexable items (e.g. ``(de, en)`` tuples).
        tokenizer: callable applied to the selected element of each item.
        index: which element of each item to tokenize.

    Yields:
        ``tokenizer(item[index])`` for each item, in order.
    """
    yield from (tokenizer(pair[index]) for pair in data_iter)
# end

def build_vocabulary(spacy_en):
    """Build the English (target-side) vocabulary from the Multi30k corpus.

    Tokenizes the English element (tuple index 1) of every train and
    validation pair with ``spacy_en``. Only tokens seen at least twice are
    kept, and ``<s>``, ``</s>``, ``<blank>``, ``<unk>`` are reserved as
    specials (torchtext places specials first by default). Unknown tokens
    map to the ``<unk>`` id.

    Args:
        spacy_en: spaCy pipeline passed through to ``tokenize``.

    Returns:
        The torchtext vocabulary object built by ``build_vocab_from_iterator``.
    """
    def _english_tokens(sentence):
        return tokenize(sentence, spacy_en)

    print("Building English Vocabulary ...")
    train_pairs, val_pairs, _ = Multi30k(language_pair=("de", "en"))
    english_side = 1  # position of "en" within the ("de", "en") tuples

    english_vocab = build_vocab_from_iterator(
        yield_tokens(train_pairs + val_pairs, _english_tokens, index=english_side),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    unk_index = english_vocab["<unk>"]
    english_vocab.set_default_index(unk_index)
    return english_vocab
# end


def load_vocab(spacy_en, path="vocab.pt"):
    """Return the English vocabulary, building it and caching it to ``path`` on first use.

    Args:
        spacy_en: spaCy pipeline forwarded to ``build_vocabulary`` when no
            cache exists (unused when loading from the cache).
        path: cache file location; defaults to ``"vocab.pt"`` in the working
            directory.

    Returns:
        The vocabulary object (its length is printed as a side effect).

    Note:
        The cache is never invalidated. Delete ``path`` after changing the
        corpus or tokenization, or a stale vocabulary will be reused.
    """
    if os.path.exists(path):
        # NOTE(review): torch.load unpickles the file; only load caches you created.
        vocab_tgt = torch.load(path)
    else:
        vocab_tgt = build_vocabulary(spacy_en)
        torch.save(vocab_tgt, path)
    print("Finished.\nVocabulary sizes:")
    print(len(vocab_tgt))
    return vocab_tgt
# end

# Load the spaCy English pipeline, then build (or load from the vocab.pt cache) the English vocabulary.
vocab = load_vocab(spacy_en=load_tokenizers())

Finished.
Vocabulary sizes:
6191


In [1]:
def _pad_sequence_ids(text, pipeline, vocab, bos, eos, device, max_padding, pad_id):
    """Tokenize ``text``, map it to ids, wrap it in ``bos``/``eos`` and fit it to ``max_padding``.

    Returns a 1-D int64 tensor of length ``max_padding``.
    """
    ids = torch.cat(
        [
            bos,
            torch.tensor(vocab(pipeline(text)), dtype=torch.int64, device=device),
            eos,
        ],
        0,
    )
    # Over-long sequences are cut from the right, which also drops their </s>.
    # This is exactly what F.pad with a negative amount did implicitly; the
    # slice makes it explicit, so the pad amount below is never negative.
    ids = ids[:max_padding]
    return torch.nn.functional.pad(ids, (0, max_padding - len(ids)), value=pad_id)
# end


def collate_batch(
    batch,
    src_pipeline,
    tgt_pipeline,
    src_vocab,
    tgt_vocab,
    device,
    max_padding=128,
    pad_id=2,
):
    """Turn a batch of (source, target) sentence pairs into two padded id tensors.

    Args:
        batch: iterable of ``(src_text, tgt_text)`` pairs.
        src_pipeline, tgt_pipeline: callables mapping a sentence to a token list.
        src_vocab, tgt_vocab: callables mapping a token list to a list of int ids.
        device: device on which all id tensors are created.
        max_padding: fixed output length per sequence. Longer sequences are
            truncated from the right, which drops their ``</s>``.
        pad_id: id used to fill sequences shorter than ``max_padding``.

    Returns:
        ``(src, tgt)``: int64 tensors of shape ``(len(batch), max_padding)``.
    """
    # Hard-coded ids 0/1 assume <s> and </s> are the first two vocab specials,
    # which matches the specials order used in build_vocabulary.
    bs_id = torch.tensor([0], device=device)  # <s> token id
    eos_id = torch.tensor([1], device=device)  # </s> token id
    src_list, tgt_list = [], []
    for _src, _tgt in batch:
        src_list.append(
            _pad_sequence_ids(
                _src, src_pipeline, src_vocab, bs_id, eos_id, device, max_padding, pad_id
            )
        )
        tgt_list.append(
            _pad_sequence_ids(
                _tgt, tgt_pipeline, tgt_vocab, bs_id, eos_id, device, max_padding, pad_id
            )
        )

    return torch.stack(src_list), torch.stack(tgt_list)

In [2]:
def create_dataloaders(
    device,
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    batch_size=12000,
    max_padding=128,
    is_distributed=True,
):
    """Build train and validation DataLoaders over the local Multi30k de->en corpus.

    Args:
        device: device on which collated id tensors are created.
        vocab_src, vocab_tgt: source (German) and target (English) vocabularies.
            The padding id is read from ``vocab_src``'s ``"<blank>"`` entry and
            used for both sides.
        spacy_de, spacy_en: spaCy pipelines used to tokenize each side.
        batch_size: number of sentence pairs per batch.
        max_padding: fixed sequence length passed to ``collate_batch``.
        is_distributed: wrap each dataset in a ``DistributedSampler``
            (otherwise batches are shuffled by the DataLoader).

    Returns:
        ``(train_dataloader, valid_dataloader)``.

    Raises:
        KeyError: if ``"<blank>"`` is missing from ``vocab_src``.
    """
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    # Look the padding id up once instead of on every batch; get_stoi() may
    # rebuild the full string->index mapping on each call.
    pad_id = vocab_src.get_stoi()["<blank>"]

    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_src,
            vocab_tgt,
            device,
            max_padding=max_padding,
            pad_id=pad_id,
        )

    train_iter, valid_iter, _ = Multi30k(language_pair=("de", "en"))

    def _make_loader(pairs):
        # DistributedSampler needs a dataset with len(), hence the map-style conversion.
        dataset = to_map_style_dataset(pairs)
        sampler = DistributedSampler(dataset) if is_distributed else None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(sampler is None),
            sampler=sampler,
            collate_fn=collate_fn,
        )

    return _make_loader(train_iter), _make_loader(valid_iter)
# end