
# TabNet-Training: Attentive Interpretable Tabular Learning
https://www.kaggle.com/code/i2nfinit3y/jane-street-tabm-ft-transformer-training/notebook

# TabNet-Inference
https://www.kaggle.com/code/i2nfinit3y/jane-street-tabm-ft-transformer-inference

In [None]:
!pip -q install rtdl_num_embeddings delu rtdl_revisiting_models 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import rtdl_num_embeddings
from rtdl_num_embeddings import compute_bins
import rtdl_revisiting_models
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset

from sklearn.model_selection import train_test_split

from sklearn.metrics import r2_score
import pandas as pd
import math
import numpy as np
import delu
from tqdm import tqdm
import polars as pl
from collections import OrderedDict
import sys

from torch import Tensor
from typing import List, Callable, Union, Any, TypeVar, Tuple

import joblib

import gc

In [3]:
# ---- Runtime and feature configuration ------------------------------------
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

target_col = "responder_6"
feature_cat = ["feature_09", "feature_10", "feature_11"]

# Raw anonymised features feature_00 .. feature_78, followed by the nine
# lag-1 responder columns.
feature_train_list = [f"feature_{k:02d}" for k in range(79)]
feature_train = [*feature_train_list, *(f"responder_{k}_lag_1" for k in range(9))]

# Non-categorical columns (same membership/order in both lists).
feature_cont = list(filter(lambda name: name not in feature_cat, feature_train))
std_feature = [name for name in feature_train_list if name not in feature_cat]
std_feature += [f"responder_{k}_lag_1" for k in range(9)]

# Inclusive date_id window used for the training split.
start_dt, end_dt = 800, 1577

batch_size, num_epochs = 1024, 20

# Per-column normalisation statistics from the preprocessing notebook.
data_stats = joblib.load("/kaggle/input/jane-street-data-preprocessing/data_stats.pkl")
means, stds = data_stats['mean'], data_stats['std']

def standardize(df, feature_cols, means, stds):
    """Z-score the given columns of a polars frame.

    Every column ``col`` in ``feature_cols`` is replaced (keeping its name)
    by ``(col - means[col]) / stds[col]``. Other columns are untouched.

    Args:
        df: polars DataFrame or LazyFrame containing ``feature_cols``.
        feature_cols: names of the columns to standardise.
        means: per-column means, indexable by column name.
        stds: per-column standard deviations, indexable by column name.
            NOTE(review): a zero std yields inf/NaN for that column.

    Returns:
        A new frame of the same kind with the standardised columns.
    """
    return df.with_columns(
        [((pl.col(col) - means[col]) / stds[col]).alias(col) for col in feature_cols]
    )



In [4]:
# Lazily scan both pre-built splits (with lag features); nothing is read
# until a query is collected.
train_original, valid_original = (
    pl.scan_parquet(path)
    for path in (
        "/kaggle/input/js24-preprocessing-create-lags/training.parquet",
        "/kaggle/input/js24-preprocessing-create-lags/validation.parquet",
    )
)
all_original = pl.concat([train_original, valid_original])

def get_category_mapping(df, column):
    """Build a deterministic ``{category_value: code}`` mapping for ``column``.

    The distinct non-null values of ``column`` are sorted ascending and
    numbered ``0 .. K-1``. Sorting matters: polars' ``unique()`` does not
    preserve any order by default, so without it the codes can change
    between runs. A model trained with one numbering would then see
    different one-hot columns at inference.

    Args:
        df: polars LazyFrame (``.collect()`` is called on the query).
        column: name of the categorical column.

    Returns:
        dict mapping each distinct non-null value to its integer code.
    """
    unique_values = (
        df.select(pl.col(column).drop_nulls().unique().sort())
        .collect()
        .to_series()
    )
    return {cat: idx for idx, cat in enumerate(unique_values)}
# Data-driven codes for the categorical features and id columns.
category_mappings = {
    name: get_category_mapping(all_original, name)
    for name in (*feature_cat, "symbol_id", "time_id")
}

# Hard-coded reference table: each column's values in ascending order,
# numbered 0..K-1 (symbol_id and time_id are identity maps).
_category_mappings = {
    'feature_09': {value: code for code, value in enumerate(
        [2, 4, 9, 11, 12, 14, 15, 25, 26, 30, 34, 42, 44, 46, 49, 50, 57, 64, 68, 70, 81, 82])},
    'feature_10': {value: code for code, value in enumerate(
        [1, 2, 3, 4, 5, 6, 7, 10, 12])},
    'feature_11': {value: code for code, value in enumerate(
        [9, 11, 13, 16, 24, 25, 34, 40, 48, 50, 59, 62, 63, 66,
         76, 150, 158, 159, 171, 195, 214, 230, 261, 297, 336, 376, 388, 410, 522, 534, 539])},
    'symbol_id': {k: k for k in range(39)},
    'time_id': {k: k for k in range(968)},
}

def encode_column(df, column, mapping):
    """Replace ``column`` by its integer code from ``mapping``, cast to Int16.

    Values that are not keys of ``mapping`` are encoded as -1. The column
    keeps its original name.
    """
    def _lookup(value):
        # Unknown categories fall back to -1.
        return mapping.get(value, -1)

    encoded = pl.col(column).map_elements(_lookup, return_dtype=pl.Int16)
    return df.with_columns(encoded.alias(column))

# Encode the categorical features and id columns in both splits with the
# same shared mappings.
for col in [*feature_cat, 'symbol_id', 'time_id']:
    train_original, valid_original = (
        encode_column(frame, col, category_mappings[col])
        for frame in (train_original, valid_original)
    )

In [5]:
# Training rows: date_id in [start_dt, end_dt] from the training file plus
# date_id <= end_dt from the validation file. Validation: everything after
# end_dt, in chronological order. Every query ends with the same column
# layout: features, target, weight, symbol_id, time_id.
train_data1 = (
    train_original
    .filter(pl.col("date_id").is_between(start_dt, end_dt))
    .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])
)

train_data2 = (
    valid_original
    .filter(pl.col("date_id") <= end_dt)
    .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])
)

train_data = pl.concat([train_data1, train_data2])

valid_data = (
    valid_original
    .filter(pl.col("date_id") > end_dt)
    .sort(['date_id', 'time_id'])
    .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])
)


In [6]:
# Materialise both lazy queries into dense float32 matrices. Categorical and
# id columns are included and get split apart per batch in the training loop.
train_data_tensor, valid_data_tensor = (
    torch.tensor(frame.collect().to_numpy(), dtype=torch.float32)
    for frame in (train_data, valid_data)
)

In [20]:
# Each dataset item is a 1-tuple holding one packed row.
train_ds = TensorDataset(train_data_tensor)
valid_ds = TensorDataset(valid_data_tensor)

# Shuffle only the training stream; validation keeps its chronological order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=False)

In [8]:
# Optional final fit on train + validation. valid_dl is left unchanged, so
# validation metrics are no longer held-out when this is enabled.
all_data = False
if all_data:
    train_ds = ConcatDataset([train_ds, valid_ds])
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

In [9]:
# Model input sizes, derived from the feature lists so they cannot silently
# drift out of sync with the data columns.
n_cont_features = len(feature_cont)    # 79 raw - 3 categorical + 9 lag columns = 85
n_cat_features = len(feature_cat) + 2  # feature_09/10/11 + symbol_id + time_id = 5
n_classes = None                       # only used when task_type == 'multiclass'
# One-hot widths for [feature_09, feature_10, feature_11, symbol_id, time_id].
# Each is one more than the number of categories listed in
# `_category_mappings` (22, 9, 31, 39, 968).
# NOTE(review): presumably spare headroom; confirm.
cat_cardinalities = [23, 10, 32, 40, 969]
assert len(cat_cardinalities) == n_cat_features

In [13]:
# Continuous feature matrix for bin estimation: drop the 4 trailing columns
# (target, weight, symbol_id, time_id), then the categorical columns 9, 10, 11.
bins_input = train_data_tensor[:, :-4]
bins_input = bins_input[:, [c for c in range(bins_input.shape[1]) if c not in (9, 10, 11)]]

# Keep only rows whose continuous features are all finite.
nan_mask = bins_input.isnan()
inf_mask = bins_input.isinf()
valid_rows = (nan_mask | inf_mask).any(dim=1).logical_not()
valid_rows

tensor([False, False, False,  ...,  True, False,  True])

In [14]:

# Bin edges for the piecewise-linear embeddings (n_bins=32 per feature),
# estimated from the first 1,000,000 fully finite rows in dataset order.
# NOTE(review): this is not a random sample; if the source data is date
# ordered, these rows come from the earliest dates only.
bins_input_clean = bins_input[valid_rows][: 10**6]
bins = compute_bins(bins_input_clean, n_bins=32)

In [10]:
def make_parameter_groups(model, *, weight_decay=5e-3):
    """Split a model's trainable parameters into two optimizer param groups.

    Parameters whose name contains ``'cont_embeddings'`` (e.g. the numerical
    embedding module of ``Model``) or ``'bias'`` get no weight decay. All
    other trainable parameters get ``weight_decay``. Parameters with
    ``requires_grad=False`` are skipped entirely.

    Args:
        model: any ``nn.Module``.
        weight_decay: decay for the second group (default 5e-3, the value
            that was previously hard-coded).

    Returns:
        ``[{'params': no_decay, 'weight_decay': 0.0},
        {'params': decay, 'weight_decay': weight_decay}]``, ready to pass
        to a ``torch.optim`` optimizer.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized at all
        if 'cont_embeddings' in name or 'bias' in name:
            # Embedding layers and biases: no weight decay.
            no_decay.append(param)
        else:
            # All remaining parameters are decayed.
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': decay, 'weight_decay': weight_decay},
    ]

In [11]:
from typing import Optional

class Model(nn.Module):
    """MLP over embedded continuous features and one-hot categoricals.

    Continuous inputs go through
    ``rtdl_num_embeddings.PiecewiseLinearEmbeddings`` (8-dim per feature,
    built from precomputed ``bins``) and are flattened. Each categorical
    column is one-hot encoded with its cardinality. Both parts are
    concatenated column-wise and fed to ``rtdl_revisiting_models.MLP``.

    Args:
        n_cont_features: number of columns in ``x_cont``.
        cat_cardinalities: one-hot width for each column of ``x_cat``.
        bins: per-feature bin edges. Required, despite the Optional type.
        mlp_kwargs: forwarded to ``rtdl_revisiting_models.MLP``; ``d_in``
            is computed here.
    """

    def __init__(
        self,
        n_cont_features: int,
        cat_cardinalities: list[int],
        bins: Optional[list[Tensor]],
        mlp_kwargs: dict,
    ) -> None:
        super().__init__()
        self.cat_cardinalities = cat_cardinalities
        # Total one-hot width = sum of the category counts.
        onehot_width = sum(cat_cardinalities)

        assert bins is not None
        embed_dim = 8
        # Registered before the backbone (fixes state_dict key order).
        self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            bins, embed_dim, activation=False, version='B'
        )
        embedded_width = n_cont_features * embed_dim

        self.backbone = rtdl_revisiting_models.MLP(
            d_in=embedded_width + onehot_width, **mlp_kwargs
        )

    def forward(self, x_cont: Tensor, x_cat: Optional[Tensor]) -> Tensor:
        # (batch, n_cont, embed_dim) -> (batch, n_cont * embed_dim)
        pieces = [self.cont_embeddings(x_cont).flatten(1)]
        if x_cat is not None:
            pieces += [
                F.one_hot(codes, width)
                for codes, width in zip(x_cat.T, self.cat_cardinalities)
            ]
        return self.backbone(torch.column_stack(pieces))

In [15]:
# Regression: a single output unit (n_classes only matters for multiclass).
task_type = 'regression'
model = Model(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    bins=bins,
    mlp_kwargs={
        'n_blocks': 3,
        'd_block': 512,
        'dropout': 0.25,
        'd_out': 1 if task_type != 'multiclass' else n_classes,
    },
).to(device)

# AdamW with uniform weight decay over all parameters
# (make_parameter_groups(model) would exempt embeddings and biases).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3)


In [33]:
class R2Loss(nn.Module):
    """Un-weighted ``1 - R²`` loss with a zero baseline.

    ``loss = Σ (y_pred - y_true)² / (Σ y_true² + eps)``, summed over all
    elements. Minimising it maximises ``R² = 1 - loss``.

    Args:
        eps: added to the denominator. The default 1e-38 reproduces the
            original behaviour, but it lies below float32's smallest normal
            value (~1.18e-38). An all-zero target batch therefore produces an
            enormous (or inf) loss. Pass e.g. ``eps=1e-8`` for a real guard.
    """

    def __init__(self, eps: float = 1e-38):
        super().__init__()
        self.eps = eps

    def forward(self, y_pred, y_true):
        squared_error = torch.sum((y_pred - y_true) ** 2)
        target_energy = torch.sum(y_true ** 2)
        return squared_error / (target_energy + self.eps)


class LogCoshLoss(nn.Module):
    """Mean log-cosh loss: ``mean(log(cosh(y_pred - y_true)))``.

    Evaluated through the identity
    ``log(cosh(x)) = |x| + softplus(-2|x|) - log(2)``.
    The naive ``log(cosh(x))`` overflows to inf in float32 once |x| exceeds
    about 89. This form stays finite, and its gradient equals ``tanh(x)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        abs_diff = torch.abs(y_pred - y_true)
        loss = abs_diff + F.softplus(-2.0 * abs_diff) - math.log(2.0)
        return torch.mean(loss)
# Training objective (alternatives tried: nn.HuberLoss(delta=0.2), nn.MSELoss()).
loss_fn = R2Loss()

# Early stopping on validation R² (higher is better).
patience = 5
early_stopping = delu.tools.EarlyStopping(patience, mode="max")
best = dict(val=-math.inf, epoch=-1)

# Wall-clock timer, printed with each epoch summary.
timer = delu.tools.Timer()
timer.run()

def r2_val(y_true, y_pred, sample_weight):
    """Sample-weighted R² against a zero baseline.

    ``R² = 1 - Σ w (y_true - y_pred)² / Σ w y_true²``

    All three arguments are numpy arrays (or broadcast-compatible values).
    """
    numerator = np.sum(sample_weight * (y_true - y_pred) ** 2)
    denominator = np.sum(sample_weight * y_true ** 2)
    return 1 - numerator / denominator

In [34]:
# Set the variable in this kernel's own environment. A `!export ...` shell
# escape runs in a separate subshell and does not modify os.environ here.
# NOTE(review): presumably meant to silence polars' warning when DataLoader
# workers fork after polars' thread pool has started; confirm.
import os
os.environ["POLARS_ALLOW_FORKING_THREAD"] = "1"

In [35]:
def _split_batch(batch: Tensor):
    """Unpack one packed row-batch into model inputs, target and weight.

    Column layout (from the ``select`` in the data-preparation cell):
    ``[*feature_train, target_col, 'weight', 'symbol_id', 'time_id']``.
    Within the feature block, columns 9, 10 and 11 are feature_09/10/11
    (categorical); the rest are continuous.

    Returns ``(x_cont, x_cat, y, w)`` on ``device``:
      * x_cont: continuous features, with NaN and ±Inf replaced by 0.
      * x_cat: int64 codes of feature_09/10/11, symbol_id, time_id.
    """
    batch = batch.to(device)  # single host->device transfer per batch
    features = batch[:, :-4]
    y, w = batch[:, -4], batch[:, -3]
    symbol, time_ = batch[:, -2], batch[:, -1]

    cat_idx = [9, 10, 11]
    cont_idx = [c for c in range(features.shape[1]) if c not in cat_idx]
    x_cont = torch.nan_to_num(features[:, cont_idx], nan=0.0, posinf=0.0, neginf=0.0)
    x_cat = torch.cat(
        [features[:, cat_idx], symbol.unsqueeze(-1), time_.unsqueeze(-1)], dim=1
    ).to(torch.int64)
    return x_cont, x_cat, y, w


for epoch in range(num_epochs):
    # ------------------------------ training ------------------------------
    model.train()

    # (prediction, target, weight) sampled every 10th step for a cheap train R².
    train_pred_list = []
    with tqdm(train_dl, total=len(train_dl), leave=True) as phar:
        # Iterating the tqdm wrapper advances the bar itself; calling
        # phar.update() as well would double-count progress.
        for i, (train_batch,) in enumerate(phar):
            optimizer.zero_grad()
            x_cont_input, x_cat_input, y_input, w_input = _split_batch(train_batch)

            output = model(x_cont_input, x_cat_input).squeeze(-1)
            loss = loss_fn(output, y_input)
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
            optimizer.step()

            if i % 10 == 0:
                # Detach so the stored predictions don't hold on to autograd nodes.
                train_pred_list.append((output.detach(), y_input, w_input))
                phar.set_postfix(
                    OrderedDict(
                        epoch=f'{epoch + 1}/{num_epochs}',
                        loss=f'{loss.item():.6f}',
                        lr=f'{optimizer.param_groups[0]["lr"]:.3e}'
                    )
                )

    weights_train = torch.cat([x[2] for x in train_pred_list]).cpu().numpy()
    y_train = torch.cat([x[1] for x in train_pred_list]).cpu().numpy()
    prob_train = torch.cat([x[0] for x in train_pred_list]).cpu().numpy()
    train_r2 = r2_val(y_train, prob_train, weights_train)

    # ----------------------------- validation -----------------------------
    model.eval()
    valid_loss_list = []
    valid_pred_list = []
    with torch.no_grad(), tqdm(valid_dl, total=len(valid_dl), leave=True) as phar:
        for (valid_batch,) in phar:
            x_cont_valid, x_cat_valid, y_valid, w_valid = _split_batch(valid_batch)
            y_pred = model(x_cont_valid, x_cat_valid).squeeze(-1)

            val_loss = loss_fn(y_pred, y_valid)
            valid_loss_list.append(val_loss)
            valid_pred_list.append((y_pred, y_valid, w_valid))
            phar.set_postfix(
                OrderedDict(
                    epoch=f'{epoch + 1}/{num_epochs}',
                    val_loss=f'{val_loss.item():.6f}',
                    lr=f'{optimizer.param_groups[0]["lr"]:.3e}'
                )
            )

    valid_loss_mean = sum(valid_loss_list) / len(valid_loss_list)
    weights_eval = torch.cat([x[2] for x in valid_pred_list]).cpu().numpy()
    y_eval = torch.cat([x[1] for x in valid_pred_list]).cpu().numpy()
    prob_eval = torch.cat([x[0] for x in valid_pred_list]).cpu().numpy()
    val_r2 = r2_val(y_eval, prob_eval, weights_eval)

    print(f"Epoch {epoch + 1}: train_r2 = {train_r2:.6f}, val_loss_mean={valid_loss_mean:.6f}, val_r2={val_r2:.6f}, [time] {timer}")

    if val_r2 > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_r2, "epoch": epoch}
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'r2': val_r2,
        }
        # Write only on improvement, and immediately: state_dict() tensors
        # share storage with the live parameters/optimizer state, so this
        # in-memory dict keeps changing as training continues.
        torch.save(checkpoint, f'epoch{epoch}_r2_{val_r2}.pt')
    print()

    early_stopping.update(val_r2)
    if early_stopping.should_stop():
        print("Early stop")
        break

# Final (last-epoch) state. Unlike the best-epoch files above, it has no
# 'epoch' key and is not saved to disk in this cell.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'r2': val_r2,
}

 43%|████▎     | 11502/26687 [04:26<06:51, 36.94it/s, epoch=1/4, loss=0.882956, lr=1.000e-04]