In [1]:
import torch
import os
from os.path import exists
import torch.nn as nn
from torch.nn.functional import log_softmax, pad, one_hot
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torch.utils.data import DataLoader
import random

### utils.py ###

class Dotdict(dict):
    """Dictionary whose items are also reachable through attribute syntax.

    * ``d.key``        -> ``d.get('key')`` (``None`` when the key is missing)
    * ``d.key = v``    -> ``d['key'] = v``
    * ``del d.key``    -> ``del d['key']`` (``KeyError`` when the key is missing)
    """

    def __getattr__(self, name):
        # Only reached when regular attribute lookup has already failed.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
# end

### utils.py ###




### core.py ###

"Produce N identical layers."
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    # end
    return copies
# end



class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Four ``d_model x d_model`` linear layers are kept in ``self.linears``:
    the first three project query / key / value, the last one mixes the
    concatenated heads.  The attention weights of the most recent call are
    stored in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Take in the number of heads ``h`` and the model size ``d_model``."""
        super().__init__()
        assert d_model % h == 0
        # d_v is assumed to always equal d_k.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
    # end


    def attention(self, query, key, value, mask=None, dropout=None):
        """Compute scaled dot-product attention.

        Positions where ``mask == 0`` get a score of ``-1e9`` before the
        softmax.  Returns ``(weighted values, attention weights)``.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # end
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        # end
        return torch.matmul(weights, value), weights
    # end


    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, then merge the heads again.

        The inputs are reshaped to ``(nbatches, -1, h, d_k)``, so their last
        dimension must hold ``h * d_k`` features after projection.
        """
        if mask is not None:
            # Insert a head axis so the same mask applies to all h heads.
            mask = mask.unsqueeze(1)
        # end
        nbatches = query.size(0)

        def split_heads(projection, tensor):
            # d_model -> (h, d_k), with the head axis moved in front of the sequence axis
            return projection(tensor).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        heads_q = split_heads(self.linears[0], query)
        heads_k = split_heads(self.linears[1], key)
        heads_v = split_heads(self.linears[2], value)

        context, self.attn = self.attention(
            heads_q, heads_k, heads_v, mask=mask, dropout=self.dropout
        )

        # "Concat" the heads using a view, then apply the final linear layer.
        merged = context.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
    # end
# end class


"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
class ResidualLayer(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(layer_norm(x)))``."""

    def __init__(self, size, dropout=0.1, eps=1e-6):
        super().__init__()
        self.norm = torch.nn.LayerNorm(size, eps)
        self.dropout = nn.Dropout(dropout)
    # end

    def forward(self, x, sublayer):
        """Apply the residual connection around ``sublayer`` (any callable keeping the size)."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch
    # end
# end class


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: ``w_2(dropout(relu(w_1(x))))``."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = torch.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
    # end
# end


class SimpleIDEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by ``sqrt(dim_hidden)``."""

    def __init__(self, size_vocab, dim_hidden, id_pad):
        super().__init__()
        self.lut = nn.Embedding(size_vocab, dim_hidden, padding_idx=id_pad)
        self.dim_hidden = dim_hidden

    def forward(self, x):
        scale = math.sqrt(self.dim_hidden)
        return self.lut(x) * scale
    # end

    def get_shape(self):
        """Return ``(number of embeddings, embedding dimension)`` of the lookup table."""
        table = self.lut
        return (table.num_embeddings, table.embedding_dim)
    # end
# end


"Implement the PE function."
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table is precomputed once for ``max_len`` positions and registered as
    the buffer ``pe`` with shape ``(1, max_len, dim_positional)``.  Being a
    buffer, it is part of the ``state_dict`` and moves together with the
    module on ``.to(device)`` / ``.cuda()``; it is therefore built on the
    default device instead of being pinned to CUDA, which made construction
    fail on CPU-only machines.

    Args:
        dim_positional: size of the last dimension of the inputs.
        max_len: number of positions to precompute.  Inputs whose second
            dimension exceeds ``max_len`` cannot be broadcast against the
            table and make ``forward`` fail.
    """

    def __init__(self, dim_positional, max_len=512):
        super(PositionalEncoding, self).__init__()

        # Compute the positional encodings once in log space.
        self.dim_positional = dim_positional
        pe = torch.zeros(max_len, dim_positional)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim_positional, 2) * -(math.log(10000.0) / dim_positional)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd dim_positional there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles)[:, : dim_positional // 2]
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        # ``.to`` is a no-op when the module already lives on the input's device.
        return x + self.pe[:, : x.size(1)].to(x.device)
    # end
# end


class SimpleEmbedder(nn.Module):    # no segment embedder as we do not need that
    """Token embedding pipeline: id lookup -> positional encoding -> dropout."""

    def __init__(self, size_vocab=None, dim_hidden=128, dropout=0.1, id_pad=0):
        super().__init__()
        self.size_vocab = size_vocab
        self.dim_hidden = dim_hidden
        self.id_pad = id_pad

        stages = [
            SimpleIDEmbeddings(size_vocab, dim_hidden, id_pad),
            PositionalEncoding(dim_hidden),
            nn.Dropout(dropout),
        ]
        self.embedder = nn.Sequential(*stages)
    # end

    def forward(self, ids_input):   # (batch, seqs_with_padding)
        embedded = self.embedder(ids_input)
        return embedded
    # end

    def get_vocab_size(self):
        """Return the vocabulary size this embedder was built with."""
        return self.size_vocab
    # end
# end

### core.py ###



class SimpleEncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward, each in a residual wrapper.

    Args:
        dim_hidden: model width, used for attention, feed-forward and layer norms.
        dim_feedforward: inner width of the feed-forward block.
        n_head: number of attention heads.
        dropout: dropout rate used by every sub-module of the layer, now
            including the attention module (which previously ignored this
            argument and always used its own default of 0.1).
    """

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleEncoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 2)
    # end

    def forward(self, embeddings, masks, *args):
        """Run self-attention (masked by ``masks``) then the feed-forward block.

        Extra positional arguments are accepted and ignored so that encoder
        and decoder layers can be invoked with the same argument list.
        """
        def self_attention(normed):
            return self.layer_attention(normed, normed, normed, masks)

        embeddings = self.layers_residual[0](embeddings, self_attention)
        return self.layers_residual[1](embeddings, self.layer_feedforward)
    # end
# end



class SimpleDecoderLayer(nn.Module):
    """One decoder layer: self-attention, encoder attention and feed-forward.

    Each of the three sub-layers runs inside its own residual wrapper.

    Args:
        dim_hidden: model width.
        dim_feedforward: inner width of the feed-forward block.
        n_head: number of attention heads.
        dropout: dropout rate used by every sub-module of the layer, now
            including both attention modules (which previously ignored this
            argument and always used their own default of 0.1).
    """

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleDecoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention_decoder = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_attention_encoder = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 3)

    def forward(self, embeddings, masks_encoder, output_encoder, masks_decoder, *args):
        """Attend over the decoder input, then over ``output_encoder``, then feed forward.

        ``masks_decoder`` masks the self-attention; ``masks_encoder`` masks the
        attention whose keys and values are ``output_encoder``.
        """
        def self_attention(normed):
            return self.layer_attention_decoder(normed, normed, normed, masks_decoder)

        def encoder_attention(normed):
            return self.layer_attention_encoder(normed, output_encoder, output_encoder, masks_encoder)

        embeddings = self.layers_residual[0](embeddings, self_attention)
        embeddings = self.layers_residual[1](embeddings, encoder_attention)
        return self.layers_residual[2](embeddings, self.layer_feedforward)
    # end
# end


class SimpleTransformerStack(nn.Module):
    """A stack of ``n_layers`` deep copies of ``obj_layer`` followed by a layer norm.

    The same class serves as encoder and decoder stack: every layer is called
    as ``layer(embeddings, masks_encoder, output_encoder, masks_decoder)``.
    The normalised output of the last forward pass is kept in
    ``self.cache.outputs`` unless caching is disabled for that call.

    Args:
        obj_layer: prototype layer; must expose ``dim_hidden``.
        n_layers: number of copies to stack.
    """

    def __init__(self, obj_layer, n_layers):
        super(SimpleTransformerStack, self).__init__()
        self.layers = clones(obj_layer, n_layers)

        self.norm = torch.nn.LayerNorm(obj_layer.dim_hidden)
        # Must list exactly the keys stored in ``self.cache``.  It used to be
        # ['output'], which never matched 'outputs', so clear_cache() left the
        # cached tensor in place.
        self.keys_cache = ['outputs']
        self.cache = Dotdict({
            'outputs': None
        })
    # end

    def forward(self, embedding_encoder=None, masks_encoder=None, output_encoder=None, embedding_decoder=None, masks_decoder=None, noncache=False, nocache=False, **kwargs):  # input -> (batch, len_seq, vocab)
        """Run the layers over the input embeddings and normalise the result.

        The decoder input ``embedding_decoder`` is used when
        ``output_encoder``, ``embedding_decoder`` and ``masks_decoder`` are all
        given; otherwise ``embedding_encoder`` is used.

        ``noncache`` is the historical spelling of the flag; ``nocache`` is the
        spelling callers in this file actually pass (it was previously
        swallowed by ``**kwargs`` and ignored).  Setting either one skips
        storing the output in ``self.cache``.
        """
        is_decoding = (
            output_encoder is not None
            and embedding_decoder is not None
            and masks_decoder is not None
        )
        embeddings = embedding_decoder if is_decoding else embedding_encoder

        for layer in self.layers:
            embeddings = layer(embeddings, masks_encoder, output_encoder, masks_decoder)
        # end

        outputs = self.norm(embeddings)

        if not (noncache or nocache):
            self.cache.outputs = outputs
        # end

        return outputs
    # end

    def clear_cache(self):
        """Reset every cached value to ``None``."""
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end


class SimpleEncoderDecoder(nn.Module):
    """Embedder + encoder stack, optionally followed by embedder + decoder stack.

    Args:
        encoder: encoder stack; must provide ``clear_cache()``.
        decoder: decoder stack or ``None`` for an encoder-only model.
        embedder_encoder: module mapping encoder token ids to embeddings.
        embedder_decoder: module mapping decoder token ids to embeddings, or
            ``None`` for an encoder-only model.
    """

    def __init__(self, encoder, decoder, embedder_encoder, embedder_decoder):
        super(SimpleEncoderDecoder, self).__init__()

        self.embedder_encoder = embedder_encoder
        self.encoder = encoder

        self.embedder_decoder = embedder_decoder
        self.decoder = decoder
    # end

    def forward(self, ids_encoder=None, masks_encoder=None, ids_decoder=None, masks_decoder=None, nocache=False, **kwargs):
        """Encode ``ids_encoder`` and, when a decoder is configured, decode ``ids_decoder``.

        Returns the decoder output when both ``embedder_decoder`` and
        ``decoder`` are set.  Otherwise returns the pooled encoder output
        broadcast back to the encoder output's shape.  The pooled tensor is
        only computed on the encoder-only path; it used to be computed (and
        thrown away) on the decoder path as well, which also made the decoder
        path fail for ``masks_encoder=None``.
        """
        output_encoder = self.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder, nocache=nocache)

        if self.embedder_decoder is not None and self.decoder is not None:
            return self.embed_and_decode(ids_decoder=ids_decoder, masks_encoder=masks_encoder, output_encoder=output_encoder, masks_decoder=masks_decoder, nocache=nocache)
        # end if

        return self._pool_encoder_output(output_encoder, masks_encoder)
    # end

    def _pool_encoder_output(self, output_encoder, masks_encoder):
        """Zero the masked positions, average over the sequence axis and expand back."""
        if masks_encoder is not None:
            # The mask is transposed so that it broadcasts over the hidden axis.
            output_encoder = output_encoder.masked_fill(masks_encoder.transpose(-1, -2) == False, 0)
        # end
        # NOTE(review): the mean divides by the full sequence length, masked
        # positions included - confirm that a masked mean is not what is wanted.
        pooled = torch.mean(output_encoder, dim=-2)
        return pooled.unsqueeze(-2).expand(output_encoder.shape)
    # end

    def embed_and_encode(self, ids_encoder=None, masks_encoder=None, nocache=False, **kwargs):
        """Clear the encoder cache, embed ``ids_encoder`` and run the encoder stack."""
        self.encoder.clear_cache()

        embedding_encoder = self.embedder_encoder(ids_encoder)
        output_encoder = self.encoder(
            embedding_encoder=embedding_encoder,
            masks_encoder=masks_encoder,
            nocache=nocache
        )

        return output_encoder
    # end


    def embed_and_decode(self, ids_decoder=None, masks_encoder=None, output_encoder=None, masks_decoder=None, nocache=False, **kwargs):
        """Clear the decoder cache, embed ``ids_decoder`` and run the decoder stack."""
        self.decoder.clear_cache()

        embedding_decoder = self.embedder_decoder(ids_decoder)
        output_decoder = self.decoder(
            masks_encoder=masks_encoder,
            output_encoder=output_encoder,
            embedding_decoder=embedding_decoder,
            masks_decoder=masks_decoder,
            nocache=nocache
        )

        return output_decoder
    # end


    def clear_cache(self):
        """Clear the caches of the encoder and, if present, the decoder."""
        self.encoder.clear_cache()
        if self.decoder is not None:
            self.decoder.clear_cache()
        # end
    # end

    def get_vocab_size(self, name_embedder):
        """Return the vocabulary size of ``self.embedder_<name_embedder>``."""
        embedder = getattr(self, f'embedder_{name_embedder}')
        return embedder.get_vocab_size()
    # end

# end

class LinearAndNorm(nn.Module):
    """Linear projection, ReLU, then layer normalisation over the output features."""

    def __init__(self, dim_in = None, dim_out = None, eps_norm=1e-12):
        super().__init__()

        self.linear = torch.nn.Linear(dim_in, dim_out)
        self.norm = torch.nn.LayerNorm(dim_out, eps_norm)
    # end

    def forward(self, seqs_in):
        projected = self.linear(seqs_in)
        activated = torch.relu(projected)
        return self.norm(activated)
    # end
# end




class TokenizerWrapper:
    """Pair a vocabulary with a splitter and add four special token ids.

    The special ids are appended after the vocabulary:
    ``[PAD] = len(vocab)``, ``[CLS] = len(vocab) + 1``,
    ``[SEP] = len(vocab) + 2``, ``[MASK] = len(vocab) + 3``;
    ``size_vocab`` is therefore ``len(vocab) + 4``.

    Args:
        vocab: object supporting ``len()``, being called with a list of token
            strings (returning their ids) and ``lookup_token(id)``.
        splitter: callable turning a line into an iterable of objects that
            expose a ``text`` attribute.
    """

    def __init__(self, vocab, splitter):
        self.splitter = splitter
        self.vocab = vocab

        self.id_pad = len(vocab)
        self.id_cls = len(vocab) + 1
        self.id_sep = len(vocab) + 2
        self.id_mask = len(vocab) + 3

        self.size_vocab = len(vocab) + 4

        self.token_pad = '[PAD]'
        self.token_cls = '[CLS]'
        self.token_sep = '[SEP]'
        self.token_mask = '[MASK]'

        self.index_id_token_special = {
            self.id_pad: self.token_pad,
            self.id_cls: self.token_cls,
            self.id_sep: self.token_sep,
            self.id_mask: self.token_mask
        }

    # end

    def encode(self, line):
        """Split ``line`` and return the vocabulary ids of its pieces (no special tokens)."""
        return self.vocab([doc.text for doc in self.splitter(line)])
    # end

    def decode(self, tokens):
        """Turn an iterable of token ids back into a space-joined string.

        Special ids map to their bracketed names.  Ids the vocabulary cannot
        resolve are rendered as ``[ERROR_LOOKUP_<id>]`` instead of raising.
        """
        words = []
        for token in tokens:
            token = int(token)

            if token in self.index_id_token_special:
                word_target = self.index_id_token_special[token]
            else:
                try:
                    # Use this wrapper's own vocabulary.  The code used to read
                    # a module-level ``vocab``; when none existed, the bare
                    # ``except`` hid the NameError and every word became an
                    # error placeholder.
                    word_target = self.vocab.lookup_token(token)
                except Exception:
                    word_target = '[ERROR_LOOKUP_{}]'.format(token)
                # end
            # end

            words.append(word_target)
        # end

        return ' '.join(words)
    # end
# end



class Batch:
    """Container that moves the tensors of one batch to ``Batch.DEVICE``.

    Keyword arguments whose value is ``None`` or a ``bool`` are dropped; every
    other value is moved with ``value.to(Batch.DEVICE)``.  Calling the
    instance returns the resulting dict, ready for ``model(**batch())``.
    """

    # Falls back to the CPU when CUDA is unavailable instead of failing on the
    # first ``.to('cuda')``.  Assign ``Batch.DEVICE`` to override.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, **kwargs):
        self.kwargs = {
            key: value.to(Batch.DEVICE)
            for key, value in kwargs.items()
            if value is not None and not isinstance(value, bool)
        }
    # end

    def __call__(self):
        return self.kwargs
    # end
# end



class Collator_S2S:
    """Collate raw corpus tuples into padded id / mask tensors for seq2seq training.

    Args:
        tokenizer: object providing ``encode(line)`` and the ids ``id_pad``,
            ``id_cls``, ``id_sep`` and ``id_mask``.
        size_seq_max: length every sequence is padded to; token lists are
            truncated to ``size_seq_max - 2`` to leave room for [CLS] / [SEP].
        need_masked: fraction of the inner tokens of each encoder input that is
            replaced by ``id_mask``; a falsy value disables masking.
    """

    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked
    # end


    def __call__(self, list_corpus_source):
        """Build one ``Batch`` from a list of corpus tuples.

        A 3-tuple ``(line0, line1, similarity)`` contributes both lines and
        one similarity label.  Any other tuple contributes only its element
        at index 1.
        """
        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []
        labels_similarity = []

        for corpus_source in list_corpus_source: # (line0, line1, sim), output of zip remove single case
            if len(corpus_source) == 3:
                # Previously read the misspelt ``courpus_source`` (NameError)
                # and took the label from ``corpus_line[2]`` (IndexError).
                corpus_line = [corpus_source[0], corpus_source[1]]
                labels_similarity.append(corpus_source[2])
            else:
                corpus_line = [corpus_source[1]]
            # end

            for line in corpus_line:
                tokens = self.tokenizer.encode(line)

                # TODO: check edge
                if len(tokens) > self.size_seq_max - 2:
                    tokens = tokens[:self.size_seq_max-2]
                # end

                tokens_input_encoder.append([self.tokenizer.id_cls] + tokens + [self.tokenizer.id_sep])
                tokens_input_decoder.append([self.tokenizer.id_cls] + tokens)
                tokens_label_decoder.append(tokens + [self.tokenizer.id_sep])
            # end
        # end

        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked)
        inputs_decoder, masks_decoder, segments_decoder, _ = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label, _ = self.pad_sequences(tokens_label_decoder, self.size_seq_max)
        labels_similarity = torch.Tensor(labels_similarity)

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label,
            labels_similarity=labels_similarity
        )
    # end


    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0): # need_diagonal and need_masked cannot both set, one for bert seq one for s2s seq
        """Pad id lists to ``size_seq_max`` and build the matching masks.

        Args:
            sequences: list of token-id lists, each at most ``size_seq_max`` long.
            size_seq_max: target length after padding with ``id_pad``.
            need_diagonal: when true the attention mask also hides future
                positions (see ``make_std_mask``).
            need_masked: fraction of the positions strictly between the first
                and last token that are replaced by ``id_mask``; the positions
                are chosen with ``random.shuffle``.

        Returns:
            ``(inputs, masks_attention, masks_segment, labels)``.  With masking,
            ``inputs`` holds the masked ids, ``labels`` the unmasked ids and the
            attention mask additionally hides the masked positions.  Without
            masking, ``labels`` is ``None``.
        """
        id_pad = self.tokenizer.id_pad
        id_mask = self.tokenizer.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            sequences_padded.append(sequence_padded)

            if need_masked:
                # Candidates exclude the first and the last token ([CLS] / [SEP]).
                index_masked = list(range(1, len_seq-1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq-2))])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence_masked, torch.LongTensor([id_pad] * count_pad)))

                sequences_masked_padded.append(sequence_masked_padded)
            # end
        # end for

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)
        # end

        masks_segment = (inputs != self.tokenizer.id_pad).unsqueeze(-2)    #(nbatch, 1, seq)
        masks_attention = self.make_std_mask(inputs, self.tokenizer.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment, None
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0

    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end

In [2]:
import spacy


def _read_parallel_corpus(split, language_pair):
    """Read ``text/<split>.<lang>`` for each language and zip the files line by line."""
    corpus_lines = []

    for lan in language_pair:
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # decode non-ASCII corpus text (e.g. German umlauts).
        with open('text/{}.{}'.format(split, lan), 'r', encoding='utf-8') as file:
            corpus_lines.append(file.read().splitlines())
        # end
    # end

    return list(zip(*corpus_lines))
# end


def Multi30k(language_pair=None):
    """Load the local train / validation corpora for the given languages.

    Args:
        language_pair: iterable of language suffixes, e.g. ``("de", "en")``.
            Must not be ``None``; the default exists only for signature
            compatibility.

    Returns:
        ``(corpus_train, corpus_eval, None)`` where each corpus is a list of
        tuples holding one line per language, read from ``text/train.<lang>``
        and ``text/val.<lang>`` relative to the working directory.  The third
        element is always ``None``.
    """
    corpus_train = _read_parallel_corpus('train', language_pair)
    corpus_eval = _read_parallel_corpus('val', language_pair)

    return corpus_train, corpus_eval, None
# end


def load_vocab(spacy_en):
    """Return the vocabulary cached in ``vocab.pt``, building and caching it if absent.

    When the cache file is missing the vocabulary is created with
    ``build_vocabulary(spacy_en)`` and saved.  The vocabulary size is printed
    in either case.
    """
    path_vocab = "vocab.pt"
    if os.path.exists(path_vocab):
        # NOTE(review): torch.load unpickles the file - only load caches you
        # created yourself.
        vocab_tgt = torch.load(path_vocab)
    else:
        vocab_tgt = build_vocabulary(spacy_en)
        torch.save(vocab_tgt, path_vocab)
    # end
    size_vocab = len(vocab_tgt)
    print("Finished.\nVocabulary sizes: {}".format(size_vocab))
    return vocab_tgt
# end

def load_spacy():
    """Load the spaCy pipeline ``en_core_web_sm``, downloading it on first use.

    If the first ``spacy.load`` raises ``OSError`` the model is downloaded
    with the interpreter that runs this code (``sys.executable``) rather than
    whatever ``python`` resolves to on ``PATH``, so the model lands in the
    active environment.  A failed download surfaces as the ``OSError`` of the
    second ``spacy.load``.
    """
    name_model = "en_core_web_sm"

    try:
        spacy_en = spacy.load(name_model)
    except OSError:
        import subprocess
        import sys

        subprocess.run([sys.executable, "-m", "spacy", "download", name_model])
        spacy_en = spacy.load(name_model)

    return spacy_en
# end

In [3]:
# class SimpleEncoderHead_MLM(nn.Module):

#     def __init__(self, model, size_vocab, dim_hidden=128):
#         super(SimpleEncoderHead_MLM, self).__init__()
#         self.model = model
        
#         self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)
#         self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
#         self.extractor.weight = nn.Parameter(model.embedder_encoder.embedder[0].lut.weight)
        
#         self.func_loss = torch.nn.CrossEntropyLoss()
#     # end


#     def forward(self, **kwargs):   # labels_input -> (batch, seq, labels)
#         labels_mlm = kwargs['labels_encoder']
        
#         outputs_encoder = self.model(**kwargs)
#         outputs_ffn = self.ffn(outputs_encoder)
#         outputs_mlm = self.extractor(outputs_ffn) # outputs_mlm = prediction_logits
        
#         segments_encoder = kwargs['segments_encoder']        
#         segments_encoder_2d = segments_encoder.transpose(-1,-2)[:,:,0]

#         # loss_segments = self.func_loss(outputs_mlm.masked_select(segments_encoder_2d.unsqueeze(-1)).reshape(-1, outputs_mlm.shape[-1]), labels_mlm.masked_select(segments_encoder_2d)) / segments_encoder_2d.reshape(-1).shape[0]
#         loss_segments = self.func_loss(outputs_mlm.masked_select(segments_encoder_2d.unsqueeze(-1)).reshape(-1, outputs_mlm.shape[-1]), labels_mlm.masked_select(segments_encoder_2d))
        
#         masks_encoder = kwargs['masks_encoder']
#         masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder # True is masked
#         masks_masked_perbatch = masks_masked[:,0,:]
#         # loss_masked = self.func_loss(outputs_mlm.masked_select(masks_masked_perbatch.unsqueeze(-1)).reshape(-1, outputs_mlm.shape[-1]), labels_mlm.masked_select(masks_masked_perbatch)) / masks_masked_perbatch.reshape(-1).shape[0]
#         loss_masked = self.func_loss(outputs_mlm.masked_select(masks_masked_perbatch.unsqueeze(-1)).reshape(-1, outputs_mlm.shape[-1]), labels_mlm.masked_select(masks_masked_perbatch))       
        
#         # loss_mlm = loss_segments + loss_masked * 3
#         loss_mlm = loss_segments
        
#         return outputs_mlm, loss_mlm
#     # end

# # end

In [4]:
class SimpleDecoderHead_S2S(nn.Module):
    """Sequence-to-sequence head: decoder output -> vocabulary logits + loss.

    The output projection ``extractor`` shares its weight with the decoder's
    token embedding table.

    Args:
        model: encoder-decoder model; must expose
            ``embedder_decoder.embedder[0].lut`` (the decoder embedding table).
        size_vocab: number of output classes.
        dim_hidden: width of the decoder output.
    """

    def __init__(self, model, size_vocab, dim_hidden=128):
        super(SimpleDecoderHead_S2S, self).__init__()
        self.model = model

        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)

        self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
        # Tie the weights by sharing the very same Parameter object.  Wrapping
        # it in a new nn.Parameter(...) created a second, separately optimised
        # parameter whose link to the embedding was lost after ``.to(device)``.
        self.extractor.weight = model.embedder_decoder.embedder[0].lut.weight

        self.func_loss = torch.nn.CrossEntropyLoss()

    # end


    def forward(self, **kwargs):   # labels_input -> (batch, seq, labels)
        """Run the model and compute the cross-entropy over non-padded label positions.

        Required keyword arguments (all kwargs are also forwarded to the model):
            labels_decoder: target ids, indexed as (batch, seq).
            segments_label: boolean mask of valid label positions; its first
                row along the middle axis is used, i.e. (batch, 1, seq).

        Returns:
            ``(logits, loss)``.
        """
        labels_s2s = kwargs['labels_decoder']

        outputs_decoder = self.model(**kwargs)
        outputs_ffn = self.ffn(outputs_decoder)
        outputs_s2s = self.extractor(outputs_ffn)   # prediction logits

        segments_decoder = kwargs['segments_label']
        segments_decoder_2d = segments_decoder.transpose(-1,-2)[:,:,0]

        # Keep only the logits / labels at valid (non-padded) positions.
        logits_valid = outputs_s2s.masked_select(segments_decoder_2d.unsqueeze(-1)).reshape(-1, outputs_s2s.shape[-1])
        labels_valid = labels_s2s.masked_select(segments_decoder_2d)
        loss_segments = self.func_loss(logits_valid, labels_valid)

        return outputs_s2s, loss_segments
    # end


    def beam_generate(self):
        """Placeholder; not implemented."""
        pass
    # end

# end

In [5]:
class Builder:
    """Factory for the transformer model wrapped in a task head.

    Both builders accept two optional, backward-compatible arguments:

    * ``dropout`` - dropout rate handed to the embedders and the layers.
    * ``id_pad`` - padding id handed to the embedders (used as the embedding
      table's ``padding_idx``).  The default ``0`` matches the previous
      behaviour; pass the tokenizer's pad id to keep the two consistent.
    """

    @classmethod
    def _build_stack(cls, cls_layer, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout, id_pad):
        """Create one ``(embedder, layer stack)`` pair using ``cls_layer`` as layer type."""
        embedder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden, dropout=dropout, id_pad=id_pad)
        sample_layer = cls_layer(dim_hidden, dim_feedforward, n_head, dropout)
        stack = SimpleTransformerStack(sample_layer, n_layer)
        return embedder, stack
    # end

    @classmethod
    def build_model_with_mlm(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout=0.1, id_pad=0):
        """Build an encoder-only model wrapped in ``SimpleEncoderHead_MLM``.

        NOTE(review): ``SimpleEncoderHead_MLM`` is only present as
        commented-out code in this notebook; this raises ``NameError`` until
        that class is restored.
        """
        embedder_encoder, encoderstack = cls._build_stack(
            SimpleEncoderLayer, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout, id_pad
        )

        model = SimpleEncoderDecoder(encoderstack, None, embedder_encoder, None)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)

        return head_mlm
    # end

    @classmethod
    def build_model_with_s2s(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout=0.1, id_pad=0):
        """Build an encoder-decoder model wrapped in ``SimpleDecoderHead_S2S``."""
        embedder_encoder, encoderstack = cls._build_stack(
            SimpleEncoderLayer, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout, id_pad
        )
        embedder_decoder, decoderstack = cls._build_stack(
            SimpleDecoderLayer, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, dropout, id_pad
        )

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)

        return head_s2s
    # end

# end

In [6]:
import json
import transformers
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset


# Select the CUDA device used by the rest of this notebook.
gpu = 0
torch.cuda.set_device(gpu)

epochs = 60

# source
seq_max = 16
batch_size = 16


# model & head
dim_hidden = 512
dim_feedforward = 512
n_head = 8
n_layer = 8

# optimizer
# NOTE(review): the three optimizer constants below are not referenced by the
# Adam call further down, which hard-codes its own values - confirm which set
# is intended.
lr_base_optimizer = 5e-4
betas_optimizer = (0.9, 0.999)
eps_optimizer = 1e-9

# scheduler
# NOTE(review): `warmup` is not used by the ExponentialLR scheduler created below.
warmup = 200

# The spaCy pipeline doubles as the tokenizer's splitter.
spacy_en = load_spacy()
vocab = load_vocab(spacy_en)
tokenizer = TokenizerWrapper(vocab, spacy_en)

# NOTE(review): `valid_iter` is not used in the visible part of this notebook.
train_iter, valid_iter, _ = Multi30k(language_pair=("de", "en"))
train_source = to_map_style_dataset(train_iter)

# Collator_S2S is built with its default need_masked; batches are served in
# corpus order (shuffle=False).
collator = Collator_S2S(tokenizer, seq_max)
dataloader_train = DataLoader(train_source, batch_size, shuffle=False, collate_fn=collator)

head = Builder.build_model_with_s2s(tokenizer.size_vocab, dim_hidden, dim_feedforward, n_head, n_layer)

# Xavier-initialise every parameter with more than one dimension.  This loop
# does not skip the embedding tables, so the row reserved for `padding_idx`
# is re-initialised as well.
for p in head.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    # end
# end

head = head.to('cuda')


# Adam with weight decay and an exponentially decaying learning rate
# (multiplied by `decayRate` on every scheduler step).
optimizer = torch.optim.Adam(head.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
decayRate = 0.96
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)

print()

Finished.
Vocabulary sizes: 6191



In [7]:
def train_a_batch(batch, head, optimizer=None, scheduler=None):
    head.train()
    _, loss_s2s = head.forward(**batch())    # save to cache

    # crossentropy loss
    
    loss_all = loss_s2s * 5
    loss_all_value = loss_all.item()
    
    # print(loss_all)
    loss_all.backward()                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                             

    
    if optimizer:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # end
    
    if scheduler:
        scheduler.step()
    # end
    # manager.clear_cache()
    return loss_all_value
# end

In [8]:
from datetime import datetime


# Training driver: run `epochs` passes over `dataloader_train`, logging the
# batch loss every 100 batches and the mean loss at the end of each epoch.
#
# Put the model in training mode explicitly. A later cell calls `head.eval()`,
# so re-running this cell afterwards would otherwise train in eval mode.
# NOTE(review): assumes `head` is a torch.nn.Module (it exposes `.eval()`).
head.train()

for e in range(epochs):
    losss_per_e = []
    for i, batch in enumerate(dataloader_train):
        # `None` is passed as the scheduler so train_a_batch skips its
        # per-batch scheduler step; the learning rate is stepped once per
        # epoch at the bottom of this loop instead.
        loss_current = train_a_batch(batch, head, optimizer, None)
        losss_per_e.append(loss_current)
        if i % 100 == 0:
            print('Epoch: {} Batch: {}, loss: {}, rate: {}'.format(e, i, loss_current, optimizer.param_groups[0]['lr']))
        # end
    # end

    # Guard the mean: an empty dataloader reports nan instead of raising
    # ZeroDivisionError.
    loss_average_per_e = sum(losss_per_e) / len(losss_per_e) if losss_per_e else float('nan')
    print('[{}] Epoch: {} ends. Average loss: {}'.format(datetime.utcnow(), e, loss_average_per_e))

    lr_scheduler.step() # schedule per epoch
# end

Epoch: 0 Batch: 0, loss: 43.997989654541016, rate: 0.0001
Epoch: 0 Batch: 100, loss: 25.74346923828125, rate: 0.0001
Epoch: 0 Batch: 200, loss: 24.688373565673828, rate: 0.0001
Epoch: 0 Batch: 300, loss: 19.072385787963867, rate: 0.0001
Epoch: 0 Batch: 400, loss: 19.663721084594727, rate: 0.0001
Epoch: 0 Batch: 500, loss: 21.98863983154297, rate: 0.0001
Epoch: 0 Batch: 600, loss: 17.68052101135254, rate: 0.0001
Epoch: 0 Batch: 700, loss: 19.75946044921875, rate: 0.0001
Epoch: 0 Batch: 800, loss: 18.86446762084961, rate: 0.0001
Epoch: 0 Batch: 900, loss: 17.95880889892578, rate: 0.0001
Epoch: 0 Batch: 1000, loss: 18.245805740356445, rate: 0.0001
Epoch: 0 Batch: 1100, loss: 17.002105712890625, rate: 0.0001
Epoch: 0 Batch: 1200, loss: 16.108827590942383, rate: 0.0001
Epoch: 0 Batch: 1300, loss: 13.102991104125977, rate: 0.0001
Epoch: 0 Batch: 1400, loss: 12.315221786499023, rate: 0.0001
Epoch: 0 Batch: 1500, loss: 13.446344375610352, rate: 0.0001
Epoch: 0 Batch: 1600, loss: 12.85431289672

In [9]:
# Switch the model to evaluation mode before generation.
# The trailing `;` stops the notebook from echoing the value returned by
# eval() (presumably the full module repr), replacing the dummy print() that
# was used for the same purpose.
head.eval();




In [None]:
def greedy_generate(head, tokenizer, collator, **kwargs):
    """Greedily decode a single target sequence from an encoded source.

    The encoder is run once; the decoder is then re-run on the growing prefix,
    and at every step the highest-scoring token at the last position is
    appended. Decoding stops when the end token is predicted (it is not
    appended) or when the prefix reaches ``collator.size_seq_max`` tokens.

    Only one sequence is decoded: the decoder prefix is created with shape
    ``(1, 1)`` and only element 0 of the per-step prediction is used.
    NOTE(review): callers are assumed to pass a batch of size 1 -- confirm.

    Args:
        head: object exposing ``model.embed_and_encode``,
            ``model.embed_and_decode``, ``ffn`` and ``extractor``.
        tokenizer: object exposing ``id_cls`` (used as the start token id)
            and ``id_sep`` (used as the end token id).
        collator: object exposing ``size_seq_max`` and
            ``subsequent_mask(length)``.
        **kwargs: must contain ``ids_encoder`` and ``masks_encoder``; any
            other keys are ignored, so a whole batch dict can be splatted in.

    Returns:
        Tensor of shape ``(1, T)`` with the dtype and device of
        ``ids_encoder``, starting with the start token id,
        ``T <= collator.size_seq_max``.

    Raises:
        KeyError: if ``ids_encoder`` or ``masks_encoder`` is missing.
    """
    id_start = tokenizer.id_cls
    id_end = tokenizer.id_sep
    size_seq_max = collator.size_seq_max

    ids_encoder = kwargs['ids_encoder']
    masks_encoder = kwargs['masks_encoder']

    def _token(id_token):
        # (1, 1) tensor holding one token id, matching the encoder ids' dtype/device.
        return torch.full((1, 1), int(id_token), dtype=ids_encoder.dtype, device=ids_encoder.device)
    # end

    # Pure inference: disable autograd so no graph is built at each step.
    with torch.no_grad():
        outputs_encoder = head.model.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder)
        ids_decoder = _token(id_start)

        for _ in range(size_seq_max - 1):
            masks_decoder = collator.subsequent_mask(ids_decoder.size(1)).type_as(ids_encoder)
            outputs_decoder = head.model.embed_and_decode(ids_decoder=ids_decoder, masks_encoder=masks_encoder, output_encoder=outputs_encoder, masks_decoder=masks_decoder)
            outputs_ffn = head.ffn(outputs_decoder)
            outputs_s2s = head.extractor(outputs_ffn)   # prediction scores per position

            # Only the last position predicts the next token. argmax is taken
            # on the raw scores: softmax is monotonic, so it cannot change the
            # winner and is skipped.
            scores_nextword = outputs_s2s[:, -1]
            id_nextword = int(torch.argmax(scores_nextword, dim=-1)[0])

            if id_nextword == id_end:
                break
            # end

            ids_decoder = torch.cat([ids_decoder, _token(id_nextword)], dim=1)
        # end
    # end

    return ids_decoder
# end


In [19]:
# Qualitative check: greedily decode the first validation examples one at a
# time (batch size 1) and print each next to the text decoded from
# `labels_decoder`.
eval_source = to_map_style_dataset(valid_iter)
dataloader_eval = DataLoader(eval_source, 1, shuffle=False, collate_fn=collator)

# Inference only: disable autograd so no graph is built while decoding.
with torch.no_grad():
    for i, batch in enumerate(dataloader_eval):
        # The collated batch is called to obtain the dict of model inputs.
        info_batch = batch()
        result = greedy_generate(head, tokenizer, collator, **info_batch)
        sentence_predicted = tokenizer.decode(result.cpu().tolist()[0])
        sentence_origin = tokenizer.decode(info_batch['labels_decoder'].cpu().tolist()[0])
        # NOTE(review): the line labelled "source" is decoded from
        # `labels_decoder` (the reference output) and "target" is the model
        # prediction -- consider renaming the labels.
        print('source: {}\ntarget: {}\n\n'.format(sentence_origin, sentence_predicted))
        # Stops after 21 examples (i = 0..20).
        if i >= 20:
            break
        # end
    # end
# end

source: A group of men are loading cotton onto a truck [SEP] [PAD] [PAD] [PAD] [PAD] [PAD]
target: [CLS] A group of people are loading cotton candy truck truck


source: A man sleeping in a green room on a couch . [SEP] [PAD] [PAD] [PAD] [PAD]
target: [CLS] A man standing in a green shirt on a bench .


source: A boy wearing headphones sits on a woman 's shoulders . [SEP] [PAD] [PAD] [PAD] [PAD]
target: [CLS] A boy wearing headphones sits sits a woman 's shoulders .


source: Two men setting up a blue ice fishing hut on an iced over lake [SEP] [PAD]
target: [CLS] Two women setting up a blue ice fishing hut on an object over over


source: A balding man wearing a red life jacket is sitting in a small boat [SEP] [PAD]
target: [CLS] A young man wearing a black jacket jacket is sitting in a small small


source: A lady in a red coat , holding a <unk> hand bag likely of [SEP] [PAD]
target: [CLS] A woman in a red coat , holding a microphone in a space of


source: A brown dog is running afte

In [17]:
# eval_source = to_map_style_dataset(valid_iter)
# dataloader_eval = DataLoader(eval_source, 1, shuffle=False, collate_fn=collator)

# for i, batch in enumerate(dataloader_eval):
#     info_batch = batch()
#     output_s2s, loss_s2s = head.forward(**info_batch)
#     preds_s2s = torch.argmax(output_s2s,dim=-1)
    
#     sentence_predicted = tokenizer.decode(preds_s2s.cpu().tolist()[0])
#     sentence_origin = tokenizer.decode(info_batch['labels_decoder'].cpu().tolist()[0])
#     print('source: {}\ntarget: {}\n\n'.format(sentence_origin, sentence_predicted))
#     if i >= 0:
#         break
#     # end
# # end

In [18]:
# ids_encoder = info_batch['ids_encoder']
# masks_encoder = info_batch['masks_encoder']

# outputs_encoder = head.model.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder)
# ids_decoder = torch.zeros(1, 1).fill_(tokenizer.id_cls).type_as(ids_encoder.data)
# masks_decoder = collator.subsequent_mask(ids_decoder.size(1)).type_as(ids_encoder.data)

# outputs_decoder = head.model.embed_and_decode(ids_decoder=ids_decoder, masks_encoder=masks_encoder, output_encoder=outputs_encoder, masks_decoder=masks_decoder)
# outputs_ffn = head.ffn(outputs_decoder)
# outputs_s2s = head.extractor(head.ffn(outputs_ffn))   # outputs_mlm = prediction_logits
# preds_s2s = torch.argmax(output_s2s,dim=-1)

In [None]:
# sentence_predicted = tokenizer.decode(preds_s2s.cpu().tolist()[0])

hello
