In [1]:
import torch
import os
from os.path import exists
import torch.nn as nn
from torch.nn.functional import log_softmax, pad, one_hot
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torch.utils.data import DataLoader
import random
import json
import csv
from copy import deepcopy
from pathlib import Path
import shutil
import re

### utils.py ###

class Dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    
    def __iadd__(self, other):
        for k, v in self.items():
            if k in other and other[k]:
                self[k] += other[k]
            # end
        # end

        return self
    # end
# end



# Takes the file paths as arguments
def parse_csv_file_to_json(path_file_csv):
    # create a dictionary
    elements = []

    # Open a csv reader called DictReader
    with open(path_file_csv, encoding='utf-8') as file_csv:
    #with open(path_file_csv) as file_csv:
        reader_csv = csv.DictReader(file_csv, delimiter="\t")

        # Convert each row into a dictionary
        # and add it to data
        for dict_head_value in reader_csv:
            element = {}

            for head, value in dict_head_value.items():
                #print(value)
                if value and (value[0] in ["[", "{"]):
                    #element[head] = eval(value)
                    element[head] = value
                else:
                    element[head] = value

            elements.append(element)
        # end
    # end

    return elements
# end

### utils.py ###




### core.py ###

"Produce N identical layers."
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
# end



class MultiHeadedAttention(nn.Module):

    "Take in model size and number of heads."
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
    # end


    "Compute 'Scaled Dot Product Attention'"
    def attention(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # print('jinyuj: scores: {}, mask: {}'.format(scores.shape, mask.shape))
            scores = scores.masked_fill(mask == 0, -1e9)
        # end
        p_attn = scores.softmax(dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # end
        return torch.matmul(p_attn, value), p_attn
    # end


    "Implements Figure 2"
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view and apply a final linear.
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )
        del query
        del key
        del value
        return self.linears[-1](x)
    # end
# end class


"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
class ResidualLayer(nn.Module):

    def __init__(self, size, dropout=0.1, eps=1e-6):
        super(ResidualLayer, self).__init__()
        self.norm = torch.nn.LayerNorm(size, eps)
        self.dropout = nn.Dropout(dropout)
    # end

    "Apply residual connection to any sublayer with the same size."
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
    # end
# end class


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))
    # end
# end


class SimpleIDEmbeddings(nn.Module):
    def __init__(self, size_vocab, dim_hidden, id_pad):
        super(SimpleIDEmbeddings, self).__init__()
        self.lut = nn.Embedding(size_vocab, dim_hidden, padding_idx=id_pad)
        self.dim_hidden = dim_hidden

    def forward(self, x):
        result = self.lut(x)
        return result * math.sqrt(self.dim_hidden)
    # end

    def get_shape(self):
        return (self.lut.num_embeddings, self.lut.embedding_dim)
    # end
# end


"Implement the PE function."
class PositionalEncoding(nn.Module):

    def __init__(self, dim_positional, max_len=512):
        super(PositionalEncoding, self).__init__()

        # Compute the positional encodings once in log space.
        self.dim_positional = dim_positional
        pe = torch.zeros(max_len, dim_positional)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim_positional, 2) * -(math.log(10000.0) / dim_positional)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).to('cuda')
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return x
    # end
# end


class SimpleEmbedder(nn.Module):    # no segment embedder as we do not need that
    def __init__(self, size_vocab=None, dim_hidden=128, dropout=0.1, id_pad=0):
        super(SimpleEmbedder, self).__init__()
        self.size_vocab = size_vocab
        self.dim_hidden = dim_hidden
        self.id_pad = id_pad

        self.embedder = nn.Sequential(
            SimpleIDEmbeddings(size_vocab, dim_hidden, id_pad),
            PositionalEncoding(dim_hidden),
            nn.Dropout(dropout)
        )
    # end

    def forward(self, ids_input):   # (batch, seqs_with_padding)
        return self.embedder(ids_input)
    # end

    def get_vocab_size(self):
        return self.size_vocab
    # end
# end

### core.py ###



class SimpleEncoderLayer(nn.Module):

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleEncoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 2)
    # end

    def forward(self, embeddings, masks, *args):
        embeddings = self.layers_residual[0](embeddings, lambda embeddings: self.layer_attention(embeddings, embeddings, embeddings, masks))
        return self.layers_residual[1](embeddings, self.layer_feedforward)
    # end
# end



class SimpleDecoderLayer(nn.Module):

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleDecoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention_decoder = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_attention_encoder = MultiHeadedAttention(n_head, dim_hidden)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 3)

    def forward(self, embeddings, masks_encoder, output_encoder, masks_decoder, *args):
        embeddings = self.layers_residual[0](embeddings, lambda embeddings: self.layer_attention_decoder(embeddings, embeddings, embeddings, masks_decoder))
        embeddings = self.layers_residual[1](embeddings, lambda embeddings: self.layer_attention_encoder(embeddings, output_encoder, output_encoder, masks_encoder))
        return self.layers_residual[2](embeddings, self.layer_feedforward)
    # end
# end


class SimpleTransformerStack(nn.Module):

    def __init__(self, obj_layer, n_layers):
        super(SimpleTransformerStack, self).__init__()
        self.layers = clones(obj_layer, n_layers)

        self.norm = torch.nn.LayerNorm(obj_layer.dim_hidden)
        self.keys_cache = ['output']
        self.cache = Dotdict({
            'output': None
        })
    # end

    def forward(self, embedding_encoder=None, masks_encoder=None, output_encoder=None, embedding_decoder=None, masks_decoder=None ,noncache=False, **kwargs):  # input -> (batch, len_seq, vocab)

        if output_encoder is not None and embedding_decoder is not None and masks_decoder is not None:
            embeddings = embedding_decoder
        else:
            embeddings = embedding_encoder
        # end

        for layer in self.layers:
            embeddings = layer(embeddings, masks_encoder, output_encoder, masks_decoder)
        # end

        output = self.norm(embeddings)

        if not noncache:
            self.cache.output = output
        # end

        return output
    # end

    # def get_vocab_size(self):
    #     return self.embedder.embedder_token.shape[-1]
    # # end

    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end


class SimpleEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, embedder_encoder, embedder_decoder, pooling=False):
        super(SimpleEncoderDecoder, self).__init__()

        self.pooling = pooling
        
        self.embedder_encoder = embedder_encoder
        self.encoder = encoder

        self.embedder_decoder = embedder_decoder
        self.decoder = decoder
        
        self.keys_cache = ['output_encoder_pooled']
        self.cache = Dotdict({
            'output_encoder_pooled': None
        })
    # end

    def forward(self, ids_encoder=None, masks_encoder=None, ids_decoder=None, masks_decoder=None, nocache=False, **kwargs):
        
        output_encoder = self.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder, nocache=nocache)
        output = output_encoder
        
        if self.pooling:
            output_encoder_refilled = output_encoder.masked_fill(masks_encoder.transpose(-1,-2)==False, 0)
            output_encoder_pooled = torch.mean(output_encoder_refilled, dim=-2)
            self.cache.output_encoder_pooled = output_encoder_pooled
            
            output_encoder_pooled_expanded = output_encoder_pooled.unsqueeze(-2).expand(output_encoder.shape)
            output = output_encoder_pooled_expanded
        # end
        
        if self.embedder_decoder and self.decoder:
            output_decoder = self.embed_and_decode(ids_decoder=ids_decoder, masks_encoder=masks_encoder, output_encoder=output, masks_decoder=masks_decoder, nocache=nocache)
            output = output_decoder
        # end if
        
        return output
    # end
    
    def embed_and_encode(self, ids_encoder=None, masks_encoder=None, nocache=False, **kwargs):
        self.encoder.clear_cache()
        
        embedding_encoder = self.embedder_encoder(ids_encoder)
        output_encoder = self.encoder(
            embedding_encoder=embedding_encoder,
            masks_encoder=masks_encoder,
            nocache=nocache
        )
        
        return output_encoder
    # end

    
    def embed_and_decode(self, ids_decoder=None, masks_encoder=None, output_encoder=None, masks_decoder=None, nocache=False, **kwargs):
        self.decoder.clear_cache()
        
        embedding_decoder = self.embedder_decoder(ids_decoder)
        output_decoder = self.decoder(
            masks_encoder=masks_encoder,
            output_encoder=output_encoder,    #(len_seq, dim_hidden) -> (1, dim_hidden)
            embedding_decoder=embedding_decoder,
            masks_decoder=masks_decoder,
            nocache=nocache
        )

        return output_decoder
    # end
    

    def clear_cache(self):
        self.encoder.clear_cache()
        
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
        
        if self.decoder:
            self.decoder.clear_cache()
        # end
    # end


    def get_vocab_size(self, name_embedder):
        embedder = getattr(self, f'embedder_{name_embedder}')
        return embedder.get_vocab_size()
    # end

# end

class LinearAndNorm(nn.Module):
    def __init__(self, dim_in = None, dim_out = None, eps_norm=1e-12):
        super(LinearAndNorm, self).__init__()

        self.linear = torch.nn.Linear(dim_in, dim_out)
        self.norm = torch.nn.LayerNorm(dim_out, eps_norm)
    # end

    def forward(self, seqs_in):
        return self.norm(self.linear(seqs_in).relu())
    # end
# end




class TokenizerWrapper:
    def __init__(self, vocab, splitter):
        self.splitter = splitter
        self.vocab = vocab

        self.id_pad = len(vocab)
        self.id_cls = len(vocab) + 1
        self.id_sep = len(vocab) + 2
        self.id_mask = len(vocab) + 3
        
        self.size_vocab = len(vocab) + 4
        self.vocab_size = self.size_vocab

        self.token_pad = '[P@D]'
        self.token_cls = '[CL$]'
        self.token_sep = '[$EP]'
        self.token_mask = '[M@$K]'
           
        self.index_id_token_special = {
            self.id_pad: self.token_pad,
            self.id_cls: self.token_cls,
            self.id_sep: self.token_sep,
            self.id_mask: self.token_mask
        }
        
    # end

    def encode(self, line):
        return self.vocab([doc.text.lower() for doc in self.splitter(line)])
    # end

    def decode(self, tokens):
        words = []
        for token in tokens:
            token = int(token)
            
            if token in self.index_id_token_special:
                word_target = self.index_id_token_special[token]
            else:
                try:
                    word_target = vocab.lookup_token(token)
                except:
                    word_target = '[ERROR_LOOKUP_{}]'.format(token)
                # end
            # end
            
            words.append(word_target)
        # end
        
        return ' '.join(words)
    # end
# end



class Batch:
    DEVICE = 'cuda'

    def __init__(self, **kwargs):
        self.kwargs = {}
        for k, v in kwargs.items():
            if v is not None and type(v) is not bool:
                self.kwargs[k] = v.to(Batch.DEVICE)
        # end
    # end

    def __call__(self):
        return self.kwargs
    # end
# end



class Collator_S2S:

    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked
    # end
    

    def __call__(self, list_corpus_source):

        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []
        labels_similarity = []

        for corpus_source in list_corpus_source: # (line0, line1, sim), output of zip remove single case
            if len(corpus_source) == 3:
                corpus_line = [corpus_source[0], corpus_source[1]]
                labels_similarity.append(corpus_source[2])
            else:
                corpus_line = [corpus_source[1]]
            # end
            
            for line in corpus_line:
                tokens = self.tokenizer.encode(line)

                # TODO: check edge
                if len(tokens) > self.size_seq_max - 2:
                    tokens = tokens[:self.size_seq_max-2]
                # end

                tokens_input_encoder.append([self.tokenizer.id_cls] + tokens + [self.tokenizer.id_sep])
                tokens_input_decoder.append([self.tokenizer.id_cls] + tokens)
                tokens_label_decoder.append(tokens + [self.tokenizer.id_sep])
            # end
            

        # end

        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked)
        inputs_decoder, masks_decoder, segments_decoder, _ = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label, _ = self.pad_sequences(tokens_label_decoder, self.size_seq_max)
        # labels_similarity = torch.Tensor(labels_similarity).unsqueeze(0).transpose(0,1)
        labels_similarity = torch.Tensor(labels_similarity)

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label,
            labels_similarity=labels_similarity
        )
    # end

    
    # return masks_attention?, return masks_segment?
    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0): # need_diagonal and need_masked cannot both set, one for bert seq one for s2s seq
        id_pad = self.tokenizer.id_pad
        id_mask = self.tokenizer.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            sequences_padded.append(sequence_padded)

            if need_masked:
                index_masked = list(range(1, len_seq-1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq-2))])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence_masked, torch.LongTensor([id_pad] * count_pad)))
                
                sequences_masked_padded.append(sequence_masked_padded)
            # end
    #   # end for

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)
        # end

        masks_segment = (inputs != self.tokenizer.id_pad).unsqueeze(-2)    #(nbatch, 1, seq)
        masks_attention = self.make_std_mask(inputs, self.tokenizer.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment, None
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0

    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end



class Collator_BERT:
    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked
        
        index_special_token_2_id = {k:v for k,v in zip(tokenizer.all_special_tokens,tokenizer.all_special_ids)}
        
        self.id_pad = index_special_token_2_id['[PAD]']
        self.id_mask = index_special_token_2_id['[MASK]']
        self.id_cls = index_special_token_2_id['[CLS]']
        self.id_sep = index_special_token_2_id['[SEP]']
        self.id_unk = index_special_token_2_id['[UNK]']
        
        self.regex_special_token = re.compile(r'\[(PAD|MASK|CLS|SEP|EOL|UNK)\]')
    # end
    
    def _preprocess(self, line):
        line = re.sub(self.regex_special_token, r'<\1>', line)
        line = re.sub(r'''('|"|`){2}''', '', line)
        line = re.sub(r'\.{2,3}', '', line)
        line = re.sub(r' {2,}', ' ', line)
        line = line.lstrip().rstrip()
        return line
    # end
    

    def __call__(self, list_sequence_batch):
        list_sequence_batch = [self._preprocess(sequence) for sequence in list_sequence_batch]   # remove special tokens
        
        list_sequence_tokenized = self.tokenizer.batch_encode_plus(list_sequence_batch, add_special_tokens=False)['input_ids']
        
        # Process I. 
        list_list_tokenized = []
        
        # batch initialized condition
        list_tokenized_cache = []
        len_tokenized_accumulated = 2 # add cls and sep
        
        while list_sequence_tokenized:
            tokenized_poped = list_sequence_tokenized.pop(0)
            len_tokenized_current = len(tokenized_poped)
            
            if len_tokenized_accumulated + len_tokenized_current > self.size_seq_max:
                if list_tokenized_cache:
                    list_list_tokenized.append(list_tokenized_cache)
                
                    # clear
                    list_tokenized_cache = []
                    len_tokenized_accumulated = 2
                # end
            # end

            list_tokenized_cache.append(tokenized_poped)
            len_tokenized_accumulated += len_tokenized_current
        # end
        
        list_list_tokenized.append(list_tokenized_cache)
        
        
        # Process II. Merge list_tokenized
        list_tokenized_merged = []
        
        for list_tokenized in list_list_tokenized:
            # tokenized_merged = [token for tokenized_padded in [tokenized + [self.id_eol] for tokenized in list_tokenized] for token in tokenized_padded]
            tokenized_merged = [token for tokenized in list_tokenized for token in tokenized][:self.size_seq_max-2]
            list_tokenized_merged.append(tokenized_merged)
        # end
        
        
        # Process III. Add begin and stop special token, same as jinyuj_transformers_quora.ipynb
        tokens_input_encoder = []
        tokens_input_decoder = []
        tokens_label_decoder = []
        
        for tokenized_merged in list_tokenized_merged:
            tokens_input_encoder.append([self.id_cls] + tokenized_merged + [self.id_sep])
            tokens_input_decoder.append([self.id_cls] + tokenized_merged)
            tokens_label_decoder.append(tokenized_merged + [self.id_sep])
        # end
        
        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked)
        inputs_decoder, masks_decoder, segments_decoder, _ = self.pad_sequences(tokens_input_decoder, self.size_seq_max, need_diagonal=True)
        labels_decoder, masks_label, segments_label, _ = self.pad_sequences(tokens_label_decoder, self.size_seq_max)
        
        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label
        )
    # end


    # return masks_attention?, return masks_segment?
    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0): # need_diagonal and need_masked cannot both set, one for bert seq one for s2s seq
        id_pad = self.id_pad
        id_mask = self.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)

            count_pad = size_seq_max - len_seq

            sequence = torch.LongTensor(sequence)
            sequence_padded = torch.cat((sequence, torch.LongTensor([id_pad] * count_pad)))
            # print(sequence_padded.shape)
            sequences_padded.append(sequence_padded)

            if need_masked:
                index_masked = list(range(1, len_seq-1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq-2))])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequence_masked_padded = torch.cat((sequence_masked, torch.LongTensor([id_pad] * count_pad)))
                
                sequences_masked_padded.append(sequence_masked_padded)
            # end
    #   # end for

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        if need_masked:
            inputs_masked_padded = torch.stack(sequences_masked_padded)
        # end

        masks_segment = (inputs != self.id_pad).unsqueeze(-2)    #(nbatch, 1, seq)
        masks_attention = self.make_std_mask(inputs, self.id_pad) if need_diagonal else masks_segment

        if need_masked:
            masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
            masks_attention = masks_attention & masks_masked
            return inputs_masked_padded, masks_attention, masks_segment, inputs # (inputs, masks_attention, masks_segment, labels)
        else:
            return inputs, masks_attention, masks_segment, None
        # end
    # end


    def subsequent_mask(self, size):
        "Mask out subsequent positions."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
            torch.uint8
        )
        return subsequent_mask == 0
    # end

    
    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask
    # end
# end

In [2]:
import spacy


def Multi30k(language_pair=None):
    corpus_lines_train = []

    for lan in language_pair:
        with open('text/train.{}'.format(lan), 'r') as file:
            corpus_lines_train.append(file.read().splitlines())
        # end
    # end

    corpus_train = list(zip(*corpus_lines_train))

    corpus_lines_eval = []

    for lan in language_pair:
        with open('text/val.{}'.format(lan), 'r') as file:
            corpus_lines_eval.append(file.read().splitlines())
        # end
    # end

    corpus_eval = list(zip(*corpus_lines_eval))

    return corpus_train, corpus_eval, None
# end


def Quora(split=0.05):
    filename_quora = 'quora_duplicate_questions.tsv'
    
    contents_quora = parse_csv_file_to_json(filename_quora)
    list_corpus_quora = []
    for c in contents_quora:
        label = int(c['is_duplicate'])
        score = 1.0 if label else 0.5
        corpus_quora = (c['question1'], c['question2'], score)
        list_corpus_quora.append(corpus_quora)
    # end
    
    indexs_all = list(range(len(list_corpus_quora)))
    random.shuffle(indexs_all)
    
    index_split = int(split * len(list_corpus_quora))
    
    indexs_eval = indexs_all[:index_split]
    indexs_train = indexs_all[index_split:]
    
    list_corpus_eval = [list_corpus_quora[i_e] for i_e in indexs_eval]
    list_corpus_train = [list_corpus_quora[i_t] for i_t in indexs_train]
    
    return list_corpus_train, list_corpus_eval, None
# end

def BookCorpus2000(split=0.1):
    filename = 'bookcorpus_2000.json'
    
    with open(filename, 'r') as file:
        list_corpus = json.load(file)
    # end
    
    indexs_all = list(range(len(list_corpus)))
    random.shuffle(indexs_all)
    
    index_split = int(split * len(list_corpus))
    
    indexs_eval = indexs_all[:index_split]
    indexs_train = indexs_all[index_split:]
    
    list_corpus_eval = [list_corpus[i_e] for i_e in indexs_eval]
    list_corpus_train = [list_corpus[i_t] for i_t in indexs_train]
    
    return list_corpus_train, list_corpus_eval, None
# end


def BookCorpus(split=0.0001, used=-1):
    import datasets
    
    list_corpus = datasets.load_dataset('bookcorpus')['train']['text'][:used]   # 70,000,000, 70 Million
    
    indexs_all = list(range(len(list_corpus)))
    random.shuffle(indexs_all)
    
    index_split = int(split * len(list_corpus))
    
    indexs_eval = indexs_all[:index_split]
    indexs_train = indexs_all[index_split:]
    
    list_corpus_eval = [list_corpus[i_e] for i_e in indexs_eval]
    list_corpus_train = [list_corpus[i_t] for i_t in indexs_train]
    
    return list_corpus_train, list_corpus_eval, None
# end




def load_vocab(filename_vocab):
    vocab_tgt = torch.load(filename_vocab)
    return vocab_tgt
# end

def load_spacy():
    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en
# end

In [3]:
class SimpleEncoderHead_MLM(nn.Module):

    @classmethod
    def get_info_accuracy_template(cls):
        return Dotdict({
            'corrects_segmented': 0,
            'corrects_masked': 0,
            'num_segmented': 0,
            'num_masked': 0 
        })
    # end
    
    def __init__(self, model, size_vocab, dim_hidden=128):
        super(SimpleEncoderHead_MLM, self).__init__()
        
        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
        self.extractor.weight = nn.Parameter(model.embedder_encoder.embedder[0].lut.weight)
        
        self.keys_cache = ['labels_mlm', 'masks_encoder', 'segments_encoder', 'output']
        self.cache = Dotdict({
            'labels_mlm': None,
            'masks_encoder': None,
            'segments_encoder': None,
            'output': None
        })
        
        self.func_loss = torch.nn.CrossEntropyLoss()
    # end


    def forward(self, model, labels_encoder=None, segments_encoder=None, masks_encoder=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        output_encoder = model.encoder.cache.output
        output_ffn = self.ffn(output_encoder)
        output_mlm = self.extractor(output_ffn) # output_mlm = prediction_logits

        if not nocache:
            self.cache.labels_mlm = labels_encoder
            self.cache.masks_encoder = masks_encoder
            self.cache.segments_encoder = segments_encoder
            self.cache.output = output_mlm
        # end

        return output_mlm
    # end
    
    def get_loss(self):
        
        labels_mlm = self.cache.labels_mlm
        masks_encoder = self.cache.masks_encoder
        segments_encoder = self.cache.segments_encoder
        output_mlm = self.cache.output
        
        info_acc = SimpleEncoderHead_MLM.get_info_accuracy_template()
        
        segments_encoder_2d = segments_encoder.transpose(-1,-2)[:,:,0]
        hidden_mlm_segmented = output_mlm.masked_select(segments_encoder_2d.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1]) # should be (segmented_all_batchs, size_vocab)
        
        loss_segments = self.func_loss(hidden_mlm_segmented, labels_mlm.masked_select(segments_encoder_2d))
        info_acc.corrects_segmented = torch.sum(hidden_mlm_segmented.argmax(-1) == labels_mlm.masked_select(segments_encoder_2d)).cpu().item()
        info_acc.num_segmented = hidden_mlm_segmented.shape[0]
        
        masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder # True is masked
        masks_masked_perbatch = masks_masked[:,0,:]
        hidden_mlm_masked = output_mlm.masked_select(masks_masked_perbatch.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1])
        
        loss_masked = self.func_loss(hidden_mlm_masked, labels_mlm.masked_select(masks_masked_perbatch))       
        info_acc.corrects_masked = torch.sum(hidden_mlm_masked.argmax(-1) == labels_mlm.masked_select(masks_masked_perbatch)).cpu().item()
        info_acc.num_masked = hidden_mlm_masked.shape[0]
        
        loss_mlm = loss_segments + loss_masked * 3
        
        return loss_mlm, info_acc
    # end
    
    
    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end

# end

In [4]:
class SimpleDecoderHead_S2S(nn.Module):

    @classmethod
    def get_info_accuracy_template(cls):
        return Dotdict({
            'corrects_segmented': 0,
            'num_segmented': 0 
        })
    # end
    
    
    def __init__(self, model, size_vocab, dim_hidden=128):
        super(SimpleDecoderHead_S2S, self).__init__()
        
        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
        self.extractor.weight = nn.Parameter(model.embedder_decoder.embedder[0].lut.weight)

        self.func_loss = torch.nn.CrossEntropyLoss()
        
        self.keys_cache = ['output', 'labels_s2s', 'segments_decoder']
        self.cache = Dotdict({
            'output': None,
            'labels_s2s': None,
            'segments_decoder': None
        })

    # end



    def forward(self, model, labels_decoder=None, segments_label=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        output_decoder = model.decoder.cache.output
        output_ffn = self.ffn(output_decoder)
        output_s2s = self.extractor(output_ffn)   # output_mlm = prediction_logits
        
        if not nocache:
            self.cache.segments_label = segments_label
            self.cache.labels_s2s =  labels_decoder
            self.cache.output = output_s2s
        # end

        return output_s2s
    # end


    def get_loss(self):
        labels_s2s = self.cache.labels_s2s
        output_s2s = self.cache.output
        info_acc = SimpleDecoderHead_S2S.get_info_accuracy_template()
        
        segments_label = self.cache.segments_label
        segments_label_2d = segments_label.transpose(-1,-2)[:,:,0]
        hidden_s2s_segmented = output_s2s.masked_select(segments_label_2d.unsqueeze(-1)).reshape(-1, output_s2s.shape[-1])

        loss_segments = self.func_loss(hidden_s2s_segmented, labels_s2s.masked_select(segments_label_2d))
        info_acc.corrects_segmented = torch.sum(hidden_s2s_segmented.argmax(-1) == labels_s2s.masked_select(segments_label_2d)).cpu().item()
        info_acc.num_segmented = hidden_s2s_segmented.shape[0]
        
        return loss_segments * 4, info_acc
    # end


    def evaluate(self):
        pass
    # end


    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end



In [5]:
class SimpleEncoderHead_Similarity(nn.Module):

    
    @classmethod
    def get_info_accuracy_template(cls):
        return Dotdict({
            'meansquares': []
        })
    # end    
    

    def __init__(self):
        super(SimpleEncoderHead_Similarity, self).__init__()

        self.func_loss = torch.nn.MSELoss()
        self.cos_score_transformation = torch.nn.Identity()
        self.keys_cache = ['labels_sim', 'output']
        self.cache = Dotdict({
            'labels_sim': None,
            'output': None
        })
    # end

    def forward(self, model, labels_similarity=None, nocache=False, **kwargs):  # labels_sim (batch/2, 1)   for every two sentences, we have a label

        output_encoder_pooled = model.cache.output_encoder_pooled
        size_batch, dim_hidden = output_encoder_pooled.shape

        if size_batch % 2 != 0:
            raise Exception('sim calculation is not prepared as size_batch % 2 != 0')
        # end

        # pooling (batch, pair, dim_hidden)
        output_pooling = output_encoder_pooled.squeeze(1).view(-1, 2, dim_hidden)   # might cls + sep, but abandon now (as it's not easy to get sep for every batch, different location)
        output_pooling_x1 = output_pooling[:, 0, :]
        output_pooling_x2 = output_pooling[:, 1, :]
        sims = self.cos_score_transformation(torch.cosine_similarity(output_pooling_x1, output_pooling_x2))  # -> (batch, scores)

        if not nocache:
            self.cache.output = sims
            self.cache.labels_sim = labels_similarity
        # end

        return sims
    # end

    def get_loss(self):
        sims = self.cache.output
        labels_sim = self.cache.labels_sim
        info_acc = SimpleEncoderHead_Similarity.get_info_accuracy_template()

        loss_sim = self.func_loss(sims, labels_sim)
        info_acc.meansquares.append((torch.mean((sims - labels_sim) ** 2)).cpu().item())
        return loss_sim * 64, info_acc
    # end

    def clear_cache(self):
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end

    def evaluate(self):
        pass
    # end
# end


In [6]:
class HeadManager(nn.Module):
    def __init__(self):
        super(HeadManager, self).__init__()
        self.index_name_head = set()
    # end

    def register(self, head):
        name_head = head.__class__.__name__
        setattr(self, name_head, head)
        self.index_name_head.add(name_head)
        return self
    # end

    def forward(self, model, **kwargs):
        for name in self.index_name_head:
            head = getattr(self, name)
            head.forward(model, **kwargs)
        # end
    # end

    def get_head(self, klass):
        return getattr(self, klass.__name__)
    # end

    def clear_cache(self):
        for name_head in self.index_name_head:
            getattr(self, name_head).clear_cache()
        # end
    # end
# end


class Trainer(nn.Module):
    def __init__(self, model=None, manager=None):
        super(Trainer, self).__init__()
        self.model = model
        self.manager = manager
    # end

    def forward(self, **kwargs):
        self.clear_cache()
        
        self.model.forward(**kwargs)
        self.manager.forward(self.model, **kwargs)
    # end
    
    def clear_cache(self):
        self.model.clear_cache() if self.model else None
        self.manager.clear_cache() if self.manager else None
    # end
# end


class SaverAndLoader:
    def __init__(self, path_checkpoints='./checkpoints'):
        self.dict_name_item = {}
        self.path_checkpoints = path_checkpoints
    # end
    
    def add_item(self, item, name=None):
        if not name:
            name = item.__class__.__name__
        # end
        
        self.dict_name_item[name] = item
        return self
    # end
    
    
    def update_checkpoint(self, name_checkpoint, name_checkpoint_previous=None):  # epoch_n
        if name_checkpoint_previous:
            result = self._delete_checkpoint_folder(name_checkpoint_previous)
            if result:
                print(f'[INFO] {name_checkpoint_previous} is cleared.')
            else:
                print(f'[ALERT] {name_checkpoint_previous} fail to be cleared.')
            # end
        # end
        
        folder_checkpoint = self._create_checkpoint_folder(name_checkpoint)
        for name_item, item in self.dict_name_item.items():
            path_checkpoint_item = os.path.join(folder_checkpoint, f'{name_item}.pt')
            torch.save(item.state_dict(), path_checkpoint_item)
            
            size_file_saved_MB = os.path.getsize(path_checkpoint_item) / 1024 / 1024
            print(f'[INFO] {name_item} is saved, {size_file_saved_MB} MB')
        # end
        
        print(f'[INFO] {name_checkpoint} is saved')
    # end

    
    def load_item_state(self, name_checkpoint, instance_item, name_item=None):
        if not name_item:
            name_item = instance_item.__class__.__name__
        # end
        
        path_checkpoint_item = os.path.join(self.path_checkpoints, name_checkpoint, f'{name_item}.pt')
        if not os.path.exists(path_checkpoint_item):
            print(f'[ERROR] {path_checkpoint_item} not exists')
            return None
        # end
        if issubclass(instance_item.__class__, torch.nn.Module):
            instance_item.load_state_dict(torch.load(path_checkpoint_item), strict=False)
        else:
            instance_item.load_state_dict(torch.load(path_checkpoint_item))
        # end
        
        print(f'[INFO] {name_item} loaded for {name_checkpoint}.')
        return instance_item
    # end
    
    
    def list_items(self):
        return list(self.dict_name_item.keys())
    # end
    
    def _create_checkpoint_folder(self, name_checkpoint):
        path_folder_target = os.path.join(self.path_checkpoints, name_checkpoint)
        Path(path_folder_target).mkdir(parents=True, exist_ok=True)
        return path_folder_target
    # end
    
    def _delete_checkpoint_folder(self, name_checkpoint_previous):
        path_folder_target = os.path.join(self.path_checkpoints, name_checkpoint_previous)
        if os.path.exists(path_folder_target):
            shutil.rmtree(path_folder_target, ignore_errors=True)
        # end
        return (not os.path.exists(path_folder_target))
    # end
# end

In [7]:
class Builder:
    
    @classmethod
    def build_model_with_mlm_v2(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, None, embedder_encoder, None)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)

        manager = HeadManager().register(head_mlm)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
    
    @classmethod
    def build_model_with_s2s_v2(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder, pooling=True)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)
        
        manager = HeadManager().register(head_s2s)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
    
    @classmethod
    def build_model_with_2heads(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder, pooling=True)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)
        
        manager = HeadManager().register(head_s2s).register(head_mlm)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
    
    @classmethod
    def load_model_with_2heads(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, saver, name_checkpoint):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder, pooling=True)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)
        
        loader.load_item_state(name_checkpoint, model)
        loader.load_item_state(name_checkpoint, head_s2s)
        loader.load_item_state(name_checkpoint, head_mlm)
        
        manager = HeadManager().register(head_s2s).register(head_mlm)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
    
    @classmethod
    def build_model_with_sim_v2(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, None, embedder_encoder, None, pooling=True)
        head_sim = SimpleEncoderHead_Similarity()

        manager = HeadManager().register(head_sim)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
    

    @classmethod
    def build_model_with_3heads(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer):
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)
        
        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder, pooling=True)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)
        head_sim = SimpleEncoderHead_Similarity()
        
        manager = HeadManager().register(head_s2s).register(head_mlm).register(head_sim)
        trainer = Trainer(model=model, manager=manager)

        return trainer
    # end
# end

In [8]:
import re
import json
import transformers
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset
from transformers import AutoTokenizer

gpu = 0
torch.cuda.set_device(gpu)

epochs = 3

# source
seq_max = 128
batch_size = 64


# model & head
dim_hidden = 512
dim_feedforward = 512
n_head = 8
n_layer = 8

# optimizer
lr_base_optimizer = 1e-4
betas_optimizer = (0.9, 0.999)
eps_optimizer = 1e-9

# scheduler
warmup = 200

### for bookcorpus 2 heads ###
train_source, valid_source, _ = BookCorpus(split=0.001, used=100000)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = Collator_BERT(tokenizer, seq_max)
###########


dataloader_train = DataLoader(train_source, batch_size, shuffle=False, collate_fn=collator)
dataloader_eval = DataLoader(valid_source, 1, shuffle=False, collate_fn=collator)

# trainer= Builder.load_model_with_2heads(tokenizer.vocab_size, dim_hidden, dim_feedforward, n_head, n_layer, loader, 'epoch1')
trainer = Builder.build_model_with_2heads(tokenizer.vocab_size, dim_hidden, dim_feedforward, n_head, n_layer)

for p in trainer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    # end
# end

trainer = trainer.to('cuda')

optimizer = torch.optim.Adam(trainer.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
decayRate = 0.96
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)


# optimizer = loader.load_item_state('epoch1', optimizer)
# lr_scheduler = loader.load_item_state('epoch1', lr_scheduler)

loader = SaverAndLoader('checkpoints_0')
loader.add_item(trainer.model)
loader.add_item(trainer.manager.get_head(SimpleEncoderHead_MLM))
loader.add_item(trainer.manager.get_head(SimpleDecoderHead_S2S))
loader.add_item(optimizer)
loader.add_item(lr_scheduler)

print()




In [9]:
def train_a_batch(batch, trainer, optimizer=None, scheduler=None):
    trainer.train()
    trainer.forward(**batch())
    
    
    loss_s2s, info_acc_s2s = trainer.manager.get_head(SimpleDecoderHead_S2S).get_loss()
    loss_mlm, info_acc_mlm = trainer.manager.get_head(SimpleEncoderHead_MLM).get_loss()
    # loss_sim, info_acc_sim = trainer.manager.get_head(SimpleEncoderHead_Similarity).get_loss()

    # crossentropy loss
    
    # loss_all = loss_s2s * 5
    # loss_all = loss_mlm
    # loss_all = loss_sim
    loss_all = (loss_s2s + loss_mlm) / 2
    # loss_all = (loss_s2s + loss_mlm + loss_sim) / 3
    # loss_all = (loss_s2s + loss_mlm + loss_sim)
    loss_all_value = loss_all.item()
    
    # print(loss_all)
    loss_all.backward()

    
    if optimizer:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # end
    
    if scheduler:
        scheduler.step()
    # end
    
    trainer.clear_cache()
    return loss_all_value, Dotdict({'mlm': info_acc_mlm, 's2s': info_acc_s2s})
    # return loss_all_value, Dotdict({'mlm': info_acc_mlm, 's2s': info_acc_s2s, 'sim': info_acc_sim})
# end

In [10]:
def evaluate_a_batch(batch, trainer, *args, **kwargs):
    trainer.eval()
    with torch.no_grad():
        trainer.forward(**batch())
    # end
    
    loss_s2s, info_acc_s2s = trainer.manager.get_head(SimpleDecoderHead_S2S).get_loss()
    loss_mlm, info_acc_mlm = trainer.manager.get_head(SimpleEncoderHead_MLM).get_loss()

    # crossentropy loss
    
    # loss_all = loss_s2s * 5
    # loss_all = loss_mlm
    # loss_all = loss_sim
    loss_all = (loss_s2s + loss_mlm) / 2
    # loss_all = (loss_s2s + loss_mlm + loss_sim) / 3
    # loss_all = (loss_s2s + loss_mlm + loss_sim)
    loss_all_value = loss_all.item()
    
    trainer.clear_cache()
    return loss_all_value, Dotdict({'mlm': info_acc_mlm, 's2s': info_acc_s2s})
# end

In [11]:
from datetime import datetime
from tqdm import tqdm

name_checkpoint_current = None
name_checkpoint_last = None

for e in range(epochs):
    
    info_acc_heads_train = Dotdict({
        'mlm': SimpleEncoderHead_MLM.get_info_accuracy_template(),
        's2s': SimpleDecoderHead_S2S.get_info_accuracy_template(),
        # 'sim': SimpleEncoderHead_Similarity.get_info_accuracy_template()
    })


    info_acc_heads_eval = Dotdict({
        'mlm': SimpleEncoderHead_MLM.get_info_accuracy_template(),
        's2s': SimpleDecoderHead_S2S.get_info_accuracy_template(),
        # 'sim': SimpleEncoderHead_Similarity.get_info_accuracy_template()
    })
    
    # train phase
    losss_per_e = []
    for i, batch in enumerate(tqdm(dataloader_train)):
        loss_current, info_acc_heads_batch = train_a_batch(batch, trainer, optimizer, None)
        info_acc_heads_train += info_acc_heads_batch
        
        losss_per_e.append(loss_current)
        if i % 100 == 0:
            print('Epoch: {} Batch: {}, loss: {}, rate: {}, acc_mlm: {}, acc_s2s: {}'.format(
                e, i, loss_current, optimizer.param_groups[0]['lr'],
                info_acc_heads_batch.mlm.corrects_masked / info_acc_heads_batch.mlm.num_masked,
                info_acc_heads_batch.s2s.corrects_segmented / info_acc_heads_batch.s2s.num_segmented,
                # sum(info_acc_heads_batch.sim.meansquares) / len(info_acc_heads_batch.sim.meansquares)
            ), end='\r')
        # end
    # end
    
    loss_average_per_e = sum(losss_per_e) / len(losss_per_e)
    print('[{}] Epoch: {} training ends. Status: Average loss: {}, Average MLM accuracy: {}, Average S2S accuracy: {}'.format(
        datetime.utcnow(), e, loss_average_per_e,
        info_acc_heads_train.mlm.corrects_masked / info_acc_heads_train.mlm.num_masked,
        info_acc_heads_train.s2s.corrects_segmented / info_acc_heads_train.s2s.num_segmented,
        # sum(info_acc_heads_train.sim.meansquares) / len(info_acc_heads_train.sim.meansquares)
    ))
    
    if e % 2 == 0:
        lr_scheduler.step() # schedule per 2 epoch
    # end

    
    # eval phase
    losss_per_e = []
    for i, batch in enumerate(tqdm(dataloader_train)):
        loss_current, info_acc_heads_batch = evaluate_a_batch(batch, trainer)
        info_acc_heads_eval += info_acc_heads_batch
        
        losss_per_e.append(loss_current)
    # end
    
    loss_average_per_e = sum(losss_per_e) / len(losss_per_e)
    print('[{}] Epoch: {} Evalutation ends. Status: Average loss: {}, Average MLM accuracy: {}, Average S2S accuracy: {}'.format(
        datetime.utcnow(), e, loss_average_per_e,
        info_acc_heads_eval.mlm.corrects_masked / info_acc_heads_eval.mlm.num_masked,
        info_acc_heads_eval.s2s.corrects_segmented / info_acc_heads_eval.s2s.num_segmented,
        # sum(info_acc_heads_eval.sim.meansquares) / len(info_acc_heads_eval.sim.meansquares)
    ))
    
    name_checkpoint_current = f'epoch_{e}'
    loader.update_checkpoint(name_checkpoint_current, name_checkpoint_last)
    name_checkpoint_last = name_checkpoint_current
# end

  0%|          | 2/1561 [01:06<11:54:34, 27.50s/it]

Epoch: 0 Batch: 0, loss: 41.342159271240234, rate: 0.0001, acc_mlm: 0.0, acc_s2s: 0.0

  7%|▋         | 102/1561 [01:20<03:20,  7.29it/s] 

Epoch: 0 Batch: 100, loss: 27.255756378173828, rate: 0.0001, acc_mlm: 0.05514705882352941, acc_s2s: 0.059203444564047365

 13%|█▎        | 202/1561 [01:33<02:50,  7.99it/s]

Epoch: 0 Batch: 200, loss: 25.891868591308594, rate: 0.0001, acc_mlm: 0.0627062706270627, acc_s2s: 0.04961089494163424

 19%|█▉        | 302/1561 [01:47<02:50,  7.37it/s]

Epoch: 0 Batch: 300, loss: 24.867835998535156, rate: 0.0001, acc_mlm: 0.05653710247349823, acc_s2s: 0.11431513903192585

 26%|██▌       | 402/1561 [02:01<02:27,  7.87it/s]

Epoch: 0 Batch: 400, loss: 24.38874053955078, rate: 0.0001, acc_mlm: 0.06578947368421052, acc_s2s: 0.11923076923076924

 32%|███▏      | 502/1561 [02:14<02:21,  7.48it/s]

Epoch: 0 Batch: 500, loss: 23.752307891845703, rate: 0.0001, acc_mlm: 0.042483660130718956, acc_s2s: 0.15479582146248813

 39%|███▊      | 602/1561 [02:28<02:07,  7.49it/s]

Epoch: 0 Batch: 600, loss: 23.078828811645508, rate: 0.0001, acc_mlm: 0.06885245901639345, acc_s2s: 0.15784408084696824

 45%|████▍     | 702/1561 [02:41<01:51,  7.74it/s]

Epoch: 0 Batch: 700, loss: 22.693668365478516, rate: 0.0001, acc_mlm: 0.08664259927797834, acc_s2s: 0.16382978723404254

 51%|█████▏    | 802/1561 [02:54<01:42,  7.40it/s]

Epoch: 0 Batch: 800, loss: 23.112232208251953, rate: 0.0001, acc_mlm: 0.052805280528052806, acc_s2s: 0.15272373540856032

 58%|█████▊    | 902/1561 [03:07<01:31,  7.22it/s]

Epoch: 0 Batch: 900, loss: 22.144081115722656, rate: 0.0001, acc_mlm: 0.06382978723404255, acc_s2s: 0.16666666666666666

 64%|██████▍   | 1002/1561 [03:21<01:17,  7.23it/s]

Epoch: 0 Batch: 1000, loss: 21.730398178100586, rate: 0.0001, acc_mlm: 0.04391891891891892, acc_s2s: 0.16122650840751732

 71%|███████   | 1102/1561 [03:34<01:02,  7.39it/s]

Epoch: 0 Batch: 1100, loss: 21.985755920410156, rate: 0.0001, acc_mlm: 0.05263157894736842, acc_s2s: 0.1723076923076923

 77%|███████▋  | 1202/1561 [03:48<00:49,  7.31it/s]

Epoch: 0 Batch: 1200, loss: 21.946741104125977, rate: 0.0001, acc_mlm: 0.05639097744360902, acc_s2s: 0.17672886937431395

 83%|████████▎ | 1302/1561 [04:02<00:35,  7.38it/s]

Epoch: 0 Batch: 1300, loss: 22.588062286376953, rate: 0.0001, acc_mlm: 0.0672782874617737, acc_s2s: 0.1693548387096774

 90%|████████▉ | 1402/1561 [04:15<00:21,  7.29it/s]

Epoch: 0 Batch: 1400, loss: 22.26475715637207, rate: 0.0001, acc_mlm: 0.08865248226950355, acc_s2s: 0.178125

 96%|█████████▌| 1502/1561 [04:29<00:07,  7.38it/s]

Epoch: 0 Batch: 1500, loss: 22.573814392089844, rate: 0.0001, acc_mlm: 0.06741573033707865, acc_s2s: 0.16575192096597147

100%|██████████| 1561/1561 [04:37<00:00,  5.63it/s]


[2023-11-29 11:06:52.081279] Epoch: 0 training ends. Status: Average loss: 23.931120262170435, Average MLM accuracy: 0.05945389255244097, Average S2S accuracy: 0.1419955644668053


100%|██████████| 1561/1561 [01:26<00:00, 18.07it/s]


[2023-11-29 11:08:18.462823] Epoch: 0 Evalutation ends. Status: Average loss: 22.03921740212401, Average MLM accuracy: 0.06133217188524349, Average S2S accuracy: 0.1746098964510416
[INFO] SimpleEncoderDecoder is saved, 249.78036785125732 MB
[INFO] SimpleEncoderHead_MLM is saved, 60.62121295928955 MB
[INFO] SimpleDecoderHead_S2S is saved, 60.62121295928955 MB
[INFO] Adam is saved, 737.990008354187 MB
[INFO] ExponentialLR is saved, 0.0005445480346679688 MB
[INFO] epoch_0 is saved


  0%|          | 2/1561 [00:00<03:56,  6.59it/s]

Epoch: 1 Batch: 0, loss: 22.64373207092285, rate: 9.6e-05, acc_mlm: 0.05128205128205128, acc_s2s: 0.17418546365914786

  7%|▋         | 102/1561 [00:14<03:15,  7.46it/s]

Epoch: 1 Batch: 100, loss: 22.654911041259766, rate: 9.6e-05, acc_mlm: 0.05514705882352941, acc_s2s: 0.17115177610333693

 13%|█▎        | 202/1561 [00:27<02:55,  7.74it/s]

Epoch: 1 Batch: 200, loss: 21.91657829284668, rate: 9.6e-05, acc_mlm: 0.052805280528052806, acc_s2s: 0.16245136186770429

 19%|█▉        | 302/1561 [00:41<02:43,  7.72it/s]

Epoch: 1 Batch: 300, loss: 21.64832878112793, rate: 9.6e-05, acc_mlm: 0.045936395759717315, acc_s2s: 0.164778578784758

 26%|██▌       | 402/1561 [00:54<02:37,  7.36it/s]

Epoch: 1 Batch: 400, loss: 22.04288101196289, rate: 9.6e-05, acc_mlm: 0.05263157894736842, acc_s2s: 0.15096153846153845

 32%|███▏      | 502/1561 [01:07<02:19,  7.57it/s]

Epoch: 1 Batch: 500, loss: 21.985258102416992, rate: 9.6e-05, acc_mlm: 0.049019607843137254, acc_s2s: 0.1785375118708452

 39%|███▊      | 602/1561 [01:21<02:04,  7.68it/s]

Epoch: 1 Batch: 600, loss: 22.03728485107422, rate: 9.6e-05, acc_mlm: 0.05573770491803279, acc_s2s: 0.17709335899903753

 45%|████▍     | 702/1561 [01:34<01:51,  7.72it/s]

Epoch: 1 Batch: 700, loss: 21.441957473754883, rate: 9.6e-05, acc_mlm: 0.05054151624548736, acc_s2s: 0.1595744680851064

 51%|█████▏    | 802/1561 [01:47<01:39,  7.64it/s]

Epoch: 1 Batch: 800, loss: 21.92365837097168, rate: 9.6e-05, acc_mlm: 0.07590759075907591, acc_s2s: 0.17704280155642024

 58%|█████▊    | 902/1561 [02:00<01:24,  7.82it/s]

Epoch: 1 Batch: 900, loss: 21.715312957763672, rate: 9.6e-05, acc_mlm: 0.07801418439716312, acc_s2s: 0.18229166666666666

 64%|██████▍   | 1002/1561 [02:13<01:17,  7.19it/s]

Epoch: 1 Batch: 1000, loss: 21.17093276977539, rate: 9.6e-05, acc_mlm: 0.033783783783783786, acc_s2s: 0.17012858555885263

 71%|███████   | 1102/1561 [02:27<01:00,  7.56it/s]

Epoch: 1 Batch: 1100, loss: 21.971891403198242, rate: 9.6e-05, acc_mlm: 0.05964912280701754, acc_s2s: 0.18051282051282053

 77%|███████▋  | 1202/1561 [02:41<00:47,  7.55it/s]

Epoch: 1 Batch: 1200, loss: 21.082435607910156, rate: 9.6e-05, acc_mlm: 0.08270676691729323, acc_s2s: 0.18660812294182216

 83%|████████▎ | 1302/1561 [02:54<00:34,  7.41it/s]

Epoch: 1 Batch: 1300, loss: 22.20851707458496, rate: 9.6e-05, acc_mlm: 0.04281345565749235, acc_s2s: 0.17293906810035842

 90%|████████▉ | 1402/1561 [03:08<00:20,  7.66it/s]

Epoch: 1 Batch: 1400, loss: 22.034414291381836, rate: 9.6e-05, acc_mlm: 0.07092198581560284, acc_s2s: 0.18229166666666666

 96%|█████████▌| 1502/1561 [03:21<00:07,  7.60it/s]

Epoch: 1 Batch: 1500, loss: 22.425548553466797, rate: 9.6e-05, acc_mlm: 0.04868913857677903, acc_s2s: 0.16794731064763996

100%|██████████| 1561/1561 [03:29<00:00,  7.43it/s]


[2023-11-29 11:12:15.133408] Epoch: 1 training ends. Status: Average loss: 22.02207019373364, Average MLM accuracy: 0.06064483053519619, Average S2S accuracy: 0.17379152993984906


100%|██████████| 1561/1561 [01:26<00:00, 18.09it/s]


[2023-11-29 11:13:41.408403] Epoch: 1 Evalutation ends. Status: Average loss: 21.8013455954826, Average MLM accuracy: 0.06187660067736016, Average S2S accuracy: 0.1799449015300994
[INFO] epoch_0 is cleared.
[INFO] SimpleEncoderDecoder is saved, 249.78036785125732 MB
[INFO] SimpleEncoderHead_MLM is saved, 60.62121295928955 MB
[INFO] SimpleDecoderHead_S2S is saved, 60.62121295928955 MB
[INFO] Adam is saved, 737.990008354187 MB
[INFO] ExponentialLR is saved, 0.0005445480346679688 MB
[INFO] epoch_1 is saved


  0%|          | 2/1561 [00:00<03:19,  7.81it/s]

Epoch: 2 Batch: 0, loss: 22.15630340576172, rate: 9.6e-05, acc_mlm: 0.08974358974358974, acc_s2s: 0.18796992481203006

  7%|▋         | 102/1561 [00:13<03:09,  7.68it/s]

Epoch: 2 Batch: 100, loss: 22.27186393737793, rate: 9.6e-05, acc_mlm: 0.0625, acc_s2s: 0.17007534983853606

 13%|█▎        | 202/1561 [00:27<02:54,  7.78it/s]

Epoch: 2 Batch: 200, loss: 21.770280838012695, rate: 9.6e-05, acc_mlm: 0.0429042904290429, acc_s2s: 0.16828793774319067

 19%|█▉        | 302/1561 [00:41<02:47,  7.53it/s]

Epoch: 2 Batch: 300, loss: 21.187597274780273, rate: 9.6e-05, acc_mlm: 0.0706713780918728, acc_s2s: 0.17507723995880536

 26%|██▌       | 402/1561 [00:54<02:40,  7.21it/s]

Epoch: 2 Batch: 400, loss: 22.005496978759766, rate: 9.6e-05, acc_mlm: 0.04276315789473684, acc_s2s: 0.15

 32%|███▏      | 502/1561 [01:07<02:22,  7.41it/s]

Epoch: 2 Batch: 500, loss: 22.020832061767578, rate: 9.6e-05, acc_mlm: 0.06209150326797386, acc_s2s: 0.1842355175688509

 39%|███▊      | 602/1561 [01:21<02:12,  7.24it/s]

Epoch: 2 Batch: 600, loss: 22.051027297973633, rate: 9.6e-05, acc_mlm: 0.05573770491803279, acc_s2s: 0.17998075072184794

 45%|████▍     | 702/1561 [01:34<01:59,  7.18it/s]

Epoch: 2 Batch: 700, loss: 22.011154174804688, rate: 9.6e-05, acc_mlm: 0.05415162454873646, acc_s2s: 0.1723404255319149

 51%|█████▏    | 802/1561 [01:48<01:37,  7.76it/s]

Epoch: 2 Batch: 800, loss: 22.026836395263672, rate: 9.6e-05, acc_mlm: 0.0429042904290429, acc_s2s: 0.17898832684824903

 58%|█████▊    | 902/1561 [02:01<01:32,  7.11it/s]

Epoch: 2 Batch: 900, loss: 21.534942626953125, rate: 9.6e-05, acc_mlm: 0.06028368794326241, acc_s2s: 0.18854166666666666

 64%|██████▍   | 1002/1561 [02:15<01:16,  7.35it/s]

Epoch: 2 Batch: 1000, loss: 20.9748477935791, rate: 9.6e-05, acc_mlm: 0.07094594594594594, acc_s2s: 0.17309594460929772

 71%|███████   | 1102/1561 [02:28<01:00,  7.62it/s]

Epoch: 2 Batch: 1100, loss: 21.47995376586914, rate: 9.6e-05, acc_mlm: 0.07017543859649122, acc_s2s: 0.18153846153846154

 77%|███████▋  | 1202/1561 [02:42<00:49,  7.18it/s]

Epoch: 2 Batch: 1200, loss: 21.380634307861328, rate: 9.6e-05, acc_mlm: 0.04887218045112782, acc_s2s: 0.19099890230515917

 83%|████████▎ | 1302/1561 [02:55<00:34,  7.54it/s]

Epoch: 2 Batch: 1300, loss: 21.920467376708984, rate: 9.6e-05, acc_mlm: 0.0764525993883792, acc_s2s: 0.17831541218637992

 90%|████████▉ | 1402/1561 [03:09<00:21,  7.32it/s]

Epoch: 2 Batch: 1400, loss: 21.82544708251953, rate: 9.6e-05, acc_mlm: 0.0673758865248227, acc_s2s: 0.190625

 96%|█████████▌| 1502/1561 [03:23<00:08,  7.33it/s]

Epoch: 2 Batch: 1500, loss: 21.802059173583984, rate: 9.6e-05, acc_mlm: 0.08614232209737828, acc_s2s: 0.16575192096597147

100%|██████████| 1561/1561 [03:31<00:00,  7.39it/s]


[2023-11-29 11:17:38.819180] Epoch: 2 training ends. Status: Average loss: 21.908296596079893, Average MLM accuracy: 0.060229703581207224, Average S2S accuracy: 0.1770776271492924


100%|██████████| 1561/1561 [01:26<00:00, 18.10it/s]


[2023-11-29 11:19:05.080765] Epoch: 2 Evalutation ends. Status: Average loss: 21.734466985278523, Average MLM accuracy: 0.06152952732238578, Average S2S accuracy: 0.18240332505438048
[INFO] epoch_1 is cleared.
[INFO] SimpleEncoderDecoder is saved, 249.78036785125732 MB
[INFO] SimpleEncoderHead_MLM is saved, 60.62121295928955 MB
[INFO] SimpleDecoderHead_S2S is saved, 60.62121295928955 MB
[INFO] Adam is saved, 737.990008354187 MB
[INFO] ExponentialLR is saved, 0.0005445480346679688 MB
[INFO] epoch_2 is saved


In [12]:
trainer.eval()
print()




In [13]:
# For s2s head
def greedy_generate(model, head, tokenizer, collator, **kwargs):
    id_start = tokenizer.id_cls if hasattr(tokenizer, 'id_cls') else collator.id_cls
    id_end = tokenizer.id_sep if hasattr(tokenizer, 'id_sep') else collator.id_sep
    id_pad = tokenizer.id_pad if hasattr(tokenizer, 'id_pad') else collator.id_pad
    size_seq_max = collator.size_seq_max

    ids_encoder_twin = kwargs['ids_encoder']
    masks_encoder_twin = kwargs['masks_encoder']
    
    ids_decoder_all = []
    
    for j in range(ids_encoder_twin.shape[0]):
        ids_encoder = ids_encoder_twin[j,].unsqueeze(0)
        masks_encoder = masks_encoder_twin[j,].unsqueeze(0)

        output_encoder = model.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder)
        ids_decoder = torch.zeros(1, 1).fill_(id_start).type_as(ids_encoder.data)

        for i in range(size_seq_max - 1):
            masks_decoder = collator.subsequent_mask(ids_decoder.size(1)).type_as(ids_encoder.data)
            output_decoder = model.embed_and_decode(ids_decoder=ids_decoder, masks_encoder=masks_encoder, output_encoder=output_encoder, masks_decoder=masks_decoder)

            output_ffn = head.ffn(output_decoder)
            output_s2s = head.extractor(output_ffn)   # output_mlm = prediction_logits

            logits_nextword = torch.softmax(output_s2s[:, -1], dim=-1)  # mynote: select dim2=-1, remain=all; last is the next

            id_nextword = torch.argmax(logits_nextword, dim=-1)
            id_nextword = id_nextword.data[0]

            if id_nextword == id_end:
                break
            # end

            ids_decoder = torch.cat([ids_decoder, torch.zeros(1, 1).type_as(ids_encoder.data).fill_(id_nextword)], dim=1)
        # end
        
        ids_pad = torch.full((1, size_seq_max - ids_decoder.shape[-1]), id_pad).type_as(ids_decoder.data)
        
        ids_decoder_all.append(torch.cat([ids_decoder, ids_pad], dim=-1).squeeze(0))
    # end for 

    return torch.stack(ids_decoder_all)
# end

# eval_source = to_map_style_dataset(valid_iter)
dataloader_eval = DataLoader(valid_source, 1, shuffle=False, collate_fn=collator)

for i, batch in enumerate(dataloader_eval):
    info_batch = batch()
    result = greedy_generate(trainer.model, trainer.manager.get_head(SimpleDecoderHead_S2S), tokenizer, collator, **info_batch)
    
    
    result_cpu_list = result.cpu().tolist()
    labels_decoder_cpu_list = info_batch['labels_decoder'].cpu().tolist()
    
    for result_cpu, labels_decoder in zip(result_cpu_list, labels_decoder_cpu_list):
    
        sentence_predicted = tokenizer.decode(result_cpu).split(' [PAD]')[0]
        sentence_origin = tokenizer.decode(labels_decoder).split(' [PAD]')[0]
        
        print('source: {}\ntarget: {}\n\n'.format(sentence_origin, sentence_predicted))
    # end
    
    if i >= 5:
        break
    # end
# end 

source: he probably wasn't even attracted to her. [SEP]
target: [CLS] she wasn't have to the other, she wasn't have to the other. she't have to the other. i't have to the other. i't have to the other. she wasn't have to the other. i't have to the other. i't have to the other. i't have to the other. she wasn't have to the other.


source: yes. [SEP]
target: [CLS] she wasn't have to the other, she wasn't have to the other. i't have to the other. i't have to the other. i't have to the other. i't have to the other. i't have to the other. i't have to the other. i't have to the other. she wasn't have to the other.


source: he got back into his truck and drove the rest of the way to his cabin seething with anger. [SEP]
target: [CLS] she wasn't have to the other, she wasn't have to the other. she't have to the other. she't have to the other. she't have to the other. she wasn't have to the other. she't have to the other. she't have to the other. she't have to the other. she wasn't have to the 

In [14]:
def decode_output(out_mlm, masks_masked_prebatch, labels_mlm, ids_encoder, tokenizer):
    # print segments
    # sentence_predicts = tokenizer.decode(out_mlm.softmax(dim=-1).argmax(dim=-1).masked_select(segments_encoder[:, 0, :]).numpy().tolist())
    # sentence_labels = tokenizer.decode(labels_mlm.masked_select(segments_encoder[:, 0, :]).numpy().tolist())
    # sentence_inputs = tokenizer.decode(ids_encoder.masked_select(segments_encoder[:, 0, :]).numpy().tolist())

    # print masks
    sentence_predicts = tokenizer.decode(out_mlm.softmax(dim=-1).argmax(dim=-1).masked_select(masks_masked_perbatch).numpy().tolist())
    sentence_labels = tokenizer.decode(labels_mlm.masked_select(masks_masked_perbatch).numpy().tolist())
    sentence_inputs = tokenizer.decode(ids_encoder.masked_select(masks_masked_perbatch).numpy().tolist())


#     sentence_predicts = tokenizer.decode(out_mlm.softmax(dim=-1).argmax(dim=-1).numpy().tolist()[0])
#     sentence_labels = tokenizer.decode(labels_mlm.numpy().tolist()[0])
#     sentence_inputs = tokenizer.decode(ids_encoder.numpy().tolist()[0])

    predicts_masked = out_mlm.softmax(dim=-1).argmax(dim=-1).masked_select(masks_masked_perbatch)
    labels_masked = labels_mlm.masked_select(masks_masked_perbatch)

    acc = torch.count_nonzero(predicts_masked == labels_masked) / labels_masked.shape[0]
    # acc = torch.count_nonzero(out_mlm.softmax(dim=-1).argmax(dim=-1).view(-1) == labels_mlm.view(-1)) / labels_mlm.view(-1).shape[0]
    return acc, sentence_labels, sentence_inputs, sentence_predicts
# end


# eval_source = to_map_style_dataset(valid_iter)
dataloader_eval = DataLoader(valid_source, 1, shuffle=False, collate_fn=collator)
# dataloader_eval = DataLoader(train_source, 1, shuffle=False, collate_fn=collator)

for i, batch in enumerate(dataloader_eval):
    info_batch = batch()
    trainer.forward(**info_batch)
    
    head = trainer.manager.get_head(SimpleEncoderHead_MLM)
    out_mlm = head.cache.output
    loss_mlm, _ = head.get_loss()
    
    out_mlm = out_mlm.cpu().detach()
    loss_mlm = loss_mlm.cpu().detach()
    labels_mlm = info_batch['labels_encoder'].cpu().detach()
    masks_encoder = info_batch['masks_encoder'].cpu().detach()
    segments_encoder = info_batch['segments_encoder'].cpu().detach()
    ids_encoder = info_batch['ids_encoder'].cpu().detach()
    
    
    masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder # True is masked
    masks_masked_perbatch = masks_masked[:,0,:]
    
    for j in range(masks_masked_perbatch.shape[0]):
        acc, sentence_labels, sentence_inputs, sentence_predicts = decode_output(out_mlm[j,].unsqueeze(0), masks_masked_perbatch[j,].unsqueeze(0), labels_mlm[j,].unsqueeze(0), ids_encoder[j,].unsqueeze(0), tokenizer)
        print('loss: {}, acc: {}\nsource: {}\ninput: {}\npredict: {}\n\n'.format(loss_mlm.item(), acc, sentence_labels, sentence_inputs, sentence_predicts))
    
    if i >= 5:
        break
    # end
# end

loss: 17.235618591308594, acc: 0.3333333432674408
source: probably to.
input: [MASK] [MASK] [MASK]
predict: he,.


loss: nan, acc: nan
source: 
input: 
predict: 


loss: 23.20349884033203, acc: 0.1666666716337204
source: he truck and rest of cabin
input: [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
predict: he,,,..


loss: 15.194304466247559, acc: 0.0
source: t
input: [MASK]
predict: ,


loss: 22.411466598510742, acc: 0.0
source: others gossip d have to a bit more but up
input: [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
predict: ,,........


loss: 19.523229598999023, acc: 0.0
source: i about everything d
input: [MASK] [MASK] [MASK] [MASK]
predict: he,,,




In [15]:
# lr_scheduler.state_dict()

In [16]:
# loader = SaverAndLoader('./checkpoints')
# loader.add_item(trainer.model)
# loader.add_item(trainer.manager.get_head(SimpleEncoderHead_MLM))
# loader.add_item(trainer.manager.get_head(SimpleDecoderHead_S2S))
# loader.add_item(optimizer)
# loader.add_item(lr_scheduler)

In [17]:
# loader.update_checkpoint('epoch1')