In [19]:
import os
from typing import List, Optional, Tuple, Dict, Any

import pandas as pd
import polars as pl
import numpy as np
import time
from pathlib import Path
import random

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv

import kaggle_evaluation.jane_street_inference_server

In [20]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [3]:
class MultiStockGraphDataset(Dataset):
    
    def __init__(self, dataset: pl.LazyFrame, adjacency_matrices: np.ndarray, stock_ids: list):
        self.dataset = dataset
        self.adjacency_matrices = adjacency_matrices
        self.stock_ids = stock_ids
        self.num_stocks = len(self.stock_ids)
        self.dataset_len = self.dataset.select(['date_id', 'time_id']).unique().shape[0]
        self._load()
    
    def _load(self):
        all_combinations = (
            self.dataset.select(['date_id', 'time_id'])
            .unique()
            .join(pl.DataFrame({'symbol_id': self.stock_ids}, 
                               schema={'symbol_id': pl.Int8}), how="cross")
        )
        feature_cols = [f'feature_{i:02d}' for i in range(79)]
        self.batch = (
            all_combinations
            .join(self.dataset.with_columns(pl.lit(1).alias('mask')), 
                  on=['date_id', 'time_id', 'symbol_id'], how="left")
            .fill_null(0)  # fill all columns with 0 for missing stocks (including the mask)
            .sort(['date_id', 'time_id', 'symbol_id'])
        )
        # num_stocks rows for each date and time
        self.X = self.batch.select(feature_cols).to_numpy().astype(np.float32)
        self.y = self.batch.select(['responder_6']).to_numpy().flatten().astype(np.float32)
        self.s = self.batch.select(['symbol_id']).to_numpy().flatten().astype(np.int32)
        self.date_ids = self.batch.select(['date_id']).to_numpy().flatten()
        self.masks = self.batch.select(['mask']).to_numpy().flatten() == 0
        self.weights = self.batch.select(['weight']).to_numpy().flatten().astype(np.float32)
    
    def __len__(self):
        return self.dataset_len
    
    def __getitem__(self, idx):
        start_row = idx * self.num_stocks
        features = self.X[start_row:start_row+self.num_stocks, :]
        targets = self.y[start_row:start_row+self.num_stocks]
        masks = self.masks[start_row:start_row+self.num_stocks]
        weights = self.weights[start_row:start_row+self.num_stocks]
        symbols = self.s[start_row:start_row+self.num_stocks]

        date_id = self.date_ids[start_row]
        adj_matrix = self.adjacency_matrices[date_id]
        
        return (
            torch.tensor(features), 
            torch.tensor(targets), 
            torch.tensor(masks), 
            torch.tensor(weights), 
            torch.tensor(symbols),
            torch.tensor(adj_matrix, dtype=torch.int)
        )

In [4]:
class TransposeLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input.transpose(1, 2)

In [5]:
class WeightedMSELoss(nn.Module):
    def __init__(self):
        super(WeightedMSELoss, self).__init__()
    
    def forward(self, predictions: Tensor, targets: Tensor, weights: Tensor) -> Tensor:
        squared_diff = (predictions - targets) ** 2
        weighted_squared_diff = weights * squared_diff
        return weighted_squared_diff.sum() / weights.sum()

In [6]:
class GraphConvEncoderLayer(nn.Module):
    def __init__(self, hidden_dim, dim_feedforward_mult=4, dropout_rate=0.1):
        super(GraphConvEncoderLayer, self).__init__()
        
        self.graph_conv = GCNConv(
            in_channels=hidden_dim, 
            out_channels=hidden_dim
        )

        self.feedforward = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * dim_feedforward_mult),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim * dim_feedforward_mult, hidden_dim)
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index):
        batch_size, num_nodes, num_features = x.size()

        residual = x
        x = x.reshape(batch_size * num_nodes, num_features)
        x = self.graph_conv(x, edge_index)
        x = x.reshape(batch_size, num_nodes, num_features)        
        x = self.dropout1(x) + residual
        # x = self.norm1(x)

        residual = x
        x = self.feedforward(x)
        x = self.dropout2(x) + residual
        # x = self.norm2(x)

        return x

In [7]:
class GraphConvEncoder(nn.Module):
    def __init__(self, hidden_dim, num_layers, dim_feedforward_mult=4, dropout_rate=0.1):
        super(GraphConvEncoder, self).__init__()
        self.layers = nn.ModuleList([
            GraphConvEncoderLayer(
                hidden_dim=hidden_dim,
                dim_feedforward_mult=dim_feedforward_mult,
                dropout_rate=dropout_rate
            ) for _ in range(num_layers)
        ])

    def forward(self, x, adj):
        batch_size, num_nodes, _ = x.size()

        edge_indices = []
        for batch_idx in range(batch_size):
            adj_matrix = adj[batch_idx]
            src, tgt = torch.nonzero(adj_matrix, as_tuple=True)
            src = src + batch_idx * num_nodes
            tgt = tgt + batch_idx * num_nodes
            edge_indices.append(torch.stack([src, tgt], dim=0))

        edge_index = torch.cat(edge_indices, dim=1).to(x.device)
        
        for layer in self.layers:
            x = layer(x, edge_index)
        return x

In [8]:
class StockGCNModel(nn.Module):
    def __init__(
        self,
        input_features,
        hidden_dim=64,
        output_dim=1,
        num_layers=2,
        num_stocks=39,
        embedding_dim=16,
        use_embeddings=False,
        dropout_rate=0.2,
        dim_feedforward_mult=4,
    ):
        super(StockGCNModel, self).__init__()

        self.use_embeddings = use_embeddings

        self.init_layers = nn.Sequential(
            # TransposeLayer(),
            # nn.BatchNorm1d(input_features),
            # TransposeLayer(),
            nn.Dropout(dropout_rate),
        )
        self.feature_projector = []
        if use_embeddings:
            self.feature_projector.append(nn.Linear(input_features + embedding_dim, hidden_dim))
            self.embedding_layer = nn.Embedding(num_stocks, embedding_dim)
        else:
            self.feature_projector.append(nn.Linear(input_features, hidden_dim))
        self.feature_projector += [
            # TransposeLayer(),
            # nn.BatchNorm1d(hidden_dim),
            # TransposeLayer(),
            nn.Dropout(dropout_rate),
        ]
        self.feature_projector = nn.Sequential(*self.feature_projector)

        self.encoder = GraphConvEncoder(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dim_feedforward_mult=dim_feedforward_mult,
            dropout_rate=dropout_rate
        )

        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            # TransposeLayer(),
            # nn.BatchNorm1d(hidden_dim),
            # TransposeLayer(),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, symbols, adj):
        batch_size, num_stocks, num_features = x.size()

        x = self.init_layers(x)
        if self.use_embeddings:
            stock_embeddings = self.embedding_layer(symbols)
            x = torch.cat([x, stock_embeddings], dim=-1)
        x = self.feature_projector(x)
        x = self.encoder(x, adj)

        output = self.predictor(x)
        return 5 * torch.tanh(output)

In [9]:
def compute_correlation_from_pivot(pivot_df):
    correlations = (
        pivot_df
        .drop(['date_id', 'time_id'])
        .corr()
        .fill_nan(0).fill_null(0)
    )
    order = [int(i) for i in correlations.columns]
    new_order = np.argsort(order).tolist()
    columns_order = [str(i) for i in np.array(order)[new_order].tolist()]
    correlations = correlations[columns_order]
    correlations = correlations[new_order, :]
    return np.abs(correlations.to_numpy())

In [10]:

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
    
device = 'cpu'

In [None]:
save_path = '/home/lorecampa/projects/jane_street_forecasting/dataset/models/graph_conv/model_4_7_norm_nolayer.pth'
model = StockGCNModel(
    input_features=79,
    output_dim=1,
    num_layers=1,
    dropout_rate=0.2,
    dim_feedforward_mult=4,
    hidden_dim=64)
model.load_state_dict(torch.load(save_path, weights_only=True, map_location=torch.device(device)))
model = model.to(device)
inference_model = StockGCNModel(
    input_features=79,
    output_dim=1,
    num_layers=1,
    dropout_rate=0.2,
    dim_feedforward_mult=4,
    hidden_dim=64)
inference_model.load_state_dict(torch.load(save_path, weights_only=True, map_location=torch.device(device)))
inference_model = inference_model.to(device)

loss_fn = WeightedMSELoss()
inference_model.eval()

StockGCNModel(
  (init_layers): Sequential(
    (0): Dropout(p=0.2, inplace=False)
  )
  (feature_projector): Sequential(
    (0): Linear(in_features=79, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
  )
  (encoder): GraphConvEncoder(
    (layers): ModuleList(
      (0): GraphConvEncoderLayer(
        (graph_conv): GCNConv(64, 64)
        (feedforward): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): SiLU()
          (2): Dropout(p=0.2, inplace=False)
          (3): Linear(in_features=256, out_features=64, bias=True)
        )
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (predictor): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): SiLU()
    (2): Dropout(p=0.2, inplace=False)


In [12]:
from prj.config import DATA_DIR

FINE_TUNING = True
N_EPOCHS_PER_TRAIN_MAX = 10
BATCH_SIZE = 2048
OLD_DATA_FRACTION = 0.1
FEATURE_COLS = [f'feature_{i:02d}' for i in range(79)]
GRADIENT_CLIPPING = 10
EARLY_STOPPING_DAYS = 7
ES_PATIENCE = 7
TRAIN_EVERY = 23
date_idx = -1
epoch = None
best_epoch = None
best_score = None
train_dataloader, val_dataloader, train_iterator, val_iterator = None, None, None, None
save_path = './best_model.pth'
acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
start_train = False
is_training_loop = False

gradient_clipping_decay = 0.5
gradient_clipping = GRADIENT_CLIPPING * gradient_clipping_decay
lr_decay = 0.7
lr = 1e-5
optimizer = None

TIME_LIMIT = 30
MAX_FINE_TUNING_TIME_LIMIT = time.time() + 60 * 60 * 8 # after 8 hours, stop all the online learning

FEATURES = [f'feature_{i:02d}' for i in range(79)]
COLUMNS = FEATURES + ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']
BASE_PATH = DATA_DIR / 'train.parquet'

In [13]:
from prj.data.data_loader import DataLoader as PrjDataLoader
from prj.data.data_loader import DataConfig as PrjDataConfig

config = PrjDataConfig(**{})
loader = PrjDataLoader(data_dir=DATA_DIR, config=config)
start, end = 1360, 1529
start, end = 1360, 1370

# start, end = 1190, 1200
test_ds = loader.load(start-1, end).collect()

to_remove_symbols = np.random.choice(test_ds['symbol_id'].unique().to_numpy(), size=5, replace=False)
test_ds = test_ds.filter(~pl.col('symbol_id').is_in(to_remove_symbols))


y_test = test_ds.filter(pl.col('date_id').ge(start))['responder_6'].to_numpy().flatten()
w_test = test_ds.filter(pl.col('date_id').ge(start))['weight'].to_numpy().flatten()

to_remove_symbols

2025-01-12 16:59:49.450472: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-12 16:59:49.450507: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-12 16:59:49.451849: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-12 16:59:49.458772: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


array([33, 36,  4, 13, 30], dtype=int8)

In [14]:
def standardize(df: pl.LazyFrame, data_stats_dict: dict, features: list[str], eps=1e-9) -> pl.DataFrame:
    cat_features = ['feature_09', 'feature_10', 'feature_11']
    features = [f for f in features if f not in cat_features]
        
    eps = 1e-8
    return df.with_columns(
        [(pl.col(col).sub(data_stats_dict[f'{col}_mean'])).truediv(data_stats_dict[f'{col}_std']).add(eps) for col in features]
    ).with_columns(
        pl.col(f).truediv(data_stats_dict[f'{f}_max']) for f in cat_features
    )
    
FEATURES = [f'feature_{i:02d}' for i in range(79)]

cat_features = ['feature_09', 'feature_10', 'feature_11']

data_stats = test_ds.filter(pl.col('date_id').ge(start)).select(
    pl.col(FEATURES).mean().name.suffix('_mean'),
    pl.col(FEATURES).std().name.suffix('_std'),
    pl.col(cat_features).max().name.suffix('_max')
)

data_stats_dict = data_stats.to_dicts()[0]

test_ds = standardize(test_ds, data_stats_dict, FEATURES)

test_ds = test_ds.fill_nan(None).fill_null(strategy='zero')

test_ds

date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,…,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8,partition_id
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64
1359,0,0,2.136611,2.422234,-0.516497,2.192171,2.539526,2.473287,-0.481699,0.879098,0.548299,-0.212614,0.134146,0.583333,0.141002,-1.081436,2.52483,-0.062748,0.0,-0.373516,0.0,-1.620773,-1.231624,0.725762,-0.432059,0.711123,0.526897,-0.203114,-0.199322,1.256969,0.943244,0.296373,0.293305,-0.150386,-0.865181,0.0,…,0.0,0.0,-0.533592,0.0,-1.739991,2.93767,0.0,1.811655,1.190002,3.851669,0.088155,0.524211,-1.261496,-1.410816,-1.21496,-0.943442,1.986899,-0.019753,-1.217643,2.425279,-0.080253,0.0,0.0,-0.347528,-0.317906,-0.45527,-0.479321,-0.194527,-0.199862,-0.05801,-0.391492,0.473886,-0.220418,-0.337185,0.654231,-0.316881,7
1359,0,1,1.413127,2.944739,-0.449057,0.310592,1.332943,2.228261,-0.308488,0.851587,0.574776,-0.531986,0.134146,0.583333,0.141002,-0.77292,4.313714,0.317823,0.0,0.207595,0.0,-1.617025,-2.119984,-0.025225,0.074288,0.293362,0.180437,-0.229156,-0.318933,-1.677353,-0.136319,0.463709,-0.238748,-1.099758,0.058065,0.0,…,0.0,0.0,-0.623788,0.0,-1.050706,2.110111,0.0,0.924037,0.524254,3.851669,-0.388955,0.63682,-0.88224,-1.330283,-1.439945,-1.012512,4.215051,0.343801,-1.119219,2.736124,0.043281,0.0,0.0,0.207296,0.19957,-0.278074,-0.260069,-0.599824,-0.319232,-0.802312,-0.3617,0.502022,-0.173755,-0.201655,0.720056,-0.09412,7
1359,0,2,1.518645,2.242425,-0.404063,0.465731,1.737105,2.552789,-0.255517,1.035678,0.58726,-0.124289,0.987805,0.166667,0.109462,-1.615565,0.445302,-0.472148,0.0,1.143563,0.0,-2.157037,-0.79123,-0.275964,-1.115692,-0.344736,-0.750789,-1.104792,-1.28049,0.85386,0.512554,-0.16115,-0.635355,-0.640427,-0.962834,0.0,…,0.0,0.0,-0.922758,0.0,-0.712226,2.420122,0.0,4.606506,1.734409,3.851669,3.454577,0.689913,1.884014,-1.826468,-1.584165,-1.112662,0.124197,-0.610694,-1.114943,1.06576,-0.701126,0.0,0.0,0.841072,1.343543,-0.085355,-0.065144,-0.302588,-0.485465,-0.370119,0.722382,0.210889,0.397347,1.606368,0.806006,1.240832,7
1359,0,3,1.342317,1.995452,0.197798,1.008745,1.815249,2.355587,-0.270152,0.503513,0.478007,-0.266015,0.04878,0.25,0.020408,-1.091128,1.52646,-0.205106,0.0,-0.303416,0.0,-1.541598,-2.453977,1.462868,-0.13117,-0.486692,-0.577188,1.071896,0.931131,-0.362029,-0.696776,-0.427228,-0.639676,-0.538232,-0.184962,0.0,…,0.0,0.0,-2.398488,0.0,-1.530771,1.76038,0.0,-1.498801,-1.352574,3.851669,1.643704,-0.139378,-0.024075,-1.445658,-2.021217,-1.145029,4.12744,0.186234,-0.697002,0.990352,-0.725536,0.0,0.0,0.47503,0.359566,-0.376677,-0.389206,-0.352015,-0.301331,-0.844495,-0.188697,0.510724,-1.134449,-0.159172,0.968843,-1.236539,7
1359,0,5,1.469287,1.636272,-0.242845,0.897076,1.409509,2.331006,-0.437219,1.261788,0.440067,0.116636,0.02439,0.833333,0.317254,-1.112629,1.214978,-0.491099,0.0,-0.28322,0.0,-1.099059,-1.299383,-0.552897,-0.493447,0.060737,0.488685,-1.262748,-1.905434,-1.080464,0.217649,0.38157,-0.587423,0.065154,-0.338904,0.0,…,0.0,0.0,-2.845521,0.0,-1.888879,1.667516,0.0,-1.724386,-0.822925,3.851669,-0.497117,-0.124717,-0.479826,-1.594099,-1.410835,-0.977505,0.711486,-0.3419,-1.320605,1.337668,-0.279872,0.0,0.0,-0.173756,-0.119563,-0.442301,-0.354211,-0.138384,-0.043467,0.178531,0.067827,1.043191,-0.044119,0.325286,0.970462,-0.442468,7
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1370,967,32,2.392336,-0.970513,-0.689064,-0.80464,0.188512,0.977977,0.32006,-0.866912,-0.308539,-0.301335,0.317073,1.0,0.293135,-0.457903,1.422804,0.191951,0.970012,0.6943,0.789821,-1.27194,0.420144,-2.020144,3.422769,-0.135455,1.447991,2.255607,1.267027,2.376685,-1.019575,-2.066507,1.982367,0.982517,2.810244,-0.562868,…,-1.102517,-0.174381,-0.910987,-1.011996,1.56789,-0.154542,0.295458,-1.430611,-0.631919,1.227449,-0.626681,-1.369254,0.151332,-1.235409,0.412719,-0.496586,2.533943,0.339227,-0.35601,-0.219466,-0.056391,-0.544337,-0.396218,-0.332147,-0.177569,-0.544805,-0.349197,-1.771054,-0.525312,-1.946798,0.241072,0.181044,0.065421,0.375176,0.137428,0.428109,8
1370,967,34,1.911111,-0.73649,-0.752915,-0.448448,-2.091238,1.174098,0.575665,-1.369187,-0.839989,-0.412729,0.512195,0.416667,0.278293,0.75222,0.589476,0.550262,-0.65345,-0.386201,0.178793,0.189882,0.992539,1.233011,-0.451411,-0.0366,-0.15942,-0.965048,-0.848866,-0.814554,0.773128,1.620654,-0.620883,-0.46228,-0.668776,-2.073051,…,0.260011,-1.327027,1.263522,-0.473861,-0.038519,-0.612673,-1.371018,-0.01813,-0.632441,1.227449,-0.92545,-1.231382,-1.325025,-0.253767,0.609254,0.911673,0.715444,0.714629,0.888505,0.336581,0.374347,-0.201689,-0.212223,-0.386479,-0.398518,-0.311951,-0.344617,0.781782,0.389013,0.911636,0.773548,0.619074,0.137061,0.169996,0.167079,0.349549,8
1370,967,35,2.430476,-0.390004,-0.27989,-1.174875,-1.348503,0.674799,0.373151,-0.440463,-0.207015,-0.349554,0.134146,0.583333,0.141002,-1.007735,-0.063938,-0.620457,0.788178,1.076969,1.082795,2.103139,-1.174479,-0.967522,2.896023,0.810477,1.915284,2.106562,2.531167,0.280818,-0.315183,-1.234272,2.732495,1.034769,3.119164,-1.140089,…,-0.90763,1.262501,-0.909521,0.763333,-0.589021,-0.821753,0.211839,-0.031535,0.212399,1.227449,0.197552,1.550038,0.735273,1.904985,-2.053403,-0.880794,-0.376071,-0.642854,-0.819059,0.326489,-0.462855,-0.437922,-0.757116,-0.484134,-0.847847,-0.550796,-0.756279,-0.96301,-0.068327,-0.703112,-0.115929,-0.067838,0.090418,0.113785,0.080935,0.216028,8
1370,967,37,0.76466,0.229862,-0.130822,-0.366252,-1.395082,0.911715,0.4517,-1.112004,-0.46962,-0.145647,0.414634,0.333333,0.397032,2.261843,3.228687,1.606249,-0.026959,3.196151,1.266449,0.302986,2.11229,-2.730362,-0.563215,-1.667471,-1.151794,0.291971,0.468838,0.786002,-0.820648,-1.309597,-0.584387,0.410452,-0.596625,-2.088345,…,1.596131,-0.346931,-2.968961,-1.331302,1.357634,-0.586922,-0.455652,4.387135,0.839644,1.227449,-0.107197,1.069073,-0.047615,-0.130344,1.360324,1.553387,0.90202,0.810563,2.457252,7.613956,1.643733,0.941442,0.807299,3.846327,5.864861,1.441216,1.966457,-0.06673,-1.519372,-2.203012,0.209954,0.102338,-1.303606,0.428269,0.187324,0.527199,8


In [15]:
current_day_data : pl.DataFrame | None = None

old_dataset = loader.load(start-60, start-1)\
    .select(COLUMNS)\
    .collect() \
    .filter(~pl.col('symbol_id').is_in(to_remove_symbols))

old_dataset_stocks = old_dataset['symbol_id'].unique().sort().to_list()

old_dataset = standardize(old_dataset, data_stats_dict, FEATURES).fill_nan(None).fill_null(strategy='zero')


last_train_date = start-1
new_dataset = old_dataset.filter(pl.col('date_id') > start-30)
if OLD_DATA_FRACTION > 0:
    old_dataset = old_dataset.filter(pl.col('date_id') <= start-30)
else:
    old_dataset = None

In [16]:
CORRELATION_THR = 0.1
WINDOW_LEN = 2 if not os.getenv('KAGGLE_IS_COMPETITION_RERUN') else 7
past_responders_pivot: pl.DataFrame | None = None
current_date_id = -1
current_stock_ids = old_dataset_stocks
num_dates = 0

adjacency_matrices = np.load('/home/lorecampa/projects/jane_street_forecasting/dataset/sources/graph_conv_torch/adjacency_matrices.npy')[:last_train_date+1, :, :]
current_corr_matrix = np.load('/home/lorecampa/projects/jane_street_forecasting/dataset/sources/graph_conv_torch/correlations.npy')[last_train_date, :, :]

In [17]:
def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    global BATCH_SIZE, GRADIENT_CLIPPING, N_EPOCHS_PER_TRAIN_MAX, TRAIN_EVERY
    global EARLY_STOPPING_DAYS, ES_PATIENCE, OLD_DATA_FRACTION
    global MAX_FINE_TUNING_TIME_LIMIT, TIME_LIMIT, FINE_TUNING
    global CORRELATION_THR, WINDOW_LEN
    global date_idx, new_dataset, old_dataset, current_day_data, last_train_date
    global train_dataloader, val_dataloader, train_iterator, val_iterator, adjacency_matrices
    global acc_metrics, save_path, start_train, is_training_loop, epoch, best_epoch, best_score
    global current_stock_ids, current_date_id, current_corr_matrix, num_dates, past_responders_pivot
    global gradient_clipping_decay, gradient_clipping, lr, lr_decay, optimizer

    initial_time = time.time()
    FINE_TUNING = FINE_TUNING & (initial_time < MAX_FINE_TUNING_TIME_LIMIT)
    start_train = start_train if FINE_TUNING else False

    if lags is not None:
        # print(f"Date id: {test['date_id'].min()}")
        # new date_id
        lags_ = lags.select(
            pl.col('date_id').sub(1),
            pl.col(['time_id', 'symbol_id']),
            pl.col('responder_6_lag_1').alias('responder_6'),
        )
        if current_day_data is not None:
            # print(current_day_data, new_dataset['date_id'].unique().to_list(), old_dataset['date_id'].unique().to_list())
            current_day_data = current_day_data.join(lags_, on=['date_id', 'time_id', 'symbol_id'], 
                                                     how='left').fill_null(0)
            current_day_data = current_day_data.select(COLUMNS)
            # replacing date id to ensure that adjacency_matrices array is consistent
            current_day_data = (
                current_day_data
                .drop('date_id')
                .with_columns(pl.lit(last_train_date + date_idx + 1).cast(pl.Int16).alias('date_id'))
                .select(COLUMNS)
            )

            new_dataset = new_dataset.vstack(current_day_data)
            last_adj = current_corr_matrix.copy()
            
            last_adj[np.arange(len(current_stock_ids)), np.arange(len(current_stock_ids))] = 0
            last_adj = (last_adj > CORRELATION_THR).astype(np.int32)[np.newaxis, :, :]
            adjacency_matrices = np.concatenate([adjacency_matrices, last_adj], axis=0)
            
        current_day_data = test

        all_combinations = (
            lags_.select(['date_id', 'time_id'])
            .unique()
            .join(pl.DataFrame({'symbol_id': current_stock_ids}, 
                               schema={'symbol_id': pl.Int8}), how="cross")
        )
        
        pivot_lags = (
            all_combinations
            .join(lags_, on=['date_id', 'time_id', 'symbol_id'], how="left")
            .fill_null(0)
            .sort(['date_id', 'time_id', 'symbol_id'])
            .pivot(index=['date_id', 'time_id'], values='responder_6', on='symbol_id')
            .fill_null(0)
        )
        
        past_responders_pivot = (
            pl.concat([past_responders_pivot, pivot_lags], how='diagonal')
            .filter(pl.col('date_id') >= current_date_id - WINDOW_LEN - 1)
        ) if past_responders_pivot is not None else pivot_lags
        
        if num_dates >= WINDOW_LEN:
            current_corr_matrix = compute_correlation_from_pivot(past_responders_pivot)

        if FINE_TUNING and not start_train:
            start_train = (date_idx+1) % TRAIN_EVERY == 0
            if start_train:
                print('Starting new fine tuning')
                model.eval()
                max_date = new_dataset.select(pl.col('date_id').max()).item()
                new_validation_dataset = new_dataset.filter(pl.col('date_id') > max_date - EARLY_STOPPING_DAYS)
                new_training_dataset = new_dataset.filter(pl.col('date_id') <= max_date - EARLY_STOPPING_DAYS)
                train_days = new_training_dataset['date_id'].unique().sort().to_list()
                val_days = new_validation_dataset['date_id'].unique().sort().to_list()
                print(f'Training days: {train_days}')
                print(f'Validation days: {val_days}')
                
                if OLD_DATA_FRACTION > 0:
                    old_data_len = OLD_DATA_FRACTION * new_training_dataset.shape[0] / (1 - OLD_DATA_FRACTION)
                    time_factions = min(1, old_data_len / old_dataset.shape[0])
                    old_date_times = old_dataset.select(['date_id', 'time_id']).unique().sample(fraction=time_factions)
                                        
                    old_training_dataset = old_dataset.join(old_date_times, on=['date_id', 'time_id'], how='inner')
                    
                    print(f'Old training days: {old_training_dataset["date_id"].unique().to_list()}')
                    
                    train_dataloader = MultiStockGraphDataset(pl.concat([old_training_dataset, new_training_dataset]), adjacency_matrices.copy(), current_stock_ids)
                    val_dataloader = MultiStockGraphDataset(new_validation_dataset, adjacency_matrices.copy(), current_stock_ids)
                else:
                    train_dataloader = MultiStockGraphDataset(new_training_dataset, adjacency_matrices.copy(), current_stock_ids)
                    val_dataloader = MultiStockGraphDataset(new_validation_dataset, adjacency_matrices.copy(), current_stock_ids)
                
                
                
                
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
                train_dataloader = DataLoader(train_dataloader, shuffle=True, batch_size=BATCH_SIZE, num_workers=0)
                val_dataloader = DataLoader(val_dataloader, shuffle=False, batch_size=2048, num_workers=0)
                val_iterator = iter(val_dataloader)
                acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
                is_training_loop = False
                epoch = -1
                best_epoch = -1
                best_score = -1e10

                if OLD_DATA_FRACTION > 0:
                    max_new_date_id = new_training_dataset['date_id'].max()
                    old_dataset = old_dataset.vstack(new_training_dataset).filter(
                        pl.col('date_id').is_between(max_new_date_id - 30, max_new_date_id)
                    )
                    
                new_dataset = new_validation_dataset
                
        date_idx += 1
    else:
        current_day_data = current_day_data.vstack(test)
        
    if FINE_TUNING:
        while start_train and time.time() - initial_time < TIME_LIMIT:
            if is_training_loop:
                try:
                    batch = next(train_iterator)
                except StopIteration:
                    model.eval()
                    val_iterator = iter(val_dataloader)
                    acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
                    is_training_loop = False
                    continue
        
                x, targets, m, w, s, A = batch
                optimizer.zero_grad()
                y_out = model.forward(x.to(device), s.to(device), A.to(device)).squeeze()
                loss = loss_fn(y_out, targets.to(device), w.to(device))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                optimizer.step()
                
            else:
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    score = 1 - acc_metrics['ss_res'] / acc_metrics['ss_tot']
                    print(f'Epoch {epoch} completed with score {score}')
                    epoch += 1
                    if score > best_score:
                        torch.save(model.state_dict(), save_path)
                        inference_model.load_state_dict(torch.load(save_path, weights_only=True))
                        inference_model.to(device)
                        inference_model.eval()
                        best_epoch = epoch
                        best_score = score
                    if epoch - best_epoch >= ES_PATIENCE or epoch == N_EPOCHS_PER_TRAIN_MAX:
                        print(f'Stopping after {epoch} epochs')
                        print(f'Completed Fine Tuning at time {test.select(pl.col("time_id").first()).item()}')
                        model.load_state_dict(torch.load(save_path, weights_only=True))
                        model.to(device)
                        model.eval()
                        start_train = False
                        lr *= lr_decay
                        gradient_clipping *= gradient_clipping_decay
                        break
                    model.train()
                    train_iterator = iter(train_dataloader)
                    is_training_loop = True
                    continue

                x, targets, m, w, s, A = batch
                with torch.no_grad():
                    y_out = model(x.to(device), s.to(device), A.to(device)).squeeze()
                w = w.to(device)
                targets = targets.to(device)
                acc_metrics['ss_res'] += (w * (y_out - targets) ** 2).sum().cpu()
                acc_metrics['ss_tot'] += (w * (targets ** 2)).sum().cpu()

    if test.select(pl.col('is_scored').cast(pl.Int8).first()).item() > 0:
        test_ = test.fill_nan(None).fill_null(strategy='zero')
        predict_df = (
            test_.select(['date_id', 'time_id'])
                .unique()
                .join(pl.DataFrame({'symbol_id': list(range(39))}, 
                                schema={'symbol_id': pl.Int8}), how="cross")
                .join(test_.with_columns(pl.lit(1).alias('mask')), 
                    on=['date_id', 'time_id', 'symbol_id'], how="left")
                .fill_null(0)  # fill all columns with 0 for missing stocks (including the mask)
                .sort(['date_id', 'time_id', 'symbol_id'])
        )
        valid_data = predict_df.select(['mask']).to_numpy().flatten() == 1
        x = torch.tensor(predict_df.select([f'feature_{i:02d}' for i in range(79)]).to_numpy().reshape(-1, 39, 79), dtype=torch.float32).to(device)
        s = torch.tensor(predict_df.select(['symbol_id']).to_numpy().flatten().reshape(-1, 39).astype(np.int32)).to(device)
        # adj = adjacency_matrices[predict_df.select(pl.col('date_id').first()).item()][np.newaxis, :, :]
        # adj = torch.tensor(adj, dtype=torch.int, device=device).repeat(x.shape[0], 1, 1)
        
        adj_matrix = current_corr_matrix.copy()
        adj_matrix[np.arange(39), np.arange(39)] = 0
        adj_matrix = (adj_matrix > CORRELATION_THR).astype(np.int32)
        adj_matrix = torch.tensor(adj_matrix, dtype=torch.int, device=device).unsqueeze(0).repeat(x.shape[0], 1, 1)
        with torch.no_grad():
            preds = inference_model(x, s, adj_matrix).cpu().numpy().flatten()[valid_data]
        predict_df = predict_df.filter(pl.col('mask') == 1).with_columns(pl.Series(preds).alias('responder_6'))
        
        predictions = test.join(predict_df, on=['time_id', 'symbol_id'], how='left').select(['row_id', 'responder_6'])
    else:
        predictions = test.select('row_id', pl.Series(np.zeros(test.shape[0])).alias('responder_6'))
    predictions = predictions.with_columns(pl.col('responder_6').cast(pl.Float32))

    if isinstance(predictions, pl.DataFrame):
        assert predictions.columns == ['row_id', 'responder_6']
    elif isinstance(predictions, pd.DataFrame):
        assert (predictions.columns == ['row_id', 'responder_6']).all()
    else:
        raise TypeError('The predict function must return a DataFrame')
    
    assert len(predictions) == len(test)

    return predictions

In [None]:
from nbs.tabm.predict_tabm import predict_tabm
from prj.utils import online_iterator, online_iterator_daily
from sklearn.metrics import r2_score

y_hat_iterator = []
for i, (test, lags) in enumerate(online_iterator_daily(test_ds, show_progress=True)):
    # print(len(test))
    res = predict(test, lags)
    y_hat_iterator.append(res['responder_6'].to_numpy())
    
y_hat_iterator = np.concatenate(y_hat_iterator) if len(y_hat_iterator) > 0 else None


score = r2_score(y_test, y_hat_iterator, sample_weight=w_test)

score

  0%|          | 0/11 [00:00<?, ?it/s]

Starting new fine tuning
Training days: [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352]
Validation days: [1353, 1354, 1355, 1356, 1357, 1358, 1359]
Old training days: [1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327, 1328, 1329, 1330]


  0%|          | 0/11 [00:01<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [41]:
# 0.013081669807434082


In [42]:
# 0.011371493339538574