In [1]:
%cd ../

/mnt/kireev/pycharm-deploy/vtb


In [2]:
import pickle
import random

In [3]:
from glob import glob

In [4]:
from itertools import chain

In [5]:
import numpy as np
import pandas as pd

In [6]:
import torch
import pytorch_lightning as pl

In [7]:
from pyhocon import ConfigFactory

In [8]:
import matplotlib.pyplot as plt

In [9]:
from dltranz.data_load.iterable_processing.category_size_clip import CategorySizeClip
from dltranz.data_load import augmentation_chain
from dltranz.data_load.augmentations.seq_len_limit import SeqLenLimit
from dltranz.data_load.augmentations.random_slice import RandomSlice

from dltranz.seq_encoder import create_encoder

from dltranz.metric_learn.sampling_strategies import get_sampling_strategy
from dltranz.metric_learn.losses import get_loss

from dltranz.tb_interface import get_scalars

In [10]:
from vtb_code.data import PairedDataset, paired_collate_fn, PairedZeroDataset, DropDuplicate
from vtb_code.metrics import PrecisionK, MeanReciprocalRankK, ValidationCallback

In [None]:
# Index of the fold used as the test split; it also selects which features pickle
# (data/features_f{FOLD_ID}.pickle) is loaded further down.
FOLD_ID = 1

In [None]:
# The test fold is the fold named by FOLD_ID.
fold_id_test = FOLD_ID

In [None]:
# Number of folds = number of per-fold matching CSV files found under data/.
folds_count = len(glob('data/train_matching_*.csv'))
folds_count

In [None]:
# Validation fold: the fold right after the test fold, wrapping around at folds_count.
# The commented-out line below is an earlier variant that picked a random non-test fold.
# fold_id_valid = np.random.choice([i for i in range(folds_count) if i != fold_id_test], size=1)[0]
fold_id_valid = (fold_id_test + 1) % folds_count
fold_id_valid

In [None]:
# Training pairs: every fold except the test and validation folds, concatenated.
df_matching_train = pd.concat([pd.read_csv(f'data/train_matching_{i}.csv')
                              for i in range(folds_count) 
                              if i not in (fold_id_test, fold_id_valid)])
# Validation and test pairs each come from their own single fold file.
df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
df_matching_test = pd.read_csv(f'data/train_matching_{fold_id_test}.csv')

In [None]:
# Drop the temporary alias; FOLD_ID keeps identifying the test fold from here on.
del fold_id_test

In [None]:
%%time
# Load the pre-computed features for this fold. The pickle holds an 8-tuple:
# transaction features for train / valid / test / puzzle, then click features in the same order.
# NOTE: pickle.load can execute arbitrary code, so only open files produced by this project.
with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
    (
        features_trx_train,
        features_trx_valid,
        features_trx_test,
        features_trx_puzzle,
        features_click_train,
        features_click_valid,
        features_click_test,
        features_click_puzzle,
    ) = pickle.load(f)

# Pretrain, one for all models in the ensemble

In [None]:
# DataLoader feeding the transaction MLM pretraining. It is also iterated once in a later
# cell to collect transaction amounts for the quantile boundaries.
train_dl_mlm_trx = torch.utils.data.DataLoader(
    PairedDataset(
        # one row per key of the transaction feature dict, sorted for a deterministic order
        np.sort(np.array(list(features_trx_train.keys()))).reshape(-1, 1),
        data=[features_trx_train],
        augmentations=[augmentation_chain(
            # presumably collapses repeated 'mcc_code' events and stores a count in 'c_cnt' — TODO confirm
            DropDuplicate('mcc_code', col_new_cnt='c_cnt'), 
            # presumably takes a random sub-sequence of 32..128 events — TODO confirm argument meaning
            RandomSlice(32, 128)
        )],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    # NOTE(review): batches are not shuffled although this loader is used for training — confirm intended
    shuffle=False,
    num_workers=12,
    batch_size=128,
    persistent_workers=True,
)

In [None]:
# DataLoader feeding the click-stream MLM pretraining; mirrors the transaction loader
# but de-duplicates on 'cat_id' instead of 'mcc_code'.
train_dl_mlm_click = torch.utils.data.DataLoader(
    PairedDataset(
        # one row per key of the click feature dict, sorted for a deterministic order
        np.sort(np.array(list(features_click_train.keys()))).reshape(-1, 1),
        data=[features_click_train],
        augmentations=[augmentation_chain(
            # presumably collapses repeated 'cat_id' events and stores a count in 'c_cnt' — TODO confirm
            DropDuplicate('cat_id', col_new_cnt='c_cnt'), 
            # presumably takes a random sub-sequence of 32..128 events — TODO confirm argument meaning
            RandomSlice(32, 128)
        )],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    # NOTE(review): batches are not shuffled although this loader is used for training — confirm intended
    shuffle=False,
    num_workers=12,
    batch_size=128,
    persistent_workers=True,
)

In [None]:
# Gather the transaction amounts the loader yields at positions where seq_len_mask is set,
# one tensor per batch, then join them into a single 1-D tensor.
# Note that the loader applies RandomSlice, so only the sliced events are seen here.
v = []
for batch in train_dl_mlm_trx:
    v.append(batch[0][0].payload['transaction_amt'][batch[0][0].seq_len_mask.bool()])
v = torch.cat(v)

# 100 evenly spaced quantiles (0..1) over the distinct amounts; CustomTrxTransform uses them
# as bucket boundaries for the 'transaction_amt_q' feature.
trx_amnt_quantiles = torch.quantile(torch.unique(v), torch.linspace(0, 1, 100))
trx_amnt_quantiles

In [11]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch

In [12]:
from vtb_code.models import MeanLoss

In [13]:
class CustomTrxTransform(torch.nn.Module):
    """Adds a bucketized amount feature to a transaction batch.

    The boundaries are stored as a frozen parameter so they are saved with the
    module and follow it across devices.
    """

    def __init__(self, trx_amnt_quantiles):
        super().__init__()
        boundaries = torch.nn.Parameter(trx_amnt_quantiles, requires_grad=False)
        self.trx_amnt_quantiles = boundaries

    def forward(self, x):
        """Write 'transaction_amt_q' (bucket index + 1) into x.payload and return x."""
        amounts = x.payload['transaction_amt']
        bucket_ids = torch.bucketize(amounts, self.trx_amnt_quantiles)
        # shift by one so the smallest bucket maps to 1 rather than 0
        x.payload['transaction_amt_q'] = bucket_ids + 1
        return x
    
class CustomClickTransform(torch.nn.Module):
    """Placeholder transform for click batches: currently returns the batch unchanged.

    The commented-out lines are disabled experiments that clamped categorical ids
    and the duplicate counter.
    """
    def forward(self, x):
#         x.payload['cat_id'] = torch.clamp(x.payload['cat_id'], 0, 300)
#         x.payload['level_0'] = torch.clamp(x.payload['level_0'], 0, 200)
#         x.payload['level_1'] = torch.clamp(x.payload['level_1'], 0, 200)
#         x.payload['level_2'] = torch.clamp(x.payload['level_2'], 0, 200)
#         x.payload['c_cnt_clamp'] = torch.clamp(x.payload['c_cnt'], 0, 20).int()
        return x

In [14]:
class DateFeaturesTransform(torch.nn.Module):
    """Derives calendar features from 'event_time' (treated as a count of seconds).

    Adds three payload entries and returns the same batch object:
    'hour' in 1..24, 'weekday' in 1..7 and 'day_diff' — the day gap to the
    previous event along dim 1, clamped to 0..14 (the first event gets 0).
    """

    def forward(self, x):
        seconds = x.payload['event_time'].int()
        day_index = torch.div(seconds, 24 * 60 * 60, rounding_mode='floor').int()

        hour_index = torch.div(seconds, 60 * 60, rounding_mode='floor')
        x.payload['hour'] = hour_index % 24 + 1

        x.payload['weekday'] = torch.div(seconds, 60 * 60 * 24, rounding_mode='floor') % 7 + 1

        # prepend the first day so the output keeps the input length
        day_steps = torch.diff(day_index, prepend=day_index[:, :1], dim=1)
        x.payload['day_diff'] = torch.clamp(day_steps, 0, 14)
        return x

In [15]:
class PBLinear(torch.nn.Linear):
    """Linear layer applied to the payload of a PaddedBatch; seq_lens pass through unchanged."""

    def forward(self, x: PaddedBatch):
        projected = super().forward(x.payload)
        return PaddedBatch(projected, x.seq_lens)

In [16]:
class PBL2Norm(torch.nn.Module):
    """Rescales every payload vector of a PaddedBatch to L2 length `beta`."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        scaled = self.beta * x.payload
        # the small epsilon keeps all-zero vectors from dividing by zero
        squared_norm = x.payload.pow(2).sum(dim=-1, keepdim=True) + 1e-9
        return PaddedBatch(scaled / squared_norm.pow(0.5), x.seq_lens)

In [17]:
from vtb_code.data import frequency_encoder

In [18]:
class MLMPretrainModule(pl.LightningModule):
    """Masked-token contrastive pretraining over encoded event sequences.

    A random subset of positions is replaced by a learned mask token, the sequence
    is passed through a transformer encoder, and each masked position must pick its
    own original embedding among sampled negatives (softmax over dot products).

    Subclasses must assign `self.seq_encoder`, a module mapping a raw batch to a
    PaddedBatch of embeddings of width `params['common_trx_size']`.

    Args:
        data_type: tag used in logged metric names (e.g. 'trx', 'click').
        params: config with `common_trx_size`, `transf.*` and `mlm.*` keys.
        lr, weight_decay: passed to Adam.
        max_lr, pct_start, total_steps: passed to OneCycleLR.
    """
    def __init__(self, data_type, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                ):
        super().__init__()
        self.save_hyperparameters()

        common_trx_size = params['common_trx_size']
        # placeholder; the concrete encoder is assigned by subclasses
        self.seq_encoder = None

        # learned embedding that replaces masked positions
        self.token_mask = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)
        self.transf = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=common_trx_size,
                nhead=params['transf.nhead'],
                dim_feedforward=params['transf.dim_feedforward'],
                dropout=params['transf.dropout'],
                batch_first=True,
            ),
            num_layers=params['transf.num_layers'],
            norm=torch.nn.LayerNorm(common_trx_size) if params['transf.norm'] else None,
        )

        if params['transf.use_pe']:
            self.pe = torch.nn.Parameter(self.get_pe(), requires_grad=False)
        else:
            self.pe = None
        # lookup table indexed by seq_len_mask: index 0 -> True (padding), index 1 -> False
        self.padding_mask = torch.nn.Parameter(torch.tensor([True, False]).bool(), requires_grad=False)

        self.train_mlm_loss_all = MeanLoss(compute_on_step=False)
        self.valid_mlm_loss_all = MeanLoss(compute_on_step=False)
        self.train_mlm_loss_self = MeanLoss(compute_on_step=False)
        self.valid_mlm_loss_self = MeanLoss(compute_on_step=False)

    def get_pe(self):
        """Build fixed sinusoidal positional encodings.

        Returns a tensor of shape (1, max_len, 2 * (H // 2)): sines followed by
        cosines, with periods spaced geometrically from 4 to `transf.max_len`.
        """
        max_len = self.hparams.params['transf.max_len']
        H = self.hparams.params['common_trx_size']
        periods = torch.exp(torch.linspace(*np.log([4, max_len]), H // 2)).view(1, 1, -1)
        f = 2 * np.pi * torch.arange(max_len).view(1, -1, 1) / periods
        return torch.cat([torch.sin(f), torch.cos(f)], dim=2)

    def configure_optimizers(self):
        """Adam with a three-phase, per-step OneCycleLR schedule."""
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def get_mask(self, x: PaddedBatch):
        """Sample the boolean mask of positions to replace.

        Each position with seq_len_mask == 1 is selected with probability
        `mlm.replace_proba`; positions with seq_len_mask == 0 are never selected.
        """
        return torch.bernoulli(x.seq_len_mask.float() * self.hparams.params['mlm.replace_proba']).bool()

    def mask_x(self, x: PaddedBatch, mask):
        """Return x.payload with the masked positions replaced by the mask token."""
        return torch.where(mask.unsqueeze(2).expand_as(x.payload),
                           self.token_mask.expand_as(x.payload), x.payload)

    def get_neg_ix(self, mask, neg_type):
        """Sample from predicts, where `mask == True`, without self element.
        For `neg_type='all'` - sample from predicted tokens from batch
        For `neg_type='self'` - sample from predicted tokens from row

        Returns a pair (batch indices, time indices), one row per masked token.
        Raises AttributeError for an unknown `neg_type`.
        """
        if neg_type == 'all':
            # one row per masked token: weight 1 on every masked position of the batch,
            # minus 1 on the token's own position
            mn = mask.float().view(1, -1) - \
                torch.eye(mask.numel(), device=mask.device)[mask.flatten()]
            neg_ix = torch.multinomial(mn, self.hparams.params['mlm.neg_count_all'])
            # convert flat indices back to (batch, time)
            b_ix = neg_ix.div(mask.size(1), rounding_mode='trunc')
            neg_ix = neg_ix % mask.size(1)
            return b_ix, neg_ix
        if neg_type == 'self':
            mask_ix = mask.nonzero(as_tuple=False)
            one_pos = torch.eye(mask.size(1), device=mask.device)[mask_ix[:, 1]]
            mn = mask[mask_ix[:, 0]].float() - one_pos
            # tiny weight on every other position so sampling still works when the row
            # contains no other masked token
            mn = mn + 1e-9 * (1 - one_pos)
            neg_ix = torch.multinomial(mn, self.hparams.params['mlm.neg_count_self'], replacement=True)
            b_ix = mask_ix[:, 0].view(-1, 1).expand_as(neg_ix)
            return b_ix, neg_ix
        raise AttributeError(f'Unknown neg_type: {neg_type}')

    def sentence_encoding(self, x: PaddedBatch):
        """Hook for an additive per-position encoding; the base class adds none."""
        return None

    def mlm_loss(self, x: PaddedBatch, neg_type, x_orig: PaddedBatch):
        """Contrastive masked-token loss, one value per masked position.

        Args:
            x: encoded batch (output of `seq_encoder`).
            neg_type: 'all' or 'self', see `get_neg_ix`.
            x_orig: raw batch, forwarded to `sentence_encoding`.
        """
        mask = self.get_mask(x)
        masked_x = self.mask_x(x, mask)
        _, T, _ = masked_x.size()

        if self.pe is not None:
            if self.training:
                max_len = self.hparams.params['transf.max_len']
                # randint's upper bound is exclusive: `+ 1` makes the last valid offset
                # (max_len - T) reachable and keeps T == max_len from raising
                start_pos = np.random.randint(0, max_len - T + 1, 1)[0]
            else:
                start_pos = 0
            pe = self.pe[:, start_pos:start_pos + T]
            masked_x = masked_x + pe

        se = self.sentence_encoding(x_orig)
        if se is not None:
            masked_x = masked_x + se

        out = self.transf(masked_x, src_key_padding_mask=self.padding_mask[x.seq_len_mask])

        # remove the additive encodings again before comparing with the targets
        if self.pe is not None:
            out = out - pe
        if se is not None:
            out = out - se

        target = x.payload[mask].unsqueeze(1)  # N, 1, H
        predict = out[mask].unsqueeze(1) # N, 1, H
        neg_ix = self.get_neg_ix(mask, neg_type)
        negative = out[neg_ix[0], neg_ix[1]]  # N, nneg, H
        out_samples = torch.cat([predict, negative], dim=1)
        logits = (target * out_samples).sum(dim=2)
        # log_softmax avoids log(0) when the positive's probability underflows
        loss = -torch.log_softmax(logits, dim=1)[:, 0]
        return loss

    def training_step(self, batch, batch_idx):
        """Sum of the mean 'all' and 'self' MLM losses for one batch."""
        (x_trx, _), = batch

        z_trx = self.seq_encoder(x_trx)  # PB: B, T, H

        loss_mlm = self.mlm_loss(z_trx, neg_type='all', x_orig=x_trx)
        self.train_mlm_loss_all(loss_mlm)
        loss_mlm_all = loss_mlm.mean()
        self.log(f'loss/mlm_{self.hparams.data_type}', loss_mlm_all)

        loss_mlm = self.mlm_loss(z_trx, neg_type='self', x_orig=x_trx)
        self.train_mlm_loss_self(loss_mlm)
        loss_mlm_self = loss_mlm.mean()
        self.log(f'loss/mlm_{self.hparams.data_type}_self', loss_mlm_self)

        return loss_mlm_all + loss_mlm_self

    def validation_step(self, batch, batch_idx):
        """Accumulate both MLM losses into the validation metrics."""
        (x_trx, _), = batch
        z_trx = self.seq_encoder(x_trx)  # PB: B, T, H

        loss_mlm = self.mlm_loss(z_trx, neg_type='all', x_orig=x_trx)
        self.valid_mlm_loss_all(loss_mlm)

        loss_mlm = self.mlm_loss(z_trx, neg_type='self', x_orig=x_trx)
        self.valid_mlm_loss_self(loss_mlm)

    def training_epoch_end(self, _):
        """Log the epoch-level training losses."""
        self.log(f'metrics/train_{self.hparams.data_type}_mlm', self.train_mlm_loss_all, prog_bar=False)
        self.log(f'metrics/train_{self.hparams.data_type}_mlm_self', self.train_mlm_loss_self, prog_bar=False)

    def validation_epoch_end(self, _):
        """Log the epoch-level validation losses."""
        self.log(f'metrics/valid_{self.hparams.data_type}_mlm', self.valid_mlm_loss_all, prog_bar=True)
        self.log(f'metrics/valid_{self.hparams.data_type}_mlm_self', self.valid_mlm_loss_self, prog_bar=True)

        
class MLMPretrainModuleTrx(MLMPretrainModule):
    """MLM pretraining module for transaction sequences.

    The encoder chain is: amount bucketizing, date features, TrxEncoder,
    linear projection to `common_trx_size`, L2 normalisation to length `mlm.beta`.
    """
    def __init__(self,
                 trx_amnt_quantiles, 
                 params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                ):
        super().__init__(
            data_type='trx',
            params=params,
            lr=lr,
            weight_decay=weight_decay,
            max_lr=max_lr,
            pct_start=pct_start,
            total_steps=total_steps,
        )
        # called again so the subclass-only argument is recorded as well
        self.save_hyperparameters()

        cfg = self.hparams.params
        embed_dim = cfg['common_trx_size']
        trx_encoder = TrxEncoder(cfg['trx_seq.trx_encoder'])
        stages = [
            CustomTrxTransform(trx_amnt_quantiles=trx_amnt_quantiles),
            DateFeaturesTransform(),
            trx_encoder,
            PBLinear(trx_encoder.output_size, embed_dim),
            PBL2Norm(cfg['mlm.beta']),
        ]
        self.seq_encoder = torch.nn.Sequential(*stages)
        

class MLMPretrainModuleClick(MLMPretrainModule):
    """MLM pretraining module for click-stream sequences.

    The encoder chain is: click transform (currently a no-op), date features,
    TrxEncoder, linear projection to `common_trx_size`, L2 normalisation to
    length `mlm.beta`.
    """
    def __init__(self, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                ):
        super().__init__(
            data_type='click',
            params=params,
            lr=lr,
            weight_decay=weight_decay,
            max_lr=max_lr,
            pct_start=pct_start,
            total_steps=total_steps,
        )
        self.save_hyperparameters()

        cfg = self.hparams.params
        embed_dim = cfg['common_trx_size']
        click_encoder = TrxEncoder(cfg['click_seq.trx_encoder'])
        stages = [
            CustomClickTransform(),
            DateFeaturesTransform(),
            click_encoder,
            PBLinear(click_encoder.output_size, embed_dim),
            PBL2Norm(cfg['mlm.beta']),
        ]
        self.seq_encoder = torch.nn.Sequential(*stages)

#         self.se = torch.nn.Embedding(7, common_trx_size, padding_idx=0)

#     def sentence_encoding(self, x: PaddedBatch):
#         se = torch.stack([frequency_encoder(v, m.bool())
#                           for v, m in zip(x.payload['new_uid'], x.seq_len_mask)], dim=0).clamp(None, 6)
#         return self.se(se)
    

In [None]:
# HOCON configuration shared by both MLM pretraining modules:
#   common_trx_size - embedding width used by both encoders and the transformer
#   transf          - transformer encoder and positional-encoding settings
#   mlm             - masking probability, negative counts and the L2-norm scale (beta)
#   trx_seq / click_seq - TrxEncoder settings for the transaction and click encoders
config = ConfigFactory.parse_string('''
    common_trx_size: 256
    transf: {
        nhead: 4
        dim_feedforward: 1024
        dropout: 0.1
        num_layers: 3
        norm: false
        max_len: 6000
        use_pe: true
    }
    mlm: {
        replace_proba: 0.1
        neg_count_all: 64
        neg_count_self: 8
        beta: 10
    }
    trx_seq: {
        trx_encoder: {
          use_batch_norm_with_lens: false
          norm_embeddings: false,
          embeddings_noise: 0.000,
          embeddings: {
            mcc_code: {in: 350, out: 64},
            currency_rk: {in: 10, out: 4}
            transaction_amt_q: {in: 110, out: 8}
            
            hour: {in: 30, out: 16}
            weekday: {in: 10, out: 4}
            day_diff: {in: 15, out: 8}
          },
          numeric_values: {
            transaction_amt: identity
            c_cnt: log
          }
          was_logified: false
          log_scale_factor: 1.0
        },
    }
    click_seq: {
        trx_encoder: {
          use_batch_norm_with_lens: false
          norm_embeddings: false,
          embeddings_noise: 0.000,
          embeddings: {
            cat_id: {in: 400, out: 64},
            level_0: {in: 400, out: 16}
            level_1: {in: 400, out: 8}
            level_2: {in: 400, out: 4}
            
            hour: {in: 30, out: 16}
            weekday: {in: 10, out: 4}
            day_diff: {in: 15, out: 8}
          },
          numeric_values: {
            c_cnt: log
          }
          was_logified: false
          log_scale_factor: 1.0
        },
    }
''')

# Create the two pretraining modules from the shared config.
# pct_start evaluates to 0.45 and the one-cycle schedule is laid out over 10000 steps.
mlm_model_trx = MLMPretrainModuleTrx(
    params=config,                     
    lr=0.001, weight_decay=0,
    max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
    trx_amnt_quantiles=trx_amnt_quantiles,
)
mlm_model_click = MLMPretrainModuleClick(
    params=config,                     
    lr=0.001, weight_decay=0,
    max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
)


In [None]:
# Pretrain the transaction MLM model on GPU 0, writing a checkpoint every 2000 steps
# and keeping all of them (save_top_k=-1).
# NOTE(review): max_steps=8000 is below the total_steps=10000 given to the scheduler when the
# model was created, so training stops before the one-cycle schedule ends — confirm intended.
trainer = pl.Trainer(
    gpus=[0],
    max_steps=8000,
    enable_progress_bar=False,
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        pl.callbacks.ModelCheckpoint(
            every_n_train_steps=2000, save_top_k=-1,
        ),
    ]
)
model_version_trx = trainer.logger.version
# Reference losses: log(neg_count + 1) is the loss of a uniform guess among one positive
# and neg_count negatives.
print('baseline all:  {:.3f}'.format(np.log(mlm_model_trx.hparams.params['mlm.neg_count_all'] + 1)))
print('baseline self: {:.3f}'.format(np.log(mlm_model_trx.hparams.params['mlm.neg_count_self'] + 1)))
print(f'version = {model_version_trx}')
trainer.fit(mlm_model_trx, train_dl_mlm_trx)
print('done')

In [None]:
# Pretrain the click MLM model on GPU 0, writing a checkpoint every 2000 steps
# and keeping all of them (save_top_k=-1).
# NOTE(review): max_steps=6000 is below the total_steps=10000 given to the scheduler when the
# model was created, so training stops before the one-cycle schedule ends — confirm intended.
trainer = pl.Trainer(
    gpus=[0],
    max_steps=6000,
    enable_progress_bar=False,
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        pl.callbacks.ModelCheckpoint(
            every_n_train_steps=2000, save_top_k=-1,
        ),
    ]
)
model_version_click = trainer.logger.version
# Reference losses: log(neg_count + 1) is the loss of a uniform guess among one positive
# and neg_count negatives.
print('baseline all:  {:.3f}'.format(np.log(mlm_model_click.hparams.params['mlm.neg_count_all'] + 1)))
print('baseline self: {:.3f}'.format(np.log(mlm_model_click.hparams.params['mlm.neg_count_self'] + 1)))
print(f'version = {model_version_click}')
trainer.fit(mlm_model_click, train_dl_mlm_click)
print('done')

# Use pretrained

In [19]:
from dltranz.seq_encoder.utils import NormEncoder

In [20]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch
from dltranz.seq_encoder.rnn_encoder import RnnEncoder
from dltranz.seq_encoder.utils import LastStepEncoder

In [21]:
class L2Scorer(torch.nn.Module):
    """Scores a pair of embeddings packed side by side in one 2-D tensor.

    The input of shape (B, H) is split at column H // 2; the score is the
    negative squared L2 distance between the two halves, shape (B,).
    """

    def forward(self, x):
        _, width = x.size()
        half = width // 2
        left = x[:, :half]
        right = x[:, half:]
        squared_distance = (left - right).pow(2).sum(dim=1)
        return -squared_distance

In [22]:
class PBLayerNorm(torch.nn.LayerNorm):
    """LayerNorm applied to the payload of a PaddedBatch; seq_lens pass through unchanged."""

    def forward(self, x: PaddedBatch):
        normalized = super().forward(x.payload)
        return PaddedBatch(normalized, x.seq_lens)

In [23]:
class PairedModule(pl.LightningModule):
    """Supervised matcher between transaction and click sequences.

    Both modalities are embedded by their pretrained MLM sequence encoders followed
    by a LayerNorm, then by one shared RNN encoder. Training uses a softmax loss
    over negative squared L2 distances with sampled negatives.

    NOTE: the constructor reads the notebook-level globals `mlm_model_trx` and
    `mlm_model_click`; they must exist before this class is instantiated or
    restored with `load_from_checkpoint`.

    Args:
        params: config; only the `rnn` section is read here (the embedding width is
            taken from `mlm_model_trx.hparams.params['common_trx_size']`).
        k: cut-off passed to the precision / MRR metrics.
        lr, weight_decay: passed to Adam.
        max_lr, pct_start, total_steps: passed to OneCycleLR.
        beta: multiplier applied to the negative squared distances before the softmax.
        neg_count: number of candidates per positive (1 positive + neg_count - 1 negatives).
    """
    def __init__(self, params, k,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 beta, neg_count,
                ):
        super().__init__()
        self.save_hyperparameters()

        common_trx_size = mlm_model_trx.hparams.params['common_trx_size']
        # one RNN shared by both modalities
        self.rnn_enc = torch.nn.Sequential(
            RnnEncoder(common_trx_size, params['rnn']),
            LastStepEncoder(),
#             NormEncoder(),
        )
        self._seq_encoder_trx = torch.nn.Sequential(
            mlm_model_trx.seq_encoder,
            PBLayerNorm(common_trx_size),
        )
        self._seq_encoder_click = torch.nn.Sequential(
            mlm_model_click.seq_encoder,
            PBLayerNorm(common_trx_size),
        )
        # kept as a submodule; only the commented-out sentence-encoding path refers to it
        self.mlm_model_click = mlm_model_click

        self.cls = torch.nn.Sequential(
            L2Scorer(),
        )

        self.train_precision = PrecisionK(k=k, compute_on_step=False)
        self.train_mrr = MeanReciprocalRankK(k=k, compute_on_step=False)
        self.valid_precision = PrecisionK(k=k, compute_on_step=False)
        self.valid_mrr = MeanReciprocalRankK(k=k, compute_on_step=False)

    def seq_encoder_trx(self, x):
        """Embed a transaction batch: pretrained encoder + LayerNorm, then the shared RNN."""
        x = self._seq_encoder_trx(x)
        return self.rnn_enc(x)

    def seq_encoder_click(self, x_orig):
        """Embed a click batch: pretrained encoder + LayerNorm, then the shared RNN."""
        x = self._seq_encoder_click(x_orig)
#         x = PaddedBatch(
#             x.payload + self.mlm_model_click.sentence_encoding(x_orig),
#             x.seq_lens,
#         )
        return self.rnn_enc(x)

    def configure_optimizers(self):
        """Adam with a three-phase, per-step OneCycleLR schedule."""
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def loss_fn_p(self, embeddings, labels, ref_emb, ref_labels):
        """Softmax loss of each positive pair against sampled negatives.

        For every (i, j) with labels[i] == ref_labels[j], `neg_count - 1` reference
        rows with a different label are sampled without replacement. Logits are
        `-beta * squared L2 distance`; the loss is the mean negative log-probability
        of the positive (column 0).
        """
        beta = self.hparams.beta
        neg_count = self.hparams.neg_count

        pos_ix = (labels.view(-1, 1) == ref_labels.view(1, -1)).nonzero(as_tuple=False)
        pos_labels = labels[pos_ix[:, 0]]
        neg_w = ((pos_labels.view(-1, 1) != ref_labels.view(1, -1))).float()
        neg_ix = torch.multinomial(neg_w, neg_count - 1)
        all_ix = torch.cat([pos_ix[:, [1]], neg_ix], dim=1)
        logits = -(embeddings[pos_ix[:, [0]]] - ref_emb[all_ix]).pow(2).sum(dim=2)
        logits = logits * beta
        # log_softmax avoids log(0) when the positive's probability underflows
        logs = -torch.log_softmax(logits, dim=1)[:, 0]
#         logs = torch.relu(logs + np.log(0.1))
        return logs.mean()

    def training_step(self, batch, batch_idx):
        """Compute the two directional losses and update the training metrics."""
        # pairs
        x_trx, l_trx, m_trx, x_click, l_click, m_click = batch
        z_trx = self.seq_encoder_trx(x_trx)  # B, H
        z_click = self.seq_encoder_click(x_click)  # B, H
        loss_pt = self.loss_fn_p(embeddings=z_trx, labels=l_trx, ref_emb=z_click, ref_labels=l_click)
        self.log('loss/loss_pt', loss_pt)

        loss_pc = self.loss_fn_p(embeddings=z_click, labels=l_click, ref_emb=z_trx, ref_labels=l_trx)
        self.log('loss/loss_pc', loss_pc)

        with torch.no_grad():
            out = -(z_trx.unsqueeze(1) - z_click.unsqueeze(0)).pow(2).sum(dim=2)
            # keep only rows / columns whose m_* flag is 0
            # (presumably the real, non-"zero" pairs — TODO confirm against PairedZeroDataset)
            out = out[m_trx == 0][:, m_click == 0]
            T, C = out.size()
            assert T == C
            n_samples = z_trx.size(0) // (l_trx.max().item() + 1)
            # score each strided sub-matrix separately
            for i in range(n_samples):
                l2 = out[i::n_samples, i::n_samples]
                self.train_precision(l2)
                self.train_mrr(l2)

        # the click -> transaction direction is down-weighted
        return loss_pt + 0.1 * loss_pc  #  loss_pc 

    def training_epoch_end(self, _):
        """Log the epoch-level training metrics."""
        self.log('train_metrics/precision', self.train_precision, prog_bar=True)
        self.log('train_metrics/mrr', self.train_mrr, prog_bar=True)


In [None]:
# Train one supervised model per ensemble member. Every member sees the same data
# (train + test folds of the current split); members differ only by the random seed.
for enseble_fold_id in range(folds_count):
    if enseble_fold_id == fold_id_valid:
        continue
    
    batch_size = 128
    train_dl = torch.utils.data.DataLoader(
        PairedZeroDataset(
            # only pairs that have a click-side match (rtk != '0')
            pd.concat([df_matching_train, df_matching_test], axis=0)[lambda x: x['rtk'].ne('0')].values,
            data=[
                dict(chain(features_trx_train.items(), features_trx_test.items())),
                dict(chain(features_click_train.items(), features_click_test.items())),
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
            ],
            n_sample=2,
        ),
        collate_fn=PairedZeroDataset.collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=24,
        batch_size=batch_size,
        persistent_workers=False,
    )
    
#     mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(
#         'lightning_logs/version_/checkpoints/epoch=86-step=7999.ckpt')  # 42
#     mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(
#         'lightning_logs/version_/checkpoints/epoch=77-step=5999.ckpt')  # 43
    # a fresh random seed per member is what makes the ensemble members differ
    pl.seed_everything(random.randint(1, 2**16-1))
    
    # PairedModule reads only the `rnn` section of this config; the embedding width
    # comes from the global mlm_model_trx, so `common_trx_size: 128` is not used by it.
    sup_model = PairedModule(
        ConfigFactory.parse_string('''
        common_trx_size: 128
        rnn: {
          type: gru,
          hidden_size: 256,
          bidir: false,
          trainable_starter: static
        }
    '''),                     
        k=100 * batch_size // 3000,
        lr=0.0022, weight_decay=0,
        max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
        beta=0.2 / 1.4, neg_count=120,
    )

    trainer = pl.Trainer(
        # was `gpus[0]`, which raised NameError; use the keyword like the pretraining cells
        gpus=[0],
        max_steps=3000,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=1000, save_top_k=-1,
            ),
        ]
    )
    print(f'FOLD_ID={FOLD_ID}, fold_id_valid={fold_id_valid}, '
          f'enseble_fold_id={enseble_fold_id}, version={trainer.logger.version}')
    trainer.fit(sup_model, train_dl)

# Validation

In [24]:
class ValidationCallback():
    """Scores a paired model on a validation split and turns scores into ranking metrics.

    NOTE: this definition shadows the `ValidationCallback` imported from
    `vtb_code.metrics` at the top of the notebook.

    Args:
        v_trx: DataLoader over transaction sequences; batches unpack as `((x, _),)`.
        v_click: DataLoader over click sequences, same batch layout.
        target: DataFrame with 'bank' and 'rtk' columns (the true matching).
        device: device the model and batches are moved to.
        k: rank cut-off for precision and MRR.
        batch_size: number of transaction rows scored against all clicks at once.
    """
    def __init__(self, v_trx, v_click, target, device, k=100, batch_size=1024):
        self.v_trx = v_trx
        self.v_click = v_click
        self.target = target
        self.k = k
        self.batch_size = batch_size
        self.device = device

    def run(self, model):
        """Return the (n_trx, n_click) matrix of negative squared L2 distances.

        Puts `model` into eval mode and moves it to `self.device`. Metrics are
        computed separately with `logits_to_metrics`.
        """
        device = self.device
        model.eval()
        model.to(device)

        with torch.no_grad():
            z_trx = []
            for ((x_trx, _),) in self.v_trx:
                z_trx.append(model.seq_encoder_trx(x_trx.to(device)))
            z_trx = torch.cat(z_trx, dim=0)
            z_click = []
            for ((x_click, _),) in self.v_click:
                z_click.append(model.seq_encoder_click(x_click.to(device)))
            z_click = torch.cat(z_click, dim=0)

            # score in row chunks to bound the size of the broadcast difference tensor
            z_out = []
            for i in range(0, z_trx.size(0), self.batch_size):
                z_out.append(
                    -((z_trx[i:i + self.batch_size].unsqueeze(1) - z_click.unsqueeze(0)).pow(2)).sum(dim=2)
                )
            z_out = torch.cat(z_out, dim=0)

        return z_out

    def ranks(self, z_out):
        """Return, per row, the 1-based rank of every column (highest score gets rank 1)."""
        T, C = z_out.size()
        z_ranks = torch.zeros_like(z_out)
        # scatter 1..C into the positions given by the descending argsort of each row
        z_ranks[
            torch.arange(T, device=z_out.device).view(-1, 1).expand(T, C),
            torch.argsort(z_out, dim=1, descending=True),
        ] = torch.arange(C, device=z_out.device).float().view(1, -1).expand(T, C) + 1
        return z_ranks

    def logits_to_metrics(self, z_out):
        """Compute (precision@k, mrr@k, r1) as Python floats from a score matrix.

        A leading candidate for click id '0' is added with rank 1 for every row and
        all real candidates are shifted down by one rank ('0' presumably stands for
        "no match" — TODO confirm). r1 is the harmonic mean of precision and MRR.
        """
        T, C = z_out.size()
        z_ranks = self.ranks(z_out)
        z_ranks = torch.cat([
            torch.ones(T, device=z_out.device).float().view(-1, 1),
            z_ranks + 1,
        ], dim=1)

        click_uids = np.concatenate([['0'], self.v_click.dataset.pairs[:, 0]])
        # column of the true click id for every transaction row;
        # searchsorted relies on click_uids being sorted
        true_ranks = z_ranks[
            np.arange(T),
            np.searchsorted(click_uids,
                            self.target.set_index('bank')['rtk'].loc[self.v_trx.dataset.pairs[:, 0]].values)
        ]
        precision = torch.where(true_ranks <= self.k,
                                torch.ones(1, device=z_out.device), torch.zeros(1, device=z_out.device)).mean()
        mrr = torch.where(true_ranks <= self.k, 1 / true_ranks, torch.zeros(1, device=z_out.device)).mean()
        r1 = 2 * mrr * precision / (mrr + precision)
        return precision.item(), mrr.item(), r1.item()


In [25]:
# Restore the pretrained MLM modules from checkpoints. PairedModule reads these two globals
# in its constructor, so they must exist before any PairedModule checkpoint is loaded.
# only model size required here
mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(
    'lightning_logs/version_0/checkpoints/epoch=21-step=1999.ckpt')  # 42
mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(
    'lightning_logs/version_5/checkpoints/epoch=25-step=1999.ckpt')  # 43

In [26]:
# ensemble with train_dl.pd.sample(frac=0.85)
# Each entry describes one split: the test fold, its validation fold, and 'mv' — the
# lightning_logs version numbers of the models that form the ensemble for that split.
model_map = [
    {'FOLD_ID': 0, 'fold_id_valid': 1, 'mv': [23, 27, 30, 33, 36]},
    {'FOLD_ID': 1, 'fold_id_valid': 2, 'mv': [24, 28, 31, 34, 37]},
    {'FOLD_ID': 2, 'fold_id_valid': 3, 'mv': [25, 26, 29, 32, 35]},
#     {'FOLD_ID': 3, 'fold_id_valid': 4, 'mv': []},
#     {'FOLD_ID': 4, 'fold_id_valid': 5, 'mv': []},
#     {'FOLD_ID': 5, 'fold_id_valid': 0, 'mv': []},
]

In [27]:
!ls -l lightning_logs/version_23/checkpoints/

total 64104
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 05:42 'epoch=15-step=1999.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 05:53 'epoch=23-step=2999.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 05:31 'epoch=7-step=999.ckpt'


In [28]:
# For every split in model_map: rebuild the validation loaders, score each ensemble member,
# then compare three ways of combining the members' score matrices.
for mmap in model_map:
    FOLD_ID = mmap['FOLD_ID']
    fold_id_test = FOLD_ID
    folds_count = len(glob('data/train_matching_*.csv'))
    # recomputed with the same rule as during training; the printed value below comes
    # from mmap['fold_id_valid'] instead
    fold_id_valid = (fold_id_test + 1) % folds_count
    del fold_id_test
    df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
    # only the two validation feature sets of the 8-tuple are needed here
    with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
        (
            _,
            features_trx_valid,
            _,
            _,
            _,
            features_click_valid,
            _,
            _,
        ) = pickle.load(f)
    
    # transaction side: every distinct 'bank' id of the validation fold, sorted
    valid_dl_trx = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid['bank'].unique()).reshape(-1, 1), 
            data=[
                features_trx_valid,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), SeqLenLimit(2000)),  # 2000
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=128,
        persistent_workers=False,
    )

    # click side: every distinct 'rtk' id except '0', sorted
    # (logits_to_metrics relies on this order for its searchsorted lookup)
    valid_dl_click = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid[lambda x: x['rtk'].ne('0')]['rtk'].unique()).reshape(-1, 1),
            data=[
                features_click_valid,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), SeqLenLimit(5000)),  # 5000
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=128,
        persistent_workers=False,
    )
    
    # NOTE(review): scoring runs on 'cuda:1' while training used gpus=[0] — confirm the device
    vc = ValidationCallback(valid_dl_trx, valid_dl_click, df_matching_valid, 
                            torch.device('cuda:1'))    
    # one score matrix per ensemble member
    res = []
    for model_version in mmap['mv']:
        sup_model = PairedModule.load_from_checkpoint(
            f'lightning_logs/version_{model_version}/checkpoints/epoch=23-step=2999.ckpt')
        
        zo = vc.run(sup_model)
        precision, mrr, r1 = vc.logits_to_metrics(zo)
        res.append(zo)
        print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, model_version={model_version}: '
              f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')
        
    # ensemble 1: sum of the members' scores
    precision, mrr, r1 = vc.logits_to_metrics(torch.stack(res, dim=0).sum(dim=0))
    print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, ensemble distance sum: '
          f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')
    # ensemble 2: element-wise minimum of the scores
    # NOTE(review): scores are negative squared distances, so the minimum score corresponds
    # to the largest distance across members — confirm that matches the "distance min" label
    precision, mrr, r1 = vc.logits_to_metrics(torch.stack(res, dim=0).min(dim=0).values)
    print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, ensemble distance min: '
          f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')
    # ensemble 3: sum of negated per-member ranks (a lower rank sum gives a higher score)
    precision, mrr, r1 = vc.logits_to_metrics(torch.stack([-vc.ranks(o) for o in res], dim=0).sum(dim=0))
    print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, ensemble     rank sum: '
          f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')
   

FOLD_ID=0, fold_id_valid=1, model_version=23: precision = 0.472, mrr = 0.190, r1 = 0.271
FOLD_ID=0, fold_id_valid=1, model_version=27: precision = 0.479, mrr = 0.191, r1 = 0.273
FOLD_ID=0, fold_id_valid=1, model_version=30: precision = 0.487, mrr = 0.190, r1 = 0.273
FOLD_ID=0, fold_id_valid=1, model_version=33: precision = 0.484, mrr = 0.192, r1 = 0.275
FOLD_ID=0, fold_id_valid=1, model_version=36: precision = 0.490, mrr = 0.190, r1 = 0.274
FOLD_ID=0, fold_id_valid=1, ensemble distance sum: precision = 0.524, mrr = 0.197, r1 = 0.287
FOLD_ID=0, fold_id_valid=1, ensemble distance min: precision = 0.502, mrr = 0.194, r1 = 0.280
FOLD_ID=0, fold_id_valid=1, ensemble     rank sum: precision = 0.522, mrr = 0.197, r1 = 0.286
FOLD_ID=1, fold_id_valid=2, model_version=24: precision = 0.481, mrr = 0.194, r1 = 0.276
FOLD_ID=1, fold_id_valid=2, model_version=28: precision = 0.481, mrr = 0.193, r1 = 0.275
FOLD_ID=1, fold_id_valid=2, model_version=31: precision = 0.483, mrr = 0.192, r1 = 0.275
FOLD_I

In [56]:
# Ensemble of models trained on the full data with different seeds.
# Each entry names the test fold ('FOLD_ID'), the validation fold
# ('fold_id_valid') and the model versions ('mv') that form the ensemble.
# Folds 3, 4 and 5 (validation folds 4, 5 and 0) are left out: no model
# versions are listed for them.
model_map = [
    {'FOLD_ID': fold, 'fold_id_valid': valid, 'mv': versions}
    for fold, valid, versions in (
        (0, 1, [8, 11, 14, 17, 20]),
        (1, 2, [9, 13, 15, 18, 21]),
        (2, 3, [10, 12, 16, 19, 22]),
    )
]

In [57]:
# Evaluate the ensemble of differently seeded models on each validation fold.
#
# For every entry of `model_map` this cell
#   1. loads the validation matching table and the validation features,
#   2. builds one DataLoader per data source ('bank' ids / 'rtk' ids),
#   3. runs every listed model version through `ValidationCallback`,
#   4. prints the metrics of each model and of three ensemble variants.

# The set of fold files does not change between iterations, so the directory
# is scanned once instead of once per fold.
folds_count = len(glob('data/train_matching_*.csv'))

for mmap in model_map:
    FOLD_ID = mmap['FOLD_ID']
    # The validation fold is the one following the test fold, wrapping around.
    fold_id_valid = (FOLD_ID + 1) % folds_count
    # The printed label is taken from `model_map` while the data is selected by
    # the computed id; fail loudly if the two disagree so that metrics are
    # never reported under the wrong fold.
    assert fold_id_valid == mmap['fold_id_valid'], (
        f'model_map says fold_id_valid={mmap["fold_id_valid"]} for FOLD_ID={FOLD_ID}, '
        f'but {fold_id_valid} was derived from {folds_count} fold files')

    df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only load feature files produced by this project.
    with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
        # The pickle holds an 8-element tuple; only the two validation feature
        # sets (positions 1 and 5) are used here.
        (
            _,
            features_trx_valid,
            _,
            _,
            _,
            features_click_valid,
            _,
            _,
        ) = pickle.load(f)

    # Loader over the sorted unique 'bank' ids of the validation fold,
    # passed to PairedDataset as an (n, 1) column.
    valid_dl_trx = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid['bank'].unique()).reshape(-1, 1),
            data=[
                features_trx_valid,
            ],
            augmentations=[
                # presumably collapses repeated 'mcc_code' events into a count
                # column 'c_cnt' and caps the sequence at 2000 items —
                # TODO confirm in vtb_code.data / dltranz
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), SeqLenLimit(2000)),
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=32,
        persistent_workers=False,
    )

    # Loader over the sorted unique 'rtk' ids; rows whose 'rtk' equals the
    # string '0' are excluded (presumably the "no match" marker — confirm).
    valid_dl_click = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid[lambda x: x['rtk'].ne('0')]['rtk'].unique()).reshape(-1, 1),
            data=[
                features_click_valid,
            ],
            augmentations=[
                # same chain as above, keyed on 'cat_id' and capped at 5000
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), SeqLenLimit(5000)),
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=32,
        persistent_workers=False,
    )

    vc = ValidationCallback(valid_dl_trx, valid_dl_click, df_matching_valid,
                            torch.device('cuda:1'))

    # One `vc.run` output per model version, in the order of mmap['mv'].
    res = []
    for model_version in mmap['mv']:
        sup_model = PairedModule.load_from_checkpoint(
            f'lightning_logs/version_{model_version}/checkpoints/epoch=26-step=2999.ckpt')

        zo = vc.run(sup_model)
        precision, mrr, r1 = vc.logits_to_metrics(zo)
        res.append(zo)
        print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, model_version={model_version}: '
              f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')

    # Stack the per-model outputs once and reuse the tensor for both the
    # element-wise sum and the element-wise minimum.
    stacked = torch.stack(res, dim=0)
    ensembles = [
        ('ensemble distance sum', stacked.sum(dim=0)),
        ('ensemble distance min', stacked.min(dim=0).values),
        # Ranks are negated before summing — presumably so that a better rank
        # maps to a larger score; confirm against ValidationCallback.ranks.
        ('ensemble     rank sum', torch.stack([-vc.ranks(o) for o in res], dim=0).sum(dim=0)),
    ]
    for label, logits in ensembles:
        precision, mrr, r1 = vc.logits_to_metrics(logits)
        print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, {label}: '
              f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')

FOLD_ID=0, fold_id_valid=1, model_version=8: precision = 0.494, mrr = 0.193, r1 = 0.277
FOLD_ID=0, fold_id_valid=1, model_version=11: precision = 0.507, mrr = 0.194, r1 = 0.280
FOLD_ID=0, fold_id_valid=1, model_version=14: precision = 0.512, mrr = 0.196, r1 = 0.283
FOLD_ID=0, fold_id_valid=1, model_version=17: precision = 0.500, mrr = 0.194, r1 = 0.280
FOLD_ID=0, fold_id_valid=1, model_version=20: precision = 0.498, mrr = 0.194, r1 = 0.279
FOLD_ID=0, fold_id_valid=1, ensemble distance sum: precision = 0.544, mrr = 0.201, r1 = 0.294
FOLD_ID=0, fold_id_valid=1, ensemble distance min: precision = 0.533, mrr = 0.198, r1 = 0.289
FOLD_ID=0, fold_id_valid=1, ensemble     rank sum: precision = 0.541, mrr = 0.201, r1 = 0.293
FOLD_ID=1, fold_id_valid=2, model_version=9: precision = 0.500, mrr = 0.196, r1 = 0.282
FOLD_ID=1, fold_id_valid=2, model_version=13: precision = 0.512, mrr = 0.195, r1 = 0.282
FOLD_ID=1, fold_id_valid=2, model_version=15: precision = 0.502, mrr = 0.195, r1 = 0.281
FOLD_ID=

In [38]:
# Checkpoint ensemble: a single model version per fold. The evaluation cell
# below takes mv[0] and ensembles several checkpoints of that one model.
# Each entry names the test fold ('FOLD_ID'), the validation fold
# ('fold_id_valid') and the model version ('mv').
# Folds 3, 4 and 5 (validation folds 4, 5 and 0) are left out: no model
# versions are listed for them.
model_map = [
    dict(zip(('FOLD_ID', 'fold_id_valid', 'mv'), row))
    for row in (
        (0, 1, [42]),
        (1, 2, [41]),
        (2, 3, [43]),
    )
]

In [39]:
# List the checkpoint files saved for model version 41 (the model mapped to
# FOLD_ID 1 above) to see which file names are available for the ensemble.
!ls -l lightning_logs/version_41/checkpoints/

total 320520
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 08:48 'epoch=1-step=199.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:03 'epoch=10-step=1199.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:05 'epoch=12-step=1399.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:08 'epoch=14-step=1599.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:10 'epoch=15-step=1799.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:12 'epoch=17-step=1999.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:15 'epoch=19-step=2199.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:17 'epoch=21-step=2399.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:19 'epoch=22-step=2599.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:22 'epoch=24-step=2799.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 09:24 'epoch=26-step=2999.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 08:51 'epoch=3-step=399.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 08:55 'epoch=5-step=599.ckpt'
-rw-r--r-- 1 ivan sudo 21880483 Apr 21 08:58 'epoch=7-st

In [40]:
# Evaluate an ensemble of checkpoints of a single model on each validation fold.
#
# For every entry of `model_map` this cell
#   1. loads the validation matching table and the validation features,
#   2. builds one DataLoader per data source ('bank' ids / 'rtk' ids),
#   3. runs each selected checkpoint of model mv[0] through `ValidationCallback`,
#   4. prints the metrics of each checkpoint and of three ensemble variants.

# The set of fold files does not change between iterations, so the directory
# is scanned once instead of once per fold.
folds_count = len(glob('data/train_matching_*.csv'))

# Checkpoint file names to ensemble; the same list is used for every fold, so
# it is built once. Earlier checkpoints are kept here, disabled, for reference.
tails = [
#     'epoch=1-step=199.ckpt',
#     'epoch=3-step=399.ckpt',
#     'epoch=5-step=599.ckpt',
#     'epoch=7-step=799.ckpt',
#     'epoch=8-step=999.ckpt',
#     'epoch=10-step=1199.ckpt',
#     'epoch=12-step=1399.ckpt',
#     'epoch=14-step=1599.ckpt',
#     'epoch=15-step=1799.ckpt',
#     'epoch=17-step=1999.ckpt',
    'epoch=19-step=2199.ckpt',
    'epoch=21-step=2399.ckpt',
    'epoch=22-step=2599.ckpt',
    'epoch=24-step=2799.ckpt',
    'epoch=26-step=2999.ckpt',
]

for mmap in model_map:
    FOLD_ID = mmap['FOLD_ID']
    # The validation fold is the one following the test fold, wrapping around.
    fold_id_valid = (FOLD_ID + 1) % folds_count
    # The printed label is taken from `model_map` while the data is selected by
    # the computed id; fail loudly if the two disagree so that metrics are
    # never reported under the wrong fold.
    assert fold_id_valid == mmap['fold_id_valid'], (
        f'model_map says fold_id_valid={mmap["fold_id_valid"]} for FOLD_ID={FOLD_ID}, '
        f'but {fold_id_valid} was derived from {folds_count} fold files')

    df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only load feature files produced by this project.
    with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
        # The pickle holds an 8-element tuple; only the two validation feature
        # sets (positions 1 and 5) are used here.
        (
            _,
            features_trx_valid,
            _,
            _,
            _,
            features_click_valid,
            _,
            _,
        ) = pickle.load(f)

    # Loader over the sorted unique 'bank' ids of the validation fold,
    # passed to PairedDataset as an (n, 1) column.
    valid_dl_trx = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid['bank'].unique()).reshape(-1, 1),
            data=[
                features_trx_valid,
            ],
            augmentations=[
                # presumably collapses repeated 'mcc_code' events into a count
                # column 'c_cnt' and caps the sequence at 2000 items —
                # TODO confirm in vtb_code.data / dltranz
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), SeqLenLimit(2000)),
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=128,
        persistent_workers=False,
    )

    # Loader over the sorted unique 'rtk' ids; rows whose 'rtk' equals the
    # string '0' are excluded (presumably the "no match" marker — confirm).
    valid_dl_click = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(df_matching_valid[lambda x: x['rtk'].ne('0')]['rtk'].unique()).reshape(-1, 1),
            data=[
                features_click_valid,
            ],
            augmentations=[
                # same chain as above, keyed on 'cat_id' and capped at 5000
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), SeqLenLimit(5000)),
            ],
            n_sample=1,
        ),
        collate_fn=paired_collate_fn,
        shuffle=False,
        num_workers=4,
        batch_size=128,
        persistent_workers=False,
    )

    vc = ValidationCallback(valid_dl_trx, valid_dl_click, df_matching_valid,
                            torch.device('cuda:1'))

    # Only the first listed version is used; its checkpoints form the ensemble.
    model_version = mmap['mv'][0]

    # One `vc.run` output per checkpoint, in the order of `tails`.
    res = []
    for step_tail in tails:
        sup_model = PairedModule.load_from_checkpoint(
            f'lightning_logs/version_{model_version}/checkpoints/{step_tail}')

        zo = vc.run(sup_model)
        precision, mrr, r1 = vc.logits_to_metrics(zo)
        res.append(zo)
        print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, '
              f'model_version={model_version}[{step_tail}]: '
              f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')

    # Stack the per-checkpoint outputs once and reuse the tensor for both the
    # element-wise sum and the element-wise minimum.
    stacked = torch.stack(res, dim=0)
    ensembles = [
        ('ensemble distance sum', stacked.sum(dim=0)),
        ('ensemble distance min', stacked.min(dim=0).values),
        # Ranks are negated before summing — presumably so that a better rank
        # maps to a larger score; confirm against ValidationCallback.ranks.
        ('ensemble     rank sum', torch.stack([-vc.ranks(o) for o in res], dim=0).sum(dim=0)),
    ]
    for label, logits in ensembles:
        precision, mrr, r1 = vc.logits_to_metrics(logits)
        print(f'FOLD_ID={FOLD_ID}, fold_id_valid={mmap["fold_id_valid"]}, {label}: '
              f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f}')
    

FOLD_ID=0, fold_id_valid=1, model_version=42[epoch=19-step=2199.ckpt]: precision = 0.515, mrr = 0.193, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, model_version=42[epoch=21-step=2399.ckpt]: precision = 0.514, mrr = 0.194, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, model_version=42[epoch=22-step=2599.ckpt]: precision = 0.513, mrr = 0.193, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, model_version=42[epoch=24-step=2799.ckpt]: precision = 0.513, mrr = 0.194, r1 = 0.282
FOLD_ID=0, fold_id_valid=1, model_version=42[epoch=26-step=2999.ckpt]: precision = 0.513, mrr = 0.194, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, ensemble distance sum: precision = 0.515, mrr = 0.194, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, ensemble distance min: precision = 0.514, mrr = 0.194, r1 = 0.281
FOLD_ID=0, fold_id_valid=1, ensemble     rank sum: precision = 0.514, mrr = 0.194, r1 = 0.281
FOLD_ID=1, fold_id_valid=2, model_version=41[epoch=19-step=2199.ckpt]: precision = 0.495, mrr = 0.195, r1 = 0.280
FOLD_ID=1, fold_id_valid=2, model_

```
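Metric: r1 on the validation fold.
"mean by folds" = mean r1 of the individual models within the fold; "ensemble" = the "ensemble distance sum" result; boost = ensemble / mean.

| Models in ensemble          | FOLD_ID | model count | mean by folds | ensemble | boost  |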
| --------------------------- | ------- | ----------- | ------------- | -------- | ------ |
| sample 85% of train dataset | 0       |        5    | 0.2732        | 0.287    | 1.0505 |
| sample 85% of train dataset | 1       |        5    | 0.2748        | 0.288    | 1.0480 |
| sample 85% of train dataset | 2       |        5    | 0.2746        | 0.287    | 1.0452 |
| --------------------------- | ------- | ----------- | ------------- | -------- | ------ |
| full dataset                | 0       |        5    | 0.2798        | 0.294    | 1.0508 |
| full dataset                | 1       |        5    | 0.2820        | 0.295    | 1.0461 |
| full dataset                | 2       |        5    | 0.2802        | 0.293    | 1.0457 |
| --------------------------- | ------- | ----------- | ------------- | -------- | ------ |
| model checkpoints           | 0       |        5    | 0.2812        | 0.281    | 0.9993 |
| model checkpoints           | 1       |        5    | 0.2796        | 0.280    | 1.0014 |
| model checkpoints           | 2       |        5    | 0.2796        | 0.280    | 1.0014 |
| --------------------------- | ------- | ----------- | ------------- | -------- | ------ |

```