In [1]:
import numpy as np
import polars as pl
from pathlib import Path
import gc
import os
from typing import List, Union, Dict, Any

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchmetrics.functional as tmf
import lightning as L
import lightning.pytorch.callbacks as C

In [2]:
from prj.config import DATA_DIR


BASE_PATH = DATA_DIR


def _partition_file(partition_id):
    """Return the single parquet file stored under ``partition_id=<partition_id>``."""
    return BASE_PATH / f'partition_id={partition_id}' / 'part-0.parquet'


# Lazily scan the training partitions (range(8, 9) -> only partition 8) ...
train_ds = pl.concat([pl.scan_parquet(_partition_file(pid)) for pid in range(8, 9)])
# ... and hold partition 9 out for validation.
val_ds = pl.scan_parquet(_partition_file(9))

In [3]:
class JaneStreetBaseDataset(Dataset):
    """Row-level map-style dataset over a Jane Street parquet scan.

    Each item is one row of the sorted, null-filled frame, returned as a tuple
    ``(features, responder_6, weight)`` of float32 tensors, where ``features``
    holds the 79 columns ``feature_00`` .. ``feature_78``.

    Args:
        dataset: lazy frame with ``date_id``, ``time_id``, ``symbol_id``,
            ``feature_00`` .. ``feature_78``, ``responder_6`` and ``weight`` columns.
        num_days_batch, num_stocks, num_timesteps: stored as attributes; not used
            by the row-level loading/indexing below.
    """

    def __init__(self, dataset: 'pl.LazyFrame', num_days_batch: int = 10,
                 num_stocks: int = 39, num_timesteps: int = 50):
        super(JaneStreetBaseDataset, self).__init__()
        self.dataset = dataset
        self.num_days_batch = num_days_batch
        self.num_stocks = num_stocks
        self.num_timesteps = num_timesteps
        self._load()
        # Kept for backward compatibility. It is now the number of rows actually
        # indexable by __getitem__. Previously it was a separate unique-key count
        # that could disagree with len(self.X).
        self.dataset_len = len(self.y)

    def _shuffle_batches(self):
        """Shuffle the distinct ``date_id`` values (populated by ``_load``) in place."""
        np.random.shuffle(self.dates)

    def _load(self):
        """Sort, forward-fill features per symbol, and materialise numpy arrays.

        The frame is collected exactly once. Collecting X, y and weights
        separately re-runs the (non-stable) sort each time, so rows sharing the
        same (date_id, time_id) could come back in different orders and
        misalign features with targets/weights.
        """
        feature_cols = [f'feature_{i:02d}' for i in range(79)]
        frame = (
            self.dataset
            # symbol_id as a final key makes the row order fully deterministic;
            # the per-symbol order (date_id, time_id) used by the fill is unchanged.
            .sort(['date_id', 'time_id', 'symbol_id'])
            .with_columns(
                pl.col(feature_cols)
                .fill_null(strategy='forward', limit=10)
                .over('symbol_id')
                .fill_null(strategy='zero')
            )
            .select(['date_id', *feature_cols, 'responder_6', 'weight'])
            .collect()
        )

        self.X = frame.select(feature_cols).to_numpy().astype(np.float32)
        self.y = frame.get_column('responder_6').to_numpy().astype(np.float32)
        self.weights = frame.get_column('weight').to_numpy().astype(np.float32)
        # Distinct days, consumed by _shuffle_batches.
        self.dates = frame.get_column('date_id').unique().sort().to_numpy()

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.X[idx, :], dtype=torch.float32),
            torch.tensor(self.y[idx], dtype=torch.float32),
            torch.tensor(self.weights[idx], dtype=torch.float32),
        )

In [None]:
train_dataset = JaneStreetBaseDataset(train_ds, num_days_batch=200)
val_dataset = JaneStreetBaseDataset(val_ds, num_days_batch=200)

batch_size = 1024
# drop_last on the training loader avoids a trailing batch with a single row.
# BatchNorm1d (used by SimpleNNModel when use_bn=True) raises in training mode
# on a batch of size 1.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def weighted_r2_score(preds, targets, weights):
    """Zero-mean weighted R^2: ``1 - sum(w * (y - y_hat)^2) / sum(w * y^2)``.

    Works on numpy arrays or torch tensors. Returns the plain float ``0.`` when
    the denominator is not strictly positive (including NaN).
    """
    denominator = (weights * targets ** 2).sum()
    if not denominator > 0:
        return 0.
    numerator = (weights * (targets - preds) ** 2).sum()
    return 1 - numerator / denominator

In [None]:
class JaneStreetBaseModel(L.LightningModule):
    """Lightning wrapper training ``model`` with weighted losses plus L1/L2 penalties.

    Args:
        model: network applied to the feature batch. Its output is squeezed on
            the last dimension before losses and metrics are computed.
        losses: one loss module or a list of them, each called as
            ``loss(preds, targets, weights=weights)``.
        loss_weights: one positive multiplier per loss. When ``losses`` is a
            single module, the weights are ignored and ``[1.0]`` is used.
        l1_lambda: coefficient of the L1 penalty over all parameters (0 disables it).
        l2_lambda: coefficient of the L2 penalty over all parameters (0 disables it).
        optimizer: class name looked up in ``torch.optim``.
        optimizer_cfg: keyword arguments for the optimizer; ``None`` means none.
        scheduler: optional class name looked up in ``torch.optim.lr_scheduler``.
        scheduler_cfg: keyword arguments for the scheduler; ``None`` means none.
    """

    def __init__(self,
                 model: nn.Module,
                 losses: List[nn.Module] | nn.Module,
                 loss_weights: List[float],
                 l1_lambda: float = 1e-4,
                 l2_lambda: float = 1e-4,
                 optimizer: str = 'Adam',
                 optimizer_cfg: Dict[str, Any] | None = None,
                 scheduler: str | None = None,
                 scheduler_cfg: Dict[str, Any] | None = None):
        super(JaneStreetBaseModel, self).__init__()
        # Remember the input kind *before* normalising `losses` to a list.
        # Testing isinstance(losses, nn.Module) after rebinding is always False,
        # which left loss_weights empty for a single loss and crashed _compute_loss.
        single_loss = isinstance(losses, nn.Module)
        assert single_loss or len(losses) == len(loss_weights), 'Each loss must have a weight'
        assert len(loss_weights) == 0 or min(loss_weights) > 0, 'Losses must have positive weights'
        self.model = model
        self.losses = nn.ModuleList([losses] if single_loss else losses)
        self.loss_weights = [1.0] if single_loss else list(loss_weights)
        self.l1_lambda = l1_lambda
        self.l2_lambda = l2_lambda
        self.optimizer_name = optimizer
        # Fresh dicts: no shared mutable default, no aliasing of the caller's dict.
        self.optimizer_cfg = dict(optimizer_cfg or {})
        self.scheduler_name = scheduler
        self.scheduler_cfg = dict(scheduler_cfg or {})

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return data loss + regularisation; log metrics, data loss and both norms."""
        x, y, weights = batch
        # squeeze(-1) drops only the output dimension. A bare squeeze() would
        # also collapse a batch of size 1 into a 0-d tensor.
        y_hat = self.forward(x).squeeze(-1)
        loss = self._compute_loss(y_hat, y, weights)
        reg_loss, l1_loss, l2_loss = self._regularization_loss()
        with torch.no_grad():
            metrics = self._compute_metrics(y_hat, y, weights, prefix='train')
        metrics['train_loss'] = loss
        metrics['train_l1_reg'] = l1_loss
        metrics['train_l2_reg'] = l2_loss
        self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss + reg_loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics and the (unregularised) validation loss."""
        x, y, weights = batch
        y_hat = self.forward(x).squeeze(-1)
        loss = self._compute_loss(y_hat, y, weights)
        metrics = self._compute_metrics(y_hat, y, weights, prefix='val')
        metrics['val_loss'] = loss
        self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def _compute_metrics(self, preds, targets, weights, prefix='val'):
        """Weighted MSE, weighted MAE and weighted R^2, keyed ``{prefix}_<name>``."""
        metrics = dict()
        metrics[f'{prefix}_wmse'] = (weights * (preds - targets) ** 2).sum() / weights.sum()
        metrics[f'{prefix}_wmae'] = (weights * (preds - targets).abs()).sum() / weights.sum()
        metrics[f'{prefix}_wr2'] = weighted_r2_score(preds, targets, weights)
        return metrics

    def _compute_loss(self, preds, targets, weights):
        """Weighted sum of all configured losses."""
        loss = 0
        for loss_fn, loss_weight in zip(self.losses, self.loss_weights):
            loss = loss + loss_fn(preds, targets, weights=weights) * loss_weight
        return loss

    def _regularization_loss(self):
        """Return ``(reg_loss, l1_norm, l2_norm)`` over all parameters.

        A norm is only computed when its lambda is positive; otherwise it is
        reported as a zero tensor. Previously this raised UnboundLocalError.
        """
        zero = torch.zeros((), device=self.device)
        reg_loss = zero
        l1_loss = zero
        l2_loss = zero
        if self.l1_lambda > 0:
            l1_loss = sum(p.abs().sum() for p in self.parameters())
            reg_loss = reg_loss + l1_loss * self.l1_lambda
        if self.l2_lambda > 0:
            l2_loss = sum(p.pow(2).sum() for p in self.parameters())
            reg_loss = reg_loss + l2_loss * self.l2_lambda
        return reg_loss, l1_loss, l2_loss

    def configure_optimizers(self):
        """Build the named optimizer and, if configured, a scheduler monitoring ``val_wr2``."""
        optimizer = getattr(torch.optim, self.optimizer_name)(self.parameters(), **self.optimizer_cfg)
        if self.scheduler_name is None:
            return optimizer
        scheduler = getattr(torch.optim.lr_scheduler, self.scheduler_name)(optimizer, **self.scheduler_cfg)
        return [optimizer], [{'scheduler': scheduler, 'monitor': 'val_wr2'}]

In [None]:
class WeightedMSELoss(nn.Module):
    """Weighted mean squared error: ``sum(w * (pred - target)^2) / sum(w)``."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions: Tensor, targets: Tensor, weights: Tensor) -> Tensor:
        """Return the weight-normalised squared error as a 0-d tensor.

        ``predictions``, ``targets`` and ``weights`` must broadcast together.
        """
        residual = predictions - targets
        total = (residual.pow(2) * weights).sum()
        return total / weights.sum()

In [None]:
class SimpleNNModel(nn.Module):
    """Feed-forward MLP whose output is multiplied by ``final_mult``.

    Layout: for each width in ``hidden_dims``:
    ``Linear -> [BatchNorm1d] -> LeakyReLU -> [Dropout]``. This is followed by
    ``Linear(.., output_dim)`` and an optional ``Tanh``. With ``use_tanh=True``
    the output is therefore bounded by ``|final_mult|``.

    Args:
        input_features: number of input columns.
        hidden_dims: widths of the hidden layers (default: none, i.e. a linear model).
        use_dropout / dropout_rate: add ``Dropout(dropout_rate)`` after each hidden activation.
        use_bn: add ``BatchNorm1d`` after each hidden ``Linear``.
        output_dim: size of the final layer.
        use_tanh: squash the output with ``Tanh``.
        final_mult: scalar applied to the network output.
    """

    def __init__(self, input_features, hidden_dims=(), use_dropout=True,
                 dropout_rate=0.1, use_bn=True, output_dim=1, use_tanh=False, final_mult=1.0):
        # hidden_dims defaults to an immutable tuple instead of a shared mutable list.
        super(SimpleNNModel, self).__init__()
        self.final_mult = final_mult
        self.use_tanh = use_tanh

        layers = []
        in_features = input_features
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(in_features, hidden_dim))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.LeakyReLU())
            if use_dropout:
                layers.append(nn.Dropout(dropout_rate))
            in_features = hidden_dim

        layers.append(nn.Linear(in_features, output_dim))
        if self.use_tanh:
            layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.final_mult * self.model(x)

In [None]:
class EpochStatsCallback(C.Callback):
    """Print the ``train_*`` / ``val_*`` entries of ``trainer.callback_metrics`` at epoch end."""

    @staticmethod
    def _print_metrics(trainer, prefix: str, title: str) -> None:
        """Print every callback metric whose name starts with ``prefix`` under a ``title`` header."""
        selected = {
            key: float(value)
            for key, value in trainer.callback_metrics.items()
            if key.startswith(prefix)
        }
        print(f"\n[Epoch {trainer.current_epoch} - {title}]")
        for key, value in selected.items():
            print(f"{key}: {value:.4f}")

    def on_train_epoch_end(self, trainer, pl_module):
        self._print_metrics(trainer, "train_", "Training")

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip the pre-fit sanity-check pass. It runs only a few batches and
        # would otherwise be reported as a regular epoch-0 validation.
        if trainer.sanity_checking:
            return
        self._print_metrics(trainer, "val_", "Validation")

In [None]:
def train(model: L.LightningModule,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          checkpoint_path: Union[str, Path, None] = None,
          max_epochs: int = 1000,
          eval_frequency: int = 1,
          log_every_n_steps: int = 10,
          precision: str = "16-mixed",
          accumulate_grad_batches: int = 1,
          gradient_clip_val: float = 10,
          gradient_clip_algorithm: str = 'norm',
          use_swa : bool = False,
          swa_cfg: Dict[str, Any] = None,
          use_early_stopping: bool = False,
          early_stopping_cfg: Dict[str, Any] = None,
          use_model_ckpt: bool = True,
          model_ckpt_cfg: Dict[str, Any] = None,
          seed: int = 42,
          compile: bool = False):
    """Fit ``model`` with a Lightning ``Trainer`` and return it.

    Args:
        checkpoint_path: optional checkpoint to resume from (``Trainer.fit(ckpt_path=...)``).
        eval_frequency: run validation every this many epochs.
        use_swa / swa_cfg: add ``StochasticWeightAveraging(**swa_cfg)``.
        use_early_stopping / early_stopping_cfg: add ``EarlyStopping(**early_stopping_cfg)``.
        use_model_ckpt / model_ckpt_cfg: add ``ModelCheckpoint(**model_ckpt_cfg)``.
        seed: passed to ``L.seed_everything``. Note this runs *after* the model
            was constructed by the caller, so it does not seed weight initialisation.
        compile: wrap the module with ``torch.compile`` before fitting.

    A ``None`` config is treated as ``{}`` (the callback's own defaults). Previously
    ``**None`` raised TypeError, e.g. for the default ``use_model_ckpt=True``.
    Callbacks that require arguments still raise their own error when given ``{}``.

    Returns:
        The (possibly compiled) model after ``trainer.fit``.
    """
    if compile:
        model = torch.compile(model, fullgraph=False, dynamic=False)

    L.seed_everything(seed, workers=True)

    callbacks = [C.ModelSummary(max_depth=3),
                 C.LearningRateMonitor(logging_interval='epoch'),
                 EpochStatsCallback()]
    if use_swa:
        callbacks.append(C.StochasticWeightAveraging(**(swa_cfg or {})))
    if use_early_stopping:
        callbacks.append(C.EarlyStopping(**(early_stopping_cfg or {})))
    if use_model_ckpt:
        callbacks.append(C.ModelCheckpoint(**(model_ckpt_cfg or {})))

    trainer = L.Trainer(
        max_epochs=max_epochs,
        check_val_every_n_epoch=eval_frequency,
        log_every_n_steps=log_every_n_steps,
        precision=precision,
        accumulate_grad_batches=accumulate_grad_batches,
        gradient_clip_val=gradient_clip_val,
        gradient_clip_algorithm=gradient_clip_algorithm,
        callbacks=callbacks
    )
    trainer.fit(model=model, train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader, ckpt_path=checkpoint_path)
    return model

In [None]:
from prj.config import EXP_DIR


# mode='max': the scheduler monitors val_wr2 (see configure_optimizers), which should increase.
scheduler = 'ReduceLROnPlateau'
scheduler_cfg = dict(mode='max', factor=0.1, patience=3, verbose=True, min_lr=1e-8)


base_model = SimpleNNModel(79, hidden_dims=[128, 128], dropout_rate=0.1, final_mult=5.0, use_tanh=True)
model = JaneStreetBaseModel(base_model, [WeightedMSELoss()], [1], l1_lambda=1e-6,
                            l2_lambda=1e-4, scheduler=scheduler, scheduler_cfg=scheduler_cfg)


data_dir = str(EXP_DIR / 'tmp' / 'model1')
os.makedirs(data_dir, exist_ok=True)
# enable_version_counter=False overwrites `mlp.ckpt` on re-runs instead of
# writing `mlp-v1.ckpt`. The loading cell reads `{data_dir}/mlp.ckpt` and would
# otherwise silently pick up a stale checkpoint from an earlier run.
ckpt_config = {'dirpath': data_dir, 'filename': 'mlp', 'save_top_k': 1,
               'monitor': 'val_wr2', 'verbose': True, 'mode': 'max',
               'enable_version_counter': False}
early_stopping = {'monitor': 'val_wr2', 'min_delta': 0.00, 'patience': 5, 'verbose': True, 'mode': 'max'}
swa_config = {'swa_lrs': 0.05, 'swa_epoch_start': 4}
model = train(model, train_dataloader, val_dataloader, max_epochs=10, precision='32-true',
              use_model_ckpt=True, gradient_clip_val=10, use_early_stopping=True,
              early_stopping_cfg=early_stopping, model_ckpt_cfg=ckpt_config,
              use_swa=False, swa_cfg=swa_config)

In [None]:
# Rebuild the architecture used for training so the checkpoint's state dict fits it.
base_model = SimpleNNModel(
    79,
    hidden_dims=[128, 128],
    dropout_rate=0.1,
    final_mult=5.0,
    use_tanh=True,
)
model = JaneStreetBaseModel.load_from_checkpoint(
    f"{data_dir}/mlp.ckpt",
    model=base_model,
    losses=[WeightedMSELoss()],
    loss_weights=[1],
)

In [None]:
from sklearn.metrics import r2_score
from tqdm.auto import tqdm

# Score the restored checkpoint on the validation partition with weighted_r2_score.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

val_dataset = JaneStreetBaseDataset(val_ds, num_days_batch=50)
val_dataloader = DataLoader(val_dataset, batch_size=2048, shuffle=False)

y_hat = []
y = []
weights = []

model.eval()
with torch.no_grad():
    # Pass the DataLoader itself, not iter(...), so tqdm can read its length.
    for x, targets, w in tqdm(val_dataloader):
        # Only the features go to the model's device. Targets and weights are
        # consumed on the CPU, so moving them there and back is wasted work.
        preds_all = model(x.to(device))

        y_hat.append(preds_all.cpu().numpy().flatten())
        y.append(targets.numpy().flatten())
        weights.append(w.numpy().flatten())

y = np.concatenate(y)
y_hat = np.concatenate(y_hat)
weights = np.concatenate(weights)

weighted_r2_score(y_hat, y, weights)

In [None]:
# Open TensorBoard on the training logs.
# NOTE(review): `train` passes no logger to L.Trainer, so this presumably relies on
# Lightning's default logger writing to ./lightning_logs — confirm.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/