In [1]:
# IPython autoreload: re-import edited modules (e.g. the project-local
# dataset_utils and models imported below) before each cell executes.
%load_ext autoreload
%autoreload 2

In [62]:
import warnings
import sys

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import pandas as pd
from os import path
from pathlib import Path
from tqdm import tqdm
from tqdm._tqdm_notebook import tqdm_notebook
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as f
import joblib
from numpy import dot
from numpy.linalg import norm
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity
from dataset_utils import SiamLikeDataset, train_val_test_split
from models import CombinedModel

# Register tqdm's pandas integration (progress_apply & co.) with the notebook bar.
# NOTE(review): no cell visible here calls progress_apply — possibly removable.
tqdm_notebook.pandas()

In [8]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
# Absolute working directory of the notebook; every input lives under ./data.
current_dir = str(Path().absolute())
data_dir = path.join(current_dir, 'data')

In [18]:
# Load the labelled pairs and split them into train / valid / test.
# stratify='target' presumably stratifies on that column; random_state pins the split.
# Judging by the displayed shapes, train keeps half the rows, valid and test a quarter each.
markup_csv = path.join(data_dir, 'markup.csv')
markup = pd.read_csv(markup_csv)
train, valid, test = train_val_test_split(
    markup,
    test_size=0.25,
    valid_size=0.25,
    random_state=42,
    stratify='target',
)
(train.shape, valid.shape, test.shape)

((73354, 3), (36678, 3), (36678, 3))

In [19]:
# One SiamLikeDataset per split. All splits read the same transaction and
# clickstream sources and differ only in the markup rows they are given.
transactions_path = path.join(data_dir, 'transaction_data')
clickstream_path = path.join(data_dir, 'clickstream_data')
dtst_train, dtst_valid, dtst_test = (
    SiamLikeDataset(markup=split_markup,
                    transactions_path=transactions_path,
                    clickstream_path=clickstream_path)
    for split_markup in (train, valid, test)
)

In [23]:
# Inference loaders: single-process, original row order (no shuffle) and no
# dropped tail batch — presumably so every markup row gets an embedding.
batch_size = 128
kwargs = dict(num_workers=0, batch_size=batch_size, shuffle=False, drop_last=False)
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(split_dataset, **kwargs)
    for split_dataset in (dtst_train, dtst_valid, dtst_test)
)

In [11]:
# Fitted label encoders saved under data/models_objects; the model below only
# uses their `classes_` to size its embedding tables.
models_objects_dir = path.join(data_dir, 'models_objects')
le_mcc, le_currency_rk, le_click_categories = (
    joblib.load(path.join(models_objects_dir, encoder_name))
    for encoder_name in ('le_mcc', 'le_currency_rk', 'le_click_categories')
)

In [22]:
# Rebuild the architecture (embedding sizes presumably must match the checkpoint,
# otherwise load_state_dict fails), load the saved weights onto `device`, and
# switch to eval mode so Dropout is disabled and BatchNorm uses running stats.
# The trailing expression displays the module structure.
model = CombinedModel(
    mcc_classes=len(le_mcc.classes_),
    mcc_emb_size=3,
    currency_rk_classes=len(le_currency_rk.classes_),
    currency_rk_emb_size=2,
    cat_id_classes=len(le_click_categories.classes_),
    cat_id_emb_size=5,
    device=device,
).to(device)
checkpoint_path = path.join(data_dir, 'nn_chpt', 'model_2022-04-03 20_29_11_0.1004_0.10036')
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

CombinedModel(
  (m_bank): BankModel(
    (emb_mcc): EmbeddingModel(
      (emb): Sequential(
        (0): Embedding(387, 3, padding_idx=0)
        (1): Dropout(p=0.1, inplace=False)
      )
    )
    (emb_currency_rk): EmbeddingModel(
      (emb): Sequential(
        (0): Embedding(5, 2, padding_idx=0)
        (1): Dropout(p=0.1, inplace=False)
      )
    )
    (lstm_mcc): LSTMModel(
      (lstm_1d): Sequential(
        (0): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Dropout(p=0.1, inplace=False)
        (2): LSTM(3, 64, batch_first=True, dropout=0.1, bidirectional=True)
      )
      (lstm_2d): Sequential(
        (0): LSTM(3, 64, batch_first=True, dropout=0.1, bidirectional=True)
      )
    )
    (lstm_currency_rk): LSTMModel(
      (lstm_1d): Sequential(
        (0): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Dropout(p=0.1, inplace=False)
        (2): LSTM(2, 64, batch_first=True, dr

In [63]:
# Run both towers of the model (m_bank, m_rtk) over every split and collect, per split:
#   bank_emb / rtk_emb : stacked tower outputs (np.ndarray, one row per pair)
#   cosine_sim         : row-wise cosine similarity of the two outputs
#   bank_id / rtk_id   : ids taken from the batch, in loader order
#   target             : labels as Python numbers
load_dict = {'train': train_dataloader, 'valid': valid_dataloader, 'test': test_dataloader}
results_dict = dict.fromkeys(load_dict.keys())

for sample, sample_loader in load_dict.items():
    sample_results = {
        'bank_emb': [],
        'rtk_emb': [],
        'cosine_sim': [],
        'bank_id': [],
        'rtk_id': [],
        'target': [],
    }
    with torch.no_grad(), tqdm(sample_loader, unit="batch", desc=sample) as tqdm_sample_loader:
        for batch in tqdm_sample_loader:
            # Fix: ids belong in 'bank_id'. They used to be appended to 'bank_emb',
            # which left 'bank_id' empty and polluted the embedding matrix.
            sample_results['bank_id'].extend(batch['bank_id'])
            sample_results['rtk_id'].extend(batch['rtk_id'])
            sample_results['target'].extend(batch['target'].detach().cpu().numpy().tolist())
            # .cpu() is required before .numpy() when the model runs on CUDA.
            bes = model.m_bank(batch).detach().cpu().numpy()
            res = model.m_rtk(batch).detach().cpu().numpy()
            # Vectorised row-wise cosine similarity.
            # Assumes tower outputs are (batch, emb_dim), as the original per-row
            # dot/norm loop implied — TODO confirm.
            cos = np.sum(bes * res, axis=1) / (norm(bes, axis=1) * norm(res, axis=1))
            sample_results['cosine_sim'].extend(cos)
            sample_results['bank_emb'].extend(bes)
            sample_results['rtk_emb'].extend(res)
    sample_results['bank_emb'] = np.array(sample_results['bank_emb'])
    sample_results['rtk_emb'] = np.array(sample_results['rtk_emb'])
    results_dict[sample] = sample_results

train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 574/574 [26:52<00:00,  2.81s/batch]
valid: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 287/287 [18:12<00:00,  3.81s/batch]
test: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 287/287 [19:47<00:00,  4.14s/batch]
