In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
import sys

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import pandas as pd
from os import path
from tqdm import tqdm
from tqdm._tqdm_notebook import tqdm_notebook
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
import joblib
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

sys.path.append('/home/ivan/Рабочий стол/vtb-matching/')
tqdm_notebook.pandas()

In [3]:
from dataset_utils import SiamLikeDataset, train_val_test_split

In [10]:
import torch.nn as nn
import torch.nn.functional as f

In [4]:
data_dir = '/home/ivan/Рабочий стол/vtb-matching/data'

In [5]:
markup = pd.read_csv(path.join(data_dir, 'markup.csv'))
train, valid, test = train_val_test_split(
    markup,test_size=0.25,valid_size=0.25,random_state=42,stratify='target'
)
train.shape, valid.shape, test.shape

((73354, 3), (36678, 3), (36678, 3))

In [8]:
dtst_train = SiamLikeDataset(markup=train,
                       transactions_path=path.join(data_dir, 'transaction_data'), 
                       clickstream_path=path.join(data_dir, 'clickstream_data'))
dtst_valid = SiamLikeDataset(markup=valid,
                       transactions_path=path.join(data_dir, 'transaction_data'), 
                       clickstream_path=path.join(data_dir, 'clickstream_data'))
dtst_test = SiamLikeDataset(markup=test,
                       transactions_path=path.join(data_dir, 'transaction_data'), 
                       clickstream_path=path.join(data_dir, 'clickstream_data'))

In [9]:
batch_size = 128
kwargs = {'num_workers': 0, 'pin_memory': False}
train_dataloader = DataLoader(dtst_train, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
valid_dataloader = DataLoader(dtst_valid, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
test_dataloader = DataLoader(dtst_test, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)

In [11]:
le_mcc = joblib.load(path.join(data_dir, 'models_objects', 'le_mcc'))
le_currency_rk = joblib.load(path.join(data_dir, 'models_objects', 'le_currency_rk'))
le_click_categories = joblib.load(path.join(data_dir, 'models_objects', 'le_click_categories'))

In [12]:
class EmbeddingModel(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int = 3):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Embedding(num_embeddings, embedding_dim, padding_idx=0),
            nn.Dropout(p=0.1)
        )

    def forward(self, x):
        output = self.emb(x)
        return output
    

class LSTMModel(nn.Module):
    def __init__(self, input_size: int, ):
        super().__init__()
        self.lstm_1d = nn.Sequential(
            nn.BatchNorm1d(input_size),
            nn.Dropout(p=0.1),
            nn.LSTM(input_size=input_size, hidden_size=64, 
                    num_layers=1, batch_first=True, 
                    dropout=0.1, bidirectional=True)
        )
        self.lstm_2d = nn.Sequential(
            nn.LSTM(input_size=input_size, hidden_size=64, 
                    num_layers=1, batch_first=True, dropout=0.1, bidirectional=True)
        )

    def forward(self, x):
        if len(x.shape) == 2:
            output, _ = self.lstm_1d(x)
        else:
            output, _ = self.lstm_2d(x)
        return torch.cat((output[:, -1, :64], output[:, 0, 64:]), dim=1)  #актуально для bidirectional


class BankModel(nn.Module):
    
    def __init__(self, 
                 mcc_classes: int, mcc_emb_size: int,
                 currency_rk_classes: int, currency_rk_emb_size: int):
        super().__init__()
        self.emb_mcc = EmbeddingModel(num_embeddings=mcc_classes+1, 
                                      embedding_dim=mcc_emb_size)
        self.emb_currency_rk = EmbeddingModel(num_embeddings=currency_rk_classes+1, 
                                              embedding_dim=currency_rk_emb_size)
        self.lstm_mcc = LSTMModel(mcc_emb_size)
        self.lstm_currency_rk = LSTMModel(currency_rk_emb_size)
        self.lstm_transaction_amt = LSTMModel(1)
        
        self.fc = nn.Sequential(
            nn.Linear(128*3, 256),
            nn.PReLU(),
            nn.Linear(256, 256),
            nn.PReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        
        mcc_out = self.emb_mcc(x['mcc_code'])
        mcc_out = self.lstm_mcc(mcc_out)
        
        currency_rk_out = self.emb_mcc(x['currency_rk'])
        currency_rk_out = self.lstm_mcc(currency_rk_out)
        
        transaction_amt_out = self.lstm_transaction_amt(torch.unsqueeze(x['transaction_amt'].float(), 2))
        
        out = torch.cat((mcc_out, currency_rk_out, transaction_amt_out), dim=1)
        out = self.fc(out)
        return out
    
    
class RTKModel(nn.Module):
    
    def __init__(self, 
                 cat_id_classes: int, cat_id_emb_size: int):
        super().__init__()
        self.emb_cat_id = EmbeddingModel(num_embeddings=cat_id_classes+1, 
                                      embedding_dim=cat_id_emb_size)
        self.lstm_cat_id = LSTMModel(cat_id_emb_size)
        
        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.PReLU(),
            nn.Linear(256, 256),
            nn.PReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        
        cat_id_out = self.emb_cat_id(x['cat_id'])
        cat_id_out = self.lstm_cat_id(cat_id_out)
        out = self.fc(cat_id_out)
        return out
    
    
class CombinedModel(nn.Module):
    def __init__(self, 
                 mcc_classes: int, mcc_emb_size: int,
                 currency_rk_classes: int, currency_rk_emb_size: int, 
                 cat_id_classes: int, cat_id_emb_size: int):
        super().__init__()
        self.m_bank = BankModel(mcc_classes=mcc_classes, 
                                mcc_emb_size=mcc_emb_size, 
                                currency_rk_classes=currency_rk_classes, 
                                currency_rk_emb_size=currency_rk_emb_size)
        self.m_rtk = RTKModel(cat_id_classes=cat_id_classes, 
                              cat_id_emb_size=cat_id_emb_size)
        
    def forward(self, x):
        bank_out = self.m_bank(x)
        rtk_out = self.m_rtk(x)
        return bank_out, rtk_out
    
    
class ContrastiveLoss(nn.Module):
    
    # https://github.com/adambielski/siamese-triplet/blob/master/losses.py
    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, output1, output2, target, raw=False):
        distances = (output2 - output1).pow(2).sum(1)  # squared distances
        losses = 0.5 * (target.float() * distances +
                        (1 + -1 * target).float() * f.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        if raw:
            return losses
        else:
            return losses.mean() #if size_average else losses.sum()

In [13]:
model = CombinedModel(mcc_classes=len(le_mcc.classes_), 
              mcc_emb_size=3, 
              currency_rk_classes=len(le_currency_rk.classes_), 
              currency_rk_emb_size=2, cat_id_classes=len(le_click_categories.classes_), 
              cat_id_emb_size=5)



In [14]:
loss = ContrastiveLoss(1)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 50

In [24]:
for epoch in range(n_epochs):
    with tqdm(train_dataloader, unit="batch") as tqdm_train_dataloader:
        model.train()
        for batch in tqdm_train_dataloader:
            tqdm_train_dataloader.set_description(f"Epoch {epoch}")
            model.zero_grad()
            bank_out, rtk_out = model(batch)
            batch_loss = loss(bank_out, rtk_out, batch['target'])
            batch_loss.backward()
            optimizer.step()
            tqdm_train_dataloader.set_postfix(loss=batch_loss.item())
            break
    with torch.no_grad():
        model.eval()
        with tqdm(valid_dataloader, unit="batch") as tqdm_valid_dataloader:
            sample_loss = []
            for batch in tqdm_train_dataloader:
                tqdm_valid_dataloader.set_description(f"Epoch {epoch}")
                bank_out, rtk_out = model(batch)
                batch_loss = loss(bank_out, rtk_out, batch['target'], raw=True).detach()
                sample_loss.extend(batch_loss.tolist())
                tqdm_valid_dataloader.set_postfix(loss=batch_loss.mean().item())

Epoch 0:   0%|                                                                                                                                                                             | 0/573 [00:08<?, ?batch/s, loss=0.0339]
Epoch 0:   0%|                                                                                                                                                                             | 0/286 [10:31<?, ?batch/s, loss=0.0994]


KeyboardInterrupt: 

In [None]:
loss = ContrastiveLoss(1)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 50

train_epoch_losses = []
valid_epoch_losses = []
es_counter = 0
for epoch in range(n_epochs):
    with tqdm(train_dataloader, unit="batch") as tqdm_train_dataloader:
        model.train()
        train_loss = []
        for batch in tqdm_train_dataloader:
            tqdm_train_dataloader.set_description(f"train Epoch {epoch}")
            model.zero_grad()
            bank_out, rtk_out = model(batch)
            batch_loss = loss(bank_out, rtk_out, batch['target'], raw=True)
            batch_loss.mean().backward()
            optimizer.step()
            train_loss.extend(batch_loss.detach().tolist())
            tqdm_train_dataloader.set_postfix(
                batch_loss=batch_loss.mean().item(),
                epoch_loss=np.mean(train_loss) if not train_epoch_losses else np.mean(train_epoch_losses))
            break
    with torch.no_grad():
        model.eval()
        with tqdm(valid_dataloader, unit="batch") as tqdm_valid_dataloader:
            valid_loss = []
            for batch in tqdm_train_dataloader:
                tqdm_valid_dataloader.set_description(f"valid Epoch {epoch}")
                bank_out, rtk_out = model(batch)
                batch_loss = loss(bank_out, rtk_out, batch['target'], raw=True).detach()
                valid_loss.extend(batch_loss.tolist())
                tqdm_valid_dataloader.set_postfix(
                    batch_loss=batch_loss.mean().item(),
                    epoch_loss=np.mean(valid_loss) if not valid_epoch_losses else np.mean(valid_epoch_losses))
                break
    train_epoch_losses.append(np.mean(train_loss))
    valid_epoch_losses.append(np.mean(valid_loss))
    now_time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    torch.save(model.state_dict(),
               path.join(
                   data_dir,
                   'nn_chpt',
                   f'model_{now_time_str}_{round(train_epoch_losses[-1],5)}_{round(valid_epoch_losses[-1],5)}'))
    if (len(valid_epoch_losses) > 1) and (valid_epoch_losses[-1] >= [-2]):
        es_counter += 1
    elif(len(valid_epoch_losses) > 1) and (valid_epoch_losses[-1] < [-2]):
        es_counter = 0
    if es_counter == 5:
        break