In [1]:
import numpy as np
import polars as pl
from pathlib import Path
import gc
import os
from typing import List, Union, Dict, Any

from prj.data.data_loader import DataLoader as BaseDataLoader
import torch
from torch.utils.data import DataLoader
from prj.model.torch.datasets.base import JaneStreetBaseDataset
from prj.model.torch.losses import WeightedMSELoss
from prj.model.torch.models.mlp import Mlp
from prj.model.torch.wrappers.base import JaneStreetModelWrapper
from prj.model.torch.utils import train

2024-12-18 16:26:49.853756: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-18 16:26:49.853787: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-18 16:26:49.855029: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-18 16:26:49.861402: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from prj.config import DATA_DIR
from prj.data.data_loader import DataConfig

# Options forwarded to DataConfig; kept in a dict so they can be inspected or reused.
data_args = dict(zero_fill=True)
config = DataConfig(**data_args)

# Project data loader rooted at the configured data directory.
loader = BaseDataLoader(config=config, data_dir=DATA_DIR)

In [3]:
# date_id window handed to the loader, plus the split ratios.
start_dt, end_dt = 1020, 1529
# start_dt, end_dt = 1020, 1100
val_ratio = 0.2
es_ratio = 0.1
early_stopping = True


def _count_rows_and_dates(frame):
    """Return ``(n_rows, n_dates)`` for a lazy frame, or ``(0, 0)`` if it is None.

    ``n_dates`` is the number of distinct non-null ``date_id`` values. Both
    aggregates are requested in one ``select`` so the frame is collected once
    instead of once per statistic.
    """
    if frame is None:
        return 0, 0
    return frame.select(
        pl.len().alias('n_rows'),
        pl.col('date_id').unique().count().alias('n_dates'),
    ).collect().row(0)


train_ds, val_ds = loader.load_train_and_val(start_dt=start_dt, end_dt=end_dt, val_ratio=val_ratio)
es_ds = None
if early_stopping:
    # Carve the last `es_ratio` share of the training dates off as the
    # early-stopping set; the split is by date so no date_id lands in both parts.
    train_dates = train_ds.select('date_id').unique().collect().to_series().sort()
    split_point = int(len(train_dates) * (1 - es_ratio))
    # split_point == len(train_dates) would index past the end, and
    # split_point == 0 would leave no training data; fail with a clear message.
    if not 0 < split_point < len(train_dates):
        raise ValueError(
            f'es_ratio={es_ratio} leaves an empty train or early-stopping split '
            f'for {len(train_dates)} training dates'
        )
    split_date = train_dates[split_point]
    es_ds = train_ds.filter(pl.col('date_id').ge(split_date))
    train_ds = train_ds.filter(pl.col('date_id').lt(split_date))

# Report the size of every split (the early-stopping split counts as 0 when disabled).
n_rows_train, n_dates_train = _count_rows_and_dates(train_ds)
n_rows_es, n_dates_es = _count_rows_and_dates(es_ds)
n_rows_val, n_dates_val = _count_rows_and_dates(val_ds)
print(f'N rows train: {n_rows_train}, ES: {n_rows_es}, VAL: {n_rows_val}')
print(f'N dates train: {n_dates_train}, ES: {n_dates_es}, VAL: {n_dates_val}')

N rows train: 13495856, ES: 1457808, VAL: 3692920
N dates train: 366, ES: 41, VAL: 102


In [4]:
# Wrap the frames produced by the split cell in the project's torch dataset.
# NOTE: this rebinds train_ds / es_ds, so re-running this cell without first
# re-running the split cell feeds an already-wrapped dataset back in.
feature_names = loader.features
train_ds = JaneStreetBaseDataset(train_ds, features=feature_names)
es_ds = JaneStreetBaseDataset(es_ds, features=feature_names) if early_stopping else es_ds


In [5]:
# Learning-rate schedule: the scheduler name and its kwargs are handed to the wrapper.
# NOTE(review): `verbose` is deprecated on recent PyTorch LR schedulers - confirm
# against the installed version.
scheduler = 'ReduceLROnPlateau'
scheduler_cfg = {'mode': 'min', 'factor': 0.1, 'patience': 3, 'verbose': True, 'min_lr': 1e-8}

# MLP over the loader's feature columns, three hidden layers, dropout on, batch norm off.
backbone = Mlp(
    input_dim=(len(loader.features),),
    hidden_dims=[256, 128, 64],
    use_dropout=True,
    use_bn=False,
)

# Training wrapper around the network. The 2nd and 3rd positional arguments are
# a list of losses and a parallel list [1] (presumably per-loss weights - TODO confirm).
model = JaneStreetModelWrapper(
    backbone,
    [WeightedMSELoss()],
    [1],
    scheduler=scheduler,
    scheduler_cfg=scheduler_cfg,
)

In [6]:
# Early-stopping callback settings. They live under their own name so they do
# not overwrite the boolean `early_stopping` flag set in the split cell; a
# non-empty dict is always truthy, which would hide a disabled flag.
early_stopping_cfg = {'monitor': 'val_wr2', 'min_delta': 0.00, 'patience': 5, 'verbose': True, 'mode': 'max'}
batch_size = 1024
num_workers = 3

# Early stopping is only possible when an early-stopping split was actually
# built; es_ds stays None when the split cell ran with it disabled.
use_early_stopping = es_ds is not None

# NOTE(review): the captured run log shows polars warning about fork() when the
# worker processes start. Consider passing multiprocessing_context='spawn' to
# the DataLoaders - TODO confirm the dataset objects are picklable first.
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
es_dataloader = (
    DataLoader(es_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    if use_early_stopping
    else None
)

# Fit the wrapper; `model` is rebound to whatever train() returns.
model: JaneStreetModelWrapper = train(
    model,
    train_dataloader,
    es_dataloader,
    accelerator='auto',
    max_epochs=20,
    precision='32-true',
    use_model_ckpt=False,
    gradient_clip_val=10,
    use_early_stopping=use_early_stopping,
    early_stopping_cfg=early_stopping_cfg,
    compile=False,
)

Seed set to 42
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type            | Params | Mode 
-----------------------------------------------------------
0  | model         | Mlp             | 61.7 K | train
1  | model.model   | Sequential      | 61.7 K | train
2  | model.model.0 | Linear          | 20.5 K | train
3  | model.model.1 | LeakyReLU       | 0      | train
4  | model.model.2 | Dropout         | 0      | train
5  | model.model.3 | Linear          | 32.9 K | train
6  | model.model.4 | LeakyReLU       | 0      | train
7  | model.model.5 | Dropout         | 0      | train
8  | model.model.6 | Linear          | 8.3 K  | train
9  | model.model.7 | LeakyReLU       | 0      | t

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.

  self.pid = os.fork()



[Epoch 0 - Validation]
val_wmse: 3.1414
val_wmse_epoch: 3.1414
val_wmae: 1.2893
val_wmae_epoch: 1.2893
val_wr2: -0.1392
val_wr2_epoch: -0.1392
val_loss: 3.1414
val_loss_epoch: 3.1414


Training: |          | 0/? [00:00<?, ?it/s]

In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_wr2 improved. New best score: -0.000



[Epoch 0 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5515
val_wmae_epoch: 0.5515
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 0 - Training]
train_wmse: 0.7201
train_wmse_step: 0.9726
train_wmae: 0.5460
train_wmae_step: 0.5787
train_wr2: -0.0007
train_wr2_step: -0.0003
train_loss: 0.7201
train_loss_step: 0.9726
train_wmse_epoch: 0.7201
train_wmae_epoch: 0.5460
train_wr2_epoch: -0.0007
train_loss_epoch: 0.7201


Validation: |          | 0/? [00:00<?, ?it/s]


[Epoch 1 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5514
val_wmae_epoch: 0.5514
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 1 - Training]
train_wmse: 0.7196
train_wmse_step: 0.6711
train_wmae: 0.5458
train_wmae_step: 0.5558
train_wr2: -0.0000
train_wr2_step: 0.0006
train_loss: 0.7196
train_loss_step: 0.6711
train_wmse_epoch: 0.7196
train_wmae_epoch: 0.5458
train_wr2_epoch: -0.0000
train_loss_epoch: 0.7196


Validation: |          | 0/? [00:00<?, ?it/s]


[Epoch 2 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5517
val_wmae_epoch: 0.5517
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 2 - Training]
train_wmse: 0.7196
train_wmse_step: 0.5750
train_wmae: 0.5458
train_wmae_step: 0.5162
train_wr2: -0.0000
train_wr2_step: -0.0000
train_loss: 0.7196
train_loss_step: 0.5750
train_wmse_epoch: 0.7196
train_wmae_epoch: 0.5458
train_wr2_epoch: -0.0000
train_loss_epoch: 0.7196


Validation: |          | 0/? [00:00<?, ?it/s]


[Epoch 3 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5518
val_wmae_epoch: 0.5518
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 3 - Training]
train_wmse: 0.7196
train_wmse_step: 0.5677
train_wmae: 0.5458
train_wmae_step: 0.4896
train_wr2: -0.0000
train_wr2_step: 0.0002
train_loss: 0.7196
train_loss_step: 0.5677
train_wmse_epoch: 0.7196
train_wmae_epoch: 0.5458
train_wr2_epoch: -0.0000
train_loss_epoch: 0.7196


Validation: |          | 0/? [00:00<?, ?it/s]


[Epoch 4 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5518
val_wmae_epoch: 0.5518
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 4 - Training]
train_wmse: 0.7196
train_wmse_step: 0.6982
train_wmae: 0.5458
train_wmae_step: 0.5369
train_wr2: -0.0000
train_wr2_step: -0.0006
train_loss: 0.7196
train_loss_step: 0.6982
train_wmse_epoch: 0.7196
train_wmae_epoch: 0.5458
train_wr2_epoch: -0.0000
train_loss_epoch: 0.7196


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_wr2 did not improve in the last 5 records. Best score: -0.000. Signaling Trainer to stop.



[Epoch 5 - Validation]
val_wmse: 0.7383
val_wmse_epoch: 0.7383
val_wmae: 0.5514
val_wmae_epoch: 0.5514
val_wr2: -0.0000
val_wr2_epoch: -0.0000
val_loss: 0.7383
val_loss_epoch: 0.7383

[Epoch 5 - Training]
train_wmse: 0.7196
train_wmse_step: 0.6510
train_wmae: 0.5458
train_wmae_step: 0.5306
train_wr2: 0.0000
train_wr2_step: 0.0001
train_loss: 0.7196
train_loss_step: 0.6510
train_wmse_epoch: 0.7196
train_wmae_epoch: 0.5458
train_wr2_epoch: 0.0000
train_loss_epoch: 0.7196


In [7]:
# Materialise the validation split; only the first three of the four returned
# items (features, targets, weights) are used, the last is discarded.
val_splits = loader._build_splits(val_ds)
X_val, y_val, w_val, _ = val_splits

# Show the three shapes side by side as a quick alignment check.
tuple(arr.shape for arr in (X_val, y_val, w_val))

((3692920, 79), (3692920,), (3692920,))

In [8]:
# Run the trained wrapper on the validation features.
y_hat = model.predict(X_val)
# Display the prediction shape so it can be compared with y_val.shape / w_val.shape.
y_hat.shape

(3692920,)

In [9]:
from prj.metrics import weighted_mae, weighted_mse, weighted_r2, weighted_rmse

# Report name -> weighted metric function, in display order.
metric_fns = {
    'r2_w': weighted_r2,
    'mae_w': weighted_mae,
    'mse_w': weighted_mse,
    'rmse_w': weighted_rmse,
}

# Evaluate every metric on the validation targets, predictions and sample weights.
{name: fn(y_val, y_hat, weights=w_val) for name, fn in metric_fns.items()}

{'r2_w': -3.075599670410156e-05,
 'mae_w': 0.50378007,
 'mse_w': 0.6111065,
 'rmse_w': 0.78173304}