In [1]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

In [2]:
%load_ext autoreload
%autoreload 2

# Model definition

In [3]:
import hydra
from omegaconf import OmegaConf
import torch

conf = OmegaConf.load('config/coles.yaml')
model = hydra.utils.instantiate(conf.pl_module)
model.load_state_dict(torch.load("models/coles.p"))

<All keys matched successfully>

# Finetune

In [4]:
from glob import glob
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.data_load.datasets import MemoryMapDataset
from ptls.frames.supervised import SeqToTargetDataset
from tqdm.auto import tqdm

from ptls.data_load import IterableChain
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.utils import collate_feature_dict
from ptls.frames import PtlsDataModule

train_data = glob('data/train_transactions_clipped.parquet')
valid_data = glob('data/valid_transactions_clipped.parquet')

feature_cols = list(conf.pl_module.seq_encoder.trx_encoder.embeddings.keys()) + \
               list(conf.pl_module.seq_encoder.trx_encoder.numeric_values.keys())

dataset_conf = {
    'min_seq_len':25,
    }


process = IterableChain(
            SeqLenFilter(min_seq_len=dataset_conf['min_seq_len']),
            FeatureFilter(keep_feature_names=feature_cols + ['flag']),
            ToTorch()
            )
    
def get_dataset(data):
    ds = MemoryMapDataset(ParquetDataset(data, post_processing=process))
    return SeqToTargetDataset(ds, target_col_name='flag')

train_ds = get_dataset(train_data)
valid_ds = get_dataset(valid_data)

dm = PtlsDataModule(
    train_data=train_ds,
    valid_data=valid_ds,
    train_num_workers=4,
    train_batch_size=64)



In [5]:
from functools import partial
import torch
import torchmetrics
from ptls.frames.supervised import SequenceToTarget
from ptls.nn import Head

model_e2e = SequenceToTarget(
    seq_encoder=model.seq_encoder,
    head=Head(
        input_size=model.seq_encoder.embedding_size,
        use_batch_norm=True,
        objective='classification',
        num_classes=2,
    ),
    loss=torch.nn.NLLLoss(),
    metric_list=torchmetrics.Accuracy(compute_on_step=False),
    pretrained_lr=0.00001,
    optimizer_partial=partial(torch.optim.Adam, lr=0.001, weight_decay=1e-5),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.9),
)


In [6]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

trainer_params = conf.trainer




trainer_params = conf.trainer
trainer_params['max_epochs'] = 5
callbacks = [ModelCheckpoint(every_n_epochs=5, save_top_k=-1), LearningRateMonitor(logging_interval='step')]
logger = TensorBoardLogger(save_dir='lightning_logs', name=conf.get('logger_name'))

print(OmegaConf.to_yaml(trainer_params))

trainer = pl.Trainer(**trainer_params, callbacks=callbacks, logger=logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


gpus: 1
auto_select_gpus: false
max_epochs: 5
deterministic: true



In [7]:
%%time
print(f'logger.version = {trainer.logger.version}')
trainer.fit(model_e2e, dm)
print(trainer.logged_metrics)

logger.version = 31


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type          | Params
------------------------------------------------
0 | seq_encoder   | RnnSeqEncoder | 284 K 
1 | head          | Head          | 1.0 K 
2 | loss          | NLLLoss       | 0     
3 | train_metrics | ModuleDict    | 0     
4 | valid_metrics | ModuleDict    | 0     
5 | test_metrics  | ModuleDict    | 0     
------------------------------------------------
285 K     Trainable params
0         Non-trainable params
285 K     Total params
1.142     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{'loss': tensor(0.1038), 'seq_len': tensor(89.4000), 'y': tensor(0.0182), 'val_loss': tensor(0.1063), 'val_Accuracy': tensor(0.9739), 'train_Accuracy': tensor(0.9714)}
CPU times: user 11min 14s, sys: 54.7 s, total: 12min 9s
Wall time: 12min 49s


In [8]:
torch.save(model_e2e.state_dict(), "models/rnn-e2e-pd.pt")

# Infernece

In [9]:
model_e2e.load_state_dict(torch.load("models/rnn-e2e-pd.pt"))

<All keys matched successfully>

In [10]:
# %%time
import tqdm

def inference(model, dl, device='cuda:0'):
    
    model.to(device)
    X = []
    for batch in tqdm.tqdm(dl):
        with torch.no_grad():
            features = batch[0]
            targets = [batch[1].to(device).unsqueeze(dim=1)]
            x = model(features.to(device))
            flag = x[:, 1].unsqueeze(dim=1)
            predicted = [flag]
            X += [torch.cat(predicted + targets, dim=1)]
    return X


valid_dl = torch.utils.data.DataLoader(dataset=valid_ds, 
                                       collate_fn=valid_ds.collate_fn,
                                       num_workers=8,
                                       batch_size=64)

In [11]:
preds = torch.vstack(inference(model_e2e, valid_dl)).cpu().numpy()

100%|██████████████████████████████████████████████████████████████████████████████████████| 406/406 [00:08<00:00, 47.58it/s]


In [12]:
import numpy as np

df_valid = pd.DataFrame(preds, columns = ['predicted_flag', 'flag'])
df_valid.head()

Unnamed: 0,predicted_flag,flag
0,-2.25074,0.0
1,-3.428023,0.0
2,-1.536288,1.0
3,-5.059521,0.0
4,-5.243661,0.0


In [13]:
from sklearn.metrics import roc_auc_score

print("Roc AUC score:", {roc_auc_score(df_valid['flag'],  df_valid['predicted_flag'])})

Roc AUC score: {0.7761719920016419}
