In [1]:
import numpy as np
import polars as pl
from pathlib import Path
import gc
import os
from typing import List, Union, Dict, Any

from prj.data.data_loader import DataLoader as BaseDataLoader
import torch
from torch.utils.data import DataLoader
from prj.model.torch.datasets.base import JaneStreetBaseDataset
from prj.model.torch.losses import WeightedMSELoss
from prj.model.torch.models.mlp import Mlp
from prj.model.torch.wrappers.base import JaneStreetModelWrapper
from prj.model.torch.utils import train
from prj.config import EXP_DIR

2024-12-22 23:16:40.250526: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-22 23:16:40.250558: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-22 23:16:40.251722: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-22 23:16:40.258024: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# (To force CPU execution, assign device = 'cpu' here instead.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
from prj.config import DATA_DIR
from prj.data.data_loader import DataConfig

# Preprocessing switches, forwarded to DataConfig as keyword arguments.
data_args = dict(zero_fill=True, ffill=True, include_intrastock_norm=True)
config = DataConfig(**data_args)

# Project data loader bound to DATA_DIR and the configuration above.
loader = BaseDataLoader(config=config, data_dir=DATA_DIR)

In [4]:
from prj.data.data_loader import PARTITIONS_DATE_INFO
# Date window: from the first date of partition 5 to the last date of partition 8.
start_dt, end_dt = PARTITIONS_DATE_INFO[5]['min_date'], PARTITIONS_DATE_INFO[8]['max_date']
val_ratio = 0.2
es_ratio = 0.1
early_stopping = False

train_ds, val_ds = loader.load_train_and_val(start_dt, end_dt, val_ratio=val_ratio)

# Optionally move the highest `es_ratio` share of training date_ids into a
# separate early-stopping split. `es_ds` stays None when no split is made.
es_ds = None
if early_stopping and es_ratio > 0:
    train_dates = train_ds.select('date_id').unique().collect().to_series().sort()
    split_point = int(len(train_dates) * (1 - es_ratio))
    split_date = train_dates[split_point]
    es_ds = train_ds.filter(pl.col('date_id').ge(split_date))
    train_ds = train_ds.filter(pl.col('date_id').lt(split_date))


def _rows_and_dates(lf):
    """Return ``(n_rows, n_dates)`` for a lazy frame, or ``(0, 0)`` when ``lf`` is None.

    ``n_dates`` is the count of the unique ``date_id`` values of the frame.
    """
    if lf is None:
        return 0, 0
    n_rows = lf.select(pl.len()).collect().item()
    n_dates = lf.select('date_id').unique().collect().count().item()
    return n_rows, n_dates


# Guarding on `es_ds` itself (inside the helper) rather than on the
# `early_stopping` flag avoids calling `.select` on None when
# early_stopping is True but es_ratio <= 0.
n_rows_train, n_dates_train = _rows_and_dates(train_ds)
n_rows_es, n_dates_es = _rows_and_dates(es_ds)
n_rows_val, n_dates_val = _rows_and_dates(val_ds)
print(f'N rows train: {n_rows_train}, ES: {n_rows_es}, VAL: {n_rows_val}')
print(f'N dates train: {n_dates_train}, ES: {n_dates_es}, VAL: {n_dates_val}')

N rows train: 19101544, ES: 0, VAL: 4926152
N dates train: 544, ES: 0, VAL: 136


In [5]:
# Wrap the frames in the project's torch dataset.
train_ds = JaneStreetBaseDataset(train_ds, features=loader.features)

# An early-stopping split only exists when `es_ds` was actually created;
# checking the flag alone would wrap None when early_stopping is True but
# no split was carved out (es_ratio <= 0).
has_es = es_ds is not None
if has_es:
    es_ds = JaneStreetBaseDataset(es_ds, features=loader.features)

num_workers = 3
batch_size = 8192

# NOTE(review): the recorded training run shows Polars' fork() warning once
# worker processes start; consider multiprocessing_context='spawn' after
# confirming the dataset is picklable.
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
es_dataloader = (
    DataLoader(es_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    if has_es
    else None
)

In [6]:

# Optimiser / scheduler settings handed to the wrapper.
optimizer = 'Adam'
optimizer_cfg = {'lr': 5e-4}  # optionally add weight_decay=5e-4
scheduler = None  # e.g. 'MultiStepLR'
# NOTE(review): presumably unused while `scheduler` is None; confirm in JaneStreetModelWrapper.
scheduler_cfg = {'milestones': [9], 'gamma': 0.1}

# MLP over all loader features, wrapped for training with a single
# weighted-MSE loss and no extra L1/L2 penalty.
model = JaneStreetModelWrapper(
    model=Mlp(
        len(loader.features),
        hidden_dims=[512, 512, 256],
        dropout_rate=0.3,
        final_mult=5.0,
        use_tanh=True,
    ),
    losses=[WeightedMSELoss()],
    loss_weights=[1],
    l1_lambda=0,
    l2_lambda=0,
    optimizer=optimizer,
    optimizer_cfg=optimizer_cfg,
    scheduler=scheduler,
    scheduler_cfg=scheduler_cfg,
)

# model = JaneStreetModelWrapper.load_from_checkpoint(
#     EXP_DIR / 'model' / 'baseline_all.ckpt', 
#     model=model,
#     losses=[WeightedMSELoss()],
#     loss_weights=[1],
#     l1_lambda=0,
#     l2_lambda=0
# )


In [7]:
# for backend in ['inductor', 'eager', 'aot_eager', 'nvfuser', 'ts_nvfuser', 'aot_cudagraphs', 'ipex', 'ofi']:
#     try:
#         model = torch.compile(model, backend=backend)
#         print(f"Backend {backend} is working.")
#     except Exception as e:
#         print(f"Backend {backend} failed with error: {e}")


In [7]:
from prj.config import EXP_DIR

# Checkpoint settings: written under EXP_DIR/model with file name 'mlp'.
ckpt_config = {'dirpath': str(EXP_DIR / 'model'), 'filename': 'mlp',
               'save_on_train_epoch_end': True, 'verbose': True}

early_stopping_cfg = {'monitor': 'val_wr2', 'min_delta': 0.00, 'patience': 5, 'verbose': True, 'mode': 'max'}
swa_cfg = {'swa_lrs': 0.05, 'swa_epoch_start': 5}

os.makedirs(ckpt_config['dirpath'], exist_ok=True)

compile_cfg = {}

# Follow the device chosen earlier instead of hard-coding 'gpu', so the cell
# also runs on machines without CUDA. torch.device(...) accepts both a
# torch.device and a plain string such as 'cpu'.
accelerator = 'gpu' if torch.device(device).type == 'cuda' else 'cpu'

# Early stopping monitors a validation metric, so only enable it when an
# early-stopping dataloader actually exists.
use_early_stopping = early_stopping and es_dataloader is not None

model = train(model, train_dataloader, es_dataloader, max_epochs=10, precision='32-true',
              use_early_stopping=use_early_stopping, early_stopping_cfg=early_stopping_cfg,
              use_model_ckpt=True, gradient_clip_val=20, model_ckpt_cfg=ckpt_config, log_every_n_steps=100,
              accelerator=accelerator,
              use_swa=False, swa_cfg=swa_cfg, accumulate_grad_batches=4, compile=False, compile_kwargs=compile_cfg)

Seed set to 42
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/lorecampa/projects/jane_street_forecasting/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/lorecampa/projects/jane_street_forecasting/.venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/lorecampa/projects/jane_street_forecasting/experiments/model exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type            | Params | Mode 
------------------------------------------------------------
0  | model          | Mlp             | 465 K  | train
1  | m

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.

The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.

See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.

or by setting POLARS_ALLOW_FORKING_THREAD=1.

  self.pid = os.fork()


NameError: name 'exit' is not defined

In [None]:
# Training is finished: drop the loaders and their datasets, then ask the
# garbage collector to reclaim the memory.
del train_dataloader, es_dataloader
del train_ds, es_ds
gc.collect()

In [None]:
from tqdm import tqdm

# Alternative hold-out set: val_ds = loader.load_with_partition(start_part_id=9, end_part_id=9)
val_ds = JaneStreetBaseDataset(val_ds, features=loader.features, device=device)
val_dataloader = DataLoader(val_ds, batch_size=8192, shuffle=False, num_workers=0)

# Put the model on the target device and switch it to evaluation mode.
model = model.to(device)
model.eval()

# Per-batch predictions, targets and sample weights (kept on the CPU).
y_hat, y_true, w = [], [], []

with torch.no_grad():
    for batch in tqdm(val_dataloader, desc="Predicting", unit="batch"):
        # Entries 0, 1, 2 of each batch are used as inputs, labels and weights.
        on_device = [batch[idx].to(device) for idx in range(3)]
        preds = model(on_device[0])
        for sink, tensor in zip((y_hat, y_true, w), (preds, on_device[1], on_device[2])):
            sink.append(tensor.cpu())

# Stitch the batches together and flatten each result to a 1-D numpy array.
y_val_hat, y_val_true, w_val = (
    torch.cat(chunks, dim=0).cpu().numpy().flatten()
    for chunks in (y_hat, y_true, w)
)
y_val_hat.shape, y_val_true.shape, w_val.shape

In [None]:
from prj.metrics import weighted_mae, weighted_mse, weighted_r2, weighted_rmse

def metrics(y_true, y_pred, weights=None):
    """Evaluate predictions with the project's weighted metric helpers.

    Args:
        y_true: Ground-truth targets.
        y_pred: Predicted targets.
        weights: Optional sample weights, forwarded to every helper.

    Returns:
        dict with keys 'r2_w', 'mae_w', 'mse_w' and 'rmse_w', holding the
        results of weighted_r2, weighted_mae, weighted_mse and
        weighted_rmse respectively.
    """
    scorers = (
        ('r2_w', weighted_r2),
        ('mae_w', weighted_mae),
        ('mse_w', weighted_mse),
        ('rmse_w', weighted_rmse),
    )
    return {key: scorer(y_true, y_pred, weights=weights) for key, scorer in scorers}

# Show the weighted metrics for the validation predictions.
metrics(y_true=y_val_true, y_pred=y_val_hat, weights=w_val)