## Data load

In [1]:
import os

# Fetch and unpack the dataset only when the training CSV is not on disk yet,
# so re-running this cell does not download the archive again.
if not os.path.exists('data/transactions_train.csv'):
    ! mkdir -p data
    # -O stores the file under its remote name, -L follows HTTP redirects.
    ! curl -OL https://storage.yandexcloud.net/di-datasets/age-prediction-nti-sbebank-2019.zip
    # -j discards the directory structure stored in the archive, -o overwrites
    # existing files without prompting; only the archive's data/*.csv members
    # are extracted, and they land directly in ./data.
    ! unzip -j -o age-prediction-nti-sbebank-2019.zip 'data/*.csv' -d data
    # Keep the downloaded archive next to the extracted CSV files.
    ! mv age-prediction-nti-sbebank-2019.zip data/

## Setup

In [2]:
%load_ext autoreload
%autoreload 2

import logging
import pytorch_lightning as pl
import warnings

# Keep the notebook output compact: ignore every Python warning ...
warnings.simplefilter('ignore')

# ... and let PyTorch Lightning report errors only.
lightning_logger = logging.getLogger("pytorch_lightning")
lightning_logger.setLevel(logging.ERROR)

## Data preprocessing

In [3]:
import os
import pandas as pd

# Directory that holds the unpacked dataset CSV files.
data_path = 'data/'

# Read the raw transaction log and preview its first two rows.
transactions_file = os.path.join(data_path, 'transactions_train.csv')
source_data = pd.read_csv(transactions_file)
source_data.head(n=2)

Unnamed: 0,client_id,trans_date,small_group,amount_rur
0,33172,6,4,71.463
1,33172,6,35,45.017


In [4]:
from ptls.data_preprocessing import PandasDataPreprocessor

# Preprocessor for the flat transaction table. After fit_transform each record
# is a dict with the keys client_id, trans_date, small_group, amount_rur and
# event_time (see the keys printed when the targets are attached).
# Presumably one record is produced per client_id — TODO confirm.
preprocessor = PandasDataPreprocessor(
    col_id='client_id',
    cols_event_time='trans_date',
    time_transformation='float',
    # NOTE(review): trans_date is passed both as the event-time column and as a
    # categorical feature. The category list feeds get_category_sizes(), which
    # sizes the SequenceEncoder loaded from the checkpoint later on, so it
    # presumably has to match the configuration used in coles-emb.ipynb — confirm
    # before changing it.
    cols_category=["trans_date", "small_group"],
    cols_log_norm=["amount_rur"],
    cols_identity=[],
    print_dataset_info=False,
)


In [5]:
%%time

# Fit the preprocessor on the raw transactions and transform them in the same
# call (about 1 min 17 s wall time in the recorded run).
dataset = preprocessor.fit_transform(source_data)

CPU times: user 1min 5s, sys: 11.6 s, total: 1min 17s
Wall time: 1min 17s


In [6]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the preprocessed records for the final test evaluation;
# the fixed random_state keeps the split reproducible between runs.
train, test = train_test_split(dataset, test_size=0.2, random_state=42)

# Display the size of both splits.
len(train), len(test)

(24000, 6000)

In [7]:
# Load targets: one row per client, indexed by client_id, with the label
# column "bins" exposed under the name "target".
df_target = (
    pd.read_csv(os.path.join(data_path, 'train_target.csv'))
    .set_index('client_id')
    .rename(columns={"bins": "target"})
)
df_target.head(5)

Unnamed: 0_level_0,target
client_id,Unnamed: 1_level_1
24662,2
1046,0
34089,2
34848,1
47076,3


In [8]:
# Add targets to train and test:
# every record gets a 'target' entry looked up by its client_id.

print(train[0].keys())

for records in (train, test):
    for record in records:
        record['target'] = df_target['target'][record['client_id']]

print(train[0].keys())

dict_keys(['client_id', 'trans_date', 'small_group', 'amount_rur', 'event_time'])
dict_keys(['client_id', 'trans_date', 'small_group', 'amount_rur', 'event_time', 'target'])


## FineTuning

### load SequenceEncoder obtained from `coles-emb.ipynb`

In [9]:
import torch
from ptls.seq_encoder import SequenceEncoder
from ptls.models import Head
from ptls.lightning_modules.emb_module import EmbModule

# Rebuild the architecture whose weights are stored in 'coles-emb.pt'.
# NOTE(review): the layout presumably has to match the one used in
# coles-emb.ipynb for the state dict to load — confirm against that notebook.
seq_encoder = SequenceEncoder(
    category_features=preprocessor.get_category_sizes(),
    numeric_features=["amount_rur"],
    trx_embedding_noize=0.003
)

head = Head(input_size=seq_encoder.embedding_size, use_norm_encoder=True)

model = EmbModule(seq_encoder=seq_encoder, head=head)

# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on
# a machine without CUDA (the Trainer below explicitly supports CPU-only runs).
# Device placement for training is left to the Trainer.
model.load_state_dict(torch.load('coles-emb.pt', map_location='cpu'))
model.eval()

### model

In [10]:
import copy
from ptls.seq_to_target_demo import SeqToTargetDemo


# Work on an independent copy of the pretrained encoder so that the
# `model.seq_encoder` object loaded above is a separate object from the one
# handed to the downstream model.
pretrained_encoder = copy.deepcopy(model.seq_encoder)

# Downstream model built around the pretrained encoder. The encoder gets a much
# smaller learning rate (1e-4) than the head (2e-2).
# out_features=4 presumably equals the number of target bins (the preview of
# df_target shows values 0..3) — TODO confirm.
downstream_model = SeqToTargetDemo(
    pretrained_encoder,
    encoder_lr=0.0001,
    in_features=pretrained_encoder.embedding_size,
    out_features=4,
    head_lr=0.02,
    weight_decay=0.0,
    lr_scheduler_step_size=10,
    lr_scheduler_step_gamma=0.2)


### Data module

In [11]:
from ptls.data_load.data_module.seq_to_target_data_module import SeqToTargetDatamodule


# Data module for fine-tuning, built from the `train` split only; the `test`
# split is kept aside for the final evaluation.
# valid_size=0.05 is presumably the fraction of `train` held out for
# validation — TODO confirm against SeqToTargetDatamodule.
# num_workers=0 for both loaders, batches of 256 records.
finetune_dm = SeqToTargetDatamodule(
    dataset=train,
    pl_module=downstream_model,
    min_seq_len=0,
    valid_size=0.05,
    train_num_workers=0,
    train_batch_size=256,
    valid_num_workers=0,
    valid_batch_size=256,
    target_col='target',
    random_state=42)


### Trainer FineTuning

In [12]:
# Four epochs of fine-tuning; request one GPU when CUDA is available and
# none otherwise.
trainer_ft = pl.Trainer(
    max_epochs=4,
    gpus=int(torch.cuda.is_available()),
)

### Training FineTuning

In [13]:
# Fine-tune the downstream model on the loaders provided by the data module.
trainer_ft.fit(downstream_model, finetune_dm)

### Testing

In [14]:
from torch.utils.data import DataLoader
from ptls.data_load import padded_collate


# Pass the held-out records through the data module's post_proc step and
# materialise the result as a list so it can back a DataLoader.
test_dataset = list(finetune_dm.post_proc(iter(test)))

# Batches are assembled with padded_collate.
test_dataloader = DataLoader(dataset=test_dataset,
                             collate_fn=padded_collate,
                             num_workers=0,
                             batch_size=128)

# NOTE(review): no model is passed here, so the trainer presumably evaluates the
# model it was fitted with above — confirm for the installed Lightning version.
trainer_ft.test(dataloaders=test_dataloader)

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.6178333163261414}
--------------------------------------------------------------------------------


[{'test_accuracy': 0.6178333163261414}]