In [1]:
import torch
import os
from os.path import exists
import torch.nn as nn
# from torch.nn.functional import log_softmax, pad, one_hot
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torch.utils.data import DataLoader
import random
import json
import csv
from pathlib import Path
import shutil
import re

### utils.py ###

class Dotdict(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes.

    Attribute reads go through ``dict.get``, so a missing name yields ``None``
    instead of raising ``AttributeError``. Attribute writes and deletes map
    directly onto item assignment and item deletion.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __iadd__(self, other):
        """Accumulate ``other`` into this mapping in place (``self += other``).

        Only keys that already exist in ``self`` are touched. A key is skipped
        when ``other`` does not contain it or holds a falsy value for it.

        Returns:
            ``self``, as required by the in-place operator protocol.
        """
        for key in self:
            if key not in other:
                continue
            # end
            if other[key]:
                self[key] += other[key]
            # end
        # end

        return self
    # end
# end



# Takes the file paths as arguments
def parse_csv_file_to_json(path_file_csv):
    """Read a UTF-8 CSV file with a header row into a list of row dictionaries.

    Args:
        path_file_csv: path of the CSV file to read.

    Returns:
        A list with one plain ``dict`` per data row, keyed by the header names.
        Every value is returned exactly as ``csv.DictReader`` produced it;
        values that look like lists or dicts (starting with ``[`` or ``{``) are
        deliberately left as raw strings and are NOT evaluated.
    """
    # newline='' is what the csv module requires: without it, universal-newline
    # translation rewrites "\r\n" that appears inside quoted fields.
    with open(path_file_csv, encoding='utf-8', newline='') as file_csv:
        return [dict(row) for row in csv.DictReader(file_csv)]
    # end
# end

### utils.py ###




### core.py ###

"Produce N identical layers."
def clones(module, N):
    """Produce N identical layers.

    Args:
        module: the ``nn.Module`` to replicate.
        N: number of independent deep copies to create.

    Returns:
        An ``nn.ModuleList`` holding ``N`` deep copies of ``module``; the copies
        share no parameters with each other or with the original.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    # end
    return nn.ModuleList(copies)
# end



class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention (implements "Figure 2").

    Attributes:
        d_k: per-head feature size, ``d_model // h`` (d_v is assumed equal to d_k).
        h: number of attention heads.
        linears: four ``d_model -> d_model`` linear layers; the first three
            project query / key / value and the last one mixes the merged heads.
        attn: attention weights returned by the most recent ``forward`` call.
        dropout: dropout module applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Take in model size and number of heads."""
        super().__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
    # end

    def attention(self, query, key, value, mask=None, dropout=None):
        """Compute 'Scaled Dot Product Attention'.

        Positions where ``mask == 0`` receive a score of ``-1e9`` before the
        softmax. Returns the attended values together with the attention
        weights (after dropout, when a dropout module is supplied).
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # end
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        # end
        return torch.matmul(weights, value), weights
    # end

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, merge the heads and project again."""
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        # end
        nbatches = query.size(0)

        def split_heads(linear, tensor):
            # d_model => h x d_k, with the head axis moved in front of the sequence axis
            return linear(tensor).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        # end

        # 1) Linear projections, one dedicated layer per input.
        query = split_heads(self.linears[0], query)
        key = split_heads(self.linears[1], key)
        value = split_heads(self.linears[2], value)

        # 2) Attention over all heads at once; keep the weights for inspection.
        attended, self.attn = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" the heads with a view and apply the final linear layer.
        merged = attended.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
    # end
# end class


"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
class ResidualLayer(nn.Module):
    """Residual connection around a layer-normalised, dropped-out sublayer.

    The layer norm is applied to the input of the sublayer (norm first), and
    the sublayer output passes through dropout before being added back to
    the untouched input.
    """

    def __init__(self, size, dropout=0.1, eps=1e-6):
        super().__init__()
        self.norm = torch.nn.LayerNorm(size, eps)
        self.dropout = nn.Dropout(p=dropout)
    # end

    def forward(self, x, sublayer):
        """Apply residual connection to any sublayer with the same size."""
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch
    # end
# end class


class PositionwiseFeedForward(nn.Module):
    """Implements FFN equation: ``w_2(dropout(relu(w_1(x))))``.

    ``w_1`` maps ``d_model -> d_ff`` and ``w_2`` maps ``d_ff -> d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)
    # end

    def forward(self, x):
        """Expand, apply ReLU and dropout, then project back to ``d_model``."""
        hidden = self.w_1(x).relu()
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
    # end
# end


class SimpleIDEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by ``sqrt(dim_hidden)``.

    ``id_pad`` is handed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, size_vocab, dim_hidden, id_pad):
        super().__init__()
        self.lut = nn.Embedding(size_vocab, dim_hidden, padding_idx=id_pad)
        self.dim_hidden = dim_hidden
    # end

    def forward(self, x):
        """Look up the ids in ``x`` and scale the vectors by ``sqrt(dim_hidden)``."""
        scale = math.sqrt(self.dim_hidden)
        return self.lut(x) * scale
    # end

    def get_shape(self):
        """Return ``(num_embeddings, embedding_dim)`` of the lookup table."""
        table = self.lut
        return (table.num_embeddings, table.embedding_dim)
    # end
# end


"Implement the PE function."
class PositionalEncoding(nn.Module):
    """Implement the PE function: fixed sinusoidal position encodings.

    A ``(1, max_len, dim_positional)`` table is computed once and registered as
    the buffer ``pe``. Even feature columns hold sines, odd columns hold cosines.

    The table is no longer pinned to ``'cuda'`` at construction time, so the
    module can be built on CPU-only hosts. Being a registered buffer it follows
    ``module.to(...)``, and ``forward`` additionally aligns it with the device
    of its input.

    Args:
        dim_positional: feature size of the inputs that will be encoded.
        max_len: longest sequence length the table supports.
    """

    def __init__(self, dim_positional, max_len=512):
        super(PositionalEncoding, self).__init__()

        # Compute the positional encodings once in log space.
        self.dim_positional = dim_positional
        pe = torch.zeros(max_len, dim_positional)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim_positional, 2) * -(math.log(10000.0) / dim_positional)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd dim_positional there is one cosine column fewer than sine
        # columns, so trim the frequency vector to match.
        pe[:, 1::2] = torch.cos(position * div_term[: dim_positional // 2])
        self.register_buffer("pe", pe.unsqueeze(0))
    # end

    def forward(self, x):
        """Add the encodings of the first ``x.size(1)`` positions to ``x``."""
        # Buffers carry no gradient; .to() is a no-op when devices already match.
        encodings = self.pe[:, : x.size(1)].to(device=x.device)
        return x + encodings
    # end
# end


class SimpleEmbedder(nn.Module):    # no segment embedder as we do not need that
    """Token embedding pipeline: id lookup, positional encoding, dropout.

    The three stages live in ``self.embedder`` (an ``nn.Sequential``) at
    indices 0, 1 and 2 respectively.
    """

    def __init__(self, size_vocab=None, dim_hidden=128, dropout=0.1, id_pad=0):
        super().__init__()
        self.size_vocab = size_vocab
        self.dim_hidden = dim_hidden
        self.id_pad = id_pad

        stages = [
            SimpleIDEmbeddings(size_vocab, dim_hidden, id_pad),
            PositionalEncoding(dim_hidden),
            nn.Dropout(p=dropout),
        ]
        self.embedder = nn.Sequential(*stages)
    # end

    def forward(self, ids_input):   # (batch, seqs_with_padding)
        """Embed a batch of padded id sequences."""
        embedded = self.embedder(ids_input)
        return embedded
    # end

    def get_vocab_size(self):
        """Return the vocabulary size this embedder was built with."""
        return self.size_vocab
    # end
# end

### core.py ###



class SimpleEncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each in a residual wrapper.

    Args:
        dim_hidden: model width.
        dim_feedforward: inner width of the feed-forward sublayer.
        n_head: number of attention heads.
        dropout: dropout probability used by the attention weights, the
            feed-forward sublayer and both residual wrappers. The attention
            sublayer previously ignored this argument and always used its own
            default of 0.1; it is now forwarded.
    """

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleEncoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 2)
    # end

    def forward(self, embeddings, masks, *args):
        """Run the block on ``embeddings``; extra positional arguments are ignored."""

        def self_attention(normed):
            # query, key and value are all the (normalised) layer input
            return self.layer_attention(normed, normed, normed, masks)
        # end

        embeddings = self.layers_residual[0](embeddings, self_attention)
        return self.layers_residual[1](embeddings, self.layer_feedforward)
    # end
# end



class SimpleDecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each of the three sublayers runs inside its own residual wrapper.

    Args:
        dim_hidden: model width.
        dim_feedforward: inner width of the feed-forward sublayer.
        n_head: number of attention heads.
        dropout: dropout probability used by both attention sublayers, the
            feed-forward sublayer and the residual wrappers. The attention
            sublayers previously ignored this argument and always used their
            own default of 0.1; it is now forwarded.
    """

    def __init__(self, dim_hidden, dim_feedforward, n_head, dropout=0.1):
        super(SimpleDecoderLayer, self).__init__()

        self.n_head = n_head
        self.dim_hidden = dim_hidden
        self.dim_feedforward = dim_feedforward

        self.layer_attention_decoder = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_attention_encoder = MultiHeadedAttention(n_head, dim_hidden, dropout)
        self.layer_feedforward = PositionwiseFeedForward(dim_hidden, dim_feedforward, dropout)
        self.layers_residual = clones(ResidualLayer(dim_hidden, dropout), 3)
    # end

    def forward(self, embeddings, masks_encoder, output_encoder, masks_decoder, *args):
        """Run the block; extra positional arguments are ignored."""

        def self_attention(normed):
            # decoder attends to itself under the decoder mask
            return self.layer_attention_decoder(normed, normed, normed, masks_decoder)
        # end

        def encoder_attention(normed):
            # decoder queries attend to the encoder output under the encoder mask
            return self.layer_attention_encoder(normed, output_encoder, output_encoder, masks_encoder)
        # end

        embeddings = self.layers_residual[0](embeddings, self_attention)
        embeddings = self.layers_residual[1](embeddings, encoder_attention)
        return self.layers_residual[2](embeddings, self.layer_feedforward)
    # end
# end


class SimpleTransformerStack(nn.Module):
    """A stack of identical layers followed by a final layer norm.

    The same class serves as encoder and as decoder stack: when
    ``output_encoder``, ``embedding_decoder`` and ``masks_decoder`` are all
    given, the decoder embeddings are processed; otherwise the encoder
    embeddings are.

    The last output is kept in ``self.cache.output`` unless caching is disabled.
    The heads in this file read ``cache.output`` from the encoder/decoder
    stacks, so disable caching only when no head runs afterwards.
    """

    def __init__(self, obj_layer, n_layers):
        super(SimpleTransformerStack, self).__init__()
        self.layers = clones(obj_layer, n_layers)

        self.norm = torch.nn.LayerNorm(obj_layer.dim_hidden)
        self.keys_cache = ['output']
        self.cache = Dotdict({
            'output': None
        })
    # end

    def forward(self, embedding_encoder=None, masks_encoder=None, output_encoder=None, embedding_decoder=None, masks_decoder=None, noncache=False, nocache=False, **kwargs):  # input -> (batch, len_seq, vocab)
        """Run every layer, normalise the result and optionally cache it.

        Args:
            embedding_encoder: embeddings to process in encoder mode.
            masks_encoder: mask handed to every layer.
            output_encoder: encoder output (decoder mode only).
            embedding_decoder: embeddings to process in decoder mode.
            masks_decoder: decoder mask (decoder mode only).
            noncache: original spelling of the "do not cache" flag, kept for
                backward compatibility.
            nocache: spelling used by ``SimpleEncoderDecoder`` when it calls the
                stack. It used to fall into ``**kwargs`` and be silently
                ignored; now either flag disables caching.
            **kwargs: ignored.

        Returns:
            The layer-normalised output of the last layer.
        """
        decoding = (
            output_encoder is not None
            and embedding_decoder is not None
            and masks_decoder is not None
        )
        embeddings = embedding_decoder if decoding else embedding_encoder

        for layer in self.layers:
            embeddings = layer(embeddings, masks_encoder, output_encoder, masks_decoder)
        # end

        output = self.norm(embeddings)

        if not (noncache or nocache):
            self.cache.output = output
        # end

        return output
    # end

    def clear_cache(self):
        """Reset every cached entry to ``None``."""
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end


class SimpleEncoderDecoder(nn.Module):
    """Encoder (and optional decoder) with their embedders.

    With ``pooling=True`` the encoder output is averaged over the sequence
    axis, stored in ``self.cache.output_encoder_pooled`` and broadcast back to
    the encoder output shape before being handed to the decoder.
    """

    def __init__(self, encoder, decoder, embedder_encoder, embedder_decoder, pooling=False):
        super().__init__()

        self.pooling = pooling

        self.embedder_encoder = embedder_encoder
        self.encoder = encoder

        self.embedder_decoder = embedder_decoder
        self.decoder = decoder

        self.keys_cache = ['output_encoder_pooled']
        self.cache = Dotdict({key: None for key in self.keys_cache})
    # end

    def forward(self, ids_encoder=None, masks_encoder=None, ids_decoder=None, masks_decoder=None, nocache=False, **kwargs):
        """Encode, optionally pool, and decode when a decoder is configured.

        Returns the decoder output when both ``embedder_decoder`` and
        ``decoder`` are set; otherwise the (possibly pooled and expanded)
        encoder output.
        """
        output_encoder = self.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder, nocache=nocache)
        memory = output_encoder

        if self.pooling:
            # Zero every position the encoder mask hides, then average over the sequence axis.
            # NOTE(review): torch.mean divides by the full (padded) sequence length,
            # not by the number of unmasked positions - confirm this is intended.
            hidden_positions = masks_encoder.transpose(-1, -2) == False
            output_zeroed = output_encoder.masked_fill(hidden_positions, 0)
            pooled = torch.mean(output_zeroed, dim=-2)
            self.cache.output_encoder_pooled = pooled

            # Repeat the pooled vector along the sequence axis.
            memory = pooled.unsqueeze(-2).expand(output_encoder.shape)
        # end

        if not (self.embedder_decoder and self.decoder):
            return memory
        # end if

        return self.embed_and_decode(
            ids_decoder=ids_decoder,
            masks_encoder=masks_encoder,
            output_encoder=memory,
            masks_decoder=masks_decoder,
            nocache=nocache
        )
    # end

    def embed_and_encode(self, ids_encoder=None, masks_encoder=None, nocache=False, **kwargs):
        """Clear the encoder cache, embed ``ids_encoder`` and run the encoder."""
        self.encoder.clear_cache()

        embedded = self.embedder_encoder(ids_encoder)
        return self.encoder(
            embedding_encoder=embedded,
            masks_encoder=masks_encoder,
            nocache=nocache
        )
    # end

    def embed_and_decode(self, ids_decoder=None, masks_encoder=None, output_encoder=None, masks_decoder=None, nocache=False, **kwargs):
        """Clear the decoder cache, embed ``ids_decoder`` and run the decoder."""
        self.decoder.clear_cache()

        embedded = self.embedder_decoder(ids_decoder)
        return self.decoder(
            masks_encoder=masks_encoder,
            output_encoder=output_encoder,    #(len_seq, dim_hidden) -> (1, dim_hidden)
            embedding_decoder=embedded,
            masks_decoder=masks_decoder,
            nocache=nocache
        )
    # end

    def clear_cache(self):
        """Reset the encoder cache, this module's cache and the decoder cache (if any)."""
        self.encoder.clear_cache()

        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end

        if self.decoder:
            self.decoder.clear_cache()
        # end
    # end

    def get_vocab_size(self, name_embedder):
        """Return the vocabulary size of ``self.embedder_<name_embedder>``."""
        return getattr(self, f'embedder_{name_embedder}').get_vocab_size()
    # end

# end

class LinearAndNorm(nn.Module):
    """Linear layer followed by ReLU, layer norm and dropout, in that order."""

    def __init__(self, dim_in = None, dim_out = None, dropout=0.1, eps_norm=1e-12):
        super().__init__()

        self.linear = torch.nn.Linear(dim_in, dim_out)
        self.norm = torch.nn.LayerNorm(dim_out, eps_norm)
        self.dropout = torch.nn.Dropout(p=dropout)
    # end

    def forward(self, seqs_in):
        """Return ``dropout(norm(relu(linear(seqs_in))))``."""
        activated = self.linear(seqs_in).relu()
        normalised = self.norm(activated)
        return self.dropout(normalised)
    # end
# end



class Batch:
    """Bundle of keyword tensors moved to a common device.

    Calling the instance returns the dictionary of moved tensors. Arguments
    that are ``None`` or of type ``bool`` are dropped; everything else must
    provide a ``.to(device)`` method.
    """

    # Target device for every tensor. Falls back to the CPU when CUDA is not
    # available so batches can also be built on CPU-only hosts; assign
    # ``Batch.DEVICE`` to force a specific device.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, **kwargs):
        self.kwargs = {
            name: value.to(Batch.DEVICE)
            for name, value in kwargs.items()
            if value is not None and type(value) is not bool
        }
    # end

    def __call__(self):
        """Return the dictionary of device-resident tensors."""
        return self.kwargs
    # end
# end



class Collator_Base:
    """Shared helpers for collators: text clean-up, padding and mask creation.

    The ids of ``[PAD]``, ``[MASK]``, ``[CLS]``, ``[SEP]`` and ``[UNK]`` are
    looked up from the tokenizer's special tokens at construction time.
    """

    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked

        special_ids = dict(zip(tokenizer.all_special_tokens, tokenizer.all_special_ids))

        self.id_pad = special_ids['[PAD]']
        self.id_mask = special_ids['[MASK]']
        self.id_cls = special_ids['[CLS]']
        self.id_sep = special_ids['[SEP]']
        self.id_unk = special_ids['[UNK]']

        self.regex_special_token = re.compile(r'\[(PAD|MASK|CLS|SEP|EOL|UNK)\]')
    # end

    def _preprocess(self, line):
        """Normalise one raw text line before tokenisation.

        Special-token literals such as ``[CLS]`` become ``<CLS>``, doubled
        quote characters and runs of two or three dots are removed, repeated
        spaces collapse into one, and surrounding whitespace is stripped.
        """
        line = self.regex_special_token.sub(r'<\1>', line)
        substitutions = (
            (r'''('|"|`){2}''', ''),
            (r'\.{2,3}', ''),
            (r' {2,}', ' '),
        )
        for pattern, replacement in substitutions:
            line = re.sub(pattern, replacement, line)
        # end
        return line.strip()
    # end

    # return masks_attention?, return masks_segment?
    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False,
                      need_masked=0):
        """Pad id sequences to ``size_seq_max`` and build the matching masks.

        Args:
            sequences: list of token-id lists.
            size_seq_max: length every sequence is padded to.
            need_diagonal: when true the attention mask also hides future
                positions (built by ``make_std_mask``).
            need_masked: fraction of the inner tokens (first and last token
                excluded) to replace by ``[MASK]``; at least one index is
                requested per sequence. ``0`` disables masking.

        Returns:
            ``(inputs, masks_attention, masks_segment, labels)``. Without
            masking ``labels`` is ``None``. With masking ``inputs`` holds the
            masked ids, ``labels`` the unmasked ids, and ``masks_attention``
            additionally hides the ``[MASK]`` positions.
        """
        id_pad, id_mask = self.id_pad, self.id_mask

        rows_plain = []
        rows_masked = []

        for tokens in sequences:
            n_tokens = len(tokens)

            row = torch.LongTensor(tokens)
            padding = torch.LongTensor([id_pad] * (size_seq_max - n_tokens))
            rows_plain.append(torch.cat((row, padding)))

            if not need_masked:
                continue
            # end

            # Shuffle the inner positions and mask the first few of them.
            candidates = list(range(1, n_tokens - 1))
            random.shuffle(candidates)
            n_to_mask = int(need_masked * (n_tokens - 2)) or 1
            picked = torch.LongTensor(candidates[:n_to_mask])

            row_masked = row.detach().clone()
            row_masked.index_fill_(0, picked, id_mask)
            rows_masked.append(torch.cat((row_masked, padding)))
        # end for

        inputs = torch.stack(rows_plain)  # (batch, size_seq_max)

        masks_segment = (inputs != self.id_pad).unsqueeze(-2)  # (nbatch, 1, seq)
        if need_diagonal:
            masks_attention = self.make_std_mask(inputs, self.id_pad)
        else:
            masks_attention = masks_segment
        # end

        if not need_masked:
            return inputs, masks_attention, masks_segment, None
        # end

        inputs_masked = torch.stack(rows_masked)
        masks_masked = (inputs_masked != id_mask).unsqueeze(-2)
        # (inputs, masks_attention, masks_segment, labels)
        return inputs_masked, masks_attention & masks_masked, masks_segment, inputs
    # end

    def subsequent_mask(self, size):
        """Mask out subsequent positions: ``(1, size, size)``, True on and below the diagonal."""
        above_diagonal = torch.triu(torch.ones((1, size, size)), diagonal=1)
        return above_diagonal.type(torch.uint8) == 0
    # end

    def make_std_mask(self, tgt, pad):
        """Create a mask to hide padding and future words."""
        not_padding = (tgt != pad).unsqueeze(-2)
        not_future = self.subsequent_mask(tgt.size(-1)).type_as(not_padding.data)
        return not_padding & not_future
    # end
# end


class Collator_SC(Collator_Base):
    """Collate raw corpus tuples into a ``Batch`` of encoder/decoder tensors."""

    def __call__(self, list_corpus_source):
        """Tokenise, pad and mask a list of corpus entries.

        Each entry is one of:
            * ``(line0, line1, sim)`` - two lines plus a similarity label;
            * ``(line, label_sc)`` - one line plus a classification label;
            * anything else - only its first element is used as a line.
        """
        rows_encoder = []
        rows_decoder_input = []
        rows_decoder_label = []
        labels_similarity = []
        labels_sc = []

        # Leave room for the [CLS] and [SEP] tokens.
        budget_tokens = self.size_seq_max - 2

        for corpus_source in list_corpus_source:  # (line0, line1, sim), output of zip remove single case
            n_fields = len(corpus_source)
            if n_fields == 3:  # (line0, line1, sim)
                lines = [corpus_source[0], corpus_source[1]]
                labels_similarity.append(corpus_source[2])
            elif n_fields == 2:  # (line, label_sc)
                lines = [corpus_source[0]]
                labels_sc.append(corpus_source[1])
            else:
                lines = [corpus_source[0]]
            # end

            for line in lines:
                tokens = self.tokenizer.encode(self._preprocess(line), add_special_tokens=False)

                # TODO: check edge
                if len(tokens) > budget_tokens:
                    tokens = tokens[:budget_tokens]
                # end

                rows_encoder.append([self.id_cls] + tokens + [self.id_sep])
                rows_decoder_input.append([self.id_cls] + tokens)
                rows_decoder_label.append(tokens + [self.id_sep])
            # end
        # end

        # Encoder inputs are randomly masked; the unmasked ids come back as labels.
        padded_encoder = self.pad_sequences(rows_encoder, self.size_seq_max,
                                            need_masked=self.need_masked)
        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = padded_encoder

        # Decoder inputs additionally hide future positions.
        padded_decoder = self.pad_sequences(rows_decoder_input, self.size_seq_max,
                                            need_diagonal=True)
        inputs_decoder, masks_decoder = padded_decoder[0], padded_decoder[1]

        padded_label = self.pad_sequences(rows_decoder_label, self.size_seq_max)
        labels_decoder, segments_label = padded_label[0], padded_label[2]

        return Batch(
            ids_encoder=inputs_encoder,  # contains [mask]s
            masks_encoder=masks_encoder,
            labels_encoder=labels_encoder,  # doesn't contain [mask]
            segments_encoder=segments_encoder,
            ids_decoder=inputs_decoder,
            masks_decoder=masks_decoder,
            labels_decoder=labels_decoder,
            segments_label=segments_label,
            labels_similarity=torch.Tensor(labels_similarity),
            labels_sc=torch.LongTensor(labels_sc)
        )
    # end
# end


In [2]:
def GOSV(path_base, filename_base, postfix, index_label_2_id, split=0.1):
    """Load a labelled CSV corpus and split it randomly into train and eval parts.

    The file ``<path_base>/<filename_base><postfix>`` is read with
    ``parse_csv_file_to_json``. Every row contributes the pair
    ``(row['processed'], index_label_2_id[row['target']])``.

    Args:
        path_base: directory containing the CSV file.
        filename_base: file name without the postfix.
        postfix: suffix appended to ``filename_base``.
        index_label_2_id: mapping from the ``target`` column value to a label id.
        split: fraction of the rows assigned to the eval part.

    Returns:
        ``(list_corpus_train, list_corpus_eval, None)``.
    """
    path_file = os.path.join(path_base, f'{filename_base}{postfix}')

    list_corpus = []
    for content in parse_csv_file_to_json(path_file):
        list_corpus.append((content['processed'], index_label_2_id[content['target']]))
    # end

    # Shuffle the row indices; the first `split` share becomes the eval set.
    order = list(range(len(list_corpus)))
    random.shuffle(order)
    count_eval = int(split * len(list_corpus))

    list_corpus_eval = [list_corpus[position] for position in order[:count_eval]]
    list_corpus_train = [list_corpus[position] for position in order[count_eval:]]

    return list_corpus_train, list_corpus_eval, None
# end

In [3]:
class SimpleEncoderHead_MLM(nn.Module):
    """Masked-language-model head on top of the encoder output.

    ``forward`` reads the encoder output from ``model.encoder.cache.output``
    and projects it to vocabulary logits. ``get_loss`` then works on the
    tensors cached by the last ``forward`` call.
    """

    @classmethod
    def get_info_accuracy_template(cls):
        """Return a fresh ``Dotdict`` with all accuracy counters set to zero."""
        counters = Dotdict()
        for name in ('corrects_segmented', 'corrects_masked', 'num_segmented', 'num_masked'):
            counters[name] = 0
        # end
        return counters
    # end

    def __init__(self, model, size_vocab, dim_hidden=128, dropout=0.1):
        super().__init__()

        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden, dropout=dropout)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
        # The output projection starts from the encoder embedding table, wrapped in a new Parameter.
        self.extractor.weight = nn.Parameter(model.embedder_encoder.embedder[0].lut.weight)

        self.keys_cache = ['labels_mlm', 'masks_encoder', 'segments_encoder', 'output']
        self.cache = Dotdict({key: None for key in self.keys_cache})

        self.func_loss = torch.nn.CrossEntropyLoss()
    # end

    def forward(self, model, labels_encoder=None, segments_encoder=None, masks_encoder=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        """Compute vocabulary logits and, unless ``nocache``, cache them with the labels and masks."""
        output_mlm = self.extractor(self.ffn(model.encoder.cache.output))  # output_mlm = prediction_logits

        if nocache:
            return output_mlm
        # end

        self.cache.labels_mlm = labels_encoder
        self.cache.masks_encoder = masks_encoder
        self.cache.segments_encoder = segments_encoder
        self.cache.output = output_mlm
        return output_mlm
    # end

    def get_loss(self):
        """Return ``(loss, info_acc)`` for the cached forward pass.

        The loss is the cross entropy over all non-padding positions plus three
        times the cross entropy over the ``[MASK]`` positions. When no position
        is masked the masked loss is ``0`` and ``num_masked`` is reported as 1.
        """
        cache = self.cache
        labels_mlm, output_mlm = cache.labels_mlm, cache.output
        masks_encoder, segments_encoder = cache.masks_encoder, cache.segments_encoder
        size_vocab = output_mlm.shape[-1]

        def pick(mask_2d):
            # logits -> (selected positions, size_vocab), labels -> (selected positions,)
            logits = output_mlm.masked_select(mask_2d.unsqueeze(-1)).reshape(-1, size_vocab)
            return logits, labels_mlm.masked_select(mask_2d)
        # end

        info_acc = SimpleEncoderHead_MLM.get_info_accuracy_template()

        # All non-padding positions.
        logits_segmented, labels_segmented = pick(segments_encoder.transpose(-1, -2)[:, :, 0])
        loss_segments = self.func_loss(logits_segmented, labels_segmented)
        info_acc.corrects_segmented = torch.sum(logits_segmented.argmax(-1) == labels_segmented).cpu().item()
        info_acc.num_segmented = logits_segmented.shape[0]

        # Positions inside the segment that the attention mask hides, i.e. the [MASK]ed ones.
        masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder  # True is masked
        logits_masked, labels_masked = pick(masks_masked[:, 0, :])

        if logits_masked.shape[0] == 0:
            loss_masked = 0
            info_acc.corrects_masked = 0
            info_acc.num_masked = 1
        else:
            loss_masked = self.func_loss(logits_masked, labels_masked)
            info_acc.corrects_masked = torch.sum(logits_masked.argmax(-1) == labels_masked).cpu().item()
            info_acc.num_masked = logits_masked.shape[0]
        # end

        return loss_segments + loss_masked * 3, info_acc
    # end

    def clear_cache(self):
        """Reset every cached entry to ``None``."""
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end

In [4]:
class SimpleDecoderHead_S2S(nn.Module):
    """Sequence-to-sequence head on top of the decoder output.

    ``forward`` reads the decoder output from ``model.decoder.cache.output``
    and projects it to vocabulary logits. ``get_loss`` works on the tensors
    cached by the last ``forward`` call.
    """

    @classmethod
    def get_info_accuracy_template(cls):
        """Return a fresh ``Dotdict`` with all accuracy counters set to zero."""
        return Dotdict({
            'corrects_segmented': 0,
            'num_segmented': 0
        })
    # end

    def __init__(self, model, size_vocab, dim_hidden=128, dropout=0.1):
        super(SimpleDecoderHead_S2S, self).__init__()

        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden, dropout=dropout)
        self.extractor = torch.nn.Linear(dim_hidden, size_vocab, bias=False)
        # The output projection starts from the decoder embedding table, wrapped in a new Parameter.
        self.extractor.weight = nn.Parameter(model.embedder_decoder.embedder[0].lut.weight)

        self.func_loss = torch.nn.CrossEntropyLoss()

        # The cached mask is stored under 'segments_label' by forward() and read
        # under the same name by get_loss(). The key list previously named
        # 'segments_decoder', so clear_cache() never reset the stored mask.
        self.keys_cache = ['output', 'labels_s2s', 'segments_label']
        self.cache = Dotdict({
            'output': None,
            'labels_s2s': None,
            'segments_label': None
        })
    # end

    def forward(self, model, labels_decoder=None, segments_label=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        """Compute vocabulary logits and, unless ``nocache``, cache them with the labels and mask."""
        output_decoder = model.decoder.cache.output
        output_ffn = self.ffn(output_decoder)
        output_s2s = self.extractor(output_ffn)   # output_s2s = prediction_logits

        if not nocache:
            self.cache.segments_label = segments_label
            self.cache.labels_s2s = labels_decoder
            self.cache.output = output_s2s
        # end

        return output_s2s
    # end

    def get_loss(self):
        """Return ``(loss, info_acc)`` for the cached forward pass.

        The loss is four times the cross entropy over all positions selected
        by the cached ``segments_label`` mask.
        """
        labels_s2s = self.cache.labels_s2s
        output_s2s = self.cache.output
        info_acc = SimpleDecoderHead_S2S.get_info_accuracy_template()

        segments_label = self.cache.segments_label
        segments_label_2d = segments_label.transpose(-1, -2)[:, :, 0]
        hidden_s2s_segmented = output_s2s.masked_select(segments_label_2d.unsqueeze(-1)).reshape(-1, output_s2s.shape[-1])
        labels_segmented = labels_s2s.masked_select(segments_label_2d)

        loss_segments = self.func_loss(hidden_s2s_segmented, labels_segmented)
        info_acc.corrects_segmented = torch.sum(hidden_s2s_segmented.argmax(-1) == labels_segmented).cpu().item()
        info_acc.num_segmented = hidden_s2s_segmented.shape[0]

        return loss_segments * 4, info_acc
    # end

    def evaluate(self):
        """Placeholder; does nothing."""
        pass
    # end

    def clear_cache(self):
        """Reset every cached entry to ``None``."""
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end



In [5]:
class SimpleEncoderHead_AveragePooling_SC(nn.Module):  # SC-> SequenceClassification
    """Sequence-classification head on the pooled encoder output.

    ``forward`` reads the pooled vector from ``model.cache.output_encoder_pooled``.
    """

    @classmethod
    def get_info_accuracy_template(cls):
        """Return a fresh ``Dotdict`` with all accuracy counters set to zero."""
        counters = Dotdict()
        counters['corrects'] = 0
        counters['num'] = 0
        return counters
    # end

    def __init__(self, num_labels, dim_hidden=128, dropout=0.1):
        super().__init__()

        self.ffn = LinearAndNorm(dim_in=dim_hidden, dim_out=dim_hidden, dropout=dropout)
        self.classifier = torch.nn.Linear(dim_hidden, num_labels, bias=False)

        self.keys_cache = ['labels_sc', 'output']
        self.cache = Dotdict({key: None for key in self.keys_cache})

        self.func_loss = torch.nn.CrossEntropyLoss()
    # end

    def forward(self, model, labels_sc=None, nocache=False, **kwargs):   # labels_input -> (batch, seq, labels)
        """Compute class logits and, unless ``nocache``, cache them with the labels."""
        pooled = model.cache.output_encoder_pooled
        output_sc = self.classifier(self.ffn(pooled))  # output_sc = prediction_logits

        if nocache:
            return output_sc
        # end

        self.cache.labels_sc = labels_sc
        self.cache.output = output_sc
        return output_sc
    # end

    def get_loss(self):
        """Return ``(loss, info_acc)`` for the cached forward pass."""
        labels_sc, output_sc = self.cache.labels_sc, self.cache.output

        info_acc = SimpleEncoderHead_AveragePooling_SC.get_info_accuracy_template()

        loss_sc = self.func_loss(output_sc, labels_sc)
        hits = output_sc.argmax(-1) == labels_sc
        info_acc.corrects = torch.sum(hits).cpu().item()
        info_acc.num = output_sc.shape[0]

        return loss_sc, info_acc
    # end

    def clear_cache(self):
        """Reset every cached entry to ``None``."""
        for key_cache in self.keys_cache:
            self.cache[key_cache] = None
        # end
    # end
# end

In [6]:
class HeadManager(nn.Module):
    """Registry of task heads, each stored as an attribute named after its class.

    Heads are run and cleared in registration order. The names used to be
    iterated straight from a ``set`` of strings, whose order can change from
    one interpreter run to the next; that made the order in which heads run
    non-reproducible.
    """

    def __init__(self):
        super(HeadManager, self).__init__()
        # Kept as a set for backward compatibility (membership tests).
        self.index_name_head = set()
        # Registration order of the names in ``index_name_head``.
        self._order_name_head = []
    # end

    def _names_in_order(self):
        """Registered names in registration order.

        Names that were added to ``index_name_head`` directly (bypassing
        ``register``) follow afterwards in sorted order.
        """
        ordered = [name for name in self._order_name_head if name in self.index_name_head]
        extras = sorted(self.index_name_head.difference(ordered))
        return ordered + extras
    # end

    def register(self, head):
        """Attach ``head`` under its class name; registering the same class again replaces it.

        Returns:
            ``self`` so calls can be chained.
        """
        name_head = head.__class__.__name__
        setattr(self, name_head, head)
        if name_head not in self._order_name_head:
            self._order_name_head.append(name_head)
        # end
        self.index_name_head.add(name_head)
        return self
    # end

    def forward(self, model, **kwargs):
        """Run every registered head on ``model`` with the same keyword arguments."""
        for name in self._names_in_order():
            head = getattr(self, name)
            head.forward(model, **kwargs)
        # end
    # end

    def get_head(self, klass):
        """Return the registered head that is an instance of ``klass``."""
        return getattr(self, klass.__name__)
    # end

    def clear_cache(self):
        """Clear the cache of every registered head."""
        for name_head in self._names_in_order():
            getattr(self, name_head).clear_cache()
        # end
    # end
# end


class Trainer(nn.Module):
    """Couples a model with its head manager for one combined forward pass."""

    def __init__(self, model=None, manager=None):
        super().__init__()
        self.model = model
        self.manager = manager
    # end

    def forward(self, **kwargs):
        """Clear all caches, run the model, then run every head on it.

        Nothing is returned; results are read back from the caches of the
        model and the heads.
        """
        self.clear_cache()

        model = self.model
        model.forward(**kwargs)
        self.manager.forward(model, **kwargs)
    # end

    def clear_cache(self):
        """Clear the caches of the model and of the manager, skipping falsy ones."""
        for part in (self.model, self.manager):
            if part:
                part.clear_cache()
            # end
        # end
    # end
# end


class ModelLoader:
    """Save and restore the ``state_dict`` of named items in per-checkpoint folders.

    Layout on disk: ``<path_checkpoints>/<name_checkpoint>/<name_item>.pt``.
    """

    def __init__(self, path_checkpoints='./checkpoints'):
        self.dict_name_item = {}
        self.path_checkpoints = path_checkpoints
    # end

    def add_item(self, item, name=None):
        """Track ``item`` for saving; its class name is used when ``name`` is empty.

        Returns:
            ``self`` so calls can be chained.
        """
        if not name:
            name = item.__class__.__name__
        # end

        self.dict_name_item[name] = item
        return self
    # end

    def update_checkpoint(self, name_checkpoint, name_checkpoint_previous=None):  # epoch_n
        """Write every tracked item to ``name_checkpoint``.

        When ``name_checkpoint_previous`` is given, that folder is removed
        first. Nothing happens (apart from a message) when no item is tracked.
        """
        if not self.dict_name_item:
            print(f'[ALERT] no item added, skip saving checkpoint.')
            return
        # end

        if name_checkpoint_previous:
            result = self._delete_checkpoint_folder(name_checkpoint_previous)
            if result:
                print(f'[INFO] {name_checkpoint_previous} is cleared.')
            else:
                print(f'[ALERT] {name_checkpoint_previous} fail to be cleared.')
            # end
        # end

        folder_checkpoint = self._create_checkpoint_folder(name_checkpoint)
        for name_item, item in self.dict_name_item.items():
            path_checkpoint_item = os.path.join(folder_checkpoint, f'{name_item}.pt')
            torch.save(item.state_dict(), path_checkpoint_item)

            size_file_saved_MB = os.path.getsize(path_checkpoint_item) / 1024 / 1024
            print(f'[INFO] {name_item} is saved, {size_file_saved_MB} MB')
        # end

        print(f'[INFO] {name_checkpoint} is saved')
    # end

    def load_item_state(self, name_checkpoint, instance_item, name_item=None):
        """Load the saved state of one item into ``instance_item``.

        Args:
            name_checkpoint: checkpoint folder name; ``None`` skips loading.
            instance_item: object exposing ``load_state_dict``.
            name_item: file stem; defaults to the class name of ``instance_item``.

        Returns:
            ``instance_item`` on success, ``None`` when loading was skipped or
            the checkpoint file does not exist.
        """
        if not name_item:
            name_item = instance_item.__class__.__name__
        # end

        if name_checkpoint is None:
            print('[ALERT] ignoring loading {} due to no checkpoint'.format(name_item))
            return
        # end

        path_checkpoint_item = os.path.join(self.path_checkpoints, name_checkpoint, f'{name_item}.pt')
        if not os.path.exists(path_checkpoint_item):
            print(f'[ERROR] {path_checkpoint_item} not exists')
            return
        # end

        # Deserialise onto the CPU so a checkpoint written on a GPU host can also
        # be restored where CUDA is unavailable. For nn.Module items
        # load_state_dict copies the values into the existing parameters, which
        # keep their own device.
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from trusted sources.
        state = torch.load(path_checkpoint_item, map_location='cpu')
        if isinstance(instance_item, torch.nn.Module):
            # Modules tolerate missing / unexpected keys.
            instance_item.load_state_dict(state, strict=False)
        else:
            instance_item.load_state_dict(state)
        # end

        print(f'[INFO] {name_item} loaded for {name_checkpoint}.')
        return instance_item
    # end

    def list_items(self):
        """Return the names of all tracked items."""
        return list(self.dict_name_item.keys())
    # end

    def _create_checkpoint_folder(self, name_checkpoint):
        """Create the checkpoint folder if needed and return its path."""
        path_folder_target = os.path.join(self.path_checkpoints, name_checkpoint)
        Path(path_folder_target).mkdir(parents=True, exist_ok=True)
        return path_folder_target
    # end

    def _delete_checkpoint_folder(self, name_checkpoint_previous):
        """Remove a checkpoint folder; return True when it no longer exists afterwards."""
        path_folder_target = os.path.join(self.path_checkpoints, name_checkpoint_previous)
        if os.path.exists(path_folder_target):
            shutil.rmtree(path_folder_target, ignore_errors=True)
        # end
        return (not os.path.exists(path_folder_target))
    # end
# end

In [7]:
class Builder:
    """Factory that wires the encoder-decoder model, its three heads and a Trainer.

    Both public constructors create the same object graph; the ``load_`` variant
    additionally tries to restore every component from a checkpoint.
    """

    @classmethod
    def _build_model_and_heads(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels):
        """Create the encoder-decoder model plus its MLM, SC and S2S heads.

        Returns:
            Tuple ``(model, head_mlm, head_sc, head_s2s)``.
        """
        embedder_encoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_encoder = SimpleEncoderLayer(dim_hidden, dim_feedforward, n_head)
        encoderstack = SimpleTransformerStack(sample_encoder, n_layer)

        embedder_decoder = SimpleEmbedder(size_vocab=size_vocab, dim_hidden=dim_hidden)
        sample_decoder = SimpleDecoderLayer(dim_hidden, dim_feedforward, n_head)
        decoderstack = SimpleTransformerStack(sample_decoder, n_layer)

        model = SimpleEncoderDecoder(encoderstack, decoderstack, embedder_encoder, embedder_decoder, pooling=True)
        head_mlm = SimpleEncoderHead_MLM(model, size_vocab, dim_hidden)
        head_sc = SimpleEncoderHead_AveragePooling_SC(num_labels, dim_hidden)
        head_s2s = SimpleDecoderHead_S2S(model, size_vocab, dim_hidden)

        return model, head_mlm, head_sc, head_s2s
    # end

    @classmethod
    def _build_trainer(cls, model, head_mlm, head_sc, head_s2s):
        """Register the heads (MLM, SC, S2S order) in a HeadManager and wrap them with the model in a Trainer."""
        manager = HeadManager().register(head_mlm).register(head_sc).register(head_s2s)
        return Trainer(model=model, manager=manager)
    # end

    @classmethod
    def build_model_with_mlm_sc_s2s(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels):
        """Build a fresh Trainer (no checkpoint involved).

        Args:
            size_vocab: Vocabulary size for both embedders and the MLM/S2S heads.
            dim_hidden: Hidden size passed to embedders, layers and heads.
            dim_feedforward: Feed-forward size passed to encoder/decoder layers.
            n_head: Head count passed to encoder/decoder layers.
            n_layer: Number of layers in each stack.
            num_labels: Number of classes of the SC head.

        Returns:
            The ``Trainer`` instance.
        """
        parts = cls._build_model_and_heads(size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels)
        return cls._build_trainer(*parts)
    # end

    @classmethod
    def load_model_with_mlm_sc_s2s(cls, size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels, name_checkpoint, path_checkpoints='./checkpoints'):
        """Build a Trainer and try to restore model and heads from a checkpoint.

        Args:
            size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels:
                Same meaning as in ``build_model_with_mlm_sc_s2s``.
            name_checkpoint: Checkpoint name handed to ``load_item_state`` for
                each component.
            path_checkpoints: Root folder given to ``ModelLoader``. Defaults to
                ``'./checkpoints'``, the value that used to be hard-coded.

        Returns:
            Tuple ``(trainer, loader)`` where ``loader`` is the ``ModelLoader``
            that has the model and the three heads added to it.
        """
        model, head_mlm, head_sc, head_s2s = cls._build_model_and_heads(
            size_vocab, dim_hidden, dim_feedforward, n_head, n_layer, num_labels
        )

        loader = ModelLoader(path_checkpoints)
        # same order as before: model, S2S head, SC head, MLM head
        for item in (model, head_s2s, head_sc, head_mlm):
            loader.add_item(item).load_item_state(name_checkpoint, item)
        # end

        trainer = cls._build_trainer(model, head_mlm, head_sc, head_s2s)

        return trainer, loader
    # end

# end

def train_a_batch(batch, trainer, optimizer=None, scheduler=None):
    """Run one training step on ``batch`` and return its loss and accuracy info.

    Args:
        batch: Callable returning the keyword arguments for ``trainer.forward``.
        trainer: Object providing ``train``, ``forward``, ``manager`` and ``clear_cache``.
        optimizer: If truthy, ``step()`` and ``zero_grad(set_to_none=True)`` are
            called after the backward pass.
        scheduler: If truthy, ``step()`` is called once after the optimizer.

    Returns:
        Tuple ``(loss_value, info)``: the summed loss as a Python number and a
        ``Dotdict`` with the per-head accuracy info under ``'mlm'``, ``'sc'``, ``'s2s'``.
    """
    trainer.train()
    trainer.forward(**batch())

    # query the heads in the order S2S, MLM, SC
    order_heads = (
        ('s2s', SimpleDecoderHead_S2S),
        ('mlm', SimpleEncoderHead_MLM),
        ('sc', SimpleEncoderHead_AveragePooling_SC),
    )
    loss_of, info_of = {}, {}
    for key_head, class_head in order_heads:
        loss_of[key_head], info_of[key_head] = trainer.manager.get_head(class_head).get_loss()
    # end

    # total loss is the plain sum of the three head losses
    loss_total = loss_of['mlm'] + loss_of['sc'] + loss_of['s2s']
    value_loss_total = loss_total.item()

    loss_total.backward()

    if optimizer:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # end

    if scheduler:
        scheduler.step()
    # end

    trainer.clear_cache()

    info_acc_heads = Dotdict({'mlm': info_of['mlm'], 'sc': info_of['sc'], 's2s': info_of['s2s']})
    return value_loss_total, info_acc_heads
# end


def evaluate_a_batch(batch, trainer, *args, **kwargs):
    """Evaluate ``batch`` without a backward pass and return loss and accuracy info.

    Args:
        batch: Callable returning the keyword arguments for ``trainer.forward``.
        trainer: Object providing ``eval``, ``forward``, ``manager`` and ``clear_cache``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(loss_value, info)``: the summed loss as a Python number and a
        ``Dotdict`` with the per-head accuracy info under ``'mlm'``, ``'sc'``, ``'s2s'``.
    """
    trainer.eval()
    with torch.no_grad():
        trainer.forward(**batch())
    # end

    # query the heads in the order S2S, MLM, SC
    order_heads = (
        ('s2s', SimpleDecoderHead_S2S),
        ('mlm', SimpleEncoderHead_MLM),
        ('sc', SimpleEncoderHead_AveragePooling_SC),
    )
    loss_of, info_of = {}, {}
    for key_head, class_head in order_heads:
        loss_of[key_head], info_of[key_head] = trainer.manager.get_head(class_head).get_loss()
    # end

    # total loss is the plain sum of the three head losses
    loss_total = loss_of['mlm'] + loss_of['sc'] + loss_of['s2s']
    value_loss_total = loss_total.item()

    trainer.clear_cache()

    info_acc_heads = Dotdict({'mlm': info_of['mlm'], 'sc': info_of['sc'], 's2s': info_of['s2s']})
    return value_loss_total, info_acc_heads
# end


# For s2s head
def greedy_generate(model, head, tokenizer, collator, **kwargs):
    """Greedily decode one output sequence per encoder input.

    Every row of ``kwargs['ids_encoder']`` is encoded once; the decoder is then
    run step by step, always appending the highest-scoring next token, until
    ``collator.id_sep`` is predicted or ``collator.size_seq_max`` tokens exist.
    The end token itself is not appended. The whole procedure runs under
    ``torch.no_grad()`` so no autograd graph is built while decoding.

    Args:
        model: Provides ``embed_and_encode`` and ``embed_and_decode``.
        head: S2S head providing ``ffn`` and ``extractor``.
        tokenizer: Not used; kept so existing call sites stay valid.
        collator: Provides ``id_cls`` (start token), ``id_sep`` (end token),
            ``id_pad``, ``size_seq_max`` and ``subsequent_mask(size)``.
        **kwargs: Must contain ``ids_encoder`` and ``masks_encoder``; both are
            indexed along their first dimension. Other entries are ignored.

    Returns:
        Integer tensor with one row per encoder input and ``size_seq_max``
        columns: start token, generated tokens, then ``id_pad`` padding.
    """
    id_start = collator.id_cls
    id_end = collator.id_sep
    id_pad = collator.id_pad
    size_seq_max = collator.size_seq_max

    ids_encoder_all = kwargs['ids_encoder']
    masks_encoder_all = kwargs['masks_encoder']

    ids_decoder_all = []

    with torch.no_grad():
        for j in range(ids_encoder_all.shape[0]):
            # keep a leading batch dimension of 1 for the single sample
            ids_encoder = ids_encoder_all[j,].unsqueeze(0)
            masks_encoder = masks_encoder_all[j,].unsqueeze(0)

            output_encoder = model.embed_and_encode(ids_encoder=ids_encoder, masks_encoder=masks_encoder)

            # decoder input starts with the start token only; same dtype/device as the encoder ids
            ids_decoder = ids_encoder.new_full((1, 1), id_start)

            for _ in range(size_seq_max - 1):
                masks_decoder = collator.subsequent_mask(ids_decoder.size(1)).type_as(ids_encoder)
                output_decoder = model.embed_and_decode(
                    ids_decoder=ids_decoder,
                    masks_encoder=masks_encoder,
                    output_encoder=output_encoder,
                    masks_decoder=masks_decoder,
                )

                output_ffn = head.ffn(output_decoder)
                output_s2s = head.extractor(output_ffn)   # per-position scores over the vocabulary

                # Scores at the last position predict the next token. argmax is
                # taken on the raw scores: softmax is monotonic, so it cannot
                # change the winner.
                id_nextword = int(torch.argmax(output_s2s[:, -1], dim=-1)[0])

                if id_nextword == id_end:
                    break
                # end

                ids_decoder = torch.cat([ids_decoder, ids_decoder.new_full((1, 1), id_nextword)], dim=1)
            # end

            # right-pad to the fixed length so all rows can be stacked
            ids_pad = ids_decoder.new_full((1, size_seq_max - ids_decoder.shape[-1]), id_pad)

            ids_decoder_all.append(torch.cat([ids_decoder, ids_pad], dim=-1).squeeze(0))
        # end for
    # end

    return torch.stack(ids_decoder_all)
# end


In [8]:
def main(
    folder_data, folder_output, version_data, version_data_last, postfix_train, postfix_test,
    tokenizer, collator, index_label_2_labelid, index_labelid_2_label,
    epochs, seq_max, batch_size, dim_hidden, dim_feedforward, n_head, n_layer,
    lr_base_optimizer, betas_optimizer, eps_optimizer, warmup
):
    """Train, evaluate and test the MLM/SC/S2S model on one data version.

    Steps: build the trainer and restore it from checkpoint
    ``str(version_data_last)``, train for ``epochs`` epochs (evaluating on the
    validation split after every epoch), score the test data sample by sample,
    write the per-sample results to ``<folder_output>/<version_data>.json`` and
    finally call ``loader.update_checkpoint(str(version_data), str(version_data_last))``.

    Requires a CUDA device (the trainer is moved to ``'cuda'``). ``Builder``,
    ``GOSV``, the head classes, ``tqdm`` and ``datetime`` are resolved from the
    notebook scope when the function is called.

    Args:
        folder_data, version_data, postfix_train, postfix_test: Forwarded to
            ``GOSV`` to obtain the train/validation and test sources.
        folder_output: Folder receiving the JSON result file; created if missing.
        version_data_last: Name of the checkpoint to start from.
        tokenizer: Provides ``vocab_size`` and ``decode``.
        collator: ``collate_fn`` of all DataLoaders; also passed to ``greedy_generate``.
        index_label_2_labelid: Label-to-id mapping forwarded to ``GOSV``; its
            length is used as the number of SC classes.
        index_labelid_2_label, seq_max, warmup: Accepted but not used here.
        epochs: Number of train/eval epochs.
        batch_size: Batch size of the train and validation loaders (the test
            loader always uses 1).
        dim_hidden, dim_feedforward, n_head, n_layer: Model sizes forwarded to ``Builder``.
        lr_base_optimizer, betas_optimizer, eps_optimizer: ``lr``, ``betas`` and
            ``eps`` of the Adam optimizer.
    """

    def _ratio(numerator, denominator):
        # accuracy = corrects / total; NaN instead of a crash when nothing was counted
        return numerator / denominator if denominator else float('nan')
    # end

    def _mean(values):
        # average of a list of floats; NaN for an empty loader
        return sum(values) / len(values) if values else float('nan')
    # end

    # Derived from the mapping that is passed in, instead of relying on a
    # notebook-level global that happens to be defined in a later cell.
    num_labels = len(index_label_2_labelid)
    name_checkpoint_last = str(version_data_last)

    # Create the output folder up front so a missing folder cannot discard the
    # results after training has finished.
    if folder_output:
        os.makedirs(folder_output, exist_ok=True)
    # end

    trainer, loader = Builder.load_model_with_mlm_sc_s2s(tokenizer.vocab_size, dim_hidden, dim_feedforward, n_head, n_layer, num_labels, name_checkpoint_last)
    trainer = trainer.to('cuda')

    head_mlm = trainer.manager.get_head(SimpleEncoderHead_MLM)
    head_sc = trainer.manager.get_head(SimpleEncoderHead_AveragePooling_SC)
    head_s2s = trainer.manager.get_head(SimpleDecoderHead_S2S)

    train_source, valid_source, _ = GOSV(folder_data, version_data, postfix_train, index_label_2_labelid, split=0.1)
    test_source, _, _ = GOSV(folder_data, version_data, postfix_test, index_label_2_labelid, split=0)

    dataloader_train = DataLoader(train_source, batch_size, shuffle=False, collate_fn=collator)
    dataloader_eval = DataLoader(valid_source, batch_size, shuffle=False, collate_fn=collator)
    dataloader_test = DataLoader(test_source, 1, shuffle=False, collate_fn=collator)

    # Xavier-initialise every weight with more than one dimension ...
    for p in trainer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        # end
    # end

    # ... and then apply the checkpoint again. The Builder already loaded it, but
    # the initialisation above overwrote every restored weight matrix. Items
    # without a saved state keep their Xavier initialisation.
    for item in (trainer.model, head_s2s, head_sc, head_mlm):
        loader.load_item_state(name_checkpoint_last, item)
    # end

    # The optimizer hyper-parameters come from the arguments (they used to be
    # silently ignored in favour of hard-coded values).
    optimizer = torch.optim.Adam(
        trainer.parameters(),
        lr=lr_base_optimizer, betas=betas_optimizer, eps=eps_optimizer,
        weight_decay=0.01, amsgrad=False,
    )
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.96)

    ### start train/eval epochs ####################################
    for e in range(epochs):

        info_acc_heads_train = Dotdict({
            'mlm': SimpleEncoderHead_MLM.get_info_accuracy_template(),
            'sc': SimpleEncoderHead_AveragePooling_SC.get_info_accuracy_template(),
            's2s': SimpleDecoderHead_S2S.get_info_accuracy_template(),
        })

        info_acc_heads_eval = Dotdict({
            'mlm': SimpleEncoderHead_MLM.get_info_accuracy_template(),
            'sc': SimpleEncoderHead_AveragePooling_SC.get_info_accuracy_template(),
            's2s': SimpleDecoderHead_S2S.get_info_accuracy_template(),
        })

        # train phase
        losses_per_e = []
        for batch in tqdm(dataloader_train):
            # the LR scheduler is stepped per epoch below, so none is passed here
            loss_current, info_acc_heads_batch = train_a_batch(batch, trainer, optimizer, None)
            info_acc_heads_train += info_acc_heads_batch

            losses_per_e.append(loss_current)
        # end

        print('[{}] Epoch: {} training ends. Status: Average loss: {}, Average MLM accuracy: {}, Average SC accuracy: {}, Average S2S accuracy: {}'.format(
            datetime.utcnow(), e, _mean(losses_per_e),
            _ratio(info_acc_heads_train.mlm.corrects_masked, info_acc_heads_train.mlm.num_masked),
            _ratio(info_acc_heads_train.sc.corrects, info_acc_heads_train.sc.num),
            _ratio(info_acc_heads_train.s2s.corrects_segmented, info_acc_heads_train.s2s.num_segmented),
        ))

        if e % 2 == 0:
            lr_scheduler.step() # schedule per 2 epoch (after epochs 0, 2, 4, ...)
        # end

        # eval phase start
        losses_per_e = []
        for batch in tqdm(dataloader_eval):
            loss_current, info_acc_heads_batch = evaluate_a_batch(batch, trainer)
            info_acc_heads_eval += info_acc_heads_batch

            losses_per_e.append(loss_current)
        # end

        print('[{}] Epoch: {} Evalutation ends. Status: Average loss: {}, Average MLM accuracy: {}, Average SC accuracy: {}, Average S2S accuracy: {}'.format(
            datetime.utcnow(), e, _mean(losses_per_e),
            _ratio(info_acc_heads_eval.mlm.corrects_masked, info_acc_heads_eval.mlm.num_masked),
            _ratio(info_acc_heads_eval.sc.corrects, info_acc_heads_eval.sc.num),
            _ratio(info_acc_heads_eval.s2s.corrects_segmented, info_acc_heads_eval.s2s.num_segmented),
        ))
        # eval phase end
    # end
    ### end train/eval epochs ####################################


    ### start test  ##############################################
    trainer.eval()

    tokens_special = {'[CLS]', '[SEP]'}
    list_corpus_test = []
    for batch in tqdm(dataloader_test):

        info_batch = batch()
        with torch.no_grad():
            trainer.forward(**info_batch)
        # end

        # teacher-forced S2S accuracy counters of this sample
        loss_s2s, info_acc_s2s = head_s2s.get_loss()
        info_s2s = {'num': info_acc_s2s['num_segmented'], 'corrects': info_acc_s2s['corrects_segmented']}

        # SC prediction, gold label and softmax confidence of the prediction.
        # The .item() calls below require exactly one value each, which holds
        # because the test loader uses batch size 1.
        label_sc = head_sc.cache.labels_sc.squeeze(0).detach().cpu()
        logit_sc = head_sc.cache.output.squeeze(0).detach().cpu()

        pred_sc = logit_sc.argmax(-1)
        conf_sc = torch.index_select(logit_sc.softmax(-1), -1, pred_sc)
        info_sc = {'pred': pred_sc.item(), 'label': label_sc.item(), 'conf': conf_sc.item()}

        trainer.clear_cache()

        # greedy decoding: token-set overlap between prediction and reference
        list_acc_greedy_decode = []
        with torch.no_grad():   # generation needs no autograd graph
            result = greedy_generate(trainer.model, head_s2s, tokenizer, collator, **info_batch)
        # end
        result_cpu_list = result.cpu().tolist()
        labels_decoder_cpu_list = info_batch['labels_decoder'].cpu().tolist()

        for result_cpu, labels_decoder in zip(result_cpu_list, labels_decoder_cpu_list):

            # drop everything from the first padding token on, then compare token sets
            sentence_predicted = tokenizer.decode(result_cpu).split(' [PAD]')[0]
            set_tokens_predicted = set(sentence_predicted.split()) - tokens_special

            sentence_origin = tokenizer.decode(labels_decoder).split(' [PAD]')[0]
            set_tokens_origin = set(sentence_origin.split()) - tokens_special

            set_hit = set_tokens_predicted & set_tokens_origin
            list_acc_greedy_decode.append({'num': len(set_tokens_origin), 'corrects': len(set_hit)})
        # end for
        info_greedy = {'greedy': list_acc_greedy_decode}

        list_corpus_test.append({'info_sc': info_sc,'info_s2s': info_s2s, 'info_greedy': info_greedy})
        trainer.clear_cache()
    # end for

    path_file_output = os.path.join(folder_output, f'{version_data}.json')
    with open(path_file_output, 'w') as file:
        json.dump(list_corpus_test, file)
    # end

    ### end   test  ##############################################

    loader.update_checkpoint(str(version_data), name_checkpoint_last)
# end


In [9]:
import re
import json
import transformers
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset
from transformers import AutoTokenizer
from datetime import datetime
from tqdm import tqdm


# --- device ---------------------------------------------------------------
# index of the CUDA device that becomes the current device for this run
gpu = 2
torch.cuda.set_device(gpu)

# --- training length ------------------------------------------------------
epochs = 1

# --- source ---------------------------------------------------------------
seq_max = 256
batch_size = 28

# --- model & head ---------------------------------------------------------
dim_hidden = 768
dim_feedforward = 3072
n_head = 12
n_layer = 6

# --- optimizer ------------------------------------------------------------
lr_base_optimizer = 1e-4
betas_optimizer = (0.9, 0.999)
eps_optimizer = 1e-9

# --- scheduler ------------------------------------------------------------
warmup = 200

# --- tokenizer / collator -------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = Collator_SC(tokenizer, seq_max)  # TODO: here

# --- labels ---------------------------------------------------------------
# ids are the positions of the labels in alphabetical order
index_labelid_2_label = dict(enumerate(sorted(["product","testcase","testbed","usererror","targetvm","nimbus","infra"])))
index_label_2_labelid = {label: id_label for id_label, label in index_labelid_2_label.items()}
num_labels = len(index_label_2_labelid)

# --- data locations -------------------------------------------------------
folder_data = 'data'
folder_output = 'outputs_gosv'
postfix_train = '_train_0.35_15.csv'
postfix_test = '_test.csv'

# Collect the numeric data versions that have a test file in folder_data.
# Only names of the exact form "<digits><postfix_test>" are accepted, so hidden
# files, backups (e.g. "1_test.csv.bak") or non-numeric prefixes are skipped
# instead of aborting the run inside int().
pattern_file_test = re.compile(r'(\d+)' + re.escape(postfix_test))
versions_data = sorted(
    int(match_file.group(1))
    for match_file in map(pattern_file_test.fullmatch, os.listdir(folder_data))
    if match_file
)
# versions_data = versions_data[1:]    #TODO remove

# Checkpoint name the first data version starts from; after each version has
# been processed it becomes the starting point of the next one.
version_data_last = 202205240000
for version_data in versions_data:
    main(
        folder_data=folder_data,
        folder_output=folder_output,
        version_data=version_data,
        version_data_last=version_data_last,
        postfix_train=postfix_train,
        postfix_test=postfix_test,
        tokenizer=tokenizer,
        collator=collator,
        index_label_2_labelid=index_label_2_labelid,
        index_labelid_2_label=index_labelid_2_label,
        epochs=epochs,
        seq_max=seq_max,
        batch_size=batch_size,
        dim_hidden=dim_hidden,
        dim_feedforward=dim_feedforward,
        n_head=n_head,
        n_layer=n_layer,
        lr_base_optimizer=lr_base_optimizer,
        betas_optimizer=betas_optimizer,
        eps_optimizer=eps_optimizer,
        warmup=warmup,
    )

    print('[INFO] finish {}'.format(version_data))
    version_data_last = version_data
# end

In [None]:
# Final marker: printed once every cell above has finished.
print('done')

: 