In [1]:
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import copy
import re
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
import math
import os
import json
import random

import polars as pl
import numpy as np
from numba import njit, prange
from scipy.stats import spearmanr, rankdata, norm

from sklearn.model_selection import TimeSeriesSplit, KFold
from sklearn.base import clone

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# from torch.profiler import profile, ProfilerActivity, record_function

from CONFIG import CONFIG
from PREPROCESSOR_V2 import PREPROCESSOR
from FEATURE_ENGINEERING_V2 import FEATURE_ENGINEERING
from DATASET import SequentialDataset
from SEQUENTIAL_NN_MODEL import CNNTransformerModel, GRUModel, LSTMModel, PureTransformerModel
from CROSS_SECTIONAL_NN_MODEL import DeepMLPModel, LinearModel, ResidualMLPModel
from ENSEMBLE_NN_vembedding import ENSEMBLE_NN
from NN_V2 import NN
from LOSS import CombinedICIRLoss

import time


def timer(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} took {end - start:.4f} seconds")
        return result

    return wrapper


In [2]:
torch.manual_seed(CONFIG.RANDOM_STATE)
random.seed(CONFIG.RANDOM_STATE)
np.random.seed(CONFIG.RANDOM_STATE)

In [3]:
folder = "nn_models/v7_embedding"
os.makedirs(folder, exist_ok=True)
with open(f"{folder}/features.json", "w") as f:
    json.dump(CONFIG.IMPT_COL, f)

with open(f"{folder}/features.json", "r") as f:
    json.load(f)

In [4]:
def gaussian_rank_transform(arr: np.array):
    # n_samples, n_targets = arr.shape
    transformed_targets = np.full_like(arr, np.nan)

    for i, row in enumerate(arr):
        # Find valid (non-NaN) assets for this timestep
        valid_mask = ~np.isnan(row)
        valid_arr = row[valid_mask]
        ranks = rankdata(valid_arr, method="average")
        percentile_ranks = (ranks - 0.5) / (len(ranks))
        percentile_ranks = np.clip(percentile_ranks, 1e-8, 1 - 1e-8)
        gaussian_values = norm.ppf(percentile_ranks)
        transformed_targets[i, valid_mask] = gaussian_values
    return transformed_targets

def parse_target_specs_from_csv(csv_path: str) -> List[Tuple]:
    """
    Parse target specifications from your CSV file.
    
    Returns:
        List of tuples: [(lag, pair_string), ...]
    """
    df = pl.read_csv(csv_path)
    target_specs = []
    
    for row in df.rows():
        lag = row[1]
        pair_str = row[2]
        target_specs.append((lag, pair_str))
    
    return target_specs

In [5]:
# # --- Prepare DataLoader ---
# # Create the dataset

# train_x = pl.scan_csv(CONFIG.TRAIN_X_PATH).filter(pl.col("date_id") <= CONFIG.MAX_TRAIN_DATE)
# train_y = pl.scan_csv(CONFIG.TRAIN_Y_PATH).filter(pl.col("date_id") <= CONFIG.MAX_TRAIN_DATE).fill_null(0).collect()

# train_x = PREPROCESSOR(df=train_x)
# train_x.clean()
# train_x = train_x.transform().lazy()

# train_x = FEATURE_ENGINEERING(df=train_x)
# train_x = train_x.create_all_features().collect().pivot(index=CONFIG.DATE_COL, on=["type", "instr"])
# train_x = train_x.rename({col: re.sub(r'[{",}]', "", col).replace(" ", "_").replace(",", "_") for col in train_x.columns})
# train_x = train_x.select(set(CONFIG.IMPT_COLS + [CONFIG.DATE_COL]))

In [6]:
def rank_correlation_sharpe(targets, predictions) -> float:
    """
    Calculates the rank correlation between predictions and target values,
    and returns its Sharpe ratio (mean / standard deviation).

    :param merged_df: DataFrame containing prediction columns (starting with 'prediction_')
                    and target columns (starting with 'target_')
    :return: Sharpe ratio of the rank correlation
    :raises ZeroDivisionError: If the standard deviation is zero
    """
    correlations = []

    for i, (pred_row, target_row) in enumerate(zip(predictions, targets)):
        # Find valid (non-NaN) assets for this timestep
        valid_mask = ~np.isnan(target_row)
        valid_pred = pred_row[valid_mask]
        valid_target = target_row[valid_mask]

        if np.std(pred_row) == 0 or np.std(target_row) == 0:
            raise ZeroDivisionError("Zero standard deviation in a row.")

        rho = np.corrcoef(rankdata(valid_pred, method="average"), rankdata(valid_target, method="average"))[0, 1]
        correlations.append(rho)

    daily_rank_corrs = np.array(correlations)
    std_dev = daily_rank_corrs.std(ddof=0)
    if std_dev == 0:
        raise ZeroDivisionError("Denominator is zero, unable to compute Sharpe ratio.")

    sharpe_ratio = daily_rank_corrs.mean() / std_dev
    return float(sharpe_ratio)

In [7]:
# --- Prepare DataLoader ---
# Create the dataset

train_x = pl.scan_csv(CONFIG.TRAIN_X_PATH).drop(CONFIG.DROP_COL)
train_x = PREPROCESSOR(df=train_x)
train_x = train_x.clean()

features = FEATURE_ENGINEERING(df=train_x)
train_x: pl.DataFrame = features.create_market_features()

train_y = pl.scan_csv(CONFIG.TRAIN_Y_PATH)

curr_y = (
    train_y.with_columns([pl.col(CONFIG.LAGS[f"lag{i}"]).exclude(CONFIG.DATE_COL).shift(i + 1) for i in range(1, 5)])
    .with_columns(pl.all().exclude(CONFIG.DATE_COL).shift())
    .filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series())))
    .collect()
    .fill_null(0)
    .lazy()
)

y_feat = FEATURE_ENGINEERING(df=curr_y)
ys = y_feat.create_Y_market_features()

train_x = train_x.join(ys, on=CONFIG.DATE_COL).fill_nan(0)


train_y = train_y.filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series()))).collect()
train_x = (
    train_x.with_columns([pl.when(pl.col(col).is_infinite()).then(0.0).otherwise(pl.col(col)).alias(col) for col in train_x.columns])
    .with_columns(pl.all().shrink_dtype())
    .filter(pl.col(CONFIG.DATE_COL).is_in(train_y.select(CONFIG.DATE_COL).to_series()))
    .with_columns(pl.col(CONFIG.DATE_COL).cast(pl.Int64))
    .select([CONFIG.DATE_COL] + CONFIG.IMPT_COL)
)

retrain_x = train_x.with_columns(pl.all().exclude(CONFIG.DATE_COL).shift(5))
retrain_y = train_y.filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series()))).with_columns(
    pl.all().exclude(CONFIG.DATE_COL).shift(5)
)

train_y_arr = train_y.drop(CONFIG.DATE_COL).to_numpy()

train_y = pl.DataFrame(gaussian_rank_transform(train_y_arr), schema=train_y.drop(CONFIG.DATE_COL).columns).insert_column(
    0, train_y.select(CONFIG.DATE_COL).to_series()
)


# pl.DataFrame(
#     (train_y_arr - np.nanmean(train_y_arr, axis=1).reshape(train_y_arr.shape[0], -1))
#     / np.nanstd(train_y_arr, axis=1).reshape(train_y_arr.shape[0], -1),
#     schema=train_y.drop(CONFIG.DATE_COL).columns,
# ).insert_column(0, train_y.select(CONFIG.DATE_COL).to_series())

retrain_y_arr = retrain_y.drop(CONFIG.DATE_COL).to_numpy()
retrain_y = pl.DataFrame(gaussian_rank_transform(retrain_y_arr), schema=train_y.drop(CONFIG.DATE_COL).columns).insert_column(
    0, train_y.select(CONFIG.DATE_COL).to_series()
)


# pl.DataFrame(
#     (retrain_y_arr - np.nanmean(retrain_y_arr, axis=1).reshape(retrain_y_arr.shape[0], -1))
#     / np.nanstd(retrain_y_arr, axis=1).reshape(retrain_y_arr.shape[0], -1),
#     schema=train_y.drop(CONFIG.DATE_COL).columns,
# ).insert_column(0, train_y.select(CONFIG.DATE_COL).to_series())

target_specs = parse_target_specs_from_csv(CONFIG.TARGET_PAIRS_PATH)


create_market_features took 9.8450 seconds


In [19]:
NN_model = NN(
    model=ENSEMBLE_NN(input_dim=len(train_x.columns) - 1, hidden_dim=32, output_dim=CONFIG.NUM_TARGET_COLUMNS, target_specs=target_specs, RNN="GRU"),
    seq_len=CONFIG.SEQ_LEN,
    batch_size=CONFIG.BATCH_SIZE,
    lr=0.0006,
    lr_refit=0.0001,
    epochs=200,
    early_stopping_patience=10,
    spearman_weight=0.4,
    listnet_weight=0.2,
    kendall_weight=0.15,
    pairwise_weight=0.1,
    topk_weight=0.1,
    mse_weight=0.1,  # Small MSE for stability
    listnet_temp=1.0,
    kendall_temp=0.1,
)

In [None]:
if CONFIG.RUN_CV:
    dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) <= CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )
    real_dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) > CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )

    cv = TimeSeriesSplit(n_splits=CONFIG.N_FOLDS)
    cv_split = cv.split(dates_unique)

    scores = []
    models = []
    for fold, (train_idx, valid_idx) in enumerate(cv_split):
        if fold <= 3:
            continue
        if CONFIG.VERBOSE:
            print("-" * 20 + f"Fold {fold}" + "-" * 20)
            print(f"Train dates from {dates_unique[train_idx].min()} to {dates_unique[train_idx].max()}")
            print(f"Valid dates from {dates_unique[valid_idx].min()} to {dates_unique[valid_idx].max()}")

        dates_train = dates_unique[train_idx]
        dates_valid = dates_unique[valid_idx]

        df_train = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))
        true_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))

        valid_period = range(min(dates_valid) - CONFIG.SEQ_LEN + 1, max(dates_valid) + 1)
        df_valid = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
        df_valid_current_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

        df_valid_retrain = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
        df_valid_current_y_retrain = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

        model_fold = copy.deepcopy(NN_model)

        model_fold.fit(
            train_set=(df_train, true_y),
            val_set=(df_valid, df_valid_current_y),
            retrain_set=(df_valid_retrain, df_valid_current_y_retrain),
            verbose=CONFIG.VERBOSE,
        )

        models.append(model_fold)

        torch.save(
            model_fold.model.state_dict(),
            f"{folder}/ensemble_{fold}.pth",
        )

        preds = []
        cnt_dates = 0
        model_save = copy.deepcopy(model_fold)

        model_fold.model.load_state_dict(
            torch.load(
                f"{folder}/ensemble_{fold}.pth",
                map_location=torch.device("cuda"),
            )
        )

        for date_id in tqdm(dates_valid):
            period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)

            df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            valid_lags = df_valid_date.to_numpy()[-CONFIG.SEQ_LEN :].astype(np.float64)

            if model_fold.refit and (cnt_dates > 0):
                upd_period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)
                df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(upd_period)).drop(CONFIG.DATE_COL).to_numpy()
                df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL) == date_id).drop(CONFIG.DATE_COL).to_numpy()

                if len(df_upd) > 0:
                    model_fold.update(df_upd, df_upd_current_y)

            preds_i = model_fold.predict(valid_lags)

            preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

            cnt_dates += 1

        preds = np.array(preds)

        score = rank_correlation_sharpe(
            df_valid_current_y.drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
            preds,
        )
        scores.append(score)

        print(f"LAST VALIDIDATION Sharpe: {score:.5f}")

        model_real = copy.deepcopy(model_fold)
        preds = []
        cnt_dates = 0
        for date_id in tqdm(real_dates_unique):
            period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)

            df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            valid_lags = df_valid_date.to_numpy()[-CONFIG.SEQ_LEN :].astype(np.float64)

            if model_fold.refit and (cnt_dates > 0):
                upd_period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)
                df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(upd_period)).drop(CONFIG.DATE_COL).to_numpy()
                df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL) == date_id).drop(CONFIG.DATE_COL).to_numpy()

                if len(df_upd) > 0:
                    model_real.update(df_upd, df_upd_current_y)

            preds_i = model_real.predict(valid_lags)

            preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

            cnt_dates += 1

        preds = np.array(preds)

        score = rank_correlation_sharpe(
            train_y.filter(pl.col(CONFIG.DATE_COL).is_in(real_dates_unique)).drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
            preds,
        )
        scores.append(score)
        print(f"REAL Sharpe: {score:.5f}")
else:
    dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) <= CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )
    real_dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) > CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )

    if CONFIG.VERBOSE:
        print(f"Train dates from {dates_unique.min()} to {dates_unique.max()}")
        print(f"Valid dates from {real_dates_unique.min()} to {real_dates_unique.max()}")

    dates_train = dates_unique
    dates_valid = real_dates_unique

    df_train = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))
    true_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))

    valid_period = range(min(dates_valid) - CONFIG.SEQ_LEN + 1, max(dates_valid) + 1)
    df_valid = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
    df_valid_current_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

    df_valid_retrain = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
    df_valid_current_y_retrain = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

    model_fold = copy.deepcopy(NN_model)

    model_fold.fit(
        train_set=(df_train, true_y),
        val_set=(df_valid, df_valid_current_y),
        retrain_set=(df_valid_retrain, df_valid_current_y_retrain),
        verbose=CONFIG.VERBOSE,
    )

    torch.save(
        model_fold.model.state_dict(),
        f"{folder}/ensemble_full.pth",
    )

    preds = []
    cnt_dates = 0
    model_save = copy.deepcopy(model_fold)

    model_fold.model.load_state_dict(
        torch.load(
            f"{folder}/ensemble_full.pth",
            map_location=torch.device("cuda"),
        )
    )

    model_real = copy.deepcopy(model_save)
    preds = []
    cnt_dates = 0
    for date_id in tqdm(real_dates_unique):
        period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)

        df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
        valid_lags = df_valid_date.to_numpy()[-CONFIG.SEQ_LEN :].astype(np.float64)

        if model_fold.refit and (cnt_dates > 0):
            upd_period = range(date_id - CONFIG.SEQ_LEN + 1, date_id + 1)
            df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(upd_period)).drop(CONFIG.DATE_COL).to_numpy()
            df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL) == date_id).drop(CONFIG.DATE_COL).to_numpy()

            if len(df_upd) > 0:
                model_real.update(df_upd, df_upd_current_y)

        preds_i = model_real.predict(valid_lags)

        preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

        cnt_dates += 1

    preds = np.array(preds)

    score = rank_correlation_sharpe(
        train_y.filter(pl.col(CONFIG.DATE_COL).is_in(real_dates_unique)).drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
        preds,
    )
    print(f"REAL Sharpe: {score:.5f}")

--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826


  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.25282


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.40014


In [None]:
real = train_y.filter(pl.col(CONFIG.DATE_COL) >= 1827).drop(CONFIG.DATE_COL).to_numpy()
rank_real = np.full_like(real, np.nan)
rank_preds = np.full_like(real, np.nan)
for i in range(real.shape[0]):
    mask = ~np.isnan(real[i])
    rank_real[i][mask] = rankdata(real[i][mask])
    rank_preds[i][mask] = rankdata(preds[i][mask])

In [None]:
(pl.DataFrame(rank_real) - pl.DataFrame(rank_preds)).with_columns(pl.all().abs()).fill_nan(0).mean().transpose(include_header=True).sort(
    by="column_0"
).tail(10)

In [8]:
import optuna
from optuna.samplers import TPESampler
import optuna.study.study

In [9]:
def objective(trial: optuna.trial.Trial) -> float:
    batchsize = trial.suggest_categorical(name="batch_size", choices=[32, 64, 128, 256])
    seq_len = trial.suggest_categorical(name="seq_len", choices=[16, 32, 64, 128])
    hidden_dim = trial.suggest_categorical(name="hidden_dim", choices=[16, 32, 64])
    lr = trial.suggest_float(name="lr", low=0.005, high=0.01, step=0.001)
    refit_lr = trial.suggest_float(name="refit_lr", low=0.005, high=0.01, step=0.001)
    spearman_weight = trial.suggest_float("spearman_weight", 0.05, 0.2, step=0.05)
    listnet_weight = trial.suggest_float("listnet_weight", 0.2, 0.4, step=0.05)
    kendall_weight = trial.suggest_float("kendall_weight", 0.05, 0.25, step=0.05)
    pairwise_weight = trial.suggest_float("pairwise_weight", 0.2, 0.6, step=0.05)
    topk_weight = trial.suggest_float("topk_weight", 0.2, 0.3, step=0.05)
    mse_weight = trial.suggest_float("mse_weight", 0.05, 0.2, step=0.05)

    NN_model = NN(
        model=ENSEMBLE_NN(
            input_dim=len(train_x.columns) - 1, hidden_dim=hidden_dim, target_specs=target_specs, output_dim=CONFIG.NUM_TARGET_COLUMNS, RNN="GRU"
        ),
        batch_size=batchsize,
        lr=lr,
        seq_len=seq_len,
        lr_refit=refit_lr,
        epochs=200,
        early_stopping_patience=10,
        spearman_weight=spearman_weight,
        listnet_weight=listnet_weight,
        kendall_weight=kendall_weight,
        pairwise_weight=pairwise_weight,
        topk_weight=topk_weight,
        mse_weight=mse_weight,  # Small MSE for stability
        listnet_temp=1.0,
        kendall_temp=0.1,
    )

    dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) <= CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )
    real_dates_unique = (
        train_x.filter(pl.col(CONFIG.DATE_COL) > CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
    )

    cv = TimeSeriesSplit(n_splits=CONFIG.N_FOLDS)
    cv_split = cv.split(dates_unique)

    for fold, (train_idx, valid_idx) in enumerate(cv_split):
        if fold <= 3:
            continue
        if CONFIG.VERBOSE:
            print("-" * 20 + f"Fold {fold}" + "-" * 20)
            print(f"Train dates from {dates_unique[train_idx].min()} to {dates_unique[train_idx].max()}")
            print(f"Valid dates from {dates_unique[valid_idx].min()} to {dates_unique[valid_idx].max()}")

        dates_train = dates_unique[train_idx]
        dates_valid = dates_unique[valid_idx]

        df_train = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))
        true_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))

        valid_period = range(min(dates_valid) - seq_len + 1, max(dates_valid) + 1)
        df_valid = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
        df_valid_current_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

        df_valid_retrain = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
        df_valid_current_y_retrain = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

        model_fold = copy.deepcopy(NN_model)

        model_fold.fit(
            train_set=(df_train, true_y),
            val_set=(df_valid, df_valid_current_y),
            retrain_set=(df_valid_retrain, df_valid_current_y_retrain),
            verbose=CONFIG.VERBOSE,
        )

        preds = []
        cnt_dates = 0

        for date_id in tqdm(dates_valid):
            period = range(date_id - seq_len + 1, date_id + 1)

            df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            valid_lags = df_valid_date.to_numpy()[-seq_len:].astype(np.float64)

            if model_fold.refit and (cnt_dates > 0):
                upd_period = range(date_id - seq_len + 1, date_id + 1)
                df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(upd_period)).drop(CONFIG.DATE_COL).to_numpy()
                df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL) == date_id).drop(CONFIG.DATE_COL).to_numpy()

                if len(df_upd) > 0:
                    model_fold.update(df_upd, df_upd_current_y)

            preds_i = model_fold.predict(valid_lags)

            preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

            cnt_dates += 1

        preds = np.array(preds)

        val_score = rank_correlation_sharpe(
            df_valid_current_y.drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
            preds,
        )
        print(f"LAST VALIDIDATION Sharpe: {val_score:.5f}")

        model_real = copy.deepcopy(model_fold)
        preds = []
        cnt_dates = 0
        for date_id in tqdm(real_dates_unique):
            period = range(date_id - seq_len + 1, date_id + 1)

            df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            valid_lags = df_valid_date.to_numpy()[-seq_len:].astype(np.float64)

            if model_real.refit and (cnt_dates > 0):
                upd_period = range(date_id - seq_len + 1, date_id + 1)
                df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(upd_period)).drop(CONFIG.DATE_COL).to_numpy()
                df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL) == date_id).drop(CONFIG.DATE_COL).to_numpy()

                if len(df_upd) > 0:
                    model_real.update(df_upd, df_upd_current_y)

            preds_i = model_real.predict(valid_lags)

            preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

            cnt_dates += 1

        preds = np.array(preds)

        final_score = rank_correlation_sharpe(
            train_y.filter(pl.col(CONFIG.DATE_COL).is_in(real_dates_unique)).drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
            preds,
        )
        print(f"REAL Sharpe: {final_score:.5f}")

    return val_score * 0.2 + final_score * 0.8


In [10]:
def optimize(study_name, n_trials: int = 200) -> optuna.study.study:
    """
    Runs Optuna optimization over the defined search space.

    Parameters
    ----------
    n_trials : int, optional
        Number of trials to run, by default 200.

    Returns
    -------
    optuna.study.study
        study object
    """
    sampler = TPESampler(seed=CONFIG.RANDOM_STATE)
    study = optuna.create_study(
        study_name=study_name,
        storage=f"sqlite:///{folder}/{study_name}.db",
        direction="maximize",
        sampler=sampler,
        load_if_exists=True,
    )
    study.optimize(
        objective,
        n_trials=20,
        show_progress_bar=True,
    )

    return study

In [None]:
optimize("trials_vary_loss_v1")

[I 2025-09-17 23:33:44,654] A new study created in RDB with name: trials_vary_loss_v1


  0%|          | 0/20 [00:00<?, ?it/s]

--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.8961 seconds
   1    |   0.9500   |   0.9128   |  -0.0192   |  -0.1722   |    0.7312     |    0.6111     |    -0.0012    |    -0.0065    |     1.1865     |     1.1688     |   0.9984    |   0.9913    |   1.9923   |   2.0106   |    0.0343    |   0.1763   | 0.01000
validate_one_epoch took 8.6215 seconds
   2    |   0.8841   |   0.8779   |  -0.1110   |  -0.2582   |    0.5682     |    0.5544     |    -0.0043    |    -0.0070    |     1.0983     |     1.0974     |   0.9959    |   0.9967    |   1.9782   |   2.0191   | 

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.18297


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.11532
[I 2025-09-17 23:43:33,474] Trial 0 finished with value: 0.12884970286983816 and parameters: {'batch_size': 64, 'seq_len': 128, 'hidden_dim': 32, 'lr': 0.01, 'refit_lr': 0.009000000000000001, 'spearman_weight': 0.05, 'listnet_weight': 0.2, 'kendall_weight': 0.05, 'pairwise_weight': 0.30000000000000004, 'topk_weight': 0.25, 'mse_weight': 0.1}. Best is trial 0 with value: 0.12884970286983816.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.8858 seconds
   1    |   0.7857   |   0.7239   |  -0.0147   |  -0.0661   |    0.7145     |    0.5503 

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.23349


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: -0.05295
[I 2025-09-17 23:48:43,670] Trial 1 finished with value: 0.004334539541246504 and parameters: {'batch_size': 64, 'seq_len': 64, 'hidden_dim': 32, 'lr': 0.008, 'refit_lr': 0.006, 'spearman_weight': 0.05, 'listnet_weight': 0.4, 'kendall_weight': 0.25, 'pairwise_weight': 0.55, 'topk_weight': 0.2, 'mse_weight': 0.05}. Best is trial 0 with value: 0.12884970286983816.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.7056 seconds
   1    |   0.6940   |   0.6485   |   0.0032   |   0.0102   |    0.6547     |    0.5448     |    0.0001     |    -0.

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.43721


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.11767
[I 2025-09-18 00:00:47,291] Trial 2 finished with value: 0.18157498908186576 and parameters: {'batch_size': 32, 'seq_len': 32, 'hidden_dim': 64, 'lr': 0.006, 'refit_lr': 0.01, 'spearman_weight': 0.2, 'listnet_weight': 0.4, 'kendall_weight': 0.25, 'pairwise_weight': 0.45, 'topk_weight': 0.3, 'mse_weight': 0.05}. Best is trial 2 with value: 0.18157498908186576.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.7627 seconds
   1    |   0.9480   |   0.9095   |  -0.0658   |  -0.1387   |    0.8766     |    0.7259     |    -0.0020    |    -0.0078

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.42904


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.08177
[I 2025-09-18 00:06:19,590] Trial 3 finished with value: 0.151227852503924 and parameters: {'batch_size': 256, 'seq_len': 32, 'hidden_dim': 64, 'lr': 0.005, 'refit_lr': 0.01, 'spearman_weight': 0.2, 'listnet_weight': 0.2, 'kendall_weight': 0.05, 'pairwise_weight': 0.55, 'topk_weight': 0.3, 'mse_weight': 0.15000000000000002}. Best is trial 2 with value: 0.18157498908186576.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.4376 seconds
   1    |   0.8998   |   0.8490   |   0.0095   |  -0.0853   |    0.6691     |    0.5571     |    -0.0001  

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.39275


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.13158
[I 2025-09-18 00:16:04,824] Trial 4 finished with value: 0.18381383478293775 and parameters: {'batch_size': 32, 'seq_len': 16, 'hidden_dim': 64, 'lr': 0.008, 'refit_lr': 0.01, 'spearman_weight': 0.1, 'listnet_weight': 0.2, 'kendall_weight': 0.2, 'pairwise_weight': 0.5, 'topk_weight': 0.25, 'mse_weight': 0.2}. Best is trial 4 with value: 0.18381383478293775.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.3748 seconds
   1    |   1.0030   |   0.9577   |  -0.0217   |  -0.2561   |    0.8201     |    0.7168     |    -0.0016    |    -0.0114  

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.04137


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.05948
[I 2025-09-18 00:24:54,072] Trial 5 finished with value: 0.05585782347779414 and parameters: {'batch_size': 64, 'seq_len': 64, 'hidden_dim': 32, 'lr': 0.007, 'refit_lr': 0.009000000000000001, 'spearman_weight': 0.05, 'listnet_weight': 0.2, 'kendall_weight': 0.1, 'pairwise_weight': 0.25, 'topk_weight': 0.3, 'mse_weight': 0.2}. Best is trial 4 with value: 0.18381383478293775.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.5231 seconds
   1    |   0.7107   |   0.6876   |  -0.0310   |  -0.1317   |    0.7256     |    0.6612     |    -0.0004 

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.12803


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.20445
[I 2025-09-18 00:44:24,675] Trial 6 finished with value: 0.189165232076017 and parameters: {'batch_size': 64, 'seq_len': 128, 'hidden_dim': 16, 'lr': 0.007, 'refit_lr': 0.009000000000000001, 'spearman_weight': 0.2, 'listnet_weight': 0.2, 'kendall_weight': 0.15000000000000002, 'pairwise_weight': 0.35000000000000003, 'topk_weight': 0.2, 'mse_weight': 0.05}. Best is trial 6 with value: 0.189165232076017.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.8118 seconds
   1    |   0.8137   |   0.7672   |  -0.0289   |  -0.1786   |    0.7592     |

  0%|          | 0/304 [00:00<?, ?it/s]

LAST VALIDIDATION Sharpe: 0.07433


  0%|          | 0/134 [00:00<?, ?it/s]

REAL Sharpe: 0.18777
[I 2025-09-18 00:54:26,537] Trial 7 finished with value: 0.16508096221807173 and parameters: {'batch_size': 64, 'seq_len': 64, 'hidden_dim': 32, 'lr': 0.006, 'refit_lr': 0.005, 'spearman_weight': 0.15000000000000002, 'listnet_weight': 0.30000000000000004, 'kendall_weight': 0.05, 'pairwise_weight': 0.30000000000000004, 'topk_weight': 0.3, 'mse_weight': 0.05}. Best is trial 6 with value: 0.189165232076017.
--------------------Fold 4--------------------
Train dates from 3 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
 Epoch  | TrainLoss  |  ValLoss   | TrainSharpe  | ValSharpe  | TrainICIR  |  ValICIR   | TrainListNet  |  ValListNet   | TrainKendall  |  ValKendall   | TrainPairwise  |  ValPairwise   |  TrainTopK  |   ValTopK   |  TrainMSE  |   ValMSE   | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------
validate_one_epoch took 8.6771 seconds
   1    |   0.8570   |   0.8208   |  -0.0094   |  -0.0798   |    0.7884   