In [1]:
%cd ../

/home/kireev/pycharm-deploy/vtb


In [2]:
import pickle
import random

In [3]:
from glob import glob

In [4]:
from itertools import chain

In [5]:
import numpy as np
import pandas as pd

In [6]:
import torch
import pytorch_lightning as pl

In [7]:
from pyhocon import ConfigFactory

In [8]:
import matplotlib.pyplot as plt

In [9]:
from dltranz.data_load.iterable_processing.category_size_clip import CategorySizeClip
from dltranz.data_load import augmentation_chain
from dltranz.data_load.augmentations.seq_len_limit import SeqLenLimit
from dltranz.data_load.augmentations.random_slice import RandomSlice

from dltranz.seq_encoder import create_encoder

from dltranz.metric_learn.sampling_strategies import get_sampling_strategy
from dltranz.metric_learn.losses import get_loss

from dltranz.tb_interface import get_scalars

from dltranz.data_load import padded_collate_wo_target

In [10]:
from vtb_code.data import PairedDataset, paired_collate_fn, PairedZeroDataset, DropDuplicate
from vtb_code.metrics import PrecisionK, MeanReciprocalRankK, ValidationCallback

In [11]:
FOLD_ID = 1

In [12]:
fold_id_test = FOLD_ID

In [13]:
folds_count = len(glob('data/train_matching_*.csv'))
folds_count

6

In [14]:
# fold_id_valid = np.random.choice([i for i in range(folds_count) if i != fold_id_test], size=1)[0]
fold_id_valid = (fold_id_test + 1) % folds_count
fold_id_valid

2

In [15]:
df_matching_train = pd.concat([pd.read_csv(f'data/train_matching_{i}.csv')
                              for i in range(folds_count) 
                              if i not in (fold_id_test, fold_id_valid)])
df_matching_valid = pd.read_csv(f'data/train_matching_{fold_id_valid}.csv')
df_matching_test = pd.read_csv(f'data/train_matching_{fold_id_test}.csv')

In [16]:
[len(df) for df in [df_matching_train, df_matching_valid, df_matching_test]]

[11721, 2930, 2930]

In [17]:
%%time
with open(f'data/features_f{FOLD_ID}.pickle', 'rb') as f:
    (
        features_trx_train,
        features_trx_valid,
        features_trx_test,
        features_trx_puzzle,
        features_click_train,
        features_click_valid,
        features_click_test,
        features_click_puzzle,
    ) = pickle.load(f)

CPU times: user 8.85 s, sys: 3.61 s, total: 12.5 s
Wall time: 12.5 s


# Preetrain

In [18]:
# v = []
# for batch in features_trx_train.values():
#     v.append(batch['transaction_amt'])
# v = torch.cat(v)

# trx_amnt_quantiles = torch.quantile(torch.unique(v), torch.linspace(0, 1, 100))
# trx_amnt_quantiles

In [19]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch

In [20]:
from vtb_code.models import MeanLoss

In [21]:
class CustomTrxTransform(torch.nn.Module):
    def __init__(self, trx_amnt_quantiles):
        super().__init__()
        self.trx_amnt_quantiles = torch.nn.Parameter(trx_amnt_quantiles, requires_grad=False)
        
    def forward(self, x):
        x.payload['transaction_amt_q'] = torch.bucketize(x.payload['transaction_amt'], self.trx_amnt_quantiles) + 1
        return x
    
class CustomClickTransform(torch.nn.Module):
    def forward(self, x):
#         x.payload['cat_id'] = torch.clamp(x.payload['cat_id'], 0, 300)
#         x.payload['level_0'] = torch.clamp(x.payload['level_0'], 0, 200)
#         x.payload['level_1'] = torch.clamp(x.payload['level_1'], 0, 200)
#         x.payload['level_2'] = torch.clamp(x.payload['level_2'], 0, 200)
#         x.payload['c_cnt_clamp'] = torch.clamp(x.payload['c_cnt'], 0, 20).int()
        return x

In [22]:
class DateFeaturesTransform(torch.nn.Module):
    def forward(self, x):
        et = x.payload['event_time'].int()
        et_day = et.div(24 * 60 * 60, rounding_mode='floor').int()
        x.payload['hour'] = et.div(60 * 60, rounding_mode='floor') % 24 + 1
        x.payload['weekday'] = et.div(60 * 60 * 24, rounding_mode='floor') % 7 + 1
        x.payload['day_diff'] = torch.clamp(torch.diff(et_day, prepend=et_day[:, :1], dim=1), 0, 14)
        return x

In [23]:
class PBLinear(torch.nn.Linear):
    def forward(self, x: PaddedBatch):
        return PaddedBatch(super().forward(x.payload), x.seq_lens)

In [24]:
class PBL2Norm(torch.nn.Module):
    def __init__(self, beta):    
        super().__init__()
        self.beta = beta
    
    def forward(self, x):
        return PaddedBatch(self.beta * x.payload / (x.payload.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5), 
                           x.seq_lens)

In [25]:
from transformers import LongformerConfig, LongformerModel

In [26]:
class MLMPretrainModule(pl.LightningModule):
    def __init__(self, trx_amnt_quantiles, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                ):
        super().__init__()
        self.save_hyperparameters()
        
        common_trx_size = params['common_trx_size']
        t = TrxEncoder(self.hparams.params['trx_seq.trx_encoder'])
        self.seq_encoder_trx = torch.nn.Sequential(
            CustomTrxTransform(trx_amnt_quantiles=trx_amnt_quantiles),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, common_trx_size),
            PBL2Norm(self.hparams.params['mlm.beta']),
        )

        t = TrxEncoder(self.hparams.params['click_seq.trx_encoder'])
        self.seq_encoder_click = torch.nn.Sequential(
            CustomClickTransform(),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, common_trx_size),
            PBL2Norm(self.hparams.params['mlm.beta']),
        )
            
        self.token_mask = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)
        self.token_cls = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)
        
        self.transf = LongformerModel(
            config=LongformerConfig(
                hidden_size=common_trx_size,
                num_attention_heads=params['transf.nhead'],
                intermediate_size=params['transf.dim_feedforward'],
                num_hidden_layers=params['transf.num_layers'],
                vocab_size=4,
                max_position_embeddings=self.hparams.params['transf.max_len'],
                attention_window=params['transf.attention_window'],
            ),
            add_pooling_layer=False,
        )
        
        self.train_mlm_loss = MeanLoss(compute_on_step=False)
        self.valid_mlm_loss = MeanLoss(compute_on_step=False)
                
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]
            
    def get_mask(self, attention_mask):
        return torch.bernoulli(attention_mask.float() * self.hparams.params['mlm.replace_proba']).bool()
        
    def mask_x(self, x, attention_mask, mask):
        shuffled_tokens = x[attention_mask.bool()]
        B, T, H = x.size()
        ix = torch.multinomial(torch.ones(shuffled_tokens.size(0)), B * T, replacement=True)
        shuffled_tokens = shuffled_tokens[ix].view(B, T, H)
        
        rand = torch.rand(B, T, device=x.device).unsqueeze(2).expand(B, T, H)
        replace_to = torch.where(
            rand < 0.8,
            self.token_mask.expand_as(x),          # [MASK] token 80%
            torch.where(
                rand < 0.9,
                shuffled_tokens,                   # random token 10%
                x,                                 # unchanged 10%
            )
        )
        return torch.where(mask.bool().unsqueeze(2).expand_as(x), replace_to, x)
            
    def forward(self, z: PaddedBatch):
        B, T, H = z.payload.size()
        device = z.payload.device
        
        if self.training:
            start_pos = np.random.randint(0, self.hparams.params['transf.max_len'] - T - 1, 1)[0]
        else:
            start_pos = 0

        inputs_embeds=z.payload
        attention_mask=z.seq_len_mask.float()

        inputs_embeds = torch.cat([
            self.token_cls.expand(inputs_embeds.size(0), 1, H),
            inputs_embeds,
        ], dim=1)
        attention_mask = torch.cat([
            torch.ones(inputs_embeds.size(0), 1, device=device),
            attention_mask,
        ], dim=1)
        position_ids=torch.arange(T + 1, device=z.device).view(1, -1).expand(B, T + 1) + start_pos
        global_attention_mask = torch.cat([
            torch.ones(inputs_embeds.size(0), 1, device=device),
            torch.zeros(inputs_embeds.size(0), inputs_embeds.size(1) - 1, device=device),
        ], dim=1)

        out = self.transf(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            global_attention_mask=global_attention_mask,
        ).last_hidden_state

        return PaddedBatch(out, z.seq_lens)
        
    def get_neg_ix(self, mask):
        """Sample from predicts, where `mask == True`, without self element.
        sample from predicted tokens from batch
        """
        mn = mask.float().view(1, -1) - \
            torch.eye(mask.numel(), device=mask.device)[mask.flatten()]
        neg_ix = torch.multinomial(mn, self.hparams.params['mlm.neg_count'])
        b_ix = neg_ix.div(mask.size(1), rounding_mode='trunc')
        neg_ix = neg_ix % mask.size(1)
        return b_ix, neg_ix
        
    def loss_mlm(self, x: PaddedBatch):
        mask = self.get_mask(x.seq_len_mask)
        masked_x = self.mask_x(x.payload, x.seq_len_mask, mask)
        B, T, H = masked_x.size()
        
        out = self.forward(PaddedBatch(masked_x, x.seq_lens)).payload[:, 1:]
        
        target = x.payload[mask].unsqueeze(1)  # N, 1, H
        predict = out[mask].unsqueeze(1) # N, 1, H
        neg_ix = self.get_neg_ix(mask)
        negative = out[neg_ix[0], neg_ix[1]]  # N, nneg, H
        out_samples = torch.cat([predict, negative], dim=1)
        probas = torch.softmax((target * out_samples).sum(dim=2), dim=1)
        loss = -torch.log(probas[:, 0])
        return loss
    
    def training_step(self, batch, batch_idx):
        x_trx, x_click = batch
        
        z_trx = self.seq_encoder_trx(x_trx)  # PB: B, T, H
        z_click = self.seq_encoder_click(x_click)  # PB: B, T, H
        z = PaddedBatch(
            torch.cat([z_trx.payload, z_click.payload], dim=0),
            torch.cat([z_trx.seq_lens, z_click.seq_lens], dim=0),
        )
        
        loss_mlm = self.loss_mlm(z)
        self.train_mlm_loss(loss_mlm)
        loss_mlm = loss_mlm.mean()
        self.log(f'loss/mlm', loss_mlm)

        return loss_mlm

    def validation_step(self, batch, batch_idx):
        x_trx, x_click = batch
        
        z_trx = self.seq_encoder_trx(x_trx)  # PB: B, T, H
        z_click = self.seq_encoder_click(x_click)  # PB: B, T, H
        z = PaddedBatch(
            torch.cat([z_trx.payload, z_click.payload], dim=0),
            torch.cat([z_trx.seq_lens, z_click.seq_lens], dim=0),
        )
        
        loss_mlm = self.loss_mlm(z)
        self.valid_mlm_loss(loss_mlm)

    def training_epoch_end(self, _):
        self.log(f'metrics/train_mlm', self.train_mlm_loss, prog_bar=False)
        
    def validation_epoch_end(self, _):
        self.log(f'metrics/valid_mlm', self.valid_mlm_loss, prog_bar=True)


# Use pretrained

In [27]:
"""

"""


mlm_model = MLMPretrainModule.load_from_checkpoint(
    'lightning_logs/version_32/checkpoints/epoch=11-step=34999.ckpt')

In [28]:
# mlm_model.freeze()

In [29]:
from tqdm.auto import tqdm

In [30]:
from dltranz.seq_encoder.utils import NormEncoder

In [31]:
from dltranz.trx_encoder import TrxEncoder, PaddedBatch
from dltranz.seq_encoder.rnn_encoder import RnnEncoder
from dltranz.seq_encoder.utils import LastStepEncoder, FirstStepEncoder

In [32]:
valid_dl_trx = torch.utils.data.DataLoader(
    PairedDataset(
        np.sort(df_matching_valid['bank'].unique()).reshape(-1, 1), 
        data=[
            features_trx_valid,
        ],
        augmentations=[
            augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), SeqLenLimit(2000)),  # 2000
        ],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    shuffle=False,
    num_workers=4,
    batch_size=256,
    persistent_workers=False,
)

valid_dl_click = torch.utils.data.DataLoader(
    PairedDataset(
        np.sort(df_matching_valid[lambda x: x['rtk'].ne('0')]['rtk'].unique()).reshape(-1, 1),
        data=[
            features_click_valid,
        ],
        augmentations=[
            augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), SeqLenLimit(5000)),  # 5000
        ],
        n_sample=1,
    ),
    collate_fn=paired_collate_fn,
    shuffle=False,
    num_workers=4,
    batch_size=128,
    persistent_workers=False,
)

class ValidationCallback():
    def __init__(self, v_trx, v_click, target, k=100, batch_size=1024):
        self.v_trx = v_trx
        self.v_click = v_click
        self.target = target
        self.k = k
        self.batch_size = batch_size

    def on_train_epoch_end(self):
        seq_encoder_trx.eval()
        seq_encoder_click.eval()

        with torch.no_grad():
            z_trx = []
            for ((x_trx, _),) in self.v_trx:
                z_trx.append(seq_encoder_trx(x_trx.to(device)))
            z_trx = torch.cat(z_trx, dim=0)
            z_click = []
            for ((x_click, _),) in self.v_click:
                z_click.append(seq_encoder_click(x_click.to(device)))
            z_click = torch.cat(z_click, dim=0)

            z_out = []
            for i in range(0, z_trx.size(0), self.batch_size):
                z_out.append(
                    -((z_trx[i:i + self.batch_size].unsqueeze(1) - z_click.unsqueeze(0)).pow(2)).sum(dim=2)
                )
            z_out = torch.cat(z_out, dim=0)

            precision, mrr, r1 = self.logits_to_metrics(z_out)

        return precision.item(), mrr.item(), r1.item()

    def logits_to_metrics(self, z_out):
        T, C = z_out.size()
        z_ranks = torch.zeros_like(z_out)
        z_ranks[
            torch.arange(T, device=z_out.device).view(-1, 1).expand(T, C),
            torch.argsort(z_out, dim=1, descending=True),
        ] = torch.arange(C, device=z_out.device).float().view(1, -1).expand(T, C) + 1
        z_ranks = torch.cat([
            torch.ones(T, device=z_out.device).float().view(-1, 1),
            z_ranks + 1,
        ], dim=1)
        
        click_uids = np.concatenate([['0'], self.v_click.dataset.pairs[:, 0]])
        true_ranks = z_ranks[
            np.arange(T),
            np.searchsorted(click_uids,
                            self.target.set_index('bank')['rtk'].loc[self.v_trx.dataset.pairs[:, 0]].values)
        ]
        precision = torch.where(true_ranks <= self.k,
                                torch.ones(1, device=z_out.device), torch.zeros(1, device=z_out.device)).mean()
        mrr = torch.where(true_ranks <= self.k, 1 / true_ranks, torch.zeros(1, device=z_out.device)).mean()
        r1 = 2 * mrr * precision / (mrr + precision)
        return precision, mrr, r1

vc = ValidationCallback(valid_dl_trx, valid_dl_click, df_matching_valid)

In [33]:
df_match = pd.concat([df_matching_train, df_matching_test], axis=0)[lambda x: x['rtk'].ne('0')]

In [34]:
train_features_trx = dict(chain(features_trx_train.items(), features_trx_test.items()))
train_features_click = dict(chain(features_click_train.items(), features_click_test.items()))

In [35]:
train_uid_t = np.sort(np.array(list(train_features_trx.keys())))
train_uid_c = np.sort(np.array(list(train_features_click.keys())))

In [36]:
device = torch.device('cuda:1')

In [37]:
seq_encoder_trx = torch.nn.Sequential(
    mlm_model.seq_encoder_trx,
    mlm_model,
    FirstStepEncoder(),
    NormEncoder(),
)
seq_encoder_click = torch.nn.Sequential(
    mlm_model.seq_encoder_click,
    mlm_model,
    FirstStepEncoder(),
    NormEncoder(),
)

seq_encoder_trx.to(device)
seq_encoder_click.to(device)
pass

In [38]:
aug_trx = augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024))
aug_click = augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048))

In [39]:
class sample_wrong_idx_for_train:
    def __init__(self, epoch_id, batch_size):
        self.batch_size = batch_size
        
        seq_encoder_trx.eval()
        seq_encoder_click.eval()

        z_trx = []
        batch_size = 512
        for i in tqdm(range(0, len(train_uid_t), batch_size), leave=False, desc='trx'):
            batch = padded_collate_wo_target([aug_trx(train_features_trx[uid]) 
                                              for uid in train_uid_t[i:i + batch_size]])
            with torch.no_grad():
                z = seq_encoder_trx(batch.to(device))
            z_trx.append(z.cpu())
        z_trx = torch.cat(z_trx, dim=0)
        z_click = []
        batch_size = 256
        for i in tqdm(range(0, len(train_uid_c), batch_size), leave=False, desc='click'):
            batch = padded_collate_wo_target([aug_click(train_features_click[uid]) 
                                              for uid in train_uid_c[i:i + batch_size]])
            with torch.no_grad():
                z = seq_encoder_click(batch.to(device))
            z_click.append(z.cpu())
        z_click = torch.cat(z_click, dim=0)

        del batch
        del z
        torch.cuda.empty_cache()

        batch_size = 512
        m_distances = []
        for i in range(0, z_trx.size(0), batch_size):
            m_distances.append(
                ((z_trx[i:i + batch_size].unsqueeze(1).to(device) - 
                  z_click.unsqueeze(0).to(device)
                 ).pow(2)).sum(dim=2).cpu())
        m_distances = torch.cat(m_distances, dim=0)
        torch.cuda.empty_cache()

        T, C = m_distances.size()
        true_dist = m_distances[
            np.searchsorted(train_uid_t, df_match['bank'].values),
            np.searchsorted(train_uid_c, df_match['rtk'].values),
        ]

        m_wrong_pairs = m_distances[np.searchsorted(train_uid_t, df_match['bank'].values)] < true_dist.view(-1, 1)
        self.m_wrong_pairs = m_wrong_pairs
        
        self.true_pos = torch.zeros_like(m_wrong_pairs).bool()
        self.true_pos[
            torch.arange(self.true_pos.size(0)),
            np.searchsorted(train_uid_c, df_match['rtk'].values),
        ] = True
        
        wrong_pair_rate = m_wrong_pairs.float().mean().item()
        wrong_pair_count = m_wrong_pairs.sum().item()
        print(f'[{epoch_id:03d}]: wrong_pair_rate = {wrong_pair_rate:.3f}, wrong_pair_count = {wrong_pair_count:.3f}')
    
    def __iter__(self):
        shuffled_ix = np.arange(self.m_wrong_pairs.size(0))
        np.random.shuffle(shuffled_ix)
        
        for i in range(0, self.m_wrong_pairs.size(0), self.batch_size):
            t_ix = shuffled_ix[i : i + self.batch_size]
            anchor_uid = df_match['bank'].values[t_ix]
            
            neg_mask = ~self.true_pos[t_ix].any(dim=0)
            
            m = self.m_wrong_pairs[t_ix][:, neg_mask].sum(dim=0)
            m = m / m.sum()
            
            neg_ix = torch.multinomial(m, len(anchor_uid))
            neg_uid = train_uid_c[neg_mask][neg_ix]
            
            yield anchor_uid, neg_uid


In [40]:
beta = 10
batch_size = 64
max_epochs = 6000 // (200 // 4)

optim = torch.optim.Adam(mlm_model.parameters(), lr=0.0005 / 25)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optim,
    max_lr=0.0005,
    total_steps=int(len(df_match) * 1 / batch_size / 4 * max_epochs),
    pct_start=2 / 50,
    anneal_strategy='cos',
    cycle_momentum=False,
    div_factor=25.0,
    final_div_factor=10000.0,
    three_phase=False,
)
print(f'scheduler.total_steps = {scheduler.total_steps}')

total_step = 0

for epoch_id in range(max_epochs):
    epoch_loss = 0
    
    seq_encoder_trx.train()
    seq_encoder_click.train()
    shuffle_ix = np.arange(len(df_match))
    np.random.shuffle(shuffle_ix)
    
    if (epoch_id + 1) % 4 == 0:
        sample_wrong_idx_for_train(epoch_id, batch_size)
    
    for batch_idx, i in tqdm(enumerate(range(0, len(shuffle_ix), batch_size)), leave=False, desc='train'):
        x_anchors = padded_collate_wo_target(
            [aug_trx(train_features_trx[uid]) for uid in df_match.iloc[shuffle_ix[i:i+batch_size]]['bank'].values])
        x_positive = padded_collate_wo_target(
            [aug_click(train_features_click[uid]) for uid in df_match.iloc[shuffle_ix[i:i+batch_size]]['rtk'].values])

        z_anchors = seq_encoder_trx(x_anchors.to(device))
        z_positive = seq_encoder_click(x_positive.to(device))
        B = z_anchors.size(0)

        loss = -(z_anchors.unsqueeze(1) - 
                 z_positive.unsqueeze(0)
                ).pow(2).sum(dim=2)
#         print(loss.shape)
        loss = loss * beta
        loss = -torch.log(torch.diag(torch.softmax(loss, dim=1))).mean()
        epoch_loss = epoch_loss + loss.item()

        loss.backward()

        if (batch_idx + 1) % 4 == 0:
            optim.step()
            optim.zero_grad()
            scheduler.step()
            total_step = total_step + 1

    del loss
    del z_anchors
    del z_positive

    torch.cuda.empty_cache()
    
    epoch_loss = epoch_loss / batch_idx + 1
    precision, mrr, r1 = vc.on_train_epoch_end()
    print(f'[{total_step:05d}] epoch_loss = {epoch_loss:.4f} '
          f'precision = {precision:.3f}, mrr = {mrr:.3f}, r1 = {r1:.3f} ')


scheduler.total_steps = 22923


train: 0it [00:00, ?it/s]

[00048] epoch_loss = 5.1631 precision = 0.197, mrr = 0.167, r1 = 0.181 


train: 0it [00:00, ?it/s]

[00096] epoch_loss = 5.1625 precision = 0.200, mrr = 0.167, r1 = 0.182 


train: 0it [00:00, ?it/s]

[00144] epoch_loss = 5.1625 precision = 0.201, mrr = 0.167, r1 = 0.182 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[003]: wrong_pair_rate = 0.490, wrong_pair_count = 73211508.000


train: 0it [00:00, ?it/s]

[00192] epoch_loss = 5.1508 precision = 0.237, mrr = 0.168, r1 = 0.197 


train: 0it [00:00, ?it/s]

[00240] epoch_loss = 4.9979 precision = 0.271, mrr = 0.171, r1 = 0.209 


train: 0it [00:00, ?it/s]

[00288] epoch_loss = 4.8506 precision = 0.295, mrr = 0.172, r1 = 0.217 


train: 0it [00:00, ?it/s]

[00336] epoch_loss = 4.7823 precision = 0.311, mrr = 0.174, r1 = 0.223 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[007]: wrong_pair_rate = 0.256, wrong_pair_count = 38209510.000


train: 0it [00:00, ?it/s]

[00384] epoch_loss = 4.7034 precision = 0.334, mrr = 0.174, r1 = 0.229 


train: 0it [00:00, ?it/s]

[00432] epoch_loss = 4.6938 precision = 0.342, mrr = 0.175, r1 = 0.231 


train: 0it [00:00, ?it/s]

[00480] epoch_loss = 4.6719 precision = 0.347, mrr = 0.176, r1 = 0.233 


train: 0it [00:00, ?it/s]

[00528] epoch_loss = 4.6447 precision = 0.349, mrr = 0.175, r1 = 0.233 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[011]: wrong_pair_rate = 0.225, wrong_pair_count = 33682466.000


train: 0it [00:00, ?it/s]

[00576] epoch_loss = 4.5821 precision = 0.361, mrr = 0.176, r1 = 0.236 


train: 0it [00:00, ?it/s]

[00624] epoch_loss = 4.5989 precision = 0.367, mrr = 0.176, r1 = 0.238 


train: 0it [00:00, ?it/s]

[00672] epoch_loss = 4.5843 precision = 0.374, mrr = 0.177, r1 = 0.241 


train: 0it [00:00, ?it/s]

[00720] epoch_loss = 4.5589 precision = 0.375, mrr = 0.178, r1 = 0.242 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[015]: wrong_pair_rate = 0.206, wrong_pair_count = 30795324.000


train: 0it [00:00, ?it/s]

[00768] epoch_loss = 4.5120 precision = 0.383, mrr = 0.178, r1 = 0.243 


train: 0it [00:00, ?it/s]

[00816] epoch_loss = 4.5099 precision = 0.390, mrr = 0.180, r1 = 0.247 


train: 0it [00:00, ?it/s]

[00864] epoch_loss = 4.4922 precision = 0.395, mrr = 0.180, r1 = 0.248 


train: 0it [00:00, ?it/s]

[00912] epoch_loss = 4.4837 precision = 0.392, mrr = 0.181, r1 = 0.247 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[019]: wrong_pair_rate = 0.189, wrong_pair_count = 28241599.000


train: 0it [00:00, ?it/s]

[00960] epoch_loss = 4.4109 precision = 0.371, mrr = 0.178, r1 = 0.240 


train: 0it [00:00, ?it/s]

[01008] epoch_loss = 4.4631 precision = 0.402, mrr = 0.180, r1 = 0.249 


train: 0it [00:00, ?it/s]

[01056] epoch_loss = 4.4392 precision = 0.403, mrr = 0.180, r1 = 0.249 


train: 0it [00:00, ?it/s]

[01104] epoch_loss = 4.4033 precision = 0.395, mrr = 0.181, r1 = 0.248 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[023]: wrong_pair_rate = 0.173, wrong_pair_count = 25811316.000


train: 0it [00:00, ?it/s]

[01152] epoch_loss = 4.3437 precision = 0.403, mrr = 0.182, r1 = 0.250 


train: 0it [00:00, ?it/s]

[01200] epoch_loss = 4.3677 precision = 0.402, mrr = 0.181, r1 = 0.250 


train: 0it [00:00, ?it/s]

[01248] epoch_loss = 4.3586 precision = 0.409, mrr = 0.181, r1 = 0.251 


train: 0it [00:00, ?it/s]

[01296] epoch_loss = 4.3409 precision = 0.409, mrr = 0.181, r1 = 0.251 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[027]: wrong_pair_rate = 0.159, wrong_pair_count = 23773137.000


train: 0it [00:00, ?it/s]

[01344] epoch_loss = 4.2555 precision = 0.411, mrr = 0.183, r1 = 0.253 


train: 0it [00:00, ?it/s]

[01392] epoch_loss = 4.3026 precision = 0.413, mrr = 0.181, r1 = 0.251 


train: 0it [00:00, ?it/s]

[01440] epoch_loss = 4.3081 precision = 0.405, mrr = 0.181, r1 = 0.250 


train: 0it [00:00, ?it/s]

[01488] epoch_loss = 4.2798 precision = 0.418, mrr = 0.182, r1 = 0.253 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[031]: wrong_pair_rate = 0.146, wrong_pair_count = 21835724.000


train: 0it [00:00, ?it/s]

[01536] epoch_loss = 4.1945 precision = 0.418, mrr = 0.182, r1 = 0.253 


train: 0it [00:00, ?it/s]

[01584] epoch_loss = 4.2214 precision = 0.416, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[01632] epoch_loss = 4.2155 precision = 0.413, mrr = 0.182, r1 = 0.253 


train: 0it [00:00, ?it/s]

[01680] epoch_loss = 4.2048 precision = 0.408, mrr = 0.181, r1 = 0.251 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[035]: wrong_pair_rate = 0.140, wrong_pair_count = 20990129.000


train: 0it [00:00, ?it/s]

[01728] epoch_loss = 4.1232 precision = 0.421, mrr = 0.182, r1 = 0.254 


train: 0it [00:00, ?it/s]

[01776] epoch_loss = 4.1804 precision = 0.420, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[01824] epoch_loss = 4.1671 precision = 0.410, mrr = 0.182, r1 = 0.252 


train: 0it [00:00, ?it/s]

[01872] epoch_loss = 4.1426 precision = 0.424, mrr = 0.184, r1 = 0.256 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[039]: wrong_pair_rate = 0.125, wrong_pair_count = 18681135.000


train: 0it [00:00, ?it/s]

[01920] epoch_loss = 4.0186 precision = 0.407, mrr = 0.180, r1 = 0.250 


train: 0it [00:00, ?it/s]

[01968] epoch_loss = 4.1340 precision = 0.431, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[02016] epoch_loss = 4.0782 precision = 0.429, mrr = 0.184, r1 = 0.258 


train: 0it [00:00, ?it/s]

[02064] epoch_loss = 4.0610 precision = 0.420, mrr = 0.182, r1 = 0.254 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[043]: wrong_pair_rate = 0.122, wrong_pair_count = 18198798.000


train: 0it [00:00, ?it/s]

[02112] epoch_loss = 3.9186 precision = 0.418, mrr = 0.183, r1 = 0.254 


train: 0it [00:00, ?it/s]

[02160] epoch_loss = 4.0160 precision = 0.428, mrr = 0.182, r1 = 0.256 


train: 0it [00:00, ?it/s]

[02208] epoch_loss = 3.9993 precision = 0.414, mrr = 0.182, r1 = 0.253 


train: 0it [00:00, ?it/s]

[02256] epoch_loss = 3.9765 precision = 0.411, mrr = 0.183, r1 = 0.253 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[047]: wrong_pair_rate = 0.107, wrong_pair_count = 16066093.000


train: 0it [00:00, ?it/s]

[02304] epoch_loss = 3.8572 precision = 0.419, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[02352] epoch_loss = 3.9545 precision = 0.431, mrr = 0.186, r1 = 0.260 


train: 0it [00:00, ?it/s]

[02400] epoch_loss = 3.9471 precision = 0.427, mrr = 0.185, r1 = 0.258 


train: 0it [00:00, ?it/s]

[02448] epoch_loss = 3.8969 precision = 0.428, mrr = 0.184, r1 = 0.257 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[051]: wrong_pair_rate = 0.100, wrong_pair_count = 14873759.000


train: 0it [00:00, ?it/s]

[02496] epoch_loss = 3.7699 precision = 0.418, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[02544] epoch_loss = 3.9367 precision = 0.422, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[02592] epoch_loss = 3.8684 precision = 0.424, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[02640] epoch_loss = 3.8241 precision = 0.430, mrr = 0.185, r1 = 0.259 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[055]: wrong_pair_rate = 0.094, wrong_pair_count = 14050756.000


train: 0it [00:00, ?it/s]

[02688] epoch_loss = 3.6555 precision = 0.416, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[02736] epoch_loss = 3.7767 precision = 0.428, mrr = 0.185, r1 = 0.258 


train: 0it [00:00, ?it/s]

[02784] epoch_loss = 3.7679 precision = 0.427, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[02832] epoch_loss = 3.7592 precision = 0.427, mrr = 0.185, r1 = 0.258 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[059]: wrong_pair_rate = 0.086, wrong_pair_count = 12781837.000


train: 0it [00:00, ?it/s]

[02880] epoch_loss = 3.5809 precision = 0.425, mrr = 0.186, r1 = 0.259 


train: 0it [00:00, ?it/s]

[02928] epoch_loss = 3.7250 precision = 0.425, mrr = 0.186, r1 = 0.259 


train: 0it [00:00, ?it/s]

[02976] epoch_loss = 3.7329 precision = 0.423, mrr = 0.185, r1 = 0.257 


train: 0it [00:00, ?it/s]

[03024] epoch_loss = 3.6889 precision = 0.429, mrr = 0.185, r1 = 0.259 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[063]: wrong_pair_rate = 0.080, wrong_pair_count = 12006397.000


train: 0it [00:00, ?it/s]

[03072] epoch_loss = 3.4872 precision = 0.427, mrr = 0.186, r1 = 0.259 


train: 0it [00:00, ?it/s]

[03120] epoch_loss = 3.6563 precision = 0.416, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[03168] epoch_loss = 3.6226 precision = 0.430, mrr = 0.185, r1 = 0.258 


train: 0it [00:00, ?it/s]

[03216] epoch_loss = 3.6058 precision = 0.423, mrr = 0.185, r1 = 0.257 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[067]: wrong_pair_rate = 0.075, wrong_pair_count = 11146929.000


train: 0it [00:00, ?it/s]

[03264] epoch_loss = 3.3906 precision = 0.424, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[03312] epoch_loss = 3.5656 precision = 0.417, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[03360] epoch_loss = 3.5714 precision = 0.429, mrr = 0.185, r1 = 0.258 


train: 0it [00:00, ?it/s]

[03408] epoch_loss = 3.5262 precision = 0.423, mrr = 0.184, r1 = 0.257 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[071]: wrong_pair_rate = 0.070, wrong_pair_count = 10418070.000


train: 0it [00:00, ?it/s]

[03456] epoch_loss = 3.3244 precision = 0.419, mrr = 0.185, r1 = 0.257 


train: 0it [00:00, ?it/s]

[03504] epoch_loss = 3.4965 precision = 0.431, mrr = 0.184, r1 = 0.258 


train: 0it [00:00, ?it/s]

[03552] epoch_loss = 3.4848 precision = 0.434, mrr = 0.186, r1 = 0.260 


train: 0it [00:00, ?it/s]

[03600] epoch_loss = 3.4609 precision = 0.425, mrr = 0.183, r1 = 0.256 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[075]: wrong_pair_rate = 0.064, wrong_pair_count = 9601073.000


train: 0it [00:00, ?it/s]

[03648] epoch_loss = 3.2125 precision = 0.435, mrr = 0.185, r1 = 0.260 


train: 0it [00:00, ?it/s]

[03696] epoch_loss = 3.4482 precision = 0.425, mrr = 0.185, r1 = 0.258 


train: 0it [00:00, ?it/s]

[03744] epoch_loss = 3.4129 precision = 0.418, mrr = 0.185, r1 = 0.256 


train: 0it [00:00, ?it/s]

[03792] epoch_loss = 3.4056 precision = 0.417, mrr = 0.185, r1 = 0.257 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[079]: wrong_pair_rate = 0.059, wrong_pair_count = 8791787.000


train: 0it [00:00, ?it/s]

[03840] epoch_loss = 3.1272 precision = 0.422, mrr = 0.183, r1 = 0.256 


train: 0it [00:00, ?it/s]

[03888] epoch_loss = 3.3479 precision = 0.414, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[03936] epoch_loss = 3.3489 precision = 0.422, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[03984] epoch_loss = 3.3421 precision = 0.420, mrr = 0.184, r1 = 0.256 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[083]: wrong_pair_rate = 0.058, wrong_pair_count = 8637039.000


train: 0it [00:00, ?it/s]

[04032] epoch_loss = 3.0560 precision = 0.424, mrr = 0.183, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04080] epoch_loss = 3.3087 precision = 0.417, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04128] epoch_loss = 3.3099 precision = 0.433, mrr = 0.186, r1 = 0.260 


train: 0it [00:00, ?it/s]

[04176] epoch_loss = 3.2714 precision = 0.417, mrr = 0.184, r1 = 0.256 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[087]: wrong_pair_rate = 0.052, wrong_pair_count = 7797385.000


train: 0it [00:00, ?it/s]

[04224] epoch_loss = 3.0075 precision = 0.423, mrr = 0.186, r1 = 0.258 


train: 0it [00:00, ?it/s]

[04272] epoch_loss = 3.2530 precision = 0.419, mrr = 0.185, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04320] epoch_loss = 3.2063 precision = 0.424, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[04368] epoch_loss = 3.2084 precision = 0.413, mrr = 0.185, r1 = 0.255 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[091]: wrong_pair_rate = 0.050, wrong_pair_count = 7525797.000


train: 0it [00:00, ?it/s]

[04416] epoch_loss = 2.9317 precision = 0.412, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[04464] epoch_loss = 3.2361 precision = 0.422, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04512] epoch_loss = 3.1850 precision = 0.416, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[04560] epoch_loss = 3.1506 precision = 0.419, mrr = 0.185, r1 = 0.257 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[095]: wrong_pair_rate = 0.048, wrong_pair_count = 7121373.000


train: 0it [00:00, ?it/s]

[04608] epoch_loss = 2.8765 precision = 0.423, mrr = 0.183, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04656] epoch_loss = 3.1206 precision = 0.420, mrr = 0.185, r1 = 0.257 


train: 0it [00:00, ?it/s]

[04704] epoch_loss = 3.1034 precision = 0.421, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04752] epoch_loss = 3.1063 precision = 0.414, mrr = 0.184, r1 = 0.254 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[099]: wrong_pair_rate = 0.044, wrong_pair_count = 6568136.000


train: 0it [00:00, ?it/s]

[04800] epoch_loss = 2.7987 precision = 0.413, mrr = 0.185, r1 = 0.256 


train: 0it [00:00, ?it/s]

[04848] epoch_loss = 3.0688 precision = 0.426, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[04896] epoch_loss = 3.0415 precision = 0.418, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[04944] epoch_loss = 3.0415 precision = 0.416, mrr = 0.184, r1 = 0.255 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[103]: wrong_pair_rate = 0.043, wrong_pair_count = 6496203.000


train: 0it [00:00, ?it/s]

[04992] epoch_loss = 2.7402 precision = 0.415, mrr = 0.185, r1 = 0.256 


train: 0it [00:00, ?it/s]

[05040] epoch_loss = 2.9875 precision = 0.417, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05088] epoch_loss = 3.0240 precision = 0.417, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05136] epoch_loss = 3.0324 precision = 0.417, mrr = 0.183, r1 = 0.254 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[107]: wrong_pair_rate = 0.041, wrong_pair_count = 6149530.000


train: 0it [00:00, ?it/s]

[05184] epoch_loss = 2.6998 precision = 0.412, mrr = 0.185, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05232] epoch_loss = 2.9962 precision = 0.412, mrr = 0.184, r1 = 0.254 


train: 0it [00:00, ?it/s]

[05280] epoch_loss = 2.9810 precision = 0.423, mrr = 0.184, r1 = 0.257 


train: 0it [00:00, ?it/s]

[05328] epoch_loss = 2.9370 precision = 0.412, mrr = 0.182, r1 = 0.253 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[111]: wrong_pair_rate = 0.039, wrong_pair_count = 5798155.000


train: 0it [00:00, ?it/s]

[05376] epoch_loss = 2.6586 precision = 0.414, mrr = 0.184, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05424] epoch_loss = 2.9183 precision = 0.422, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05472] epoch_loss = 2.9110 precision = 0.418, mrr = 0.183, r1 = 0.255 


train: 0it [00:00, ?it/s]

[05520] epoch_loss = 2.8932 precision = 0.424, mrr = 0.183, r1 = 0.256 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[115]: wrong_pair_rate = 0.036, wrong_pair_count = 5453418.000


train: 0it [00:00, ?it/s]

[05568] epoch_loss = 2.5851 precision = 0.418, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[05616] epoch_loss = 2.8648 precision = 0.421, mrr = 0.184, r1 = 0.256 


train: 0it [00:00, ?it/s]

[05664] epoch_loss = 2.8573 precision = 0.411, mrr = 0.183, r1 = 0.253 


train: 0it [00:00, ?it/s]

[05712] epoch_loss = 2.9054 precision = 0.418, mrr = 0.183, r1 = 0.254 


trx:   0%|          | 0/29 [00:00<?, ?it/s]

click:   0%|          | 0/48 [00:00<?, ?it/s]

[119]: wrong_pair_rate = 0.036, wrong_pair_count = 5361652.000


train: 0it [00:00, ?it/s]

[05760] epoch_loss = 2.5363 precision = 0.413, mrr = 0.183, r1 = 0.254 
