In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import rtdl_num_embeddings
from rtdl_num_embeddings import compute_bins
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset

from sklearn.model_selection import train_test_split

from sklearn.metrics import r2_score
import pandas as pd
import math
import numpy as np
import delu
from tqdm import tqdm
import polars as pl
from collections import OrderedDict
import sys

from tanm_reference import Model, make_parameter_groups


from torch import Tensor
from typing import List, Callable, Union, Any, TypeVar, Tuple

import joblib

import gc

In [None]:
# Use the first CUDA device when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Base feature columns: feature_00 ... feature_78.
feature_train_list = [
    f"feature_{idx:02d}"
    for idx in range(79)
]
target_col = "responder_6"

# Model inputs = base features followed by the nine responder_*_lag_1 columns.
feature_train = [
    *feature_train_list,
    *(f"responder_{idx}_lag_1" for idx in range(9)),
]

# NOTE(review): presumably inclusive/exclusive date bounds for the training
# window -- confirm against the cell that consumes them.
start_dt, end_dt = 800, 1577

# Columns handled as categoricals; every other input column is continuous.
feature_cat = ["feature_09", "feature_10", "feature_11"]
feature_cont = list(filter(lambda name: name not in feature_cat, feature_train))

# Columns to standardize. By construction this lists the same columns, in the
# same order, as feature_cont (it is kept as a separate list object).
std_feature = [name for name in feature_train_list if name not in feature_cat]
std_feature.extend(f"responder_{idx}_lag_1" for idx in range(9))

# Training hyperparameters.
# batch_size = 2048  (alternative setting, currently disabled)
batch_size = 8192
num_epochs = 4

# Precomputed statistics saved by the preprocessing step. The file is read
# through its 'mean' and 'std' entries below.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load artifacts from a trusted source.
data_stats = joblib.load(
    "/kaggle/input/jane-street-data-preprocessing/data_stats.pkl"
)
# Presumably keyed by column name, as standardize() indexes them -- TODO confirm.
means = data_stats["mean"]
stds = data_stats["std"]

def standardize(df, feature_cols, means, stds):
    """Return `df` with each listed column rescaled to (value - mean) / std.

    Args:
        df: polars frame exposing `with_columns` (eager or lazy).
        feature_cols: iterable of column names to rescale.
        means: mapping indexed by column name giving the value to subtract.
        stds: mapping indexed by column name giving the value to divide by.

    Returns:
        The result of `df.with_columns(...)`; every rescaled column keeps its
        original name, all other columns are untouched.
    """
    scaled_exprs = []
    for name in feature_cols:
        centered = pl.col(name) - means[name]
        # Alias back to the source name so the column is overwritten in place.
        scaled_exprs.append((centered / stds[name]).alias(name))
    return df.with_columns(scaled_exprs)

In [None]:
# Lazily scan the preprocessed splits (scan_parquet builds a LazyFrame rather
# than reading the data eagerly).
train_original = pl.scan_parquet(
    "/kaggle/input/jane-street-data-preprocessing/training.parquet"
)
valid_original = pl.scan_parquet(
    "/kaggle/input/jane-street-data-preprocessing/validation.parquet"
)

# Both splits stacked into one lazy frame: training first, then validation.
all_original = pl.concat([train_original, valid_original])

# def get_category_mapping(df, column):
#     unique_values = df.select([column]).unique().collect().to_series()
#     return {cat: idx for idx, cat in enumerate(unique_values)}
# category_mappings = {col: get_category_mapping(all_original, col) for col in feature_cat + ['symbol_id']}

# Raw value -> dense integer code for each categorical/id column.
# For the feature_* columns the code of a raw value is its position in the
# listed sequence; symbol_id (0..38) and time_id (0..967) map to themselves.
category_mappings = {
    'feature_09': {
        raw: code
        for code, raw in enumerate(
            [2, 4, 9, 11, 12, 14, 15, 25, 26, 30, 34,
             42, 44, 46, 49, 50, 57, 64, 68, 70, 81, 82]
        )
    },
    'feature_10': {
        raw: code
        for code, raw in enumerate([1, 2, 3, 4, 5, 6, 7, 10, 12])
    },
    'feature_11': {
        raw: code
        for code, raw in enumerate(
            [9, 11, 13, 16, 24, 25, 34, 40, 48, 50, 59, 62, 63, 66, 76,
             150, 158, 159, 171, 195, 214, 230, 261, 297, 336, 376, 388,
             410, 522, 534, 539]
        )
    },
    'symbol_id': {sid: sid for sid in range(39)},
    'time_id': {i: i for i in range(968)},
}


def encode_column(df, column, mapping):
    """Return `df` with `column` replaced by its integer code from `mapping`.

    Args:
        df: polars frame exposing `with_columns` (eager or lazy).
        column: name of the column to encode; the result keeps this name.
        mapping: dict from raw value to integer code.

    Returns:
        The result of `df.with_columns(...)`, with `column` re-typed as Int16.
        Values passed to the lookup that are absent from `mapping` become -1.
    """
    def lookup(raw_value):
        # -1 is the sentinel code for values the mapping does not know about.
        return mapping.get(raw_value, -1)

    # Element-wise Python lookup, applied by polars with an explicit output dtype.
    encoded = pl.col(column).map_elements(lookup, return_dtype=pl.Int16)
    return df.with_columns(encoded.alias(column))

# Encode every categorical/id column to its integer code, using the same
# mapping for the training and validation frames.
for col in (*feature_cat, 'symbol_id', 'time_id'):
    train_original, valid_original = (
        encode_column(frame, col, category_mappings[col])
        for frame in (train_original, valid_original)
    )