In [1]:
import spacy
import os

def Multi30k(language_pair=None):
    corpus_lines_train = []

    for lan in language_pair:
        with open('text/train.{}'.format(lan), 'r') as file:
            corpus_lines_train.append(file.read().splitlines())
        # end
    # end

    corpus_train = list(zip(*corpus_lines_train))

    corpus_lines_eval = []

    for lan in language_pair:
        with open('text/val.{}'.format(lan), 'r') as file:
            corpus_lines_eval.append(file.read().splitlines())
        # end
    # end

    corpus_eval = list(zip(*corpus_lines_eval))

    return corpus_train, corpus_eval, None
# end


def load_vocab(spacy_en):
    if not os.path.exists("vocab.pt"):
        vocab_tgt = build_vocabulary(spacy_en)
        torch.save(vocab_tgt, "vocab.pt")
    else:
        vocab_tgt = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes: {}".format(len(vocab_tgt)))
    return vocab_tgt
# end

def load_spacy():

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en
# end

In [2]:
import torch
import random

class TokenizerWrapper:
    def __init__(self, vocab, splitter):
        self.splitter = splitter
        self.vocab = vocab

        self.id_pad = len(vocab)
        self.id_cls = len(vocab) + 1
        self.id_sep = len(vocab) + 2
        self.id_mask = len(vocab) + 3
        
        self.size_vocab = len(vocab) + 4

        self.token_pad = '[PAD]'
        self.token_cls = '[CLS]'
        self.token_sep = '[SEP]'
        self.token_mask = '[MASK]'
        
    # end

    def encode(self, line):
        return self.vocab([doc.text for doc in self.splitter(line)])
    # end

    def decode(self):
        pass
    # end
# end


class Batch:
    DEVICE = 'cpu'

    def __init__(self, **kwargs):
        self.kwargs = {}
        for k, v in kwargs.items():
            self.kwargs[k] = v.to(Batch.DEVICE)
        # end
    # end

    def __call__(self):
        return self.kwargs
    # end
# end


class Collator_S2S:

    def __init__(self, tokenizer, size_seq_max):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
    # end

    def __call__(self, list_corpus_source):

        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []
        labels_similarity = []

        for corpus_source in list_corpus_source: # (line0, line1, sim), output of zip remove single case
            if corpus_source == 3:
                corpus_line = [courpus_source[0], corpus_source[1]]
                labels_similarity.append(corpus_line[2])
            else:
                corpus_line = [corpus_source[1]]
            # end
            
            for line in corpus_line:
                tokens = self.tokenizer.encode(line)

                # TODO: check edge
                if len(tokens) > self.size_seq_max - 2:
                    tokens = tokens[:self.size_seq_max-2]
                # end

                tokens_input_encoder.append([self.tokenizer.id_cls] + tokens + [self.tokenizer.id_sep])
                tokens_input_decoder.append([self.tokenizer.id_cls] + tokens)
                tokens_label_decoder.append(tokens + [self.tokenizer.id_sep])
            # end
            

        # end

        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=0.3)
        inputs_decoder, masks_decoder, segments_decoder = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label = self.pad_sequences(tokens_label_decoder, self.size_seq_max)
        # labels_similarity = torch.Tensor(labels_similarity).unsqueeze(0).transpose(0,1)
        labels_similarity = torch.Tensor(labels_similarity)

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label,
            labels_similarity=labels_similarity
        )
    # end

    
    # return masks_attention?, return masks_segment?
    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0): # need_diagonal and need_masked cannot both set, one for bert seq one for s2s seq
        id_pad = self.tokenizer.id_pad
        id_mask = self.tokenizer.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            sequences_padded.append(sequence_padded)

            if need_masked:
                index_masked = list(range(1, len_seq-1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq-2))])
                print(index_masked)

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence_masked, torch.LongTensor([id_pad] * count_pad)))
                
                sequences_masked_padded.append(sequence_masked_padded)
            # end
    #   # end for

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)
        # end

        masks_segment = (inputs != self.tokenizer.id_pad).unsqueeze(-2).expand(inputs.shape[0], inputs.shape[-1], inputs.shape[-1]) #(nbatch, seq, seq)
        masks_attention = self.make_std_mask(inputs, self.tokenizer.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2).expand(inputs.shape[0], inputs.shape[-1], inputs.shape[-1])
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0

    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end

In [3]:
import json
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset


seq_max = 16
batch_size = 2



spacy_en = load_spacy()
vocab = load_vocab(spacy_en)
tokenizer = TokenizerWrapper(vocab, spacy_en)

train_iter, valid_iter, _ = Multi30k(language_pair=("de", "en"))

train_source = to_map_style_dataset(train_iter)


collator = Collator_S2S(tokenizer, seq_max)
dataloader_train = DataLoader(train_source, batch_size, shuffle=False, collate_fn=collator)

Finished.
Vocabulary sizes: 6191


In [4]:
for i, batch in enumerate(dataloader_train):
    if i >= 1:
        break
    # end
    
    info_batch = batch()
    print('ids_encoder')
    print(info_batch['ids_encoder'])
    print('\n\nlabels_encoder')
    print(info_batch['labels_encoder'])
    print('\n\nmasks_encoder')
    print(info_batch['masks_encoder'])
    print('\n\nids_decoder')
    print(info_batch['ids_decoder'])
    print('\n\nlabels_decoder')
    print(info_batch['labels_decoder'])
    print('\n\nmasks_decoder')
    print(info_batch['masks_decoder'])    
    print('\n\nsegments_label')
    print(info_batch['segments_label'])
    
# end

tensor([ 7, 10,  3])
tensor([12,  8,  7])
ids_encoder
tensor([[6192,   19,   25, 6194, 1169,  808,   17, 6194,   84,  336, 6194,    5,
         6193, 6191, 6191, 6191],
        [6192,  164,   36,    7,  333,  286,   17, 6194, 6194,  744, 3732, 2678,
         6194, 6193, 6191, 6191]])


labels_encoder
tensor([[6192,   19,   25,   15, 1169,  808,   17,   57,   84,  336, 1339,    5,
         6193, 6191, 6191, 6191],
        [6192,  164,   36,    7,  333,  286,   17, 1191,    4,  744, 3732, 2678,
            5, 6193, 6191, 6191]])


masks_encoder
tensor([[[ True,  True,  True, False,  True,  True,  True, False,  True,  True,
          False,  True,  True, False, False, False],
         [ True,  True,  True, False,  True,  True,  True, False,  True,  True,
          False,  True,  True, False, False, False],
         [ True,  True,  True, False,  True,  True,  True, False,  True,  True,
          False,  True,  True, False, False, False],
         [ True,  True,  True, False,  True,  True, 