In [None]:
import numpy as np
import polars as pl
from pathlib import Path
import gc
import os
from typing import List, Union, Dict, Any

from prj.data.data_loader import DataLoader as BaseDataLoader
import torch
from torch.utils.data import DataLoader
from prj.model.torch.datasets.base import JaneStreetBaseDataset
from prj.model.torch.losses import WeightedMSELoss
from prj.model.torch.models.mlp import Mlp
from prj.model.torch.wrappers.base import JaneStreetModelWrapper
from prj.model.torch.utils import train

In [2]:
from prj.config import DATA_DIR
from prj.data.data_loader import DataConfig

# Options forwarded verbatim to DataConfig; kept as a separate dict so they are easy to tweak.
data_args = dict(zero_fill=True)
config = DataConfig(**data_args)
# Project data loader rooted at the configured data directory.
loader = BaseDataLoader(config=config, data_dir=DATA_DIR)

In [None]:
from prj.data.data_loader import PARTITIONS_DATE_INFO
# --- Split configuration ------------------------------------------------------
# Training covers the full date range of one partition; validation is the partition after it.
train_part_id = 8
val_part_id = 9
# Alternative: explicit date range, e.g. start_dt, end_dt = 1020, 1300
start_dt = PARTITIONS_DATE_INFO[train_part_id]['min_date']
end_dt = PARTITIONS_DATE_INFO[train_part_id]['max_date']
val_ratio = 0.2  # NOTE: not used by the partition-based split below; kept for later cells
es_ratio = 0.1   # fraction of the (latest) training dates held out for early stopping
early_stopping = True

# NOTE: an earlier version also called loader.load_train_and_val(...) here, but both of
# its results were overwritten immediately by the two loads below, so the call was removed.
train_ds = loader.load(start_dt, end_dt)
val_ds = loader.load_with_partition(start_part_id=val_part_id, end_part_id=val_part_id)

es_ds = None
if early_stopping:
    # Chronological split: the last `es_ratio` share of distinct dates goes to the ES set.
    train_dates = train_ds.select('date_id').unique().collect().to_series().sort()
    if len(train_dates) < 2:
        raise ValueError(
            f'Need at least 2 distinct dates to carve out an early-stopping set, got {len(train_dates)}'
        )
    split_point = int(len(train_dates) * (1 - es_ratio))
    # Clamp so that both the training and the early-stopping part keep at least one date
    # (also prevents an out-of-range index when es_ratio is 0).
    split_point = min(max(split_point, 1), len(train_dates) - 1)
    split_date = train_dates[split_point]
    es_ds = train_ds.filter(pl.col('date_id').ge(split_date))
    train_ds = train_ds.filter(pl.col('date_id').lt(split_date))


def _rows_and_dates(lf):
    """Return (row count, number of distinct non-null date_id values) of a lazy frame.

    Both figures are computed in a single collect. ``None`` yields ``(0, 0)``.
    """
    if lf is None:
        return 0, 0
    stats = lf.select(
        pl.len().alias('n_rows'),
        pl.col('date_id').drop_nulls().n_unique().alias('n_dates'),
    ).collect()
    return stats.item(0, 'n_rows'), stats.item(0, 'n_dates')


n_rows_train, n_dates_train = _rows_and_dates(train_ds)
n_rows_es, n_dates_es = _rows_and_dates(es_ds)
n_rows_val, n_dates_val = _rows_and_dates(val_ds)
print(f'N rows train: {n_rows_train}, ES: {n_rows_es}, VAL: {n_rows_val}')
print(f'N dates train: {n_dates_train}, ES: {n_dates_es}, VAL: {n_dates_val}')

In [4]:
# Wrap the loaded frames in JaneStreetBaseDataset so they can be fed to a torch DataLoader.
# NOTE: this rebinds `train_ds` / `es_ds`; re-run the loading cell before re-running this one.
feature_cols = loader.features
train_ds = JaneStreetBaseDataset(train_ds, features=feature_cols)
if early_stopping:
    es_ds = JaneStreetBaseDataset(es_ds, features=feature_cols)

In [6]:
# LR-scheduler settings.
# NOTE(review): neither name is passed to the wrapper below (it receives scheduler=None) —
# confirm whether running without a scheduler is intentional.
scheduler = 'ReduceLROnPlateau'
scheduler_cfg = {'mode': 'min', 'factor': 0.1, 'patience': 3, 'verbose': True, 'min_lr': 1e-8}

optimizer_cfg = dict(lr=1e-4)

# MLP over the loader's feature columns, wrapped together with a single weighted-MSE loss
# (loss weight 1).
model = JaneStreetModelWrapper(
    Mlp(
        input_dim=(len(loader.features),),
        hidden_dims=[128, 64, 32],
        use_dropout=True,
        use_bn=True,
        dropout_rate=0.1,
    ),
    [WeightedMSELoss()],
    [1],
    scheduler=None,
    optimizer_cfg=optimizer_cfg,
)

In [None]:
# Early-stopping callback settings. Stored under their own name: rebinding the boolean
# `early_stopping` flag to this dict would make the flag permanently truthy, so the
# `if early_stopping` guards (here and in the dataset cell) could never take the "off" path.
early_stopping_cfg = {'monitor': 'val_wr2', 'min_delta': 0.00, 'patience': 5, 'verbose': True, 'mode': 'max'}
batch_size = 1024
num_workers = 0

train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# The early-stopping loader is only built when the flag is on; order is kept fixed (no shuffle).
es_dataloader = (
    DataLoader(es_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    if early_stopping
    else None
)

# `train` is the project helper; it returns the fitted wrapper, which replaces `model`.
model: JaneStreetModelWrapper = train(
    model, train_dataloader, es_dataloader, accelerator='auto',
    max_epochs=20, precision='32-true', use_model_ckpt=False,
    gradient_clip_val=10, use_early_stopping=bool(early_stopping),
    early_stopping_cfg=early_stopping_cfg, compile=False,
)

In [None]:
# Materialise the validation partition into arrays; the fourth returned element is unused here.
val_splits = loader._build_splits(val_ds)
X_val, y_val, w_val, _ = val_splits
# Display the shapes as a sanity check.
tuple(arr.shape for arr in (X_val, y_val, w_val))

In [None]:
# Predict on the validation features with the trained wrapper.
y_hat = model.predict(X_val)
# Display the prediction shape as the cell output.
y_hat.shape

In [None]:
from prj.metrics import weighted_mae, weighted_mse, weighted_r2, weighted_rmse

# Weighted validation metrics, keyed by the label shown in the cell output.
metric_fns = {
    'r2_w': weighted_r2,
    'mae_w': weighted_mae,
    'mse_w': weighted_mse,
    'rmse_w': weighted_rmse,
}

# The resulting dict is the cell's displayed value.
{label: fn(y_val, y_hat, weights=w_val) for label, fn in metric_fns.items()}