In [1]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module: importing
# `display` from `IPython.core.display` is deprecated and triggers a warning.
from IPython.display import HTML, display

display(HTML("<style>.container { width:100% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import logging
import pytorch_lightning as pl
import warnings

warnings.filterwarnings('ignore')
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

## Data loading and preprocessing

In [17]:
%env DATA_DIR=../../data2
! mkdir $DATA_DIR
! curl -OL https://storage.googleapis.com/di-datasets/age-prediction-nti-sbebank-2019.zip
! unzip -j -o age-prediction-nti-sbebank-2019.zip 'data/*.csv' -d $DATA_DIR
! mv age-prediction-nti-sbebank-2019.zip $DATA_DIR

!python ../make_datasets.py \
    --data_path $DATA_DIR\
    --trx_files transactions_train.csv transactions_test.csv \
    --col_client_id "client_id" \
    --cols_event_time "#float" "trans_date" \
    --cols_category "trans_date" "small_group" \
    --cols_log_norm "amount_rur" \
    --target_files train_target.csv \
    --col_target bins \
    --test_size 0.1 \
    --output_train_path "$DATA_DIR/train_trx.p" \
    --output_test_path "$DATA_DIR/test_trx.p" \
    --output_test_ids_path "$DATA_DIR/test_ids.csv" \
    --log_file "$DATA_DIR/dataset_age_pred.log"

env: DATA_DIR=../../data2
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  239M  100  239M    0     0   171M      0  0:00:01  0:00:01 --:--:--  171M
Archive:  age-prediction-nti-sbebank-2019.zip
  inflating: ../../data2/test.csv    
  inflating: ../../data2/small_group_description.csv  
  inflating: ../../data2/train_target.csv  
  inflating: ../../data2/transactions_train.csv  
  inflating: ../../data2/transactions_test.csv  
load_source_data       : Loaded 26450577 rows from "/mnt/mikheev/data2/transactions_train.csv"
load_source_data       : Loaded 17667328 rows from "/mnt/mikheev/data2/transactions_test.csv"
load_source_data       : Loaded 44117905 rows in total
trx_to_features        : Found 50000 unique clients
_td_float              : To-float time transformation
trx_to_features        : Encoder stat for "trans_date":
codes | trx_count
                cnt  % of total

## Embedding training

Model training in our framework is organised via the pytorch-lightning (pl) framework.
The key parts of neural network training in pl are:

    * model (pl.LightningModule)
    * data_module (pl.LightningDataModule)
    * trainer (pl.Trainer)
    
For further details, see https://www.pytorchlightning.ai/

### Model

In [30]:
from pyhocon import ConfigFactory
from dltranz.lightning_modules.coles_module import CoLESModule


# HOCON text holding the CoLES model hyperparameters; it is parsed with pyhocon
# at the bottom of this cell and handed to CoLESModule.
# Top-level sections:
#   validation_metric_params - settings of the validation metric (K, metric)
#   encoder_type / rnn       - selects the `rnn` encoder, configured here as a
#                              unidirectional GRU with hidden_size 160
#   head_layers              - list of [layer name, kwargs] pairs ending in NormEncoder
#   transf                   - NOTE(review): presumably only read when encoder_type
#                              selects a transformer encoder -- confirm
#   trx_encoder              - embeddings for `trans_date` and `small_group` plus the
#                              numeric `amount_rur`; these are the same column names
#                              passed to make_datasets.py in the preprocessing cell
#   lr_scheduler / train     - scheduler and training settings (sampling strategy,
#                              loss, margin, lr, weight decay)
model_params = '''
    validation_metric_params: {
      K: 4,
      metric: cosine
    },
    
    encoder_type: rnn,
    rnn: {
      type: gru,
      hidden_size: 160,
      bidir: false,
      trainable_starter: static
    },
    
    head_layers: [
      [Linear, {"in_features": "{seq_encoder.embedding_size}", "out_features": 256}],
      [BatchNorm1d, {num_features: 256}],
      [ReLU, {}],
      [Linear, {"in_features": 256, "out_features": 256}],
      [NormEncoder, {}]
    ],
    
    transf: {
      shared_layers: false,
      input_size: 64,
      n_heads: 2,
      dim_hidden: 64,
      dropout: 0.1,
      n_layers: 4,
      max_seq_len: 1,
      use_after_mask: false,
      use_positional_encoding: false,
    },

    trx_encoder: {
      norm_embeddings: false,
      embeddings_noise: 0.003
      embeddings: {
        trans_date: {in: 800, out: 16}
        small_group: {in: 250, out: 16}
      }
      numeric_values: {
        amount_rur: identity
      }
    }

    lr_scheduler: {
      step_size: 30
      step_gamma: 0.9025
    }

    train: {
      sampling_strategy: HardNegativePair
      neg_count: 5
      loss: ContrastiveLoss
      margin: 0.5
      lr: 0.002
      weight_decay: 0.0
    }
'''


# Parse the HOCON text into a config object and build the model from it.
model = CoLESModule(ConfigFactory.parse_string(model_params))

### Data module

In [32]:
from dltranz.data_load.data_module.coles_data_module import ColesDataModuleTrain

# HOCON text configuring the training data module; parsed with pyhocon at the
# bottom of this cell.
#   setup - id column, dataset location and the train/validation split settings
#   train - sequence-length filter, augmentations, slice sampling and loader
#           settings for training
#   valid - slice sampling and loader settings for validation (no augmentations)
# NOTE(review): `data_path` points at "data/train_trx_file.parquet", whereas the
# preprocessing cell in this notebook writes "$DATA_DIR/train_trx.p". Confirm
# that the parquet file is produced elsewhere or adjust the path.
dm_params = '''
    type: map
    setup: {
      col_id: client_id
      dataset_files: {
        data_path: "data/train_trx_file.parquet"
      }
      split_by: files
      valid_size: 0.1
      valid_split_seed: 42
    }
    train: {
      min_seq_len: 25
      augmentations: [
        [DropoutTrx, {trx_dropout: 0.01}]
      ]
      buffer_size: 512
      split_strategy: {
        split_strategy: "SampleSlices"
        split_count: 5
        cnt_min: 25
        cnt_max: 600
      }
      num_workers: 16
      batch_size: 512
    }
    valid: {
      augmentations: []
      split_strategy: {
        split_strategy: SampleSlices
        split_count: 5
        cnt_min: 25
        cnt_max: 100
      }
      num_workers: 16
      batch_size: 1024
    }
'''

# Build the data module from the parsed config; the model created in the
# previous cell is passed as the second argument.
dm = ColesDataModuleTrain(ConfigFactory.parse_string(dm_params), model)