In [1]:
import os
from typing import List, Optional, Tuple, Dict, Any

import pandas as pd
import polars as pl
import numpy as np
import time
from pathlib import Path
import random

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GCNConv

import kaggle_evaluation.jane_street_inference_server

In [2]:
# Seed the three random number generators used in this notebook (PyTorch,
# the stdlib `random` module and NumPy) with the same fixed value.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

In [3]:
class MultiStockGraphDataset(Dataset):
    """Dataset yielding one (date_id, time_id) snapshot of every stock per item.

    The input frame is expanded so that each (date_id, time_id) pair has
    exactly ``len(stock_ids)`` rows (one per stock, sorted by ``symbol_id``).
    Rows for stocks that are absent at a given time are filled with zeros and
    flagged in the mask.

    Args:
        dataset: polars frame holding ``date_id``, ``time_id``, ``symbol_id``,
            ``weight``, ``responder_6`` and ``feature_00`` .. ``feature_78``.
            A ``LazyFrame`` is collected first, because the length computation
            and the cross join below require an eager frame.
        adjacency_matrices: array indexed by ``date_id`` on its first axis;
            the slice for an item's date is returned with that item.
        stock_ids: symbol ids that define the rows of each snapshot.
    """

    def __init__(self, dataset: "pl.DataFrame | pl.LazyFrame", adjacency_matrices: np.ndarray, stock_ids: list):
        if isinstance(dataset, pl.LazyFrame):
            # `.shape` and the eager cross join in `_load` do not exist on LazyFrame.
            dataset = dataset.collect()
        self.dataset = dataset
        self.adjacency_matrices = adjacency_matrices
        self.stock_ids = stock_ids
        self.num_stocks = len(self.stock_ids)
        self.dataset_len = self.dataset.select(['date_id', 'time_id']).unique().shape[0]
        self._load()

    def _load(self):
        """Materialise the dense per-snapshot arrays used by ``__getitem__``."""
        all_combinations = (
            self.dataset.select(['date_id', 'time_id'])
            .unique()
            .join(pl.DataFrame({'symbol_id': self.stock_ids},
                               schema={'symbol_id': pl.Int8}), how="cross")
        )
        feature_cols = [f'feature_{i:02d}' for i in range(79)]
        self.batch = (
            all_combinations
            .join(self.dataset.with_columns(pl.lit(1).alias('mask')),
                  on=['date_id', 'time_id', 'symbol_id'], how="left")
            .fill_null(0)  # fill all columns with 0 for missing stocks (including the mask)
            .sort(['date_id', 'time_id', 'symbol_id'])
        )
        # num_stocks consecutive rows for each (date_id, time_id)
        self.X = self.batch.select(feature_cols).to_numpy().astype(np.float32)
        self.y = self.batch.select(['responder_6']).to_numpy().flatten().astype(np.float32)
        self.s = self.batch.select(['symbol_id']).to_numpy().flatten().astype(np.int32)
        self.date_ids = self.batch.select(['date_id']).to_numpy().flatten()
        # True where the stock had no row in the source frame (padding row).
        self.masks = self.batch.select(['mask']).to_numpy().flatten() == 0
        self.weights = self.batch.select(['weight']).to_numpy().flatten().astype(np.float32)

    def __len__(self):
        """Number of distinct (date_id, time_id) snapshots."""
        return self.dataset_len

    def __getitem__(self, idx):
        """Return ``(features, targets, masks, weights, symbols, adjacency)`` for snapshot ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.dataset_len
        if not 0 <= idx < self.dataset_len:
            # Without this check a negative index produced empty slices.
            raise IndexError(f'index {idx} out of range for {self.dataset_len} snapshots')

        start_row = idx * self.num_stocks
        rows = slice(start_row, start_row + self.num_stocks)

        # All rows of a snapshot share the same date, so the first one suffices.
        date_id = self.date_ids[start_row]
        adj_matrix = self.adjacency_matrices[date_id]

        return (
            torch.tensor(self.X[rows, :]),
            torch.tensor(self.y[rows]),
            torch.tensor(self.masks[rows]),
            torch.tensor(self.weights[rows]),
            torch.tensor(self.s[rows]),
            torch.tensor(adj_matrix, dtype=torch.int)
        )

In [4]:
class TransposeLayer(nn.Module):
    """Parameter-free module that swaps dimensions 1 and 2 of its input."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with axes 1 and 2 exchanged."""
        return torch.transpose(input, 1, 2)

In [5]:
class WeightedMSELoss(nn.Module):
    """Mean squared error where each element is weighted and the sum is normalised by the total weight."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions: Tensor, targets: Tensor, weights: Tensor) -> Tensor:
        """Return ``sum(weights * (predictions - targets)^2) / sum(weights)``."""
        residual = predictions - targets
        weighted_error = torch.sum(weights * residual.pow(2))
        total_weight = torch.sum(weights)
        return weighted_error / total_weight

In [6]:
class GraphConvEncoderLayer(nn.Module):
    """One encoder block: graph convolution followed by a position-wise feed-forward network.

    Each of the two sub-blocks is wrapped as ``LayerNorm(dropout(sublayer(x)) + x)``.

    Args:
        hidden_dim: size of the node feature vector (input and output).
        dim_feedforward_mult: expansion factor of the feed-forward hidden layer.
        dropout_rate: dropout probability used in both sub-blocks and inside the feed-forward network.
    """

    def __init__(self, hidden_dim, dim_feedforward_mult=4, dropout_rate=0.1):
        super().__init__()

        self.graph_conv = GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)

        ff_dim = hidden_dim * dim_feedforward_mult
        self.feedforward = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(ff_dim, hidden_dim),
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index):
        """Apply the block to ``x`` of shape (batch, nodes, features).

        ``edge_index`` is passed straight to the graph convolution, which runs
        on the input flattened to (batch * nodes, features).
        """
        batch_size, num_nodes, num_features = x.shape

        # Graph-convolution sub-block on the flattened node list.
        flat_nodes = x.reshape(batch_size * num_nodes, num_features)
        conv_out = self.graph_conv(flat_nodes, edge_index)
        conv_out = conv_out.reshape(batch_size, num_nodes, num_features)
        x = self.norm1(self.dropout1(conv_out) + x)

        # Feed-forward sub-block.
        x = self.norm2(self.dropout2(self.feedforward(x)) + x)

        return x

In [7]:
class GraphConvEncoder(nn.Module):
    """Stack of ``GraphConvEncoderLayer`` blocks sharing one batched edge list.

    Args:
        hidden_dim: node feature size of every layer.
        num_layers: number of encoder layers.
        dim_feedforward_mult: feed-forward expansion factor forwarded to each layer.
        dropout_rate: dropout probability forwarded to each layer.
    """

    def __init__(self, hidden_dim, num_layers, dim_feedforward_mult=4, dropout_rate=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            GraphConvEncoderLayer(
                hidden_dim=hidden_dim,
                dim_feedforward_mult=dim_feedforward_mult,
                dropout_rate=dropout_rate
            ) for _ in range(num_layers)
        ])

    def forward(self, x, adj):
        """Encode ``x`` (batch, nodes, features) using ``adj`` (batch, nodes, nodes).

        Every non-zero entry ``adj[b, i, j]`` becomes the edge
        ``(i + b * nodes, j + b * nodes)``, so that the whole batch forms one
        graph of disconnected per-sample components.
        """
        batch_size, num_nodes, _ = x.size()

        # A single batched nonzero replaces one nonzero call per sample.
        # Indices come back ordered by (batch, row, column), i.e. the same
        # order as concatenating the per-sample results, and an empty batch
        # simply yields an empty edge list.
        sample_idx, src, tgt = torch.nonzero(adj[:batch_size], as_tuple=True)
        node_offset = sample_idx * num_nodes
        edge_index = torch.stack([src + node_offset, tgt + node_offset], dim=0).to(x.device)

        for layer in self.layers:
            x = layer(x, edge_index)
        return x

In [8]:
class StockGCNModel(nn.Module):
    """Per-stock regressor: feature projection, graph-convolution encoder and an MLP head.

    Args:
        input_features: number of raw features per stock.
        hidden_dim: width of the projected features and of the encoder.
        output_dim: number of outputs per stock.
        num_layers: number of encoder layers.
        num_stocks: size of the symbol embedding table (only used with ``use_embeddings``).
        embedding_dim: width of the symbol embedding (only used with ``use_embeddings``).
        use_embeddings: concatenate a learned symbol embedding to the features before projecting.
        dropout_rate: dropout probability used throughout the model.
        dim_feedforward_mult: feed-forward expansion factor of the encoder layers.
    """

    def __init__(
        self,
        input_features,
        hidden_dim=64,
        output_dim=1,
        num_layers=2,
        num_stocks=39,
        embedding_dim=16,
        use_embeddings=False,
        dropout_rate=0.2,
        dim_feedforward_mult=4,
    ):
        super().__init__()

        self.use_embeddings = use_embeddings

        # Input dropout applied to the raw features.
        self.init_layers = nn.Sequential(nn.Dropout(dropout_rate))

        # The projection consumes the raw features, widened by the symbol
        # embedding when embeddings are enabled.
        projector_in = input_features + embedding_dim if use_embeddings else input_features
        projection = nn.Linear(projector_in, hidden_dim)
        if use_embeddings:
            self.embedding_layer = nn.Embedding(num_stocks, embedding_dim)
        self.feature_projector = nn.Sequential(projection, nn.Dropout(dropout_rate))

        self.encoder = GraphConvEncoder(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dim_feedforward_mult=dim_feedforward_mult,
            dropout_rate=dropout_rate,
        )

        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, symbols, adj):
        """Return ``5 * tanh(head(encoder(project(x))))`` for every stock.

        ``x`` must be three-dimensional; ``symbols`` is only read when
        embeddings are enabled; ``adj`` is forwarded to the encoder.
        """
        # Unpacking doubles as a check that x has exactly three dimensions.
        _batch, _stocks, _feats = x.shape

        hidden = self.init_layers(x)
        if self.use_embeddings:
            hidden = torch.cat([hidden, self.embedding_layer(symbols)], dim=-1)
        hidden = self.feature_projector(hidden)
        hidden = self.encoder(hidden, adj)

        raw_output = self.predictor(hidden)
        # tanh scaled by 5 bounds every prediction to (-5, 5).
        return 5 * torch.tanh(raw_output)

In [9]:
def compute_correlation_from_pivot(pivot_df):
    """Return the absolute correlation matrix of the pivoted columns, ordered by numeric column name.

    ``pivot_df`` holds ``date_id`` and ``time_id`` plus one column per symbol
    whose name is the symbol id as a string. NaN / null correlations are
    replaced by 0 before taking the absolute value.
    """
    corr = pivot_df.drop(['date_id', 'time_id']).corr()
    corr = corr.fill_nan(0).fill_null(0)

    # Column names are stringified ids; rank them numerically, not lexically.
    symbol_ids = [int(name) for name in corr.columns]
    ranking = sorted(range(len(symbol_ids)), key=symbol_ids.__getitem__)
    ordered_names = [str(symbol_ids[pos]) for pos in ranking]

    # Apply the same permutation to the columns and then to the rows.
    corr = corr[ordered_names]
    corr = corr[ranking, :]
    return np.abs(corr.to_numpy())

In [10]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [11]:
# NOTE(review): absolute, machine-specific path; it will not resolve on another host.
save_path = '/home/lorecampa/projects/jane_street_forecasting/dataset/models/graph_conv/model_3_7.pth'

# Both networks must be built with exactly the same hyper-parameters as the
# checkpoint, so they are declared once instead of being repeated per model.
model_kwargs = dict(
    input_features=79,
    output_dim=1,
    num_layers=1,
    dropout_rate=0.2,
    dim_feedforward_mult=4,
    hidden_dim=64,
)

# Read the checkpoint from disk once; load_state_dict copies the values into
# each module's own parameters, so both models can be initialised from it.
pretrained_state = torch.load(save_path, weights_only=True, map_location=torch.device(device))

# `model` is the network that gets fine-tuned online.
model = StockGCNModel(**model_kwargs)
model.load_state_dict(pretrained_state)
model = model.to(device)

# `inference_model` serves predictions; it is only refreshed from a saved
# checkpoint, so it is never used while in a partially trained state.
inference_model = StockGCNModel(**model_kwargs)
inference_model.load_state_dict(pretrained_state)
inference_model = inference_model.to(device)

loss_fn = WeightedMSELoss()
inference_model.eval()

StockGCNModel(
  (init_layers): Sequential(
    (0): Dropout(p=0.2, inplace=False)
  )
  (feature_projector): Sequential(
    (0): Linear(in_features=79, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
  )
  (encoder): GraphConvEncoder(
    (layers): ModuleList(
      (0): GraphConvEncoderLayer(
        (graph_conv): GCNConv(64, 64)
        (feedforward): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): SiLU()
          (2): Dropout(p=0.2, inplace=False)
          (3): Linear(in_features=256, out_features=64, bias=True)
        )
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (predictor): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): SiLU()
    (2): Dropout(p=0.2, inplace=False)


In [12]:
# from prj.config import DATA_DIR

# FINE_TUNING = True
# N_EPOCHS_PER_TRAIN_MAX = 10
# BATCH_SIZE = 2048
# OLD_DATA_FRACTION = 0.1
# FEATURE_COLS = [f'feature_{i:02d}' for i in range(79)]
# GRADIENT_CLIPPING = 10
# EARLY_STOPPING_DAYS = 7
# ES_PATIENCE = 7
# TRAIN_EVERY = 23
# date_idx = -1
# epoch = None
# best_epoch = None
# best_score = None
# train_dataloader, val_dataloader, train_iterator, val_iterator = None, None, None, None
# save_path = './best_model.pth'
# acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
# start_train = False
# is_training_loop = False

# gradient_clipping_decay = 0.5
# gradient_clipping = GRADIENT_CLIPPING * gradient_clipping_decay
# lr_decay = 0.7
# lr = 1e-5
# optimizer = None

# TIME_LIMIT = 30
# MAX_FINE_TUNING_TIME_LIMIT = time.time() + 60 * 60 * 8 # after 8 hours, stop all the online learning

# FEATURES = [f'feature_{i:02d}' for i in range(79)]
# COLUMNS = FEATURES + ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']
# BASE_PATH = DATA_DIR / 'train.parquet'

In [13]:
# DATA_DIR was previously imported only inside the commented-out configuration
# cell, which made this cell fail with NameError on a fresh kernel.
from prj.config import DATA_DIR
from prj.data.data_loader import DataLoader as PrjDataLoader
from prj.data.data_loader import DataConfig as PrjDataConfig

config = PrjDataConfig(**{})
loader = PrjDataLoader(data_dir=DATA_DIR, config=config)

# Evaluation window (inclusive date ids). Shorter windows used while
# debugging: (1360, 1370) and (1190, 1200).
start, end = 1360, 1529

# One extra day before `start` is loaded alongside the evaluation window.
test_ds = loader.load(start-1, end).collect()

# Targets and weights of the evaluation window only (the extra day is dropped).
scored_rows = test_ds.filter(pl.col('date_id').ge(start))
y_test = scored_rows['responder_6'].to_numpy().flatten()
w_test = scored_rows['weight'].to_numpy().flatten()

2025-01-10 11:58:26.987771: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-10 11:58:26.987805: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-10 11:58:26.989056: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 11:58:26.995704: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [15]:
def standardize(df: pl.LazyFrame | pl.DataFrame, data_stats_dict: dict, features: list[str], eps=1e-9) -> pl.LazyFrame | pl.DataFrame:
    """Scale feature columns using precomputed statistics.

    Args:
        df: eager or lazy polars frame; the same kind of frame is returned.
        data_stats_dict: maps ``'<col>_mean'`` and ``'<col>_std'`` for every
            non-categorical column in ``features`` and ``'<col>_max'`` for
            ``feature_09``, ``feature_10`` and ``feature_11``.
        features: columns to z-score; the three categorical columns are
            removed from this list and divided by their maximum instead.
        eps: added to the standard deviation so that a zero deviation cannot
            produce a division by zero.

    Raises:
        KeyError: if a required statistic is missing from ``data_stats_dict``.
    """
    cat_features = ['feature_09', 'feature_10', 'feature_11']
    numeric_features = [f for f in features if f not in cat_features]

    # `eps` used to be overwritten with 1e-8 and added to the quotient, where
    # it neither honoured the argument nor protected the division.
    return df.with_columns(
        [
            pl.col(col).sub(data_stats_dict[f'{col}_mean']).truediv(data_stats_dict[f'{col}_std'] + eps)
            for col in numeric_features
        ]
    ).with_columns(
        [pl.col(f).truediv(data_stats_dict[f'{f}_max']) for f in cat_features]
    )
    
# Statistics consumed by `standardize`: '<col>_mean' / '<col>_std' for the
# z-scored features and '<col>_max' for feature_09, feature_10 and feature_11.
# NOTE(review): left empty here, so `standardize` raises KeyError until this
# dict is populated - confirm where the statistics are meant to be loaded.
data_stats_dict = {}

In [16]:
class GraphConvTrainerConfig:
    """Constants driving the online fine-tuning and graph construction of ``GraphConvTrainer``."""

    # --- data window and graph construction ---
    LAST_TRAIN_DATE = 1698
    INITIAL_ES_DAYS = 30
    CORRELATION_THR = 0.1
    # Longer correlation window during the competition rerun, shorter otherwise.
    WINDOW_LEN = 7 if os.getenv('KAGGLE_IS_COMPETITION_RERUN') else 2

    # --- fine-tuning schedule ---
    FINE_TUNING = True
    N_EPOCHS_PER_TRAIN_MAX = 30
    BATCH_SIZE = 2048
    OLD_DATA_FRACTION = 0.1
    FEATURE_COLS = ['feature_%02d' % i for i in range(79)]
    GRADIENT_CLIPPING = 10
    EARLY_STOPPING_DAYS = 7
    ES_PATIENCE = 3
    TRAIN_EVERY = 23
    TIME_LIMIT = 30
    # Evaluated once, when the class body runs: after 8 hours, stop all the online learning.
    MAX_FINE_TUNING_TIME_LIMIT = time.time() + 8 * 60 * 60

    # --- column layout ---
    FEATURES = ['feature_%02d' % i for i in range(79)]
    COLUMNS = FEATURES + ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']


class GraphConvTrainer:
    """Stateful online predictor that periodically fine-tunes the graph model.

    ``predict`` is meant to be called once per time step. Besides producing
    predictions with ``inference_model`` it

    * archives each finished day (features plus lagged responders) as new
      training data,
    * maintains the correlation matrix / adjacency matrices of the stock graph,
    * every ``TRAIN_EVERY`` days fine-tunes ``model`` in small time-boxed
      slices spread over successive calls, refreshing ``inference_model``
      whenever the validation score improves.

    The class relies on notebook-level names defined in other cells:
    ``model``, ``inference_model``, ``loss_fn``, ``device``,
    ``data_stats_dict``, ``DATA_DIR``, ``standardize``,
    ``compute_correlation_from_pivot`` and ``MultiStockGraphDataset``.
    """

    def __init__(self):
        self.config = GraphConvTrainerConfig()
        cfg = self.config
        self.current_day_data: pl.DataFrame | None = None
        # Instance copy of the switch: it is turned off for good once the
        # global fine-tuning deadline has passed.
        self.fine_tuning = cfg.FINE_TUNING

        config = PrjDataConfig()
        loader = PrjDataLoader(data_dir=DATA_DIR, config=config)
        history = (
            loader.load(cfg.LAST_TRAIN_DATE - 60, cfg.LAST_TRAIN_DATE)
            .sort(['date_id', 'time_id', 'symbol_id'])
            .fill_nan(None)
            .fill_null(strategy='zero')
            .select(cfg.COLUMNS)
            .pipe(
                standardize,
                data_stats_dict=data_stats_dict,
                features=cfg.FEATURES,
            )
        )

        # The most recent INITIAL_ES_DAYS days seed the "new" data, the rest
        # of the 60-day history is kept as replay data.
        split_date = cfg.LAST_TRAIN_DATE - cfg.INITIAL_ES_DAYS
        self.new_dataset = history.filter(pl.col('date_id') > split_date).collect()
        if cfg.OLD_DATA_FRACTION > 0:
            self.old_dataset = history.filter(pl.col('date_id') <= split_date).collect()
        else:
            self.old_dataset = None

        # Graph state.
        self.past_responders_pivot: pl.DataFrame | None = None
        # NOTE(review): current_date_id and num_dates are initialised here but
        # never updated anywhere in this class (same as before this rewrite),
        # so the pivot window filter keeps every day and the correlation
        # matrix is never recomputed. Confirm the intended update rule.
        self.current_date_id = -1
        self.current_stock_ids = list(range(39))
        self.num_dates = 0

        self.adjacency_matrices = np.load('/kaggle/input/jane-street-2024-graph-computation/adjacency_matrices.npy')
        self.current_corr_matrix = np.load('/kaggle/input/jane-street-2024-graph-computation/correlations.npy')[-1, :, :]

        # Fine-tuning state, carried across predict() calls.
        self.date_idx = -1
        self.epoch = None
        self.best_epoch = None
        self.best_score = None
        self.train_dataloader, self.val_dataloader, self.train_iterator, self.val_iterator = None, None, None, None
        self.save_path = '/kaggle/working/best_model.pth'
        self.acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
        self.start_train = False
        self.is_training_loop = False

        self.gradient_clipping_decay = 0.5
        self.gradient_clipping = cfg.GRADIENT_CLIPPING * self.gradient_clipping_decay
        self.lr_decay = 0.7
        self.lr = 1e-5
        self.optimizer = None

    def predict(self, test: "pl.DataFrame", lags: "pl.DataFrame | None") -> "pl.DataFrame | pd.DataFrame":
        """Process one time step and return its predictions.

        Args:
            test: rows of the current time step; must contain ``row_id``,
                ``is_scored``, ``date_id``, ``time_id``, ``symbol_id``,
                ``weight`` and the feature columns.
            lags: lagged responders (``responder_6_lag_1``), provided on the
                first call of a new day and ``None`` otherwise.

        Returns:
            Frame with exactly the columns ``row_id`` and ``responder_6``
            (Float32) and one row per row of ``test``. Predictions are 0 when
            the time step is not scored.
        """
        cfg = self.config
        initial_time = time.time()
        self.fine_tuning = self.fine_tuning and initial_time < cfg.MAX_FINE_TUNING_TIME_LIMIT
        if not self.fine_tuning:
            self.start_train = False

        test = test.pipe(
            standardize,
            data_stats_dict=data_stats_dict,
            features=cfg.FEATURES,
        )

        if lags is not None:
            self._start_new_day(test, lags)
        elif self.current_day_data is None:
            # No lags and nothing accumulated yet: start the day's buffer.
            self.current_day_data = test
        else:
            self.current_day_data = self.current_day_data.vstack(test)

        if self.fine_tuning:
            self._fine_tune(test, deadline=initial_time + cfg.TIME_LIMIT)

        if test.select(pl.col('is_scored').cast(pl.Int8).first()).item() > 0:
            predictions = self._predict_scored(test)
        else:
            predictions = test.select('row_id', pl.Series(np.zeros(test.shape[0])).alias('responder_6'))
        predictions = predictions.with_columns(pl.col('responder_6').cast(pl.Float32))

        if isinstance(predictions, pl.DataFrame):
            assert predictions.columns == ['row_id', 'responder_6']
        elif isinstance(predictions, pd.DataFrame):
            assert (predictions.columns == ['row_id', 'responder_6']).all()
        else:
            raise TypeError('The predict function must return a DataFrame')
        assert len(predictions) == len(test)

        return predictions

    def _start_new_day(self, test: "pl.DataFrame", lags: "pl.DataFrame") -> None:
        """Archive the previous day, refresh the graph state and possibly schedule fine-tuning."""
        cfg = self.config
        print(f"Date id: {test['date_id'].min()}")
        # Lags describe the previous day, hence the date shift.
        lags_ = lags.select(
            pl.col('date_id').sub(1),
            pl.col(['time_id', 'symbol_id']),
            pl.col('responder_6_lag_1').alias('responder_6'),
        )
        if self.current_day_data is not None:
            self._archive_previous_day(lags_)
        self.current_day_data = test

        self._update_correlations(lags_)

        if self.fine_tuning and not self.start_train:
            self.start_train = (self.date_idx + 1) % cfg.TRAIN_EVERY == 0
            if self.start_train:
                self._begin_fine_tuning()

        self.date_idx += 1

    def _archive_previous_day(self, lags_: "pl.DataFrame") -> None:
        """Label the buffered day with its responders and append it to the new data and the graph history."""
        cfg = self.config
        finished_day = (
            self.current_day_data
            .join(lags_, on=['date_id', 'time_id', 'symbol_id'], how='left')
            .fill_null(0)
            # The date id is replaced so that it stays a valid index into
            # self.adjacency_matrices, which grows by one entry per day.
            .with_columns(pl.lit(cfg.LAST_TRAIN_DATE + self.date_idx + 1).cast(pl.Int16).alias('date_id'))
            .select(cfg.COLUMNS)
        )
        self.new_dataset = self.new_dataset.vstack(finished_day)

        last_adj = self.current_corr_matrix.copy()
        diagonal = np.arange(len(self.current_stock_ids))
        last_adj[diagonal, diagonal] = 0  # no self loops
        last_adj = (last_adj > cfg.CORRELATION_THR).astype(np.int32)[np.newaxis, :, :]
        self.adjacency_matrices = np.concatenate([self.adjacency_matrices, last_adj], axis=0)

    def _update_correlations(self, lags_: "pl.DataFrame") -> None:
        """Append the lagged responders to the pivot history and refresh the correlation matrix."""
        cfg = self.config
        all_combinations = (
            lags_.select(['date_id', 'time_id'])
            .unique()
            .join(pl.DataFrame({'symbol_id': self.current_stock_ids},
                               schema={'symbol_id': pl.Int8}), how="cross")
        )
        pivot_lags = (
            all_combinations
            .join(lags_, on=['date_id', 'time_id', 'symbol_id'], how="left")
            .fill_null(0)
            .sort(['date_id', 'time_id', 'symbol_id'])
            .pivot(index=['date_id', 'time_id'], values='responder_6', on='symbol_id')
            .fill_null(0)
        )

        if self.past_responders_pivot is None:
            self.past_responders_pivot = pivot_lags
        else:
            self.past_responders_pivot = (
                pl.concat([self.past_responders_pivot, pivot_lags], how='diagonal')
                .filter(pl.col('date_id') >= self.current_date_id - cfg.WINDOW_LEN - 1)
            )

        # NOTE(review): unreachable while num_dates is never incremented (see __init__).
        if self.num_dates >= cfg.WINDOW_LEN:
            self.current_corr_matrix = compute_correlation_from_pivot(self.past_responders_pivot)

    def _begin_fine_tuning(self) -> None:
        """Split the accumulated data, build the loaders and reset the fine-tuning state."""
        cfg = self.config
        print('Starting new fine tuning')
        model.eval()

        # The last EARLY_STOPPING_DAYS days are held out for validation.
        max_date = self.new_dataset.select(pl.col('date_id').max()).item()
        validation_start = max_date - cfg.EARLY_STOPPING_DAYS
        new_validation_dataset = self.new_dataset.filter(pl.col('date_id') > validation_start)
        new_training_dataset = self.new_dataset.filter(pl.col('date_id') <= validation_start)
        train_days = new_training_dataset['date_id'].unique().sort().to_list()
        val_days = new_validation_dataset['date_id'].unique().sort().to_list()
        print(f'Training days: {train_days}')
        print(f'Validation days: {val_days}')

        training_frame = new_training_dataset
        if cfg.OLD_DATA_FRACTION > 0:
            # Replay whole (date_id, time_id) snapshots of older data so that
            # they make up roughly OLD_DATA_FRACTION of the training rows.
            old_data_len = cfg.OLD_DATA_FRACTION * new_training_dataset.shape[0] / (1 - cfg.OLD_DATA_FRACTION)
            time_fraction = min(1, old_data_len / self.old_dataset.shape[0])
            old_date_times = self.old_dataset.select(['date_id', 'time_id']).unique().sample(fraction=time_fraction)
            old_training_dataset = self.old_dataset.join(old_date_times, on=['date_id', 'time_id'], how='inner')
            print(f'Old training days: {old_training_dataset["date_id"].unique().to_list()}')
            training_frame = pl.concat([old_training_dataset, new_training_dataset])

        train_set = MultiStockGraphDataset(training_frame, self.adjacency_matrices.copy(), self.current_stock_ids)
        val_set = MultiStockGraphDataset(new_validation_dataset, self.adjacency_matrices.copy(), self.current_stock_ids)

        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.001)
        self.train_dataloader = DataLoader(train_set, shuffle=True, batch_size=cfg.BATCH_SIZE, num_workers=0)
        self.val_dataloader = DataLoader(val_set, shuffle=False, batch_size=cfg.BATCH_SIZE, num_workers=0)
        # The round starts with a validation pass of the current weights.
        self.val_iterator = iter(self.val_dataloader)
        self.acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
        self.is_training_loop = False
        self.epoch = -1
        self.best_epoch = -1
        self.best_score = -1e10

        if cfg.OLD_DATA_FRACTION > 0:
            # Data just used for training becomes replay data; keep 30 days of it.
            max_new_date_id = new_training_dataset['date_id'].max()
            self.old_dataset = self.old_dataset.vstack(new_training_dataset).filter(
                pl.col('date_id').is_between(max_new_date_id - 30, max_new_date_id)
            )

        self.new_dataset = new_validation_dataset

    def _fine_tune(self, test: "pl.DataFrame", deadline: float) -> None:
        """Run train / validation batches until ``deadline`` or until the round finishes."""
        while self.start_train and time.time() < deadline:
            if self.is_training_loop:
                try:
                    batch = next(self.train_iterator)
                except StopIteration:
                    # Training epoch done: switch to a fresh validation pass.
                    model.eval()
                    self.val_iterator = iter(self.val_dataloader)
                    self.acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
                    self.is_training_loop = False
                    continue
                self._train_step(batch)
            else:
                try:
                    batch = next(self.val_iterator)
                except StopIteration:
                    if self._finish_validation_pass(test):
                        break
                    continue
                self._validation_step(batch)

    def _train_step(self, batch) -> None:
        """One optimiser step on ``model`` with gradient-norm clipping."""
        x, targets, _mask, w, s, A = batch
        self.optimizer.zero_grad()
        # squeeze(-1) drops only the output axis, so a batch holding a single
        # snapshot keeps its batch dimension.
        y_out = model(x.to(device), s.to(device), A.to(device)).squeeze(-1)
        loss = loss_fn(y_out, targets.to(device), w.to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clipping)
        self.optimizer.step()

    def _validation_step(self, batch) -> None:
        """Accumulate the weighted residual and total sums of squares of one validation batch."""
        x, targets, _mask, w, s, A = batch
        with torch.no_grad():
            y_out = model(x.to(device), s.to(device), A.to(device)).squeeze(-1)
        w = w.to(device)
        targets = targets.to(device)
        self.acc_metrics['ss_res'] += (w * (y_out - targets) ** 2).sum().item()
        self.acc_metrics['ss_tot'] += (w * (targets ** 2)).sum().item()

    def _finish_validation_pass(self, test: "pl.DataFrame") -> bool:
        """Score the finished validation pass; return True when the fine-tuning round is over."""
        cfg = self.config
        score = 1 - self.acc_metrics['ss_res'] / self.acc_metrics['ss_tot']
        print(f'Epoch {self.epoch} completed with score {score}')
        self.epoch += 1
        if score > self.best_score:
            # Checkpoint the improvement and serve it immediately.
            torch.save(model.state_dict(), self.save_path)
            inference_model.load_state_dict(torch.load(self.save_path, weights_only=True))
            inference_model.to(device)
            inference_model.eval()
            self.best_epoch = self.epoch
            self.best_score = score
        if self.epoch - self.best_epoch >= cfg.ES_PATIENCE or self.epoch == cfg.N_EPOCHS_PER_TRAIN_MAX:
            print(f'Stopping after {self.epoch} epochs')
            print(f'Completed Fine Tuning at time {test.select(pl.col("time_id").first()).item()}')
            # Roll the trainable model back to the best checkpoint.
            model.load_state_dict(torch.load(self.save_path, weights_only=True))
            model.to(device)
            model.eval()
            self.start_train = False
            # Each later round uses a smaller learning rate and clipping threshold.
            self.lr *= self.lr_decay
            self.gradient_clipping *= self.gradient_clipping_decay
            return True
        model.train()
        self.train_iterator = iter(self.train_dataloader)
        self.is_training_loop = True
        return False

    @staticmethod
    def _insert_symbol(matrices: np.ndarray, symbol_id: int) -> np.ndarray:
        """Insert a zero row and column for ``symbol_id`` into the last two axes of ``matrices``.

        The position is ``symbol_id`` itself, or the end when the id is not
        smaller than the current size. Works for one matrix and for a stack.
        """
        index = min(symbol_id, matrices.shape[-1])
        matrices = np.insert(matrices, index, 0, axis=-2)
        return np.insert(matrices, index, 0, axis=-1)

    def _predict_scored(self, test: "pl.DataFrame") -> "pl.DataFrame":
        """Predict ``responder_6`` for a scored time step with ``inference_model``."""
        cfg = self.config
        stock_ids = test.select('symbol_id').to_numpy().flatten().tolist()
        present, known = set(stock_ids), set(self.current_stock_ids)
        missing_ids = [x for x in self.current_stock_ids if x not in present]
        new_ids = [x for x in stock_ids if x not in known]

        predict_df = test.select([f'feature_{i:02d}' for i in range(79)] + ['symbol_id']).fill_null(0).fill_nan(0)
        predict_df = predict_df.with_columns(pl.lit(1).cast(pl.Int8).alias('valid'))
        if len(missing_ids) > 0:
            if test.select(pl.col('time_id').first()).item() == 0:
                print(f'Missing ids: {missing_ids}')
            # Pad with all-zero rows so that the graph always has one node per known stock.
            add_df = []
            for missing_id in missing_ids:
                record = {f'feature_{i:02d}': 0.0 for i in range(79)}
                record['symbol_id'] = missing_id
                record['valid'] = 0
                add_df.append(record)
            predict_df = pl.concat([predict_df, pl.DataFrame(add_df, schema=predict_df.schema)], how='diagonal')
        predict_df = predict_df.sort('symbol_id')

        self.current_stock_ids += new_ids
        if len(new_ids) > 0:
            print(f'New ids: {new_ids}')
            if self.past_responders_pivot is not None:
                self.past_responders_pivot = self.past_responders_pivot.with_columns(
                    pl.lit(0.0).cast(pl.Float32).alias(str(i)) for i in new_ids
                )
            # Grow the correlation matrix and every stored adjacency matrix.
            # (The correlation update used to read `adj` before it was bound.)
            for id_ in new_ids:
                self.current_corr_matrix = self._insert_symbol(self.current_corr_matrix, id_)
                self.adjacency_matrices = self._insert_symbol(self.adjacency_matrices, id_)

        x = torch.tensor(predict_df.drop(['symbol_id', 'valid']).to_numpy(), dtype=torch.float32).unsqueeze(0).to(device)
        symbols = predict_df.select(['symbol_id']).fill_null(0).fill_nan(0).to_numpy().flatten()
        symbols = torch.tensor(symbols, dtype=torch.int).unsqueeze(0).to(device)

        adj_matrix = self.current_corr_matrix.copy()
        # Padding nodes must not exchange messages with real ones.
        # NOTE(review): symbol ids are used directly as row/column indices,
        # which assumes the known ids are contiguous from 0 - confirm.
        adj_matrix[missing_ids, :] = 0
        adj_matrix[:, missing_ids] = 0
        diagonal = np.arange(len(self.current_stock_ids))
        adj_matrix[diagonal, diagonal] = 0
        adj_matrix = (adj_matrix > cfg.CORRELATION_THR).astype(np.int32)
        adj_matrix = torch.tensor(adj_matrix, dtype=torch.int, device=device).unsqueeze(0)

        with torch.no_grad():
            preds = inference_model(x, symbols, adj_matrix).cpu().numpy().flatten()
        predict_df = predict_df.with_columns(pl.Series(preds).alias('responder_6'))
        # The left join drops the padding rows and restores the row order of `test`.
        return test.join(predict_df, on='symbol_id', how='left').select(['row_id', 'responder_6'])
    
    

In [14]:
# Module-level state used by the global `predict` function below.
# NOTE(review): COLUMNS and OLD_DATA_FRACTION are only assigned in the
# commented-out configuration cell of this notebook, so on a fresh kernel this
# cell raises NameError unless they are defined somewhere outside this view.

# Rows of the day currently being streamed; None until the first batch arrives.
current_day_data : pl.DataFrame | None = None

# History from `start-60` to `start-1`, with NaN/null replaced by 0, sorted and
# restricted to COLUMNS.
old_dataset = loader.load(start-60, start-1)\
    .fill_nan(None).fill_null(strategy='zero')\
    .sort(['date_id', 'time_id', 'symbol_id']) \
    .select(COLUMNS).collect()

last_train_date = start-1
# Dates after `start-30` become the "new" data; the remainder is kept as old
# data only when a positive OLD_DATA_FRACTION is configured.
new_dataset = old_dataset.filter(pl.col('date_id') > start-30)
if OLD_DATA_FRACTION > 0:
    old_dataset = old_dataset.filter(pl.col('date_id') <= start-30)
else:
    old_dataset = None

In [15]:
# Graph-related module-level state for the global `predict` function below.

# Threshold compared against correlation values when building adjacency matrices.
CORRELATION_THR = 0.1
# 7 when KAGGLE_IS_COMPETITION_RERUN is set in the environment, 2 otherwise.
WINDOW_LEN = 2 if not os.getenv('KAGGLE_IS_COMPETITION_RERUN') else 7
past_responders_pivot: pl.DataFrame | None = None
current_date_id = -1
current_stock_ids = list(range(39))
num_dates = 0

# Precomputed arrays, truncated / indexed at `last_train_date`.
# NOTE(review): absolute, machine-specific paths; they will not resolve on another host.
adjacency_matrices = np.load('/home/lorecampa/projects/jane_street_forecasting/dataset/sources/graph_conv_torch/adjacency_matrices.npy')[:last_train_date+1, :, :]
current_corr_matrix = np.load('/home/lorecampa/projects/jane_street_forecasting/dataset/sources/graph_conv_torch/correlations.npy')[last_train_date, :, :]

In [16]:
def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    """Serve one inference batch and, time permitting, advance online fine tuning.

    The function keeps its state in module-level globals so that work can be
    resumed across calls:

    * When ``lags`` is given, the call is treated as the start of a new date.
      ``responder_6_lag_1`` is re-labelled as ``responder_6`` of ``date_id - 1``,
      joined onto the rows buffered in ``current_day_data`` and appended to
      ``new_dataset``; one adjacency matrix (thresholded ``current_corr_matrix``)
      is appended to ``adjacency_matrices``; the pivot of past responders is
      extended; and a new fine-tuning round may be scheduled.
    * When ``lags`` is None, ``test`` is appended to ``current_day_data``.
    * While a fine-tuning round is active, validation/training batches are
      processed until ``TIME_LIMIT`` seconds have elapsed in this call.
    * Finally, if the first row of ``test`` is scored, ``inference_model`` is run
      on a dense (time_id x stock) grid; otherwise zeros are returned.

    Args:
        test: Rows to predict. Must contain ``row_id``, ``date_id``, ``time_id``,
            ``symbol_id``, ``is_scored`` and the ``feature_00`` .. ``feature_78``
            columns.
        lags: Lagged responders (``date_id``, ``time_id``, ``symbol_id``,
            ``responder_6_lag_1``) or None.

    Returns:
        A polars DataFrame with columns ``row_id`` and ``responder_6`` (Float32)
        holding one row per row of ``test``. Rows whose symbol is not part of
        ``current_stock_ids`` get a prediction of 0.

    Raises:
        AssertionError: If the result columns or length do not match the contract.
        TypeError: If the result is not a DataFrame.
    """
    global BATCH_SIZE, GRADIENT_CLIPPING, N_EPOCHS_PER_TRAIN_MAX, TRAIN_EVERY
    global EARLY_STOPPING_DAYS, ES_PATIENCE, OLD_DATA_FRACTION
    global MAX_FINE_TUNING_TIME_LIMIT, TIME_LIMIT, FINE_TUNING
    global CORRELATION_THR, WINDOW_LEN
    global date_idx, new_dataset, old_dataset, current_day_data, last_train_date
    global train_dataloader, val_dataloader, train_iterator, val_iterator, adjacency_matrices
    global acc_metrics, save_path, start_train, is_training_loop, epoch, best_epoch, best_score
    global current_stock_ids, current_date_id, current_corr_matrix, num_dates, past_responders_pivot
    global gradient_clipping_decay, gradient_clipping, lr, lr_decay, optimizer

    initial_time = time.time()
    # FINE_TUNING is a global: once the wall-clock deadline has passed it stays off.
    FINE_TUNING = FINE_TUNING & (initial_time < MAX_FINE_TUNING_TIME_LIMIT)
    start_train = start_train if FINE_TUNING else False

    if lags is not None:
        # Lags describe the previous date: shift date_id back by one and expose the
        # lagged responder under its original name.
        lags_ = lags.select(
            pl.col('date_id').sub(1),
            pl.col(['time_id', 'symbol_id']),
            pl.col('responder_6_lag_1').alias('responder_6'),
        )
        if current_day_data is not None:
            # Attach the now-known targets to the rows buffered for the previous date.
            current_day_data = current_day_data.join(lags_, on=['date_id', 'time_id', 'symbol_id'], 
                                                     how='left').fill_null(0)
            current_day_data = current_day_data.select(COLUMNS)
            # replacing date id to ensure that adjacency_matrices array is consistent
            # NOTE(review): `date_idx` is initialised outside this function; confirm the
            # resulting date_id matches the index of the matrix appended below.
            current_day_data = (
                current_day_data
                .drop('date_id')
                .with_columns(pl.lit(last_train_date + date_idx + 1).cast(pl.Int16).alias('date_id'))
                .select(COLUMNS)
            )

            new_dataset = new_dataset.vstack(current_day_data)
            last_adj = current_corr_matrix.copy()
            
            # Remove self-loops, then binarise the correlations into an adjacency matrix.
            last_adj[np.arange(len(current_stock_ids)), np.arange(len(current_stock_ids))] = 0
            last_adj = (last_adj > CORRELATION_THR).astype(np.int32)[np.newaxis, :, :]
            adjacency_matrices = np.concatenate([adjacency_matrices, last_adj], axis=0)
            
        current_day_data = test

        # Dense (date, time) x symbol grid so that every stock has a column in the pivot.
        all_combinations = (
            lags_.select(['date_id', 'time_id'])
            .unique()
            .join(pl.DataFrame({'symbol_id': current_stock_ids}, 
                               schema={'symbol_id': pl.Int8}), how="cross")
        )
        
        pivot_lags = (
            all_combinations
            .join(lags_, on=['date_id', 'time_id', 'symbol_id'], how="left")
            .fill_null(0)
            .sort(['date_id', 'time_id', 'symbol_id'])
            .pivot(index=['date_id', 'time_id'], values='responder_6', on='symbol_id')
            .fill_null(0)
        )
        
        # NOTE(review): `current_date_id` and `num_dates` are never assigned in this
        # function. Unless they are updated elsewhere, the filter below never trims the
        # window and the correlation refresh below never fires - confirm the intent.
        past_responders_pivot = (
            pl.concat([past_responders_pivot, pivot_lags], how='diagonal')
            .filter(pl.col('date_id') >= current_date_id - WINDOW_LEN - 1)
        ) if past_responders_pivot is not None else pivot_lags
        
        if num_dates >= WINDOW_LEN:
            current_corr_matrix = compute_correlation_from_pivot(past_responders_pivot)

        if FINE_TUNING and not start_train:
            # A new fine-tuning round is scheduled every TRAIN_EVERY dates.
            start_train = (date_idx+1) % TRAIN_EVERY == 0
            if start_train:
                print('Starting new fine tuning')
                model.eval()
                # The most recent EARLY_STOPPING_DAYS dates are held out for validation.
                max_date = new_dataset.select(pl.col('date_id').max()).item()
                new_validation_dataset = new_dataset.filter(pl.col('date_id') > max_date - EARLY_STOPPING_DAYS)
                new_training_dataset = new_dataset.filter(pl.col('date_id') <= max_date - EARLY_STOPPING_DAYS)
                train_days = new_training_dataset['date_id'].unique().sort().to_list()
                val_days = new_validation_dataset['date_id'].unique().sort().to_list()
                print(f'Training days: {train_days}')
                print(f'Validation days: {val_days}')
                
                if OLD_DATA_FRACTION > 0:
                    # Sample whole (date, time) slices of old data so that old rows make up
                    # roughly OLD_DATA_FRACTION of the training set.
                    old_data_len = OLD_DATA_FRACTION * new_training_dataset.shape[0] / (1 - OLD_DATA_FRACTION)
                    time_factions = min(1, old_data_len / old_dataset.shape[0])
                    old_date_times = old_dataset.select(['date_id', 'time_id']).unique().sample(fraction=time_factions)
                                        
                    old_training_dataset = old_dataset.join(old_date_times, on=['date_id', 'time_id'], how='inner')
                    
                    print(f'Old training days: {old_training_dataset["date_id"].unique().to_list()}')
                    
                    train_dataloader = MultiStockGraphDataset(pl.concat([old_training_dataset, new_training_dataset]), adjacency_matrices.copy(), current_stock_ids)
                    val_dataloader = MultiStockGraphDataset(new_validation_dataset, adjacency_matrices.copy(), current_stock_ids)
                else:
                    train_dataloader = MultiStockGraphDataset(new_training_dataset, adjacency_matrices.copy(), current_stock_ids)
                    val_dataloader = MultiStockGraphDataset(new_validation_dataset, adjacency_matrices.copy(), current_stock_ids)
                
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
                train_dataloader = DataLoader(train_dataloader, shuffle=True, batch_size=BATCH_SIZE, num_workers=0)
                val_dataloader = DataLoader(val_dataloader, shuffle=False, batch_size=2048, num_workers=0)
                # The round starts with a validation pass (epoch -1) before any training.
                val_iterator = iter(val_dataloader)
                acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
                is_training_loop = False
                epoch = -1
                best_epoch = -1
                best_score = -1e10

                if OLD_DATA_FRACTION > 0:
                    # Roll the "old" pool forward: absorb the rows just used for training
                    # and keep only the dates within 30 of the newest training date.
                    max_new_date_id = new_training_dataset['date_id'].max()
                    old_dataset = old_dataset.vstack(new_training_dataset).filter(
                        pl.col('date_id').is_between(max_new_date_id - 30, max_new_date_id)
                    )
                    
                new_dataset = new_validation_dataset
                
        date_idx += 1
    else:
        # No lags: keep buffering rows. Guard the case where nothing has been
        # buffered yet instead of failing on `None.vstack`.
        current_day_data = test if current_day_data is None else current_day_data.vstack(test)
        
    if FINE_TUNING:
        # Resumable train/validation loop: iterators and phase flags are globals, so the
        # round continues in the next call once this call's TIME_LIMIT budget is spent.
        while start_train and time.time() - initial_time < TIME_LIMIT:
            if is_training_loop:
                try:
                    batch = next(train_iterator)
                except StopIteration:
                    # Training pass finished: switch to a fresh validation pass.
                    model.eval()
                    val_iterator = iter(val_dataloader)
                    acc_metrics = dict(ss_res=0.0, ss_tot=0.0)
                    is_training_loop = False
                    continue
        
                x, targets, m, w, s, A = batch
                optimizer.zero_grad()
                y_out = model.forward(x.to(device), s.to(device), A.to(device)).squeeze()
                loss = loss_fn(y_out, targets.to(device), w.to(device))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                optimizer.step()
                
            else:
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    # Validation pass finished: weighted zero-mean R2 over the whole pass.
                    score = 1 - acc_metrics['ss_res'] / acc_metrics['ss_tot']
                    print(f'Epoch {epoch} completed with score {score}')
                    epoch += 1
                    if score > best_score:
                        # Checkpoint the improved weights and promote them to inference.
                        torch.save(model.state_dict(), save_path)
                        inference_model.load_state_dict(torch.load(save_path, weights_only=True))
                        inference_model.to(device)
                        inference_model.eval()
                        best_epoch = epoch
                        best_score = score
                    if epoch - best_epoch >= ES_PATIENCE or epoch == N_EPOCHS_PER_TRAIN_MAX:
                        print(f'Stopping after {epoch} epochs')
                        print(f'Completed Fine Tuning at time {test.select(pl.col("time_id").first()).item()}')
                        # Restore the best checkpoint and decay the schedule for the next round.
                        model.load_state_dict(torch.load(save_path, weights_only=True))
                        model.to(device)
                        model.eval()
                        start_train = False
                        lr *= lr_decay
                        gradient_clipping *= gradient_clipping_decay
                        break
                    model.train()
                    train_iterator = iter(train_dataloader)
                    is_training_loop = True
                    continue

                x, targets, m, w, s, A = batch
                with torch.no_grad():
                    y_out = model(x.to(device), s.to(device), A.to(device)).squeeze()
                w = w.to(device)
                targets = targets.to(device)
                acc_metrics['ss_res'] += (w * (y_out - targets) ** 2).sum().cpu()
                acc_metrics['ss_tot'] += (w * (targets ** 2)).sum().cpu()

    if test.select(pl.col('is_scored').cast(pl.Int8).first()).item() > 0:
        # Use the same stock universe as the training datasets instead of a literal 39.
        n_stocks = len(current_stock_ids)
        feature_cols = [f'feature_{i:02d}' for i in range(79)]

        test_ = test.fill_nan(None).fill_null(strategy='zero')
        predict_df = (
            test_.select(['date_id', 'time_id'])
                .unique()
                .join(pl.DataFrame({'symbol_id': current_stock_ids}, 
                                schema={'symbol_id': pl.Int8}), how="cross")
                .join(test_.with_columns(pl.lit(1).alias('mask')), 
                    on=['date_id', 'time_id', 'symbol_id'], how="left")
                .fill_null(0)  # fill all columns with 0 for missing stocks (including the mask)
                .sort(['date_id', 'time_id', 'symbol_id'])
        )
        # True for grid cells that correspond to a real row of `test`.
        valid_data = predict_df.select(['mask']).to_numpy().flatten() == 1
        x = torch.tensor(predict_df.select(feature_cols).to_numpy().reshape(-1, n_stocks, len(feature_cols)), dtype=torch.float32).to(device)
        s = torch.tensor(predict_df.select(['symbol_id']).to_numpy().flatten().reshape(-1, n_stocks).astype(np.int32)).to(device)
        
        # Graph for inference: current correlations without self-loops, binarised and
        # repeated for every time step in the batch.
        adj_matrix = current_corr_matrix.copy()
        adj_matrix[np.arange(n_stocks), np.arange(n_stocks)] = 0
        adj_matrix = (adj_matrix > CORRELATION_THR).astype(np.int32)
        adj_matrix = torch.tensor(adj_matrix, dtype=torch.int, device=device).unsqueeze(0).repeat(x.shape[0], 1, 1)
        with torch.no_grad():
            preds = inference_model(x, s, adj_matrix).cpu().numpy().flatten()[valid_data]
        predict_df = predict_df.filter(pl.col('mask') == 1).with_columns(pl.Series(preds).alias('responder_6'))
        
        predictions = test.join(predict_df, on=['time_id', 'symbol_id'], how='left').select(['row_id', 'responder_6'])
    else:
        predictions = test.select('row_id', pl.Series(np.zeros(test.shape[0])).alias('responder_6'))
    # Rows of `test` with a symbol outside the grid have no match in the left join and
    # would be returned as null; predict 0 for them instead.
    predictions = predictions.with_columns(pl.col('responder_6').fill_null(0.0).cast(pl.Float32))

    if isinstance(predictions, pl.DataFrame):
        assert predictions.columns == ['row_id', 'responder_6']
    elif isinstance(predictions, pd.DataFrame):
        assert (predictions.columns == ['row_id', 'responder_6']).all()
    else:
        raise TypeError('The predict function must return a DataFrame')
    
    assert len(predictions) == len(test)

    return predictions

In [17]:
from nbs.tabm.predict_tabm import predict_tabm
from prj.utils import online_iterator, online_iterator_daily
from sklearn.metrics import r2_score

# Feed every (test, lags) pair yielded by the iterator through `predict` and keep
# the predicted responder_6 values in iteration order.
y_hat_iterator = []
for test, lags in online_iterator_daily(test_ds, show_progress=True):
    res = predict(test, lags)
    y_hat_iterator.append(res['responder_6'].to_numpy())

# Fail with an explicit message instead of passing None to r2_score.
if not y_hat_iterator:
    raise ValueError('online_iterator_daily yielded no batches: nothing to score')
y_hat_iterator = np.concatenate(y_hat_iterator)

# Sample-weighted R2 of the concatenated predictions against the held-out targets.
score = r2_score(y_test, y_hat_iterator, sample_weight=w_test)

score

100%|██████████| 170/170 [00:49<00:00,  3.46it/s]


0.016260981559753418

In [18]:
# 0.013081669807434082


In [19]:
# 0.011371493339538574