In [1]:
%cd ../

/home/kireev/pycharm-deploy/vtb


In [2]:
import pickle
import random

In [3]:
from glob import glob

In [4]:
from itertools import chain

In [5]:
import numpy as np
import pandas as pd

In [6]:
import torch
import pytorch_lightning as pl

In [7]:
from pyhocon import ConfigFactory

In [8]:
import matplotlib.pyplot as plt

In [9]:
from dltranz.data_load.iterable_processing.category_size_clip import CategorySizeClip
from dltranz.data_load import augmentation_chain
from dltranz.data_load.augmentations.seq_len_limit import SeqLenLimit
from dltranz.data_load.augmentations.random_slice import RandomSlice

from dltranz.seq_encoder import create_encoder

from dltranz.metric_learn.sampling_strategies import get_sampling_strategy
from dltranz.metric_learn.losses import get_loss

from dltranz.tb_interface import get_scalars

from dltranz.data_load import padded_collate_wo_target

In [10]:
from vtb_code.data import PairedDataset, paired_collate_fn, PairedZeroDataset, DropDuplicate
from vtb_code.metrics import PrecisionK, MeanReciprocalRankK, ValidationCallback

In [11]:
# Index of the fold held out as the test split; it also selects which
# `data/features_f{FOLD_ID}.pickle` file is loaded further down.
FOLD_ID = 1

In [12]:
# The test fold is exactly the notebook-level FOLD_ID.
fold_id_test = FOLD_ID

In [13]:
# The number of folds is inferred from how many `data/train_matching_*.csv` files exist on disk.
folds_count = len(glob('data/train_matching_*.csv'))
folds_count

6

In [14]:
# Validation fold: deterministically the fold right after the test fold (wrapping around).
# The earlier random choice is kept below for reference only.
# fold_id_valid = np.random.choice([i for i in range(folds_count) if i != fold_id_test], size=1)[0]
fold_id_valid = (fold_id_test + 1) % folds_count
fold_id_valid

2

In [15]:
# The two held-out folds are read as-is ...
df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
df_matching_test = pd.read_csv(f'data/train_matching_{fold_id_test}.csv')

# ... and every remaining fold is stacked into the training table.
df_matching_train = pd.concat(
    [
        pd.read_csv(f'data/train_matching_{fold}.csv')
        for fold in range(folds_count)
        if fold != fold_id_test and fold != fold_id_valid
    ]
)

In [16]:
# Sanity check of the split: row counts of the train / valid / test matching tables.
[len(df) for df in [df_matching_train, df_matching_valid, df_matching_test]]

[11721, 2930, 2930]

In [17]:
%%time
# Load the pre-computed per-fold features. The pickle holds an 8-tuple that is unpacked
# positionally, so the order below must match the order used when the file was written
# (the writer is not visible here -- TODO confirm).
# NOTE(review): `pickle.load` can execute arbitrary code; only load files produced by this project.
with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
    (
        features_trx_train,
        features_trx_valid,
        features_trx_test,
        features_trx_puzzle,
        features_click_train,
        features_click_valid,
        features_click_test,
        features_click_puzzle,
    ) = pickle.load(f)

CPU times: user 9.09 s, sys: 3.42 s, total: 12.5 s
Wall time: 12.5 s


# Pretrain

In [18]:
class MLMDataset(torch.utils.data.Dataset):
    """Map-style dataset of overlapping windows cut from per-key event sequences.

    `data` maps a key to a dict of equally indexed sequences (one of them must be
    ``'event_time'``). Every sequence is covered by windows of at most `seq_len`
    events that start every ``seq_len // 2`` events; at access time the window
    start is jittered by a random offset in ``[-random_shift, random_shift]``.

    Args:
        data: mapping ``key -> {feature_name: sliceable sequence}``.
        seq_len: maximum number of events in a returned window.
        random_shift: maximum absolute jitter (in events) applied to the window start.
    """
    def __init__(self, data, seq_len, random_shift):
        super().__init__()

        self.data = data
        self.seq_len = seq_len
        self.random_shift = random_shift

        # Keys are sorted so the index table is reproducible across runs.
        self.keys = np.sort(np.array(list(data.keys())))
        # Each row of `ix` is (position of the key in `self.keys`, nominal window start).
        self.ix = []
        for key_pos, key in enumerate(self.keys):
            n_events = len(self.data[key]['event_time'])
            for start in range(0, n_events, seq_len // 2):
                self.ix.append([key_pos, start])
        self.ix = np.array(self.ix)

    def __len__(self):
        """Total number of windows over all keys."""
        return self.ix.shape[0]

    def __getitem__(self, item):
        """Return one window as ``{feature_name: sequence[start: start + seq_len]}``."""
        key_pos, start = self.ix[item]
        record = self.data[self.keys[key_pos]]
        n_events = len(record['event_time'])

        start = start + random.randint(-self.random_shift, self.random_shift)
        # Keep at least half a window of events after the start ...
        start = min(start, n_events - self.seq_len // 2)
        # ... and clamp to zero LAST: for sequences shorter than half a window the bound
        # above is negative, and a negative slice start would silently return only the
        # tail of the sequence instead of the whole of it.
        start = max(start, 0)
        return {name: values[start: start + self.seq_len] for name, values in record.items()}


In [19]:
# MLM training pool = train features merged with the puzzle features (on a key collision
# the puzzle entry wins); the test-fold features serve as the MLM validation set.
mlm_train_trx_features = {**features_trx_train, **features_trx_puzzle}
mlm_valid_trx_features = features_trx_test

# Same construction for the click stream.
mlm_train_click_features = {**features_click_train, **features_click_puzzle}
mlm_valid_click_features = features_click_test

In [20]:
# Apply DropDuplicate keyed on `mcc_code` to every transaction sequence, writing a new
# `c_cnt` field (presumably the number of collapsed repeats -- confirm in vtb_code.data).
# `c_cnt` is later consumed as a numeric feature by the encoder config.
# Note: `dd` is rebound in the next cell for the click stream.
dd = DropDuplicate('mcc_code', col_new_cnt='c_cnt')
mlm_train_trx_features = {k: dd(v) for k, v in mlm_train_trx_features.items()}
mlm_valid_trx_features = {k: dd(v) for k, v in mlm_valid_trx_features.items()}

In [21]:
# Same de-duplication for click sequences, keyed on `cat_id`
# (presumably collapses repeated categories into one event with a count -- confirm).
dd = DropDuplicate('cat_id', col_new_cnt='c_cnt')
mlm_train_click_features = {k: dd(v) for k, v in mlm_train_click_features.items()}
mlm_valid_click_features = {k: dd(v) for k, v in mlm_valid_click_features.items()}

In [22]:
# Transaction MLM loaders: every item is a window of up to 512 events whose start is
# jittered by up to 32 events in either direction (see MLMDataset).
train_dl_mlm_trx = torch.utils.data.DataLoader(
    MLMDataset(mlm_train_trx_features, 512, 32),
    collate_fn=padded_collate_wo_target,
    shuffle=True,
    num_workers=12,
    batch_size=64,
    persistent_workers=False,
)

# NOTE(review): batch_size is 16 here but 64 for the click validation loader. With the
# dataset sizes printed below this gives 573 vs 564 batches, so the smaller batch
# presumably balances the two loaders when they are iterated together -- confirm.
# NOTE(review): the start jitter (32) is also active for validation, so validation
# windows are not identical between runs.
valid_dl_mlm_trx = torch.utils.data.DataLoader(
    MLMDataset(mlm_valid_trx_features, 512, 32),
    collate_fn=padded_collate_wo_target,
    shuffle=False,
    num_workers=12,
    batch_size=16,
    persistent_workers=False,
)

In [23]:
# Click MLM loaders: same windowing as for transactions (windows of up to 512 events,
# start jittered by up to 32 events).
train_dl_mlm_click = torch.utils.data.DataLoader(
    MLMDataset(mlm_train_click_features, 512, 32),
    collate_fn=padded_collate_wo_target,
    shuffle=True,
    num_workers=12,
    batch_size=64,
    persistent_workers=False,
)

# Validation loader over the test-fold click features; order is fixed (no shuffling).
valid_dl_mlm_click = torch.utils.data.DataLoader(
    MLMDataset(mlm_valid_click_features, 512, 32),
    collate_fn=padded_collate_wo_target,
    shuffle=False,
    num_workers=12,
    batch_size=64,
    persistent_workers=False,
)

In [24]:
# Number of batches in each validation loader (they should be of similar length).
len(valid_dl_mlm_trx), len(valid_dl_mlm_click)

(573, 564)

In [25]:
# Number of batches per loader: train trx, train click, valid trx, valid click.
len(train_dl_mlm_trx), len(train_dl_mlm_click), len(valid_dl_mlm_trx), len(valid_dl_mlm_click), 

(814, 3053, 573, 564)

In [26]:
# Number of windows (dataset items) behind each loader, in the same order as above.
(len(train_dl_mlm_trx.dataset), len(train_dl_mlm_click.dataset), 
 len(valid_dl_mlm_trx.dataset), len(valid_dl_mlm_click.dataset))

(52063, 195372, 9160, 36056)

In [27]:
# Pool the `transaction_amt` values of every sequence in the MLM training set.
v = torch.cat([features['transaction_amt'] for features in mlm_train_trx_features.values()])

# Bucket boundaries: 100 evenly spaced quantile levels taken over the DISTINCT amounts,
# so frequently repeated values do not dominate the boundaries.
trx_amnt_quantiles = torch.quantile(torch.unique(v), torch.linspace(0, 1, 100))
trx_amnt_quantiles

tensor([-0.9679, -0.8048, -0.7571, -0.7250, -0.7015, -0.6823, -0.6679, -0.6557,
        -0.6437, -0.6332, -0.6233, -0.6151, -0.6064, -0.5995, -0.5919, -0.5855,
        -0.5789, -0.5724, -0.5663, -0.5601, -0.5539, -0.5481, -0.5427, -0.5372,
        -0.5321, -0.5266, -0.5211, -0.5158, -0.5103, -0.5052, -0.4999, -0.4963,
        -0.4932, -0.4900, -0.4865, -0.4831, -0.4802, -0.4769, -0.4738, -0.4705,
        -0.4670, -0.4639, -0.4607, -0.4574, -0.4541, -0.4512, -0.4479, -0.4446,
        -0.4414, -0.4385, -0.4352, -0.4318, -0.4285, -0.4253, -0.4221, -0.4190,
        -0.4156, -0.4121, -0.4091, -0.4058, -0.4024, -0.3991, -0.3958, -0.3923,
        -0.3887, -0.3854, -0.3816, -0.3778, -0.3739, -0.3697, -0.3663, -0.3621,
        -0.3575, -0.3537, -0.3486, -0.3431, -0.3375, -0.3312, -0.3255, -0.3183,
        -0.3126, -0.3052, -0.2980, -0.2897, -0.2807, -0.2722, -0.2639, -0.2576,
        -0.2498, -0.2417, -0.2233, -0.1300,  0.3547,  0.4939,  0.5650,  0.6230,
         0.6739,  0.7290,  0.8014,  1.00

In [28]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch

In [29]:
from vtb_code.models import MeanLoss

In [30]:
class CustomTrxTransform(torch.nn.Module):
    """Adds a bucketed version of `transaction_amt` to a batch payload.

    Args:
        trx_amnt_quantiles: sorted tensor of bucket boundaries.
    """
    def __init__(self, trx_amnt_quantiles):
        super().__init__()
        # Kept as a frozen Parameter so the boundaries are part of the module state.
        self.trx_amnt_quantiles = torch.nn.Parameter(trx_amnt_quantiles, requires_grad=False)

    def forward(self, x):
        """Write `payload['transaction_amt_q']` in place and return the same batch object."""
        amounts = x.payload['transaction_amt']
        bucket_ids = torch.bucketize(amounts, self.trx_amnt_quantiles)
        # Shift by one so the produced ids start at 1 and id 0 is never emitted.
        x.payload['transaction_amt_q'] = bucket_ids + 1
        return x
    
class CustomClickTransform(torch.nn.Module):
    def forward(self, x):
        """Return the click batch unchanged.

        Currently a no-op placeholder: the clamping experiments below are disabled.
        """
#         x.payload['cat_id'] = torch.clamp(x.payload['cat_id'], 0, 300)
#         x.payload['level_0'] = torch.clamp(x.payload['level_0'], 0, 200)
#         x.payload['level_1'] = torch.clamp(x.payload['level_1'], 0, 200)
#         x.payload['level_2'] = torch.clamp(x.payload['level_2'], 0, 200)
#         x.payload['c_cnt_clamp'] = torch.clamp(x.payload['c_cnt'], 0, 20).int()
        return x

In [31]:
class DateFeaturesTransform(torch.nn.Module):
    """Derives categorical time features from `event_time` (treated as seconds).

    Adds to the payload, in place:
      * `hour`     -- hour within the day, shifted to 1..24;
      * `weekday`  -- day index modulo 7, shifted to 1..7 (phase relative to day 0 of the
                      time axis, not necessarily a calendar weekday);
      * `day_diff` -- whole days since the previous event in the row, clamped to 0..14
                      (0 for the first event).
    """
    def forward(self, x):
        payload = x.payload
        seconds = payload['event_time'].int()
        hour_index = seconds.div(60 * 60, rounding_mode='floor')
        day_index = seconds.div(24 * 60 * 60, rounding_mode='floor').int()

        payload['hour'] = hour_index % 24 + 1
        payload['weekday'] = day_index % 7 + 1
        # Prepending the first day makes the first difference in every row equal to 0.
        day_step = torch.diff(day_index, prepend=day_index[:, :1], dim=1)
        payload['day_diff'] = torch.clamp(day_step, 0, 14)
        return x

In [32]:
class PBLinear(torch.nn.Linear):
    """`torch.nn.Linear` that operates on the payload of a PaddedBatch.

    The sequence lengths of the input batch are passed through unchanged.
    """
    def forward(self, x: PaddedBatch):
        projected = super().forward(x.payload)
        return PaddedBatch(projected, x.seq_lens)

In [33]:
class PBL2Norm(torch.nn.Module):
    """Rescales every payload vector of a PaddedBatch to L2 norm `beta` along the last dim.

    Args:
        beta: target length of the normalised vectors.
    """
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        payload = x.payload
        # The small epsilon keeps the division finite for all-zero vectors.
        norm = (payload.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return PaddedBatch(self.beta * payload / norm, x.seq_lens)

In [34]:
from transformers import LongformerConfig, LongformerModel

In [35]:
class MLMPretrainModule(pl.LightningModule):
    """Masked-event pretraining of one shared Longformer over transaction and click sequences.

    Each stream has its own event encoder (feature transforms -> TrxEncoder -> linear
    projection to `common_trx_size` -> L2 normalisation to length `mlm.beta`). The encoded
    events of both streams are stacked along the batch axis, a fraction of them is
    replaced (BERT-style 80/10/10 scheme) and the transformer is trained to recover the
    original event embedding with a contrastive loss against negatives sampled from the
    other masked positions of the same batch.

    Args:
        trx_amnt_quantiles: bucket boundaries forwarded to `CustomTrxTransform`.
        params: config with the keys read below (`common_trx_size`, `transf.*`, `mlm.*`,
            `trx_seq.trx_encoder`, `click_seq.trx_encoder`).
        lr, weight_decay: Adam settings.
        max_lr, pct_start, total_steps: OneCycleLR settings.
    """
    def __init__(self, trx_amnt_quantiles, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                ):
        super().__init__()
        self.save_hyperparameters()

        common_trx_size = params['common_trx_size']
        t = TrxEncoder(params['trx_seq.trx_encoder'])
        self.seq_encoder_trx = torch.nn.Sequential(
            CustomTrxTransform(trx_amnt_quantiles=trx_amnt_quantiles),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, common_trx_size),
            PBL2Norm(params['mlm.beta']),
        )

        t = TrxEncoder(params['click_seq.trx_encoder'])
        self.seq_encoder_click = torch.nn.Sequential(
            CustomClickTransform(),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, common_trx_size),
            PBL2Norm(params['mlm.beta']),
        )

        # Learned embeddings used in place of a masked event and as the leading CLS position.
        self.token_mask = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)
        self.token_cls = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)

        # Inputs are always passed as `inputs_embeds`, so the vocabulary only needs a token size.
        self.transf = LongformerModel(
            config=LongformerConfig(
                hidden_size=common_trx_size,
                num_attention_heads=params['transf.nhead'],
                intermediate_size=params['transf.dim_feedforward'],
                num_hidden_layers=params['transf.num_layers'],
                vocab_size=4,
                max_position_embeddings=params['transf.max_len'],
                attention_window=params['transf.attention_window'],
            ),
            add_pooling_layer=False,
        )

        self.train_mlm_loss = MeanLoss(compute_on_step=False)
        self.valid_mlm_loss = MeanLoss(compute_on_step=False)

    def configure_optimizers(self):
        """Adam with a per-step one-cycle (cosine) learning-rate schedule."""
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def get_mask(self, attention_mask):
        """Sample the positions to corrupt: each non-padded position with prob. `mlm.replace_proba`."""
        return torch.bernoulli(attention_mask.float() * self.hparams.params['mlm.replace_proba']).bool()

    def mask_x(self, x, attention_mask, mask):
        """Corrupt the embeddings `x` (B, T, H) at the positions selected by `mask`.

        A selected position becomes the learned [MASK] token with probability 0.8, a random
        non-padded event embedding from the batch with probability 0.1, and stays unchanged
        with probability 0.1.
        """
        valid_tokens = x[attention_mask.bool()]
        B, T, H = x.size()
        # Uniform sampling with replacement, done on the input's device to avoid a
        # host-to-device index transfer on every step.
        ix = torch.randint(0, valid_tokens.size(0), (B * T,), device=x.device)
        shuffled_tokens = valid_tokens[ix].view(B, T, H)

        rand = torch.rand(B, T, device=x.device).unsqueeze(2).expand(B, T, H)
        replace_to = torch.where(
            rand < 0.8,
            self.token_mask.expand_as(x),          # [MASK] token 80%
            torch.where(
                rand < 0.9,
                shuffled_tokens,                   # random token 10%
                x,                                 # unchanged 10%
            )
        )
        return torch.where(mask.bool().unsqueeze(2).expand_as(x), replace_to, x)

    def forward(self, z: PaddedBatch):
        """Run the transformer over already-encoded events.

        A CLS embedding with global attention is prepended, so the returned payload has
        ``T + 1`` positions while `seq_lens` is passed through unchanged; callers that
        need per-event outputs drop position 0.
        """
        B, T, H = z.payload.size()
        device = z.payload.device

        if self.training:
            # Random offset into the position table so that all position embeddings get
            # trained even though the windows are short. The upper bound is kept >= 1 so
            # a window that exactly fills the table does not raise.
            max_start = self.hparams.params['transf.max_len'] - T - 1
            start_pos = int(np.random.randint(0, max(max_start, 1)))
        else:
            start_pos = 0

        inputs_embeds = torch.cat([
            self.token_cls.expand(B, 1, H),
            z.payload,
        ], dim=1)
        attention_mask = torch.cat([
            torch.ones(B, 1, device=device),
            z.seq_len_mask.float(),
        ], dim=1)
        position_ids = torch.arange(T + 1, device=device).view(1, -1).expand(B, T + 1) + start_pos
        # Only the CLS position attends globally; events use the local attention window.
        global_attention_mask = torch.cat([
            torch.ones(B, 1, device=device),
            torch.zeros(B, T, device=device),
        ], dim=1)

        out = self.transf(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            global_attention_mask=global_attention_mask,
        ).last_hidden_state

        return PaddedBatch(out, z.seq_lens)

    def get_neg_ix(self, mask):
        """Sample negatives for every masked position from the OTHER masked positions.

        Returns a pair ``(b_ix, t_ix)`` of index tensors of shape ``(N, neg_count)``, where
        N is the number of masked positions, addressing (batch row, time step).
        """
        flat_mask = mask.flatten()
        own_pos = flat_mask.nonzero(as_tuple=True)[0]          # N flat indices, row-major
        n_masked = own_pos.size(0)
        # Sampling weights: 1 on every masked position, 0 elsewhere and 0 on the row's own
        # position. Built as (N, B*T) directly instead of slicing a (B*T, B*T) identity
        # matrix, which grows quadratically with batch size times sequence length.
        weights = flat_mask.float().unsqueeze(0).repeat(n_masked, 1)
        weights[torch.arange(n_masked, device=mask.device), own_pos] = 0.0
        neg_ix = torch.multinomial(weights, self.hparams.params['mlm.neg_count'])
        b_ix = neg_ix.div(mask.size(1), rounding_mode='trunc')
        neg_ix = neg_ix % mask.size(1)
        return b_ix, neg_ix

    def loss_mlm(self, x: PaddedBatch):
        """Per-masked-position contrastive loss; returns a 1-D tensor of length N."""
        mask = self.get_mask(x.seq_len_mask)
        masked_x = self.mask_x(x.payload, x.seq_len_mask, mask)

        # Drop the CLS position so outputs align with the input events again.
        out = self.forward(PaddedBatch(masked_x, x.seq_lens)).payload[:, 1:]

        target = x.payload[mask].unsqueeze(1)  # N, 1, H
        predict = out[mask].unsqueeze(1)  # N, 1, H
        neg_ix = self.get_neg_ix(mask)
        negative = out[neg_ix[0], neg_ix[1]]  # N, nneg, H
        out_samples = torch.cat([predict, negative], dim=1)
        logits = (target * out_samples).sum(dim=2)  # N, 1 + nneg; column 0 is the positive
        # log_softmax is the numerically stable form of log(softmax(.)): it cannot
        # underflow to log(0) = -inf when the positive logit is far below a negative one.
        return -torch.log_softmax(logits, dim=1)[:, 0]

    def _encode_batch(self, batch):
        """Encode the (trx, click) pair of batches and stack them along the batch axis."""
        x_trx, x_click = batch

        z_trx = self.seq_encoder_trx(x_trx)  # PB: B, T, H
        z_click = self.seq_encoder_click(x_click)  # PB: B, T, H
        return PaddedBatch(
            torch.cat([z_trx.payload, z_click.payload], dim=0),
            torch.cat([z_trx.seq_lens, z_click.seq_lens], dim=0),
        )

    def training_step(self, batch, batch_idx):
        loss_mlm = self.loss_mlm(self._encode_batch(batch))
        self.train_mlm_loss(loss_mlm)
        loss_mlm = loss_mlm.mean()
        self.log('loss/mlm', loss_mlm)

        return loss_mlm

    def validation_step(self, batch, batch_idx):
        loss_mlm = self.loss_mlm(self._encode_batch(batch))
        self.valid_mlm_loss(loss_mlm)

    def training_epoch_end(self, _):
        self.log('metrics/train_mlm', self.train_mlm_loss, prog_bar=False)

    def validation_epoch_end(self, _):
        self.log('metrics/valid_mlm', self.valid_mlm_loss, prog_bar=True)


In [36]:
# Model configuration (HOCON). Despite its name, `config_trx` holds the settings of BOTH
# streams (`trx_seq` and `click_seq`) plus the shared transformer and the MLM objective.
# The `in` size of each derived-feature embedding must exceed the largest id produced by the
# transforms above (hour <= 24, weekday <= 7, day_diff <= 14, transaction_amt_q <= 101).
config_trx = ConfigFactory.parse_string('''
    common_trx_size: 256
    transf: {
        nhead: 2
        dim_feedforward: 512
        num_layers: 4
        attention_window: 32
        max_len: 6000
    }
    mlm: {
        replace_proba: 0.11
        neg_count: 128
        beta: 5
    }
    trx_seq: {
        trx_encoder: {
          use_batch_norm_with_lens: false
          norm_embeddings: false,
          embeddings_noise: 0.000,
          embeddings: {
            mcc_code: {in: 350, out: 64},
            currency_rk: {in: 10, out: 4}
            transaction_amt_q: {in: 110, out: 8}
            
            hour: {in: 30, out: 16}
            weekday: {in: 10, out: 4}
            day_diff: {in: 15, out: 8}
          },
          numeric_values: {
            transaction_amt: identity
            c_cnt: log
          }
          was_logified: false
          log_scale_factor: 1.0
        }
    }
    click_seq: {
        trx_encoder: {
          use_batch_norm_with_lens: false
          norm_embeddings: false,
          embeddings_noise: 0.000,
          embeddings: {
            cat_id: {in: 400, out: 64},
            level_0: {in: 400, out: 16}
            level_1: {in: 400, out: 8}
            level_2: {in: 400, out: 4}
            
            hour: {in: 30, out: 16}
            weekday: {in: 10, out: 4}
            day_diff: {in: 15, out: 8}
          },
          numeric_values: {
            c_cnt: log
          }
          was_logified: false
          log_scale_factor: 1.0
        }
    }
''')

# `total_steps` must agree with `max_steps` of the Trainer below (both are 35000);
# the learning rate warms up over the first 3000 of them.
mlm_model = MLMPretrainModule(
    trx_amnt_quantiles=trx_amnt_quantiles,
    params=config_trx,                     
    lr=0.001, weight_decay=0,
    max_lr=0.001, pct_start=3000 / 35000, total_steps=35000,
)


In [37]:
from pytorch_lightning.trainer.supporters import CombinedLoader

In [38]:
# Pretraining run: 35000 optimizer steps on GPU 0, validation every 700 training batches,
# and a checkpoint kept every 1000 steps (save_top_k=-1 keeps all of them).
trainer = pl.Trainer(
    gpus=[0],
    max_steps=35000,
    enable_progress_bar=True,
    val_check_interval=700,
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        pl.callbacks.ModelCheckpoint(
            every_n_train_steps=1000, save_top_k=-1,
        ),
    ]
)
model_version = trainer.logger.version
# Reference loss of a uniform guess over 1 positive + `neg_count` negatives: log(neg_count + 1).
print('baseline:  {:.3f}'.format(np.log(mlm_model.hparams.params['mlm.neg_count'] + 1)))
print(f'version = {model_version}')
# Each step consumes one trx batch and one click batch; with 'max_size_cycle' the shorter
# loader is cycled until the longer one is exhausted.
trainer.fit(
    mlm_model,
    CombinedLoader([train_dl_mlm_trx, train_dl_mlm_click], mode='max_size_cycle'), 
    CombinedLoader([valid_dl_mlm_trx, valid_dl_mlm_click], mode='max_size_cycle'),
)
print('done')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


baseline:  4.860
version = 32



  | Name              | Type            | Params
------------------------------------------------------
0 | seq_encoder_trx   | Sequential      | 51.5 K
1 | seq_encoder_click | Sequential      | 68.7 K
2 | transf            | LongformerModel | 4.4 M 
3 | train_mlm_loss    | MeanLoss        | 0     
4 | valid_mlm_loss    | MeanLoss        | 0     
------------------------------------------------------
4.6 M     Trainable params
100       Non-trainable params
4.6 M     Total params
18.226    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



# Use pretrained

In [None]:
"""
v16: 128 size, 2 layers
v32: 256 size, 4 layers
"""


# NOTE(review): `MLMPretrainModuleTrx` and `MLMPretrainModuleClick` are not defined or
# imported anywhere in the visible part of this notebook (only the joint
# `MLMPretrainModule` is), so this cell fails with NameError on a fresh kernel unless
# those classes are brought in from elsewhere. The checkpoint versions (152 / 153) also
# do not match the run produced above -- confirm which checkpoints are intended.
mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(
    'lightning_logs/version_152/checkpoints/epoch=54-step=9999.ckpt')  # 42
mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(
    'lightning_logs/version_153/checkpoints/epoch=65-step=9999.ckpt')  # 43

In [None]:
# mlm_model_trx.freeze()
# mlm_model_click.freeze()

In [None]:
from dltranz.seq_encoder.utils import NormEncoder

In [None]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch
from dltranz.seq_encoder.rnn_encoder import RnnEncoder
from dltranz.seq_encoder.utils import LastStepEncoder

In [None]:
class L2Scorer(torch.nn.Module):
    """Scores concatenated embedding pairs by their negative squared Euclidean distance.

    The input is a 2-D tensor whose rows hold two embeddings side by side; the row is
    split at ``width // 2`` and the result has one score per row (higher = closer).
    """
    def forward(self, x):
        _, width = x.size()
        half = width // 2
        left, right = x[:, :half], x[:, half:]
        squared_dist = torch.pow(left - right, 2).sum(dim=1)
        return -squared_dist

In [None]:
class PBLayerNorm(torch.nn.LayerNorm):
    """`torch.nn.LayerNorm` applied to the payload of a PaddedBatch.

    The sequence lengths of the input batch are passed through unchanged.
    """
    def forward(self, x: PaddedBatch):
        normalized = super().forward(x.payload)
        return PaddedBatch(normalized, x.seq_lens)

In [None]:
class PairedModule(pl.LightningModule):
    """Supervised matching model on top of the pretrained per-stream encoders.

    Transaction and click sequences are embedded by their own encoder and trained so that
    matching (same label) pairs are closer in squared L2 distance than sampled
    non-matching ones.

    NOTE(review): `__init__` reads the notebook-level globals `mlm_model_trx` and
    `mlm_model_click`; they must exist before this class is instantiated and are not
    recorded in the saved hyper-parameters.

    Args:
        params: config object; only stored in `hparams` at the moment (the RNN head that
            consumed `params['rnn']` is disabled).
        k: cut-off rank of the training precision / MRR metrics.
        lr, weight_decay: Adam settings.
        max_lr, pct_start, total_steps: OneCycleLR settings.
        beta: multiplier applied to the negative squared distances before the softmax.
        neg_count: number of candidates per positive pair (1 positive + neg_count - 1 negatives).
    """
    def __init__(self, params, k,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 beta, neg_count,
                ):
        super().__init__()
        self.save_hyperparameters()

        common_trx_size = mlm_model_trx.hparams.params['common_trx_size']
        # Empty Sequential == identity; the disabled RNN head is kept for reference.
        self.rnn_enc = torch.nn.Sequential(
            # RnnEncoder(common_trx_size, params['rnn']),
            # LastStepEncoder(),
            # NormEncoder(),
        )
        self._seq_encoder_trx = torch.nn.Sequential(
            # mlm_model_trx.seq_encoder,
            mlm_model_trx,
            torch.nn.Linear(common_trx_size, 2 * common_trx_size),
            # PBLayerNorm(common_trx_size),
        )
        self._seq_encoder_click = torch.nn.Sequential(
            # mlm_model_click.seq_encoder,
            mlm_model_click,
            torch.nn.Linear(common_trx_size, 2 * common_trx_size),
            # PBLayerNorm(common_trx_size),
        )
        # self.mlm_model_click = mlm_model_click

        # Pair scorer used by the validation callback on concatenated (trx, click) embeddings.
        self.cls = torch.nn.Sequential(
            L2Scorer(),
        )

        self.train_precision = PrecisionK(k=k, compute_on_step=False)
        self.train_mrr = MeanReciprocalRankK(k=k, compute_on_step=False)
        self.valid_precision = PrecisionK(k=k, compute_on_step=False)
        self.valid_mrr = MeanReciprocalRankK(k=k, compute_on_step=False)

    def seq_encoder_trx(self, x):
        """Embed a batch of transaction sequences."""
        x = self._seq_encoder_trx(x)
        return self.rnn_enc(x)

    def seq_encoder_click(self, x_orig):
        """Embed a batch of click sequences."""
        x = self._seq_encoder_click(x_orig)
        # x = PaddedBatch(
        #     x.payload + self.mlm_model_click.sentence_encoding(x_orig),
        #     x.seq_lens,
        # )
        return self.rnn_enc(x)

    def configure_optimizers(self):
        """Adam with a per-step, three-phase one-cycle (cosine) learning-rate schedule."""
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def loss_fn_p(self, embeddings, labels, ref_emb, ref_labels):
        """Softmax contrastive loss of `embeddings` against candidates from `ref_emb`.

        For every (anchor, reference) pair with equal labels, `neg_count - 1` references
        with a different label are sampled without replacement; the loss is the mean
        negative log-probability of the positive among these candidates, with logits
        ``-beta * squared_distance``.
        """
        beta = self.hparams.beta
        neg_count = self.hparams.neg_count

        # Rows of pos_ix are (anchor index, matching reference index).
        pos_ix = (labels.view(-1, 1) == ref_labels.view(1, -1)).nonzero(as_tuple=False)
        pos_labels = labels[pos_ix[:, 0]]
        neg_w = ((pos_labels.view(-1, 1) != ref_labels.view(1, -1))).float()
        neg_ix = torch.multinomial(neg_w, neg_count - 1)
        all_ix = torch.cat([pos_ix[:, [1]], neg_ix], dim=1)  # column 0 is the positive
        logits = -(embeddings[pos_ix[:, [0]]] - ref_emb[all_ix]).pow(2).sum(dim=2)
        logits = logits * beta
        # log_softmax instead of log(softmax(.)): squared distances are unbounded, so the
        # positive's probability can underflow to 0 and log(0) would make the loss inf.
        logs = -torch.log_softmax(logits, dim=1)[:, 0]
        # logs = torch.relu(logs + np.log(0.1))
        return logs.mean()

    def training_step(self, batch, batch_idx):
        # pairs
        x_trx, l_trx, m_trx, x_click, l_click, m_click = batch
        z_trx = self.seq_encoder_trx(x_trx)  # B, H
        z_click = self.seq_encoder_click(x_click)  # B, H
        loss_pt = self.loss_fn_p(embeddings=z_trx, labels=l_trx, ref_emb=z_click, ref_labels=l_click)
        self.log('loss/loss_pt', loss_pt)

        loss_pc = self.loss_fn_p(embeddings=z_click, labels=l_click, ref_emb=z_trx, ref_labels=l_trx)
        self.log('loss/loss_pc', loss_pc)

        with torch.no_grad():
            # Pairwise score matrix (trx x click), restricted to rows/columns whose `m_*`
            # flag is 0 (presumably the real, non-synthetic samples -- confirm in
            # PairedZeroDataset).
            out = -(z_trx.unsqueeze(1) - z_click.unsqueeze(0)).pow(2).sum(dim=2)
            out = out[m_trx == 0][:, m_click == 0]
            T, C = out.size()
            assert T == C
            # Number of augmented samples per label; each strided sub-matrix below holds
            # one sample per label and is scored independently.
            n_samples = z_trx.size(0) // (l_trx.max().item() + 1)
            for i in range(n_samples):
                l2 = out[i::n_samples, i::n_samples]
                self.train_precision(l2)
                self.train_mrr(l2)

        # The click -> trx direction is down-weighted relative to trx -> click.
        return loss_pt + 0.1 * loss_pc  #  loss_pc

    def training_epoch_end(self, _):
        self.log('train_metrics/precision', self.train_precision, prog_bar=True)
        self.log('train_metrics/mrr', self.train_mrr, prog_bar=True)


In [None]:
batch_size = 128
# Supervised training loader.
# NOTE(review): the pairs and features of the TEST fold are merged into the training set
# here, so only the validation fold remains held out in this notebook -- confirm intended.
# Rows whose `rtk` equals '0' are dropped ('0' presumably marks "no matching click user").
train_dl = torch.utils.data.DataLoader(
    PairedZeroDataset(
        pd.concat([df_matching_train, df_matching_test], axis=0)[lambda x: x['rtk'].ne('0')].values,
        data=[
            dict(chain(features_trx_train.items(), features_trx_test.items())),
            dict(chain(features_click_train.items(), features_click_test.items())),
        ],
        # One augmentation chain per stream: de-duplicate, then take a random slice
        # (the two numbers are presumably min / max slice length -- confirm in RandomSlice).
        augmentations=[
            augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
            augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
        ],
        n_sample=2,
    ),
    collate_fn=PairedZeroDataset.collate_fn,
    drop_last=True,
    shuffle=True,
    num_workers=24,
    batch_size=batch_size,
    persistent_workers=True,
)

# train_dl = torch.utils.data.DataLoader(
#     PairedZeroDataset(
#         pd.concat([df_matching_train, df_matching_test], axis=0)[lambda x: x['rtk'].ne('0')].values,
#         data=[
#             dict(chain(features_trx_train.items(), features_trx_test.items())),
#             dict(chain(features_click_train.items(), features_click_test.items())),
#         ],
#         augmentations=[
#             augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 512)),  # 1024
#             augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 1024)),  # 2048
#         ],
#         n_sample=2,
#     ),
#     collate_fn=PairedZeroDataset.collate_fn,
#     drop_last=True,
#     shuffle=True,
#     num_workers=24,
#     batch_size=batch_size,
#     persistent_workers=True,
# )

# Validation loader over the transaction side: one item per distinct `bank` id of the
# validation fold, in sorted order (the validation callback relies on this order).
valid_dl_trx = torch.utils.data.DataLoader(
    PairedDataset(
        np.sort(df_matching_valid['bank'].unique()).reshape(-1, 1), 
        data=[
            features_trx_valid,
        ],
        augmentations=[
            augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), SeqLenLimit(2000)),  # 2000
        ],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    shuffle=False,
    num_workers=4,
    batch_size=128,
    persistent_workers=True,
)

# Validation loader over the click side: one item per distinct `rtk` id of the validation
# fold, excluding the placeholder '0', in sorted order (the validation callback uses
# `np.searchsorted` on these ids, so the order must stay sorted).
valid_dl_click = torch.utils.data.DataLoader(
    PairedDataset(
        np.sort(df_matching_valid[lambda x: x['rtk'].ne('0')]['rtk'].unique()).reshape(-1, 1),
        data=[
            features_click_valid,
        ],
        augmentations=[
            augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), SeqLenLimit(5000)),  # 5000
        ],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    shuffle=False,
    num_workers=4,
    batch_size=128,
    persistent_workers=True,
)

In [None]:
# NOTE(review): PairedModule takes `common_trx_size` from the pretrained model's
# hyper-parameters, not from this config, and its RNN head is disabled, so the config
# below is only stored in `hparams`.
# k = 100 * batch_size // 3000 = 4 for batch_size = 128, i.e. the top-100 cut-off is
# scaled down from a pool of 3000 candidates to the in-batch pool size (3000 is
# presumably the approximate size of the full candidate set -- confirm).
# `total_steps` must agree with `max_steps` of the Trainer below (both are 6000).
sup_model = PairedModule(
    ConfigFactory.parse_string('''
    common_trx_size: 128
    rnn: {
      type: gru,
      hidden_size: 256,
      bidir: false,
      trainable_starter: static
    }
'''),                     
    k=100 * batch_size // 3000,
    lr=0.0022, weight_decay=0,
    max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
    beta=0.2 / 1.4, neg_count=120,
)


In [None]:
class ValidationCallback(pl.Callback):
    """Full-ranking validation run at the end of every training epoch.

    Every validation transaction user is scored against every validation click user with
    `pl_module.cls`; precision@k, MRR@k and their harmonic mean are logged under
    `valid_full_metrics/*`.

    NOTE(review): this definition shadows the `ValidationCallback` imported from
    `vtb_code.metrics` at the top of the notebook.

    Args:
        v_trx: loader over validation transaction users (dataset `pairs[:, 0]` = bank ids).
        v_click: loader over validation click users (dataset `pairs[:, 0]` = sorted rtk ids).
        target: frame with `bank` and `rtk` columns giving the true match of every bank id.
        device: device used for the evaluation.
        device_main: device the module is moved back to afterwards.
        k: rank cut-off of the metrics.
        batch_size: number of (trx, click) pairs scored per `cls` call.
    """
    def __init__(self, v_trx, v_click, target, device, device_main, k=100, batch_size=1024):
        self.v_trx = v_trx
        self.v_click = v_click
        self.target = target
        self.device = device
        self.device_main = device_main
        self.k = k
        self.batch_size = batch_size

    def on_train_epoch_end(self, trainer, pl_module):
        # Evaluate in eval mode and restore the previous mode afterwards.
        was_training = pl_module.training
        if was_training:
            pl_module.eval()

        pl_module.to(self.device)
        with torch.no_grad():
            z_trx = []
            for ((x_trx, _),) in self.v_trx:
                z_trx.append(pl_module.seq_encoder_trx(x_trx.to(self.device)))
            z_trx = torch.cat(z_trx, dim=0)
            z_click = []
            for ((x_click, _),) in self.v_click:
                z_click.append(pl_module.seq_encoder_click(x_click.to(self.device)))
            z_click = torch.cat(z_click, dim=0)

            # Enumerate all T x C (trx, click) index pairs in row-major order ...
            T = z_trx.size(0)
            C = z_click.size(0)
            device = z_trx.device
            ix_t = torch.arange(T, device=device).view(-1, 1).expand(T, C).flatten()
            ix_c = torch.arange(C, device=device).view(1, -1).expand(T, C).flatten()

            # ... and score them in chunks to bound memory.
            z_out = []
            for i in range(0, len(ix_t), self.batch_size):
                z_pairs = torch.cat([
                    z_trx[ix_t[i:i + self.batch_size]],
                    z_click[ix_c[i:i + self.batch_size]],
                ], dim=1)
                z_out.append(pl_module.cls(z_pairs).unsqueeze(1))
            z_out = torch.cat(z_out, dim=0).view(T, C)

            precision, mrr, r1 = self.logits_to_metrics(z_out)

            pl_module.log('valid_full_metrics/precision', precision, prog_bar=True)
            pl_module.log('valid_full_metrics/mrr', mrr, prog_bar=False)
            pl_module.log('valid_full_metrics/r1', r1, prog_bar=False)

        pl_module.to(self.device_main)
        if was_training:
            pl_module.train()

    def logits_to_metrics(self, z_out):
        """Turn a (T, C) score matrix into (precision@k, MRR@k, harmonic mean of the two)."""
        T, C = z_out.size()
        # Rank (1 = best) of every click candidate for every trx user.
        z_ranks = torch.zeros_like(z_out)
        z_ranks[
            torch.arange(T, device=self.device).view(-1, 1).expand(T, C),
            torch.argsort(z_out, dim=1, descending=True),
        ] = torch.arange(C, device=self.device).float().view(1, -1).expand(T, C) + 1
        # The "no match" answer ('0') is always placed at rank 1; real candidates shift by one.
        z_ranks = torch.cat([
            torch.ones(T, device=self.device).float().view(-1, 1),
            z_ranks + 1,
        ], dim=1)

        # Column of the true answer for every trx user; relies on the click ids being sorted
        # and on every true rtk being present among them (searchsorted does not verify this).
        click_uids = np.concatenate([['0'], self.v_click.dataset.pairs[:, 0]])
        true_ranks = z_ranks[
            np.arange(T),
            np.searchsorted(click_uids,
                            self.target.set_index('bank')['rtk'].loc[self.v_trx.dataset.pairs[:, 0]].values)
        ]
        hit = true_ranks <= self.k
        precision = hit.float().mean()
        mrr = torch.where(hit, 1 / true_ranks, torch.zeros(1, device=self.device)).mean()
        # Harmonic mean; defined as 0 instead of NaN when both metrics are 0.
        denom = mrr + precision
        r1 = torch.where(denom > 0, 2 * mrr * precision / denom, torch.zeros_like(denom))
        return precision, mrr, r1


In [None]:
# Supervised run: 6000 steps on GPU 0 with a checkpoint every 1000 steps (all kept).
# Validation is done by the notebook-local ValidationCallback, which evaluates on cuda:0
# and moves the module back to cuda:0 (the same device the Trainer uses).
trainer = pl.Trainer(
    gpus=[0],
    max_steps=6000,
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        pl.callbacks.ModelCheckpoint(
            every_n_train_steps=1000, save_top_k=-1,
        ),
        ValidationCallback(valid_dl_trx, valid_dl_click, df_matching_valid,
                           torch.device('cuda:0'), torch.device('cuda:0')),
    ]
)

In [None]:
# No validation dataloader is passed: validation runs inside ValidationCallback.
trainer.fit(sup_model, train_dl)  # valid_dl

In [None]:
"""
0: v78
1: v79
2: v80
3: v81
4: v82
5: v83
"""

In [None]:
# Logged scalars of the six per-fold runs (versions 78..83), indexed by version name.
# The code below expects `tag`, `step` and `value` columns in the result of `get_scalars`.
df_m = get_scalars('lightning_logs/').set_index('version').loc[[f'version_{i}' for i in [78, 79, 80, 81, 82, 83]]]

# df = df_m[lambda x: x['tag'].str.startswith('train_metrics')] \
# .pivot(index='step', columns='tag', values='value')
# _, axs = plt.subplots(2, 1, figsize=(16, 15))
# for col, ax in zip(df.columns, axs):
#     df[col].plot(ax=ax, title=col, grid=True)
# plt.show()

# df = df_m[lambda x: x['tag'].str.startswith('valid_full_metrics')] \
# .pivot(index='step', columns='tag', values='value')
# _, axs = plt.subplots(3, 1, figsize=(16, 18))
# for col, ax in zip(df.columns, axs):
#     df[col].plot(ax=ax, title=col, grid=True)
# plt.show()

In [None]:
# Last logged value of every full-validation metric, one column per run (fold).
# `.iloc[-1]` selects the last row by position; plain `x[-1]` on a Series with a
# non-integer index relies on a deprecated positional fallback of `Series.__getitem__`.
# NOTE(review): "last" follows the row order returned by `get_scalars`, presumably
# increasing step -- confirm.
(
    df_m[lambda x: x['tag'].str.startswith('valid_full_metrics')]
    .pivot_table(index='tag', columns='version', values='value', aggfunc=lambda s: s.iloc[-1])
    .round(4)
)

In [None]:
# Final full-validation metrics averaged over the runs (folds).
# `.iloc[-1]` selects the last row by position; plain `x[-1]` on a Series with a
# non-integer index relies on a deprecated positional fallback of `Series.__getitem__`.
(
    df_m[lambda x: x['tag'].str.startswith('valid_full_metrics')]
    .pivot_table(index='tag', columns='version', values='value', aggfunc=lambda s: s.iloc[-1])
    .mean(axis=1)
    .round(4)
)