In [1]:
import torch
import os
from os.path import exists
import torch.nn as nn
from torch.nn.functional import log_softmax, pad, one_hot
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torch.utils.data import DataLoader
import random

### utils.py ###

class Dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
# end

### utils.py ###




### core.py ###

"Produce N identical layers."
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
# end



class MultiHeadedAttention(nn.Module):

    "Take in model size and number of heads."
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
    # end


    "Compute 'Scaled Dot Product Attention'"
    def attention(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # end
        p_attn = scores.softmax(dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # end
        return torch.matmul(p_attn, value), p_attn
    # end


    "Implements Figure 2"
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view and apply a final linear.
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )
        del query
        del key
        del value
        return self.linears[-1](x)
    # end
# end class


"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
class ResidualLayer(nn.Module):

    def __init__(self, size, dropout=0.1, eps=1e-6):
        super(ResidualLayer, self).__init__()
        self.norm = torch.nn.LayerNorm(size, eps)
        self.dropout = nn.Dropout(dropout)
    # end

    "Apply residual connection to any sublayer with the same size."
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
    # end
# end class


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))
    # end
# end


class SimpleIDEmbeddings(nn.Module):
    def __init__(self, size_vocab, dim_hidden, id_pad):
        super(SimpleIDEmbeddings, self).__init__()
        self.lut = nn.Embedding(size_vocab, dim_hidden, padding_idx=id_pad)
        self.dim_hidden = dim_hidden

    def forward(self, x):
        result = self.lut(x)
        return result * math.sqrt(self.dim_hidden)
    # end

    def get_shape(self):
        return (self.lut.num_embeddings, self.lut.embedding_dim)
    # end
# end


"Implement the PE function."
class PositionalEncoding(nn.Module):

    def __init__(self, dim_positional, max_len=512):
        super(PositionalEncoding, self).__init__()

        # Compute the positional encodings once in log space.
        self.dim_positional = dim_positional
        pe = torch.zeros(max_len, dim_positional)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim_positional, 2) * -(math.log(10000.0) / dim_positional)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).to('cuda')
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return x
    # end
# end


class SimpleEmbedder(nn.Module):    # no segment embedder as we do not need that
    def __init__(self, size_vocab=None, dim_hidden=128, dropout=0.1, id_pad=0):
        super(SimpleEmbedder, self).__init__()
        self.size_vocab = size_vocab
        self.dim_hidden = dim_hidden
        self.id_pad = id_pad

        self.embedder = nn.Sequential(
            SimpleIDEmbeddings(size_vocab, dim_hidden, id_pad),
            PositionalEncoding(dim_hidden),
            nn.Dropout(dropout)
        )
    # end

    def forward(self, ids_input):   # (batch, seqs_with_padding)
        return self.embedder(ids_input)
    # end

    def get_vocab_size(self):
        return self.size_vocab
    # end
# end

### core.py ###



class SimpleEncoderLayer(nn.Module):

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleEncoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 2)
    # end

    def forward(self, embeddings, masks, *args):
        embeddings = self.layers_residual[0](embeddings, lambda embeddings: self.layer_attention(embeddings, embeddings, embeddings, masks))
        return self.layers_residual[1](embeddings, self.layer_feedforward)
    # end
# end



class SimpleDecoderLayer(nn.Module):

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleDecoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention_decoder = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_attention_encoder = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 3)

    def forward(self, embeddings, masks_decoder, output_encoder, masks_encoder, *args):
        embeddings = self.layers_residual[0](embeddings, lambda embeddings: self.layer_attention_decoder(embeddings, embeddings, embeddings, masks_decoder))
        embeddings = self.layers_residual[1](embeddings, lambda embeddings: self.layer_attention_encoder(embeddings, output_encoder, output_encoder, masks_encoder))
        return self.layers_residual[2](embeddings, self.layer_feedforward)
    # end
# end


class SimpleTransformerStack(nn.Module):

    def __init__(self, obj_layer, n_layers):
        super(SimpleTransformerStack, self).__init__()
        self.layers = clones(obj_layer, n_layers)

        self.norm = torch.nn.LayerNorm(obj_layer.dim_hidden)
        self.keys_cache = ['output']
        self.cache = Dotdict({
            'outputs': None
        })
    # end

    def forward(self, embedding_encoder=None, masks_encoder=None, output_encoder=None, embedding_decoder=None, masks_decoder=None ,noncache=False, **kwargs):  # input -> (batch, len_seq, vocab)

        if output_encoder is not None and embedding_decoder is not None and masks_decoder is not None:
            embeddings = embedding_decoder
        else:
            embeddings = embedding_encoder
        # end

        for layer in self.layers:
            embeddings = layer(embeddings, masks_encoder, output_encoder, masks_decoder)
        # end

        outputs = self.norm(embeddings)

        if not noncache:
            self.cache.outputs = outputs
        # end

        return outputs
    # end

    # def get_vocab_size(self):
    #     return self.embedder.embedder_token.shape[-1]
    # # end

    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end


class SimpleEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder_encoder, embedder_decoder):
        super(SimpleEncoderDecoder, self).__init__()

        self.embedder_encoder = embedder_encoder
        self.encoder = encoder

        self.embedder_decoder = embedder_decoder
        self.decoder = decoder
    # end

    def forward(self, ids_encoder=None, masks_encoder=None, ids_decoder=None, masks_decoder=None, nocache=False, **kwargs):
        self.clear_cache()

        embedding_encoder = self.embedder_encoder(ids_encoder)
        output_encoder = self.encoder(
            embedding_encoder=embedding_encoder,
            masks_encoder=masks_encoder,
            nocache=nocache
        )
        embedding_decoder = self.embedder_decoder(ids_decoder)
        output_decoder = self.decoder(
            masks_encoder=masks_encoder,
            output_encoder=output_encoder,
            embedding_decoder=embedding_decoder,
            masks_decoder=masks_decoder,
            nocache=nocache
        )
        return output_decoder
    # end

    def clear_cache(self):
        self.encoder.clear_cache()
        self.decoder.clear_cache()
    # end

    def get_vocab_size(self, name_embedder):
        embedder = getattr(self, f'embedder_{name_embedder}')
        return embedder.get_vocab_size()
    # end

# end

class LinearAndNorm(nn.Module):
    def __init__(self, dim_in = None, dim_out = None, eps_norm=1e-12):
        super(LinearAndNorm, self).__init__()

        self.linear = torch.nn.Linear(dim_in, dim_out)
        self.norm = torch.nn.LayerNorm(dim_out, eps_norm)
    # end

    def forward(self, seqs_in):
        return self.norm(self.linear(seqs_in))
    # end
# end

class SimpleEncoderHead_MLM(nn.Module):

    def __init__(self, size_vocab, dim_hidden=128):
        super(SimpleEncoderHead_MLM, self).__init__()

        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab)

        self.func_loss = torch.nn.CrossEntropyLoss()

        self.keys_cache = ['loss', 'outputs', 'labels_mlm']
        self.cache = Dotdict({
            'loss': None,
            'outputs': None,
            'labels_mlm': None
        })
    # end

    def forward(self, model, labels_encoder=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        self.clear_cache()

        outputs_encoder = model.encoder.cache.outputs    # -> (batch, seqs_input, dim_hidden)
        outputs_mlm = self.extractor(self.ffn(outputs_encoder)) # outputs_mlm = prediction_logits

        if not nocache:
            self.cache.outputs = outputs_mlm
            self.cache.labels_mlm = labels_encoder
        # end

        return outputs_mlm
    # end

    def get_loss(self):
        outputs_mlm = self.cache.outputs
        labels_mlm = self.cache.labels_mlm

        loss_mlm = self.func_loss(outputs_mlm.view(-1, outputs_mlm.size(-1)), labels_mlm.view(-1))  # labels is 1-hot labels
        self.cache.loss = loss_mlm
        return loss_mlm
    # end

    def evaluate(self):
        pass
    # end

    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end

class SimpleEncoderHead_Similarity(nn.Module):

    def __init__(self, model):
        super(SimpleEncoderHead_Similarity, self).__init__()
        self.model = model
        self.func_loss = torch.nn.MSELoss()
        self.cos_score_transformation = torch.nn.Identity()
        self.cache = Dotdict({
            'loss': None,
            'labels_sim': None
        })
    # end

    def forward(self, labels_sim=None, nocache=False, **kwargs):  # labels_sim (batch/2, 1)   for every two sentences, we have a label
        self.clear_cache()

        outputs_encoder = self.model.encoder.outputs
        size_batch, len_seq, dim_hidden = outputs_encoder.shape

        if size_batch % 2 != 0:
            raise Exception('sim calculation is not prepared as size_batch % 2 != 0')
        # end

        # pooling (batch, pair, dim_hidden)
        outputs_pooling = outputs_encoder[:, 0, :].squeeze(1).view(-1, 2, dim_hidden)   # might cls + sep, but abandon now (as it's not easy to get sep for every batch, different location)
        sims = self.cos_score_transformation(torch.cosine_similarity(outputs_pooling))  # -> (batch, scores)

        if not nocache:
            self.cache.sims = sims
            self.cache.labels_sim = labels_sim
        # end

        return sims
    # end

    def get_loss(self):
        sims = self.cache.sims
        labels_sim = self.cache.labels_sim

        loss_sim = self.func_loss(sims, labels_sim)
        self.cache.loss = loss_sim
        return loss_sim
    # end

    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end

    def evaluate(self):
        pass
    # end
# end

class SimpleDecoderHead_S2S(nn.Module):

    def __init__(self, size_vocab, dim_hidden=128):
        super(SimpleDecoderHead_S2S, self).__init__()
        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab)

        self.func_loss = torch.nn.CrossEntropyLoss()

        self.keys_cache = ['loss', 'outputs', 'labels_s2s']
        self.cache = Dotdict({
            'loss': None,
            'outputs': None,
            'labels_s2s': None
        })
    # end

    def forward(self, model, labels_decoder=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        self.clear_cache()

        outputs_decoder = model.decoder.cache.outputs    # -> (batch, seqs_input, dim_hidden)
        outputs_s2s = self.extractor(self.ffn(outputs_decoder))   # outputs_mlm = prediction_logits


        if not nocache:
            self.cache.outputs = outputs_s2s
            self.cache.labels_s2s = labels_decoder
        # end

        return outputs_s2s
    # end


    def get_loss(self):
        outputs_s2s = self.cache.outputs
        labels_s2s = self.cache.labels_s2s

        loss_s2s = self.func_loss(outputs_s2s.view(-1, outputs_s2s.size(-1)), labels_s2s.view(-1))  # labels is 1-hot labels
        self.cache.loss = loss_s2s
        return loss_s2s
    # end


    def evaluate(self):
        pass
    # end


    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end


class HeadManager:
    def __init__(self):
        self.index_name_head = {}
    # end

    def register(self, head):
        name_head = head.__class__.__name__
        self.index_name_head[name_head] = head
    # end

    def forward(self, model, **kwargs):
        for name_head, head in self.index_name_head.items():
            head.forward(model, **kwargs)
        # end
    # end

    def get_head(self, klass):
        return self.index_name_head.get(klass.__name__, None)
    # end
    
    def to(self, device):
        for name_head in self.index_name_head:
            self.index_name_head[name_head] = self.index_name_head[name_head].to(device)
        # end
        
        return self
    # end
    
    def clear_cache(self):
        for _, head in self.index_name_head.items():
            head.clear_cache()
        # end
    # end
# end


class Builder:


    @classmethod
    def build_model_with_s2s(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)

        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)

        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder)

        manager = HeadManager()
        manager.register(SimpleDecoderHead_S2S(size_vocab, dim_hidden))
        return model, manager
    # end
    
    @classmethod
    def build_model_with_mlm(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)

        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)

        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder)

        manager = HeadManager()
        manager.register(SimpleEncoderHead_MLM(size_vocab, dim_hidden))
        return model, manager
    # end

    @classmethod
    def build_model(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)

        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)

        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder)

        manager = HeadManager()
        manager.register(SimpleEncoderHead_MLM(size_vocab, dim_hidden))
        manager.register(SimpleDecoderHead_S2S(size_vocab, dim_hidden))
        return model, manager
    # end
# end


class TokenizerWrapper:
    def __init__(self, vocab, splitter):
        self.splitter = splitter
        self.vocab = vocab

        self.id_pad = len(vocab)
        self.id_cls = len(vocab) + 1
        self.id_sep = len(vocab) + 2
        self.id_mask = len(vocab) + 3
        
        self.size_vocab = len(vocab) + 4

        self.token_pad = '[PAD]'
        self.token_cls = '[CLS]'
        self.token_sep = '[SEP]'
        self.token_mask = '[MASK]'
        
    # end

    def encode(self, line):
        return self.vocab([doc.text for doc in self.splitter(line)])
    # end

    def decode(self):
        pass
    # end
# end


class Batch:
    DEVICE = 'cuda'

    def __init__(self, **kwargs):
        self.kwargs = {}
        for k, v in kwargs.items():
            self.kwargs[k] = v.to(Batch.DEVICE)
        # end
    # end

    def __call__(self):
        return self.kwargs
    # end
# end


class Collator_S2S:

    def __init__(self, tokenizer, size_seq_max):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
    # end

    def __call__(self, list_corpus_line):

        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []

        for corpus_line in list_corpus_line:
            tokens = self.tokenizer.encode(corpus_line[1])
            
            # TODO: check edge
            if len(tokens) > self.size_seq_max - 2:
                tokens = tokens[:self.size_seq_max-2]
            # end

            tokens_input_encoder.append([self.tokenizer.id_cls] + tokens + [self.tokenizer.id_sep])
            tokens_input_decoder.append([self.tokenizer.id_cls] + tokens)
            tokens_label_decoder.append(tokens + [self.tokenizer.id_sep])
        # end

        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=0.1)
        inputs_decoder, masks_decoder, segments_decoder = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label = self.pad_sequences(tokens_label_decoder, self.size_seq_max)

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label
        )
    # end

    # return masks_attention?, return masks_segment?
    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0): # need_diagonal and need_masked cannot both set, one for bert seq one for s2s seq
        id_pad = self.tokenizer.id_pad
        id_mask = self.tokenizer.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            sequences_padded.append(sequence_padded)

            if need_masked:
                index_masked = list(range(len_seq))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * len_seq)])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
                sequences_masked_padded.append(sequence_masked_padded)
            # end
    #   # end for

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)

        masks_segment = (inputs != self.tokenizer.id_pad).unsqueeze(-2).expand(inputs.shape[0], inputs.shape[-1], inputs.shape[-1]) #(nbatch, seq, seq)
        masks_attention = self.make_std_mask(inputs, self.tokenizer.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded == id_mask).unsqueeze(-2).expand(inputs.shape[0], inputs.shape[-1], inputs.shape[-1])
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0

    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end


In [2]:
import spacy


def Multi30k(language_pair=None):
    corpus_lines_train = []

    for lan in language_pair:
        with open('text/train.{}'.format(lan), 'r') as file:
            corpus_lines_train.append(file.read().splitlines())
        # end
    # end

    corpus_train = list(zip(*corpus_lines_train))

    corpus_lines_eval = []

    for lan in language_pair:
        with open('text/val.{}'.format(lan), 'r') as file:
            corpus_lines_eval.append(file.read().splitlines())
        # end
    # end

    corpus_lines_eval = list(zip(*corpus_lines_train))

    return corpus_lines_train, corpus_lines_eval, None
# end


def load_vocab(spacy_en):
    if not os.path.exists("vocab.pt"):
        vocab_tgt = build_vocabulary(spacy_en)
        torch.save(vocab_tgt, "vocab.pt")
    else:
        vocab_tgt = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes: {}".format(len(vocab_tgt)))
    return vocab_tgt
# end

def load_spacy():

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en
# end

In [3]:
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset


gpu = 0
torch.cuda.set_device(gpu)

seq_max = 64
batch_size = 4
dim_hidden = 128
dim_feedforward = 128
n_head = 4
n_layer = 2

spacy_en = load_spacy()
vocab = load_vocab(spacy_en)
tokenizer = TokenizerWrapper(vocab, spacy_en)

train_iter, valid_iter, _ = Multi30k(language_pair=("de", "en"))
train_source = to_map_style_dataset(valid_iter)


collator = Collator_S2S(tokenizer, seq_max)
dataloader_train = DataLoader(train_source, batch_size, shuffle=False, collate_fn=collator)

model, manager = Builder.build_model(tokenizer.size_vocab, dim_hidden, dim_feedforward, n_head, n_layer)
model = model.to('cuda')
manager = manager.to('cuda')

Finished.
Vocabulary sizes: 6191


In [4]:
def train_a_batch(batch, model, manager, optimizer=None):
    model.forward(**batch())    # save to cache
    manager.forward(model, **batch())    # save to cache

    loss_mlm = manager.get_head(SimpleEncoderHead_MLM).get_loss()
    loss_s2s = manager.get_head(SimpleDecoderHead_S2S).get_loss()
    # loss_sim = manager_head.get_head('sim_encoder').get_loss()

    # cross entropy loss
    loss_crossentropy = loss_mlm + loss_s2s
    # loss_crossentropy = loss_s2s
    # loss_crossentropy = loss_mlm
    print(loss_crossentropy)
    loss_crossentropy.backward()

    # mean square loss
    # loss_sim.backward()
    if optimizer:
        optimizer.step()
    # end
    
    manager.clear_cache()
# end

In [5]:
for batch in dataloader_train:
    train_a_batch(batch, model, manager)
# end

tensor(15.9083, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9723, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9543, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9319, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9284, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.0365, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9560, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9623, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.1229, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9714, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.8843, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.0467, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.9698, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.0610, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15.8572, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.0188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.1154, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16.0984, device='cuda:0'

KeyboardInterrupt: 

In [None]:
model