In [86]:
import os
from typing import Tuple

import numpy as np
import pandas as pd

from tqdm.auto import tqdm

import torch
from torch import nn

from sklearn.preprocessing import KBinsDiscretizer

from utils.data_utils import global_context, split_data, weekends, global_context_emb_avg
from utils.config_utils import DataConf, ModelConf, ClassificationParamsConf
from datamodules import TransactionRNNDataModule
from datamodules.preprocessing import data_preprocessing
from utils.config_utils import get_config_with_dirs

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [87]:
# Load the config sections from config.ini; the second element of the
# returned pair is not used in this notebook.
config_sections, _ = get_config_with_dirs('config.ini')
data_conf, model_conf, learning_conf, params_conf = config_sections

# Run the project preprocessing. It returns a frame (kept as `original_df`)
# together with a triple of train / val / test sequence tables.
original_df, sequence_splits = data_preprocessing(data_conf, model_conf, params_conf)
train_sequences, val_sequences, test_sequences = sequence_splits

  0%|          | 0/490513 [00:00<?, ?it/s]

Preparing gc from embedding layer:   0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

In [55]:
# Inspect the training split: rows are indexed by client_id and the feature
# columns hold list-valued sequences (see the rendered table below).
train_sequences

Unnamed: 0_level_0,small_group,amount_rur,hour,day,day_of_week,month,average_amt,top_mcc_1,top_mcc_2,top_mcc_3,gc_id,amnt_avg_embed,mcc_avg_embed,target_flag
client_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
1700,"[10, 10, 3, 142, 3, 2, 3, 3, 3, 3, 3, 3, 3]","[37, 50, 48, 47, 47, 2, 46, 50, 49, 50, 49, 49...","[7, 8, 0, 6, 6, 0, 4, 10, 10, 6, 6, 0, 14]","[12, 6, 11, 3, 3, 5, 5, 18, 18, 20, 20, 24, 7]","[5, 0, 5, 0, 0, 2, 2, 3, 3, 5, 5, 2, 2]","[11, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6]","[44, 43, 44, 44, 44, 44, 44, 43, 43, 43, 43, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[1, 4, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 8]","[[-0.16445142030715942, 0.208607017993927, -0....","[[-0.10418444126844406, -0.25044000148773193, ...",3
902,"[22, 2, 22, 97, 2, 3, 1, 1, 1, 22, 1, 3, 2, 61...","[33, 33, 17, 11, 28, 44, 21, 18, 15, 19, 16, 5...","[0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[27, 28, 29, 29, 31, 4, 5, 6, 6, 7, 7, 9, 9, 9...","[5, 6, 0, 0, 2, 6, 0, 1, 1, 2, 2, 4, 4, 4, 5, ...","[5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[[-0.10222120583057404, 0.24029222130775452, -...","[[-0.1319003850221634, -0.22998060286045074, -...",2
1820,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 205, 3, 3, 3...","[50, 36, 42, 50, 50, 20, 20, 20, 20, 20, 49, 3...","[7, 7, 11, 7, 7, 11, 11, 11, 11, 11, 11, 0, 7,...","[21, 21, 21, 22, 23, 27, 27, 27, 27, 27, 27, 2...","[2, 2, 2, 3, 4, 1, 1, 1, 1, 1, 1, 3, 3, 1, 1, ...","[12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 1...","[44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[[-0.12508225440979004, 0.21130897104740143, -...","[[-0.1040748655796051, -0.2802312970161438, -0...",3
721,"[23, 3, 3, 13, 22, 13, 22, 3, 13, 3, 22, 22, 1...","[29, 31, 43, 19, 10, 14, 18, 31, 18, 31, 13, 1...","[0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[1, 1, 1, 3, 3, 3, 4, 5, 5, 6, 6, 7, 7, 9, 9, ...","[5, 5, 5, 0, 0, 0, 1, 2, 2, 3, 3, 4, 4, 6, 6, ...","[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ...","[44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, ...","[[-0.1222296729683876, 0.2330232858657837, -0....","[[-0.1459289789199829, -0.21619509160518646, -...",5
1842,"[26, 39, 13, 3, 55, 16, 2, 40, 40, 43, 46, 16,...","[24, 46, 17, 48, 46, 48, 38, 44, 44, 44, 43, 4...","[0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[19, 19, 19, 19, 20, 20, 20, 20, 20, 21, 21, 2...","[4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 0, 0, 0, ...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[[-0.10222120583057404, 0.24029222130775452, -...","[[-0.1319003850221634, -0.22998060286045074, -...",19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4227,"[25, 78, 27, 5, 27, 19, 2, 27, 3, 27, 2, 14, 2...","[35, 35, 33, 35, 15, 27, 35, 13, 43, 24, 29, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 9, 9, 9, 10, 10...","[3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 6, 0, 0, ...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[[-0.08820830285549164, 0.2611183524131775, -0...","[[-0.10115717351436615, -0.18622733652591705, ...",27
558,"[3, 3, 25, 2, 2, 2, 71, 71, 2, 58, 5, 5, 43, 2...","[6, 6, 40, 38, 27, 22, 29, 34, 19, 24, 19, 28,...","[8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[23, 23, 24, 24, 24, 24, 25, 25, 25, 26, 26, 2...","[4, 4, 5, 5, 5, 5, 6, 6, 6, 0, 0, 0, 1, 1, 1, ...","[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...","[[-0.09153849631547928, 0.24858978390693665, -...","[[-0.10688411444425583, -0.1952364146709442, -...",3
9679,"[78, 78, 78, 78, 53, 78, 78, 78, 78, 78, 78, 7...","[2, 2, 2, 39, 1, 1, 4, 3, 3, 7, 2, 7, 7, 17, 3...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[6, 11, 11, 11, 12, 16, 16, 16, 16, 17, 17, 17...","[1, 6, 6, 6, 0, 4, 4, 4, 4, 5, 5, 5, 0, 0, 1, ...","[12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 1...","[44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[[-0.16445142030715942, 0.208607017993927, -0....","[[-0.10418444126844406, -0.25044000148773193, ...",78
2023,"[2, 2, 63, 68, 68, 4, 2, 4, 17, 2, 2, 27, 19, ...","[16, 12, 23, 33, 33, 16, 42, 30, 22, 34, 14, 2...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[7, 10, 10, 10, 10, 11, 11, 11, 12, 13, 14, 14...","[4, 0, 0, 0, 0, 1, 1, 1, 2, 3, 4, 4, 4, 5, 6, ...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...","[42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 4...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 1...","[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ...","[[-0.08820830285549164, 0.2611183524131775, -0...","[[-0.10115717351436615, -0.18622733652591705, ...",19


In [88]:
# Build the RNN datamodule from the three sequence splits plus the
# params / learning configs (argument order is positional and must be kept).
datamodule_args = (train_sequences, val_sequences, test_sequences, params_conf, learning_conf)
datamodule = TransactionRNNDataModule(*datamodule_args)

In [89]:
# Pull a single batch from the training dataloader for a quick sanity check.
train_loader = datamodule.train_dataloader()
a = next(iter(train_loader))

In [85]:
# Shape of the last element of the sampled batch `a`.
# NOTE(review): the three dimensions look like (batch, sequence, features) —
# confirm against TransactionRNNDataModule before relying on that reading.
a[-1].shape

torch.Size([128, 401, 16])

In [50]:
# For every client, map each transaction's global-context id (`gc_id`) to its
# amount / MCC embedding, producing one list of embeddings per client.
#
# NOTE(review): `gc_emb_amnt` and `gc_emb_mcc` are not defined in the cells
# visible here — they must already exist in the kernel for this cell to run.
avg_amnt_seqs = []
avg_mcc_seqs = []

# Iterate the `gc_id` column directly. Looping over `train_sequences.iloc`
# only works through Python's legacy __getitem__/IndexError iteration fallback
# and builds a full row Series per client just to read a single field.
for gc_ids in tqdm(train_sequences['gc_id'], total=train_sequences.shape[0]):
    avg_amnt_seqs.append([gc_emb_amnt[gc_id] for gc_id in gc_ids])
    avg_mcc_seqs.append([gc_emb_mcc[gc_id] for gc_id in gc_ids])

  0%|          | 0/3804 [00:00<?, ?it/s]

In [51]:
# Store the per-client lists of amount embeddings as a column; this mutates
# `train_sequences` in place.
# NOTE(review): `avg_mcc_seqs` is built alongside `avg_amnt_seqs` but is not
# stored anywhere in the cells visible here — confirm whether a matching
# 'mcc_avg_embed' assignment is intended.
train_sequences['amnt_avg_embed'] = avg_amnt_seqs

In [52]:
# Spot-check the amount-embedding sequence stored for the first client.
# Select the column first, then the position, so no mixed-dtype row Series
# has to be built just to read one cell.
train_sequences['amnt_avg_embed'].iloc[0]

[array([-0.14871599,  0.19030489, -0.46283162, -0.34657732, -0.24200143,
        -0.18323986,  0.06345016, -0.40866023]),
 array([-0.09616771,  0.21815494, -0.41340044, -0.37405574, -0.25912285,
        -0.20399565,  0.09143222, -0.418614  ]),
 array([-0.10876427,  0.2172938 , -0.42839098, -0.3649483 , -0.24618533,
        -0.20384894,  0.08288978, -0.42031041]),
 array([-0.10876427,  0.2172938 , -0.42839098, -0.3649483 , -0.24618533,
        -0.20384894,  0.08288978, -0.42031041]),
 array([-0.10876427,  0.2172938 , -0.42839098, -0.3649483 , -0.24618533,
        -0.20384894,  0.08288978, -0.42031041]),
 array([-0.10876427,  0.2172938 , -0.42839098, -0.3649483 , -0.24618533,
        -0.20384894,  0.08288978, -0.42031041]),
 array([-0.10876427,  0.2172938 , -0.42839098, -0.3649483 , -0.24618533,
        -0.20384894,  0.08288978, -0.42031041]),
 array([-0.08864156,  0.22401129, -0.41461539, -0.38022062, -0.24510355,
        -0.20337507,  0.09225266, -0.4231846 ]),
 array([-0.08864156,  0.