In [18]:
import torch

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
print('torch version (<2): ', torch.__version__)

import ptls
from ptls.frames import PtlsDataModule
from ptls.frames.coles import losses, sampling_strategies
from ptls.frames.coles import split_strategy
from ptls.frames.inference_module import InferenceModule

from ptls.nn.seq_encoder import agg_feature_seq_encoder
from ptls.nn import RnnSeqEncoder, TrxEncoder, Head
from ptls.nn.seq_encoder.agg_feature_seq_encoder import AggFeatureSeqEncoder

from ptls.data_load.datasets import AugmentationDataset, MemoryMapDataset
from ptls.data_load.augmentations import AllTimeShuffle, DropoutTrx
from ptls.data_load.datasets.parquet_dataset import ParquetDataset
from ptls.data_load.iterable_processing import SeqLenFilter, FeatureFilter
from ptls.data_load.datasets import parquet_file_scan
from ptls.data_load.datasets import ParquetDataset, ParquetFiles, AugmentationDataset
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter, FeatureFilter
from ptls.data_load.augmentations import DropoutTrx
from ptls.data_load.datasets import inference_data_loader
from ptls.data_load.utils import collate_feature_dict

from functools import partial

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

torch version (<2):  1.12.1+cu102


In [19]:
plt.style.use('default')
sns.set(rc={'figure.figsize':(8, 6)})
sns.set_style('white')
sns.despine()

<Figure size 800x600 with 0 Axes>

In [20]:
data_module = PtlsDataModule(
    train_data=ptls.frames.coles.ColesDataset(
        splitter=split_strategy.SampleSlices(split_count=5, cnt_min=25, cnt_max=180),
        data=ptls.data_load.datasets.AugmentationDataset(
            data=MemoryMapDataset(
                data=ParquetDataset(
                    i_filters=[SeqLenFilter(min_seq_len=25)],
                    data_files=parquet_file_scan(file_path='src/ptls-experiments/scenario_x5/data/train_trx.parquet',
                                                 valid_rate=0.05,
                                                 return_part='train')
                )
            
            ),
            f_augmentations=[ptls.data_load.augmentations.DropoutTrx(trx_dropout=0.01)]
        )
    ),
    train_batch_size=256,
    train_num_workers=8,
)

In [21]:
import torch
from torch import nn as nn
from torch.nn import functional as F

class L2NormEncoder(nn.Module):
    def __init__(self, eps=1e-9):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor):
        return x / (x.pow(2).sum(dim=-1, keepdim=True) + self.eps).pow(0.5)

class ContrastiveLoss(nn.Module):
    """
    Contrastive loss

    "Signature verification using a siamese time delay neural network", NIPS 1993
    https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf
    """

    def __init__(self, margin, sampling_strategy):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.pair_selector = sampling_strategy
        
        self.norm = L2NormEncoder()

    def forward(self, embeddings, target):
        embeddings = self.norm(embeddings)

        positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target)
        positive_loss = F.pairwise_distance(embeddings[positive_pairs[:, 0]], embeddings[positive_pairs[:, 1]]).pow(2)
        negative_loss = F.relu(
            self.margin - F.pairwise_distance(embeddings[negative_pairs[:, 0]], embeddings[negative_pairs[:, 1]])
        ).pow(2)
        loss = torch.cat([positive_loss, negative_loss], dim=0)
        
        return loss.sum()

In [22]:
class VICRegLoss(nn.Module):
    
    def __init__(self, ):
        super(VICRegLoss, self).__init__()
        
        self._agg_encoder = AggFeatureSeqEncoder(         
            embeddings={
                "level_3": {"in": 200},
                "level_4": {"in": 800},
                "segment_id": {"in": 120},
            },

            numeric_values={
                'trn_sum_from_iss': 'identity',
                'netto': 'identity',
                'regular_points_received': 'identity',            
            },
     
            was_logified=True,  
            log_scale_factor=1
        )
        self.norm = L2NormEncoder()
        
    def forward(self, embeddings, aggs):
        aggs = self._agg_encoder(aggs)
        aggs = self.norm(aggs.T)

        cov_aggs_embs = (aggs @ embeddings) / len(embeddings)
        cov_loss = cov_aggs_embs.pow_(2).sum()
        
        std_embeddings = torch.sqrt(embeddings.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_embeddings))
        
        return (cov_loss, std_loss)

In [23]:
class Loss(nn.Module):
    def __init__(self, contrastiveLoss, vicregLoss):
        super(Loss, self).__init__()
        
        self.contrastiveLoss = contrastiveLoss
        self.vicregLoss = vicregLoss
        
    def forward(self, embeddings, target, aggs):
        
        (cov_loss, std_loss) = self.vicregLoss(embeddings, aggs)
        con_loss = self.contrastiveLoss.forward(embeddings, target)
        
        return ((con_loss, cov_loss, std_loss), 0.55 * con_loss + 1 * cov_loss + 1 * std_loss)

In [24]:
import torch
import pytorch_lightning as pl
from ptls.data_load.padded_batch import PaddedBatch


class ABSModule(pl.LightningModule):
    @property
    def metric_name(self):
        raise NotImplementedError()

    @property
    def is_requires_reduced_sequence(self):
        raise NotImplementedError()

    def shared_step(self, x, y):
        """

        Args:
            x:
            y:

        Returns: y_h, y

        """
        raise NotImplementedError()

    def __init__(self, validation_metric=None,
                       seq_encoder=None,
                       loss=None,
                       optimizer_partial=None,
                       lr_scheduler_partial=None):
        """
        Parameters
        ----------
        params : dict
            params for creating an encoder
        seq_encoder : torch.nn.Module
            sequence encoder, if not provided, will be constructed from params
        """
        super().__init__()
        # self.save_hyperparameters()

        self._loss = loss
        self._seq_encoder = seq_encoder
        self._seq_encoder.is_reduce_sequence = self.is_requires_reduced_sequence
        self._validation_metric = validation_metric

        self._optimizer_partial = optimizer_partial
        self._lr_scheduler_partial = lr_scheduler_partial
        
    @property
    def seq_encoder(self):
        return self._seq_encoder

    def forward(self, x):
        return self._seq_encoder(x)

    def training_step(self, batch, _):
        
        y_h, y = self.shared_step(*batch)
                 
        (con_loss, cov_loss, std_loss), loss = self._loss(y_h, y, batch[0])

        self.log('con_loss', con_loss)
        self.log('cov_loss', cov_loss)
        self.log('std_loss', std_loss)
        self.log('loss', loss)

        if type(batch) is tuple:
            x, y = batch
            if isinstance(x, PaddedBatch):
                self.log('seq_len', x.seq_lens.float().mean(), prog_bar=True)
        else:
            # this code should not be reached
            self.log('seq_len', -1, prog_bar=True)
            raise AssertionError('batch is not a tuple')
        return loss

    def validation_step(self, batch, _):
        y_h, y = self.shared_step(*batch)
        self._validation_metric(y_h, y)

    def validation_epoch_end(self, outputs):
        self.log(self.metric_name, self._validation_metric.compute(), prog_bar=True)
        self._validation_metric.reset()

    def configure_optimizers(self):
        optimizer = self._optimizer_partial(self.parameters())
        scheduler = self._lr_scheduler_partial(optimizer)
        
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler = {
                'scheduler': scheduler,
                'monitor': self.metric_name,
            }
        return [optimizer], [scheduler]

In [25]:
from ptls.frames.coles.metric import BatchRecallTopK
from ptls.frames.coles.sampling_strategies import HardNegativePairSelector
from ptls.nn.head import Head
from ptls.nn.seq_encoder.containers import SeqEncoderContainer


class CoLESModule(ABSModule):
    def __init__(self,
                 seq_encoder: SeqEncoderContainer = None,
                 head=None,
                 loss=None,
                 validation_metric=None,
                 optimizer_partial=None,
                 lr_scheduler_partial=None):

        if head is None:
            head = Head(use_norm_encoder=True)

        if loss is None:
            loss = ContrastiveLoss(margin=0.5,
                                   sampling_strategy=HardNegativePairSelector(neg_count=5))

        if validation_metric is None:
            validation_metric = BatchRecallTopK(K=4, metric='cosine')

        super().__init__(validation_metric,
                         seq_encoder,
                         loss,
                         optimizer_partial,
                         lr_scheduler_partial
                        )

        self._head = head
    @property
    def metric_name(self):
        return 'recall_top_k'

    @property
    def is_requires_reduced_sequence(self):
        return True

    def shared_step(self, x, y):
        
        y_h = self(x)
        if self._head is not None:
            y_h = self._head(y_h)
        return y_h, y

In [None]:
model = CoLESModule(
      validation_metric=ptls.frames.coles.metric.BatchRecallTopK(K=4,
                                                                 metric="cosine"),
      seq_encoder=RnnSeqEncoder(
            trx_encoder=TrxEncoder(
            use_batch_norm_with_lens=True,
            norm_embeddings=False,
            embeddings_noise=0.003,
            
            embeddings={
                "level_3": {"in": 200, "out": 16},
                "level_4": {"in": 800, "out": 16},
                "segment_id": {"in": 120, "out": 16},
            },

            numeric_values={
                "trn_sum_from_iss": "identity",
                "netto": "identity",
                "regular_points_received": "identity",
            }
                
            ),
            type="gru",
            hidden_size=800,
            bidir=False,
            trainable_starter="static",
      ),
     
      head=Head(
            use_norm_encoder=False,
            input_size=800,
      ),
        
      loss=Loss(
          ContrastiveLoss(
            margin=0.5,
            sampling_strategy=sampling_strategies.HardNegativePairSelector(neg_count=9),
          ),
          VICRegLoss()
      ),
    
      optimizer_partial=partial(
            torch.optim.Adam, 
            lr=0.002,
            weight_decay=0.0
      ),
    
      lr_scheduler_partial=partial(
            torch.optim.lr_scheduler.StepLR,
            step_size=3,
            gamma=0.9025,
      ),
)

logger = TensorBoardLogger('src/ptls-experiments/scenario_age_pred/lightning_logs',
                           name='CoLES VICReg, hidden=800, coefs=0.55,1,1, margin=1')


trainer = pl.Trainer(
    logger=logger,
    num_sanity_val_steps=0,
#     gpus=1,
    accelerator="gpu" ,
#     auto_select_gpus=False,
    max_epochs=30,
    enable_checkpointing=False,
    deterministic=True,
)

trainer.fit(model, data_module)
print(trainer.logged_metrics)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type            | Params
-------------------------------------------------------
0 | _loss              | Loss            | 0     
1 | _seq_encoder       | RnnSeqEncoder   | 886 K 
2 | _validation_metric | BatchRecallTopK | 0     
3 | _head              | Head            | 0     
-------------------------------------------------------
886 K     Trainable params
0         Non-trainable params
886 K     Total params
3.545     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [28]:
torch.save(model.state_dict(), 'src/ptls-experiments/scenario_x5/model-vicreg-new-coles-pretrained-512.pth')

In [29]:
model.load_state_dict(torch.load('src/ptls-experiments/scenario_x5/model-vicreg-new-coles-pretrained-512.pth'))

<All keys matched successfully>

In [31]:
from ptls.data_load.utils import collate_feature_dict
from ptls.frames.inference_module import InferenceModule
iterable_inference_dataset = ParquetDataset(
    data_files=ParquetFiles(['src/ptls-experiments/scenario_x5/data/train_trx.parquet',
                             'src/ptls-experiments/scenario_x5/data/test_trx.parquet'],                             
                                                         
                            ).data_files,
    i_filters=[FeatureFilter(['client_id'])],
)
next(iter(iterable_inference_dataset))

inference_dl = torch.utils.data.DataLoader(
    dataset=iterable_inference_dataset,
    collate_fn=collate_feature_dict,
    shuffle=False,
    batch_size=128,
    num_workers=0,
)
next(iter(inference_dl)).payload

mod = InferenceModule(model, pandas_output=True, model_out_name='emb')

pred = pl.Trainer(gpus=1).predict(mod, inference_dl)

embeddings_train_test = pd.concat(pred, axis=0)

  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [32]:
embeddings_train_test

Unnamed: 0,client_id,emb_0000,emb_0001,emb_0002,emb_0003,emb_0004,emb_0005,emb_0006,emb_0007,emb_0008,...,emb_0502,emb_0503,emb_0504,emb_0505,emb_0506,emb_0507,emb_0508,emb_0509,emb_0510,emb_0511
0,00037a9650,-0.010576,-0.002334,-0.009176,-0.023901,0.067811,-0.040258,-0.022467,0.032256,0.052724,...,0.016382,-0.041395,-0.020993,-0.012628,-0.008621,-0.034619,0.058045,0.041897,-0.000332,-0.000669
1,0005ce475d,0.014146,-0.043482,-0.006340,-0.001400,0.086435,0.000150,-0.042507,0.039219,0.002245,...,0.019000,-0.051069,-0.021606,0.007528,-0.013405,-0.052815,0.080950,-0.011690,-0.005895,-0.008232
2,000c6e91c8,0.007199,0.005476,-0.005888,0.023478,0.085967,0.091710,-0.034176,0.032222,-0.032754,...,0.014419,-0.045711,-0.011153,-0.031864,0.007012,-0.058724,0.046233,0.018627,-0.005846,-0.002067
3,00183b30d3,0.037918,-0.005334,-0.001674,-0.000382,0.040750,-0.014043,-0.013233,0.017424,-0.003058,...,0.010626,-0.029577,-0.010556,-0.004739,0.002030,-0.028003,0.001890,0.010655,-0.005744,0.004546
4,0036546652,0.004386,0.000443,-0.005890,0.254002,0.040160,0.122409,-0.028133,-0.006420,-0.022485,...,0.017634,-0.049006,-0.011241,-0.001659,-0.013277,-0.300536,0.087789,0.014445,-0.005073,-0.008319
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29,ff4a662c34,0.010093,-0.008350,-0.004891,-0.008193,0.041547,-0.025976,-0.027283,0.035217,0.025977,...,0.011494,-0.033416,-0.008551,-0.003501,-0.000127,-0.028954,0.057420,-0.002152,-0.002127,-0.005268
30,ff6c9ba277,0.006813,0.003591,-0.006365,-0.033015,0.049616,-0.020507,-0.034578,0.068770,-0.112935,...,0.016795,-0.048137,-0.040577,-0.023485,-0.000156,-0.046340,0.070769,0.033496,0.000864,-0.001345
31,ff90033a39,0.005974,-0.026335,-0.006777,-0.035574,0.052344,-0.026809,-0.022433,-0.002919,-0.070382,...,0.020252,-0.054980,0.020380,-0.009821,-0.031865,-0.084448,0.060043,0.016959,-0.008727,-0.004619
32,ffbd796a83,0.017642,-0.006526,-0.002039,-0.003330,0.048242,-0.026711,-0.027495,0.010101,-0.019942,...,0.010821,-0.033886,0.022008,-0.037607,0.007756,-0.046366,0.036116,0.023412,-0.003136,-0.005165


In [33]:
embeddings_train_test.to_pickle('src/ptls-experiments/scenario_x5/data/vicreg_coles_embeddings-512.pickle')