
# TabNet-Training: Attentive Interpretable Tabular Learning
https://www.kaggle.com/code/i2nfinit3y/jane-street-tabm-ft-transformer-training/notebook

# TabNet-Inference
https://www.kaggle.com/code/i2nfinit3y/jane-street-tabm-ft-transformer-inference

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import rtdl_num_embeddings
from rtdl_num_embeddings import compute_bins
import rtdl_revisiting_models
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset

from sklearn.model_selection import train_test_split

from sklearn.metrics import r2_score
import pandas as pd
import math
import numpy as np
import delu
from tqdm import tqdm
import polars as pl
from collections import OrderedDict
import sys

from torch import Tensor
from typing import List, Callable, Union, Any, TypeVar, Tuple

import joblib

import gc

In [51]:

feature_list = [f"feature_{idx:02d}" for idx in range(79) if idx != 61]

target_col = "responder_6" 

feature_test = feature_list \
                + [f"responder_{idx}_lag_1" for idx in range(9)] 

feature_cat = ["feature_09", "feature_10", "feature_11"]
feature_cont = [item for item in feature_test if item not in feature_cat]

batch_size = 8192

std_feature = [i for i in feature_list if i not in feature_cat] + [f"responder_{idx}_lag_1" for idx in range(9)]

data_stats = joblib.load("/kaggle/input/jane-street-data-preprocessing/data_stats.pkl")
means = data_stats['mean']
stds = data_stats['std']

def standardize(df, feature_cols, means, stds):
    return df.with_columns([
        ((pl.col(col) - means[col]) / stds[col]).alias(col) for col in feature_cols
    ])

In [52]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
feature_train_list = [f"feature_{idx:02d}" for idx in range(79)] 
target_col = "responder_6"
feature_train = feature_train_list \
                + [f"responder_{idx}_lag_1" for idx in range(9)] 

start_dt = 110
end_dt = 1577

feature_cat = ["feature_09", "feature_10", "feature_11"]
feature_cont = [item for item in feature_train if item not in feature_cat]
std_feature = [i for i in feature_train_list if i not in feature_cat] + [f"responder_{idx}_lag_1" for idx in range(9)]

batch_size = 8192
num_epochs = 100

data_stats = joblib.load("/kaggle/input/jane-street-data-preprocessing/data_stats.pkl")
means = data_stats['mean']
stds = data_stats['std']

def standardize(df, feature_cols, means, stds):
    return df.with_columns([
        ((pl.col(col) - means[col]) / stds[col]).alias(col) for col in feature_cols
    ])

In [53]:
train_original = pl.scan_parquet("/kaggle/input/js24-preprocessing-create-lags/training.parquet")
valid_original = pl.scan_parquet("/kaggle/input/js24-preprocessing-create-lags/validation.parquet")
all_original = pl.concat([train_original, valid_original])

def get_category_mapping(df, column):
    unique_values = df.select([column]).unique().collect().to_series()
    return {cat: idx for idx, cat in enumerate(unique_values)}
category_mappings = {col: get_category_mapping(all_original, col) for col in feature_cat + ['symbol_id', 'time_id']}

def encode_column(df, column, mapping):
    def encode_category(category):
        return mapping.get(category, -1)  
    
    return df.with_columns(
        pl.col(column).map_elements(encode_category, return_dtype=pl.Int16).alias(column)
    )

for col in feature_cat + ['symbol_id', 'time_id']:
    train_original = encode_column(train_original, col, category_mappings[col])
    valid_original = encode_column(valid_original, col, category_mappings[col])

In [54]:
train_data1 = train_original \
             .filter((pl.col("date_id") >= start_dt) & (pl.col("date_id") <= end_dt)) \
             .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])

train_data2 = valid_original \
             .filter(pl.col("date_id") <= end_dt) \
             .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])

train_data = pl.concat([train_data1, train_data2])
valid_data = valid_original \
             .filter(pl.col("date_id") > end_dt)\
             .sort(['date_id', 'time_id'])\
             .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])


In [55]:
train_data_tensor = torch.tensor(train_data.collect().to_numpy(), dtype=torch.float32)
valid_data_tensor = torch.tensor(valid_data.collect().to_numpy(), dtype=torch.float32)

In [56]:
train_ds = TensorDataset(train_data_tensor)
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=1, pin_memory=False, shuffle=True)
valid_ds = TensorDataset(valid_data_tensor)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, num_workers=1, pin_memory=False, shuffle=False)

In [57]:
n_cont_features = 85
n_cat_features = 5
n_classes = None
cat_cardinalities = [23, 10, 32, 40, 969]

In [58]:
bins_input = train_data_tensor[:, :-4][:, [col for col in range(train_data_tensor[:, :-4].shape[1]) if col not in [9, 10, 11]]]
nan_mask = torch.isnan(bins_input)
inf_mask = torch.isinf(bins_input)
valid_rows = ~(nan_mask.any(dim=1) | inf_mask.any(dim=1))
valid_rows

tensor([False, False, False,  ...,  True, False,  True])

In [59]:

bins_input_clean = bins_input[valid_rows][:1_000_000]
bins = compute_bins(bins_input_clean , n_bins=32) 

In [60]:
from typing import Optional

class Model(nn.Module):
    def __init__(
        self,
        n_cont_features: int,
        cat_cardinalities: list[int],
        bins: Optional[list[Tensor]],
        mlp_kwargs: dict,
    ) -> None:
        super().__init__()
        self.cat_cardinalities = cat_cardinalities
        # The total representation size for categorical features
        # == the sum of one-hot representation sizes
        # == the sum of the numbers of distinct values of all features.
        d_cat = sum(cat_cardinalities)

        # Choose any of the embeddings below.

        # d_embedding = 24
        # self.cont_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(
        #     n_cont_features, d_embedding, lite=False
        # )
        # d_num = n_cont_features * d_embedding

        # assert bins is not None
        # self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEncoding(bins)
        # d_num = sum(len(b) - 1 for b in bins)

        assert bins is not None
        d_embedding = 8
        self.cont_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            bins, d_embedding, activation=False, version='B'
        )
        d_num = n_cont_features * d_embedding

        # d_embedding = 32
        # self.cont_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(
        #     n_cont_features, d_embedding
        # )
        # d_num = n_cont_features * d_embedding

        self.backbone = rtdl_revisiting_models.MLP(d_in=d_num + d_cat, **mlp_kwargs)

    def forward(self, x_cont: Tensor, x_cat: Optional[Tensor]) -> Tensor:
        x = []

        # Step 1. Embed the continuous features.
        # Flattening is needed for MLP-like models.
        x.append(self.cont_embeddings(x_cont).flatten(1))

        # Step 2. Encode the categorical features using any strategy.
        if x_cat is not None:
            x.extend(
                F.one_hot(column, cardinality)
                for column, cardinality in zip(x_cat.T, self.cat_cardinalities)
            )

        # Step 3. Assemble the vector input for the backbone.
        x = torch.column_stack(x)

        # Step 4. Apply the backbone.
        return self.backbone(x)

In [61]:
task_type = 'regression'
model = Model(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    bins=bins,
    mlp_kwargs={        
        'n_blocks': 3,
        'd_block': 512,
        'dropout': 0.25,
        'd_out': n_classes if task_type == 'multiclass' else 1,
    },
).to(device)

model.load_state_dict(torch.load("./import-tabm-inference/epoch0_r2_0.004429519176483154.pt")['model_state_dict'])

  model.load_state_dict(torch.load("./import-tabm-inference/epoch0_r2_0.004429519176483154.pt")['model_state_dict'])


<All keys matched successfully>

In [83]:
lags_ : pl.DataFrame | None = None

lags_history = None

def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    global lags_, lags_history
    if lags is not None:
        lags_ = lags
    for col in feature_cat + ['symbol_id', 'time_id']:
        test = encode_column(test, col, category_mappings[col])
    predictions = test.select('row_id', pl.lit(0.0).alias('responder_6'))
    symbol_ids = test.select('symbol_id').to_numpy()[:, 0]
    time_id = test.select("time_id").to_numpy()[0]
    timie_id_array = test.select("time_id").to_numpy()[:, 0]
    
    if time_id == 0:
        lags = lags.with_columns(pl.col('time_id').cast(pl.Int64))
        lags = lags.with_columns(pl.col('symbol_id').cast(pl.Int64))
        lags_history = lags
        lags = lags.filter(pl.col("time_id") == 0)
        test = test.join(lags, on=["time_id", "symbol_id"],  how="left")
    else:
        lags = lags_history.filter(pl.col("time_id") == time_id)
        test = test.join(lags, on=["time_id", "symbol_id"],  how="left")

    test = test.with_columns([
        pl.col(col).fill_null(0) for col in feature_list + [f"responder_{idx}_lag_1" for idx in range(9)] 
    ])
    test = standardize(test, std_feature, means, stds)
    X_test = test[feature_test].to_numpy()
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    symbol_tensor = torch.tensor(symbol_ids, dtype=torch.float32).to(device)
    time_tensor = torch.tensor(timie_id_array, dtype=torch.float32).to(device)
    X_cat = X_test_tensor[:, [9, 10, 11]]
    X_cont = X_test_tensor[:, [i for i in range(X_test_tensor.shape[1]) if i not in [9, 10, 11]]]
    X_cat = (torch.concat([X_cat, symbol_tensor.unsqueeze(-1), time_tensor.unsqueeze(-1)], axis=1)).to(torch.int64)
    model.eval()
    with torch.no_grad():
        outputs = model(X_cont, X_cat)
        # Assuming the model outputs a tensor of shape (batch_size, 1)
        preds = outputs.squeeze(-1).cpu().numpy()
        preds = preds.mean(1)
    predictions = \
    test.select('row_id').\
    with_columns(
        pl.Series(
            name   = 'responder_6', 
            values = np.clip(preds, a_min = -5, a_max = 5),
            dtype  = pl.Float64,
        )
    )

    # The predict function must return a DataFrame
    assert isinstance(predictions, pl.DataFrame | pd.DataFrame)
    # with columns 'row_id', 'responer_6'
    assert list(predictions.columns) == ['row_id', 'responder_6']
    # and as many rows as the test data.
    assert len(predictions) == len(test)

    return predictions

In [63]:
!export POLARS_ALLOW_FORKING_THREAD=1

In [64]:
model.load_state_dict(torch.load("./import-tabm-inference/epoch0_r2_0.004429519176483154.pt")['model_state_dict'])

  model.load_state_dict(torch.load("./import-tabm-inference/epoch0_r2_0.004429519176483154.pt")['model_state_dict'])


<All keys matched successfully>

In [67]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import polars as pl
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import lightgbm
import torch

In [68]:
valid_from = 1577 # for private you should change to 1455 (1 year)

In [69]:
alltraindata = pl.scan_parquet("/kaggle/input/jane-street-realtime-marketdata-forecasting/train.parquet")
valid_df = alltraindata.filter(pl.col("date_id")>=valid_from).collect()
valid_df = valid_df.with_columns(
    pl.Series(range(len(valid_df))).alias("row_id"),
    pl.lit(True).alias("is_scored")
)
len(valid_df)

4532176

In [70]:
valid_df.write_parquet("valid_df.parquet")

In [72]:
test_sample = pl.read_parquet("/kaggle/input/jane-street-realtime-marketdata-forecasting/test.parquet/date_id=0/part-0.parquet")
test_sample.head(3)

row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,…,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78
i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,3.169998,True,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
1,0,0,1,2.165993,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0
2,0,0,2,3.06555,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,-0.0,-0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0


In [74]:
lag_sample = pl.read_parquet("/kaggle/input/jane-street-realtime-marketdata-forecasting/lags.parquet/date_id=0/part-0.parquet")
train_sample = pl.read_parquet("/kaggle/input/jane-street-realtime-marketdata-forecasting/train.parquet/partition_id=0/part-0.parquet",n_rows=1)
responder_cols = [s for s in train_sample.columns if "responder" in s]

def makelag(date_id):
    """
    Making lag at the previout day

    Args:
    date_id (int): date_id at the previout day
    
    Returns:
    pl.dataframe
    """
    
    lag = alltraindata.filter(pl.col("date_id")==date_id).select(["date_id","time_id","symbol_id"] + responder_cols).collect()
    lag.columns = lag_sample.columns
    
    return lag    

In [77]:
os.makedirs("/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/test.parquet",exist_ok=True)
os.makedirs("/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/lags.parquet",exist_ok=True)

In [73]:
valid_df = valid_df.select(test_sample.columns)

In [76]:
total_iterations = len(valid_df["date_id"].unique())
total_iterations

122

In [79]:
for num_days, df_per_day in tqdm(valid_df.group_by("date_id",maintain_order=True),total=total_iterations,desc="Processing"):
    
       
    day = num_days[0] - valid_from # date_id must start from 0.
    
    os.makedirs(f"/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/test.parquet/date_id={day}",exist_ok=True)
    os.makedirs(f"/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/lags.parquet/date_id={day}",exist_ok=True)
    
    lag = makelag(num_days[0] - 1)
    
    df_per_day.write_parquet(f"/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/test.parquet/date_id={day}/part-0.parquet")
    lag.write_parquet(f"/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/lags.parquet/date_id={day}/part-0.parquet")

Processing: 100%|██████████| 122/122 [00:09<00:00, 12.38it/s]


In [80]:
features = [s for s in valid_df.columns if "feature" in s]
len(features)

79

In [84]:
%%time

EVAL = True 
if EVAL:
    test_dir = '/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/test.parquet'
    lags_dir = '/kaggle/input/janestreet-updated-simulator-for-time-series-api/debug/lags.parquet'
else:
    test_dir = '/kaggle/input/jane-street-realtime-marketdata-forecasting/test.parquet'
    lags_dir = '/kaggle/input/jane-street-realtime-marketdata-forecasting/lags.parquet'

import kaggle_evaluation.jane_street_inference_server
inference_server = kaggle_evaluation.jane_street_inference_server.JSInferenceServer(predict)
import os

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        (
            test_dir,
            lags_dir
        )
    )

GatewayRuntimeError: (<GatewayRuntimeErrorType.SERVER_RAISED_EXCEPTION: 3>, "'NoneType' object has no attribute 'filter'")