In [1]:
from dataclasses import dataclass
from typing import Dict, List, Optional
import copy
import re
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
import math

import polars as pl
import numpy as np
from numba import njit, prange
from scipy.stats import spearmanr, rankdata

from sklearn.model_selection import TimeSeriesSplit, KFold
from sklearn.base import clone

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# from torch.profiler import profile, ProfilerActivity, record_function

from CONFIG import CONFIG
from PREPROCESSOR_V2 import PREPROCESSOR
from FEATURE_ENGINEERING_V2 import FEATURE_ENGINEERING
from SEQUENTIAL_NN_MODEL import CNNTransformerModel, GRUModel, LSTMModel, PureTransformerModel
from CROSS_SECTIONAL_NN_MODEL import DeepMLPModel, LinearModel, ResidualMLPModel
from LOSS import CombinedICIRLoss

import time


def timer(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} took {end - start:.4f} seconds")
        return result

    return wrapper

In [2]:
# # --- Prepare DataLoader ---
# # Create the dataset

# train_x = pl.scan_csv(CONFIG.TRAIN_X_PATH).filter(pl.col("date_id") <= CONFIG.MAX_TRAIN_DATE)
# train_y = pl.scan_csv(CONFIG.TRAIN_Y_PATH).filter(pl.col("date_id") <= CONFIG.MAX_TRAIN_DATE).fill_null(0).collect()

# train_x = PREPROCESSOR(df=train_x)
# train_x.clean()
# train_x = train_x.transform().lazy()

# train_x = FEATURE_ENGINEERING(df=train_x)
# train_x = train_x.create_all_features().collect().pivot(index=CONFIG.DATE_COL, on=["type", "instr"])
# train_x = train_x.rename({col: re.sub(r'[{",}]', "", col).replace(" ", "_").replace(",", "_") for col in train_x.columns})
# train_x = train_x.select(set(CONFIG.IMPT_COLS + [CONFIG.DATE_COL]))

In [3]:
def rank_correlation_sharpe(targets, predictions) -> float:
    """
    Calculates the rank correlation between predictions and target values,
    and returns its Sharpe ratio (mean / standard deviation).

    :param merged_df: DataFrame containing prediction columns (starting with 'prediction_')
                      and target columns (starting with 'target_')
    :return: Sharpe ratio of the rank correlation
    :raises ZeroDivisionError: If the standard deviation is zero
    """
    correlations = []

    for i, (pred_row, target_row) in enumerate(zip(predictions, targets)):
        # Find valid (non-NaN) assets for this timestep
        valid_mask = ~np.isnan(target_row)
        valid_pred = pred_row[valid_mask]
        valid_target = target_row[valid_mask]

        if np.std(pred_row) == 0 or np.std(target_row) == 0:
            raise ZeroDivisionError("Zero standard deviation in a row.")

        rho = np.corrcoef(rankdata(valid_pred, method="average"), rankdata(valid_target, method="average"))[0, 1]
        correlations.append(rho)

    daily_rank_corrs = np.array(correlations)
    std_dev = daily_rank_corrs.std(ddof=0)
    if std_dev == 0:
        raise ZeroDivisionError("Denominator is zero, unable to compute Sharpe ratio.")

    sharpe_ratio = daily_rank_corrs.mean() / std_dev
    return float(sharpe_ratio)

In [4]:
class BaseFinancialDataset(Dataset, ABC):
    """Base class for financial datasets"""

    def __init__(self, X: pl.DataFrame, y: pl.DataFrame, date_column: str = CONFIG.DATE_COL):
        """
        Base initialization

        Args:
            data: Preprocessed DataFrame (scaling already done)
            target_columns: List of target column names (424 targets)
            feature_columns: List of feature column names
            date_column: Name of date identifier column
        """
        self.X = X.clone()
        self.y = y.clone()
        self.date_column = date_column

        # Sort by date
        self.X = self.X.sort(by=CONFIG.DATE_COL)
        self.y = self.y.sort(by=CONFIG.DATE_COL)
        self.unique_dates = sorted(self.X[self.date_column].unique())
        self.device = torch.device("cpu")

        self.lag_featues = CONFIG.LAG_FEATURES

        self._prepare_samples()

    def _prepare_samples(self):
        self.samples = {}
        for lag, features in self.lag_featues.items():
            X = self.X.select(features).to_numpy()

            # self.num_features = self.X.shape[-1]
            # Split continuous and categorical features
            continuous_data = torch.tensor(X, dtype=torch.float32)
            continuous_data = torch.nan_to_num(continuous_data, 0)
            self.samples[lag] = continuous_data

        self.dates = torch.tensor(self.unique_dates, dtype=torch.int16)
        self.y = torch.tensor(self.y.drop(CONFIG.DATE_COL).to_numpy(), dtype=torch.float32)

        self.unique_date, self.inverse_indices, self.counts = torch.unique(self.dates, return_inverse=True, return_counts=True)

        self.n_unique_dates = len(self.unique_date)

    @abstractmethod
    def __getitem__(self, idx):
        """Get item - implemented by subclasses"""
        pass

In [5]:
class SequentialDataset(BaseFinancialDataset):
    """Dataset for sequential models (LSTM, Transformers, CNN)"""

    def __init__(self, X: pl.DataFrame, Y: pl.DataFrame, date_column: str = CONFIG.DATE_COL, prediction_horizon: int = 1):
        """
        Sequential dataset for temporal models

        Args:
            data: Preprocessed DataFrame
            target_columns: Target column names
            feature_columns: Feature column names
            date_column: Date identifier column
            sequence_length: Number of time steps in sequence
            prediction_horizon: Steps ahead to predict (usually 1)
        """
        self.prediction_horizon = prediction_horizon

        super().__init__(X, Y)

        self._generate_sequence()

    def _generate_sequence(self):
        self.sequence_x = {1: [], 2: [], 3: [], 4: []}
        self.sequence_y = []

        for date in range(max(CONFIG.LAG_SEQ_LEN.values()), self.n_unique_dates):
            for lag, seq in CONFIG.LAG_SEQ_LEN.items():
                sample = self.samples[lag]
                self.sequence_x[lag].append(sample[date - seq : date])

            self.sequence_y.append(self.y[date - 1])
        for lag, seq in self.sequence_x.items():
            self.sequence_x[lag] = torch.stack(seq)
        self.sequence_y = torch.stack(self.sequence_y)

    def __len__(self):
        return self.n_unique_dates - max(CONFIG.LAG_SEQ_LEN.values())

    def __getitem__(self, idx):
        """Get sequence, target, and date_id"""
        continuous_seq = {lag: self.sequence_x[lag][idx] for lag in self.sequence_x.keys()}  # (seq_len, N_FEATURES)
        target = self.sequence_y[idx]  # (424,)

        return continuous_seq, target

In [6]:
def flatten_collate_fn(batch: list) -> dict[str, torch.Tensor]:
    """
    Collate function for DataLoader to flatten the batch.

    Args:
        batch (list): List of tuples containing tensors.

        tuple[torch.Tensor]: Flattened tensors (type, instr, X, y).
    """
    X, curr_y = zip(*batch)
    lag_keys = X[0].keys()

    # Stack each lag sequence across the batch
    continuous_batch = {
        lag: torch.stack([x_dict[lag] for x_dict in X])  # shape: [B, seq_len, num_features]
        for lag in lag_keys
    }

    curr_y = torch.stack(curr_y)

    return {"continuous": continuous_batch, "current": curr_y}


In [7]:
# seq_val_dataset = SequentialDataset(df_valid, df_valid_current_y)
# seq_val_dataloader = DataLoader(
#     seq_val_dataset,
#     batch_size=6,
#     shuffle=False,
#     collate_fn=flatten_collate_fn,
#     pin_memory=True,
#     # num_workers=6,
#     # persistent_workers=True,
#     # prefetch_factor=2,
#     drop_last=True,
# )

In [8]:
class LagSpecificEnsemble(nn.Module):
    """Ensemble of multiple architectures for each lag"""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        # seq models
        self.cnn_transformer = CNNTransformerModel(input_dim, hidden_dim, output_dim, CONFIG.SEQ_LEN)
        self.gru_model = GRUModel(input_dim, hidden_dim, output_dim)
        self.lstm_model = LSTMModel(input_dim, hidden_dim, output_dim, num_layers=2)
        self.pure_transformer = PureTransformerModel(input_dim, hidden_dim, output_dim)

        # cross sectional models
        self.mlp = DeepMLPModel(input_dim, [64, 32], output_dim)
        self.linear = LinearModel(input_dim, output_dim)
        self.residual = ResidualMLPModel(input_dim, hidden_dim, output_dim)

        # Ensemble weights (learnable)
        self.ensemble_weights = nn.Parameter(torch.ones(3) / 3)

        self.ensemble_dropout = nn.Dropout(0.3)
        self.prediction_dropout = nn.Dropout(0.3)
        self.input_dropout = nn.Dropout(0.3)

    def forward(
        self,
        x_seq,
    ):
        x_seq = self.input_dropout(x_seq)
        # Get predictions from all models
        x_cs = x_seq[:, -1, :]
        # out1 = self.cnn_transformer(x_seq)
        out2 = self.gru_model(x_seq)
        # out3 = self.lstm_model(x_seq)
        # out4 = self.pure_transformer(x_seq)
        out5 = self.mlp(x_cs)
        # out6 = self.linear(x_cs)
        # out7 = self.residual(x_cs)

        individual_outputs = [out2, out5]  # out1, out2, out3, out4, out5, out6, out7
        individual_outputs = [self.prediction_dropout(out) for out in individual_outputs]

        dropped_weights = self.ensemble_dropout(self.ensemble_weights)
        weights = F.softmax(dropped_weights, dim=0)

        ensemble_output = torch.zeros_like(individual_outputs[0])  # (batch_size, 424)
        for w, out in zip(weights, individual_outputs):
            ensemble_output += w * out

        entropy_reg = self.entropy_regularization(weights)
        return ensemble_output + 0.01 * entropy_reg

    def entropy_regularization(self, weights):
        """Penalize overly confident ensemble weights"""
        return -torch.sum(weights * torch.log(weights + 1e-8))  # Entropy term


class HierarchicalModel(nn.Module):
    """Complete hierarchical model with variable targets per lag"""

    def __init__(self, input_dim: Dict[int, int], lag_target_sizes: Dict[int, int] = CONFIG.LAGS_TARGET):
        super().__init__()

        self.lag_target_sizes = lag_target_sizes
        self.num_lags = len(lag_target_sizes)

        # Lag-specific encoders with different configurations (ensemble models for each lag)
        self.lag_encoders = nn.ModuleList(
            [
                LagSpecificEnsemble(
                    input_dim=input_dim[lag],
                    hidden_dim=64,  # More capacity for longer lags
                    output_dim=lag_target_sizes[lag],
                )
                for lag in lag_target_sizes.keys()
            ]
        )

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """
        inputs: Dict with keys 'lag_1', 'lag_2', etc.
        Each tensor shape: (batch, seq_len, input_dim)
        """
        lag_features = []

        # Process each lag using the ensemble models
        for i, (lag, encoder) in enumerate(zip(self.lag_target_sizes.keys(), self.lag_encoders)):
            lag_key = lag
            if lag_key in inputs:
                features = encoder(inputs[lag_key])
                lag_features.append(features)
            else:
                print("missing")
                # Handle missing lags (e.g., during inference)
                batch_size = list(inputs.values())[0].size(0)
                dummy_features = torch.zeros(batch_size, 64, device=next(encoder.parameters()).device)
                lag_features.append(dummy_features)

        # Return all individual lag predictions directly (no meta combiner)
        return {
            "final_prediction": torch.stack(lag_features).reshape(-1, CONFIG.NUM_TARGET_COLUMNS),  # List of tensors with different sizes for each lag
        }


In [9]:
class NN:
    def __init__(
        self,
        model: HierarchicalModel,
        lr: float = 0.001,
        batch_size: int = 1,
        epochs: int = 100,
        early_stopping_patience: int = 10,
        early_stopping: bool = True,
        lr_patience: int = 2,
        lr_factor: float = 0.5,
        lr_refit: float = 0.001,
        random_seed: int = CONFIG.RANDOM_STATE,
        refit: bool = True,
        **kwargs,
    ) -> None:
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping = early_stopping
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.lr_refit = lr_refit
        self.random_seed = random_seed
        self.refit = refit

        self.criterion = CombinedICIRLoss()

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        # Different optimizers for different components
        # self.lag_params = list(self.model.lag_encoders.parameters()) + list(self.model.prediction_heads.parameters())
        # self.meta_params = list(self.model.meta_combiner.parameters()) + list(self.model.cross_attention.parameters())

        # self.refit_lag_optimizer = torch.optim.AdamW(self.lag_params, lr=self.lr_refit["lag_optimizer"], weight_decay=0.01)
        # self.refit_meta_optimizer = torch.optim.AdamW(self.meta_params, lr=self.lr_refit["meta_optimizer"], weight_decay=0.01)
        # self.refit_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr_refit["lag_optimizer"], weight_decay=0.01)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.01)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=5, T_mult=2, eta_min=1e-6)

        self.refit_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr_refit, weight_decay=0.01)

        self.best_epoch = None
        self.features = None
        self.kwargs = kwargs

    def fit(self, train_set: tuple, val_set: tuple, retrain_set: tuple, verbose: bool = False) -> None:
        """Fit the model on the training set and validate on the validation set.

        Args:
            train_set (tuple): A tuple containing input data, targets for training.
            val_set (tuple): A tuple containing input data, targets for validation.
            verbose (bool, optional): If True, prints training progress. Defaults to False.
        """
        torch.manual_seed(self.random_seed)

        seq_train_dataset = SequentialDataset(*train_set)
        seq_train_dataloader = DataLoader(
            seq_train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=flatten_collate_fn,
            pin_memory=True,
            # num_workers=2,
            # persistent_workers=True,
            # prefetch_factor=2,
            drop_last=True,
        )

        seq_val_dataset = SequentialDataset(*val_set)
        seq_val_dataloader = DataLoader(
            seq_val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=flatten_collate_fn,
            pin_memory=True,
            # num_workers=6,
            # persistent_workers=True,
            # prefetch_factor=2,
            drop_last=True,
        )

        retrain_val_dataset = SequentialDataset(*retrain_set)
        retrain_val_dataloader = DataLoader(
            retrain_val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=flatten_collate_fn,
            pin_memory=True,
            # num_workers=6,
            # persistent_workers=True,
            # prefetch_factor=2,
            drop_last=True,
        )

        train_sharpes, val_sharpes = [], []
        if verbose:
            print(f"Device: {self.device}")
            print(
                f"{'Epoch':^5} | {'Train Loss':^10} | {'Train ICIR Loss':^15} | {'Train MSE Loss':^14} | {'Train Ranking Loss':^17} | {'Val Loss':^8} | {'Val ICIR Loss':^13} | {'Val MSE Loss':^12} | {'Val Ranking Loss':^16} | {'Train sharpe':^9} | {'Val sharpe':^7} | {'LR':^7}"
            )
            print("-" * 60)

        min_val_sharpe = -np.inf
        best_epoch = 0
        no_improvement = 0
        best_model = None
        for epoch in range(self.epochs):
            train_loss, train_sharpe, train_icir_loss, train_mse_loss, train_ranking_loss = self.train_one_epoch(seq_train_dataloader, verbose)
            val_loss, val_sharpe, val_icir_loss, val_mse_loss, val_ranking_loss = self.validate_one_epoch(
                seq_val_dataloader, retrain_val_dataloader, verbose
            )

            self.scheduler.step()
            lr_last = self.optimizer.param_groups[0]["lr"]

            # self.lag_scheduler.step()
            # self.meta_scheduler.step()
            # lag_optimizer_lr_last = self.lag_optimizer.param_groups[0]["lr"]
            # meta_optimizer_lr_last = self.meta_optimizer.param_groups[0]["lr"]

            train_sharpes.append(train_sharpe)
            val_sharpes.append(val_sharpe)

            if verbose:
                print(
                    f"{epoch + 1:^5} | {train_loss:^10.4f} | {train_icir_loss:^15.4f} | {train_mse_loss:^14.4f} | {train_ranking_loss:^17.4f} | {val_loss:^8.4f} | {val_icir_loss:^13.4f} | {val_mse_loss:^12.4f} | {val_ranking_loss:^16.4f} | {train_sharpe:^9.4f} | {val_sharpe:^7.4f} | {lr_last:^7.5f}"
                )

            if val_sharpe > min_val_sharpe:
                min_val_sharpe = val_sharpe
                best_model = copy.deepcopy(self.model.state_dict())
                no_improvement = 0
                best_epoch = epoch
            else:
                no_improvement += 1

            if self.early_stopping:
                if no_improvement >= self.early_stopping_patience + 1:
                    self.best_epoch = best_epoch + 1
                    if verbose:
                        print(f"Early stopping on epoch {best_epoch + 1}. Best score: {min_val_sharpe:.4f}")
                    break

        # Load the best model
        if self.early_stopping:
            self.model.load_state_dict(best_model)

    def train_one_epoch(self, seq_train_dataloader: DataLoader, verbose: bool) -> tuple:
        """Train the model for one epoch.

        Args:
            train_dataloader (DataLoader): DataLoader for the training set.
            verbose (bool): If True, shows progress using tqdm.

        Returns:
            tuple[float, float]: A tuple containing:
                - Train loss (float).
                - Spearman Sharpe for the training set (float).
        """
        self.model.train()
        total_loss = 0.0
        total_icir_loss = 0.0
        total_mse_loss = 0.0
        total_ranking_loss = 0.0

        y_total, preds_total = [], []

        for seq_batch in seq_train_dataloader:
            seq_x_batch = seq_batch["continuous"]
            seq_x_batch = {key: value.to(self.device) for key, value in seq_x_batch.items()}

            true_y = seq_batch["current"].to(self.device)

            self.optimizer.zero_grad()
            # self.lag_optimizer.zero_grad()
            # self.meta_optimizer.zero_grad()
            with torch.autocast(device_type="cuda"):
                pred_y = self.model(seq_x_batch)
                final_preds = pred_y["final_prediction"]
                loss, icir_loss, mse_loss, ranking_loss, _ = self.criterion(final_preds, true_y).values()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

            self.optimizer.step()
            # self.lag_optimizer.step()
            # self.meta_optimizer.step()

            total_loss += loss.item()
            total_icir_loss += icir_loss.item()
            total_mse_loss += mse_loss.item()
            total_ranking_loss += ranking_loss.item()

            y_total.append(true_y)
            preds_total.append((final_preds).detach())

            torch.cuda.empty_cache()

        y_total = torch.cat(y_total).cpu().numpy().astype(np.float64)
        preds_total = torch.cat(preds_total).cpu().numpy().astype(np.float64)

        train_sharpe = rank_correlation_sharpe(y_total, preds_total)
        train_loss = total_loss / len(seq_train_dataloader)
        train_icir_loss = total_icir_loss / len(seq_train_dataloader)
        train_mse_loss = total_mse_loss / len(seq_train_dataloader)
        train_ranking_loss = total_ranking_loss / len(seq_train_dataloader)

        return train_loss, train_sharpe, train_icir_loss, train_mse_loss, train_ranking_loss

    @timer
    def validate_one_epoch(self, seq_val_dataloader: DataLoader, retrain_val_dataloader: DataLoader, verbose=False) -> tuple:
        """Validate the model on the validation set.

        Args:
            val_dataloader (DataLoader): DataLoader for the validation set.
            verbose (bool, optional): If True, shows progress using tqdm. Defaults to False.

        Returns:
            tuple[float, float]: A tuple containing:
                - Validation loss (float).
                - Spearman Sharpe for the validation set (float).
        """
        model = copy.deepcopy(self.model).to("cpu")

        losses, icir_losses, mse_losses, ranking_losses, all_y, all_preds = [], [], [], [], [], []

        for seq_batch, retrain_batch in zip(seq_val_dataloader, retrain_val_dataloader):
            # seq_batch = {key: value.to(self.device) for key, value in seq_batch.items()}
            # cs_batch = {key: value.to(self.device) for key, value in cs_batch.items()}
            seq_x_batch = seq_batch["continuous"]
            true_y = seq_batch["current"]

            # Predict
            model.eval()
            with torch.inference_mode():
                pred_y = model(seq_x_batch)
                final_preds = pred_y["final_prediction"]
                final_preds = torch.nan_to_num(final_preds)

                loss, icir_loss, mse_loss, ranking_loss, _ = self.criterion(final_preds, true_y).values()
                losses.append(loss.cpu().numpy())
                icir_losses.append(icir_loss.cpu().numpy())
                mse_losses.append(mse_loss.cpu().numpy())
                ranking_losses.append(ranking_loss.cpu().numpy())

                all_y.append(true_y)
                all_preds.append(final_preds)

            # Update weights
            if self.refit:
                retrain_seq_x_batch = retrain_batch["continuous"]
                retrain_true_y = retrain_batch["current"]

                self.refit_optimizer.zero_grad()
                # self.refit_lag_optimizer.zero_grad()
                # self.refit_meta_optimizer.zero_grad()

                model.train()
                pred_y = model(retrain_seq_x_batch)
                final_preds = pred_y["final_prediction"]

                loss = self.criterion(final_preds, retrain_true_y)["total_loss"]
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

                self.refit_optimizer.step()
                # self.refit_lag_optimizer.step()
                # self.refit_meta_optimizer.step()

        all_y = torch.cat(all_y).numpy().astype(np.float64)
        all_preds = torch.cat(all_preds).numpy().astype(np.float64)
        loss = np.mean(losses)
        val_icir_loss = np.mean(icir_losses)
        val_mse_loss = np.mean(mse_losses)
        val_ranking_loss = np.mean(ranking_losses)

        sharpe = rank_correlation_sharpe(all_y, all_preds)

        return loss, sharpe, val_icir_loss, val_mse_loss, val_ranking_loss

    def update(self, seq_X: Dict[int, np.array], true_y: np.array):
        """Update the model with new data.

        Args:
            X (np.array): Input data.
            y (np.array): Target variable.
            n_times (int): Number of time steps.
        """
        torch.manual_seed(self.random_seed)
        if self.lr_refit == 0.0:
            return

        seq_continuous_data = {
            lag: torch.tensor(np.nan_to_num(X, nan=0.0), dtype=torch.float32, device=self.device).unsqueeze(0) for lag, X in seq_X.items()
        }

        true_y = torch.tensor(np.nan_to_num(true_y, nan=0.0), dtype=torch.float32, device=self.device)
        self.model.train()

        self.refit_optimizer.zero_grad()
        # self.refit_lag_optimizer.zero_grad()
        # self.refit_meta_optimizer.zero_grad()
        with torch.autocast(device_type="cuda"):
            pred_y = self.model(seq_continuous_data)
            loss = self.criterion(pred_y["final_prediction"], true_y)["total_loss"]

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
        self.refit_optimizer.step()
        # self.refit_lag_optimizer.step()
        # self.refit_meta_optimizer.step()

    def predict(self, seq_X: Dict[int, np.array]) -> tuple[np.array, torch.Tensor | list]:
        """Predict the target variable for the given input data.

        Args:
            X (np.array): Input data.

        Returns:
            tuple[np.array, torch.Tensor or list]: A tuple containing:
                - Predictions (np.array).
                - Hidden state (torch.Tensor or list).
        """
        torch.manual_seed(self.random_seed)
        seq_continuous_data = {
            lag: torch.tensor(np.nan_to_num(X, nan=0.0), dtype=torch.float32, device=self.device).unsqueeze(0) for lag, X in seq_X.items()
        }

        self.model.eval()
        with torch.inference_mode():
            preds = self.model(seq_continuous_data)
            preds = torch.nan_to_num(preds["final_prediction"])

        return preds.cpu().numpy().astype(np.float64)

In [10]:
# --- Prepare DataLoader ---
# Create the dataset

train_x = pl.scan_csv(CONFIG.TRAIN_X_PATH)
train_x = PREPROCESSOR(df=train_x)
train_x = train_x.clean()

features = FEATURE_ENGINEERING(df=train_x)
train_x: pl.DataFrame = features.create_market_features()

train_y = pl.scan_csv(CONFIG.TRAIN_Y_PATH)

curr_y = (
    train_y.with_columns([pl.col(CONFIG.LAGS[f"lag{i}"]).exclude(CONFIG.DATE_COL).shift(i + 1) for i in range(1, 5)])
    .with_columns(pl.all().exclude(CONFIG.DATE_COL).shift())
    .filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series())))
    .collect()
    .fill_null(0)
    .lazy()
)

y_feat = FEATURE_ENGINEERING(df=curr_y)
lags = y_feat._compute_lag_returns(df=curr_y)
market = y_feat._compute_market_stats(df=curr_y)
skew = y_feat._compute_return_skew(df=curr_y)

train_x = (
    train_x.join(curr_y.collect(), on=CONFIG.DATE_COL)
    .join(lags.collect(), on=CONFIG.DATE_COL)
    .join(market.collect(), on=CONFIG.DATE_COL)
    .join(skew.collect(), on=CONFIG.DATE_COL)
)


train_y = train_y.filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series()))).collect()
train_x = (
    train_x.with_columns([pl.when(pl.col(col).is_infinite()).then(0.0).otherwise(pl.col(col)).alias(col) for col in train_x.columns])
    .with_columns(pl.all().shrink_dtype())
    .filter(pl.col(CONFIG.DATE_COL).is_in(train_y.select(CONFIG.DATE_COL).to_series()))
    .with_columns(pl.col(CONFIG.DATE_COL).cast(pl.Int64))
    # .select([CONFIG.DATE_COL] + CONFIG.IMPT_COL)
)

retrain_x = train_x.with_columns(pl.all().exclude(CONFIG.DATE_COL).shift(5))
retrain_y = train_y.filter((pl.col(CONFIG.DATE_COL).is_in(train_x.select(CONFIG.DATE_COL).to_series()))).with_columns(
    pl.all().exclude(CONFIG.DATE_COL).shift(5)
)

train_y_arr = train_y.drop(CONFIG.DATE_COL).to_numpy()

train_y = pl.DataFrame(
    (train_y_arr - np.nanmean(train_y_arr, axis=1).reshape(train_y_arr.shape[0], -1))
    / np.nanstd(train_y_arr, axis=1).reshape(train_y_arr.shape[0], -1),
    schema=train_y.drop(CONFIG.DATE_COL).columns,
).insert_column(0, train_y.select(CONFIG.DATE_COL).to_series())

retrain_y_arr = retrain_y.drop(CONFIG.DATE_COL).to_numpy()
retrain_y = pl.DataFrame(
    (retrain_y_arr - np.nanmean(retrain_y_arr, axis=1).reshape(retrain_y_arr.shape[0], -1))
    / np.nanstd(retrain_y_arr, axis=1).reshape(retrain_y_arr.shape[0], -1),
    schema=train_y.drop(CONFIG.DATE_COL).columns,
).insert_column(0, train_y.select(CONFIG.DATE_COL).to_series())


_compute_lag_returns took 0.0125 seconds
_compute_autocorr_torch took 0.3736 seconds
_compute_obv took 0.0095 seconds
_compute_return_skew took 0.0025 seconds
_compute_volume_z took 0.0416 seconds
_compute_market_stats took 0.1341 seconds
_compute_atr took 0.0075 seconds
_compute_rolling took 0.0133 seconds
create_market_features took 2.1133 seconds
_compute_lag_returns took 0.0084 seconds
_compute_market_stats took 0.0115 seconds
_compute_return_skew took 0.0030 seconds


  (retrain_y_arr - np.nanmean(retrain_y_arr, axis=1).reshape(retrain_y_arr.shape[0], -1))
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


In [11]:
NN_model = NN(
    model=HierarchicalModel(input_dim={lag: len(features) for lag, features in CONFIG.LAG_FEATURES.items()}, lag_target_sizes=CONFIG.LAGS_TARGET),
    batch_size=CONFIG.BATCH_SIZE,
    lr=0.0005,
    lr_refit=0.0005,
    epochs=200,
    early_stopping_patience=10,
)

In [None]:
# test_size = (
#     TEST_SIZE
#     if len(dates_unique) > TEST_SIZE * (n_splits + 1)
#     else len(dates_unique) // (n_splits + 1)
# )  # For testing purposes on small samples

dates_unique = train_x.filter(pl.col(CONFIG.DATE_COL) <= CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
real_dates_unique = (
    train_x.filter(pl.col(CONFIG.DATE_COL) > CONFIG.MAX_TRAIN_DATE).select(pl.col(CONFIG.DATE_COL).unique().sort()).to_series().to_numpy()
)

cv = TimeSeriesSplit(n_splits=CONFIG.N_FOLDS)
cv_split = cv.split(dates_unique)

scores = []
models = []
for fold, (train_idx, valid_idx) in enumerate(cv_split):
    if fold <= 3:
        continue
    if CONFIG.VERBOSE:
        print("-" * 20 + f"Fold {fold}" + "-" * 20)
        print(f"Train dates from {dates_unique[train_idx].min()} to {dates_unique[train_idx].max()}")
        print(f"Valid dates from {dates_unique[valid_idx].min()} to {dates_unique[valid_idx].max()}")

    dates_train = dates_unique[train_idx]
    dates_valid = dates_unique[valid_idx]

    df_train = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))
    true_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(dates_train))

    valid_period = range(min(dates_valid) - max(CONFIG.LAG_SEQ_LEN.values()) + 1, max(dates_valid) + 1)
    df_valid = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
    df_valid_current_y = train_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

    df_valid_retrain = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))
    df_valid_current_y_retrain = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(valid_period))

    model_fold = copy.deepcopy(NN_model)

    model_fold.fit(
        train_set=(df_train, true_y),
        val_set=(df_valid, df_valid_current_y),
        retrain_set=(df_valid_retrain, df_valid_current_y_retrain),
        verbose=CONFIG.VERBOSE,
    )

    models.append(model_fold)

    torch.save(
        model_fold.model.state_dict(),
        f"C:/Users/Admin/Desktop/Personal-Projects/Kaggle/MITSUI&CO. Commodity Prediction Challenge/ensemble_{fold}.pth",
    )

    preds = []
    cnt_dates = 0
    model_save = copy.deepcopy(model_fold)

    model_fold.model.load_state_dict(
        torch.load(
            f"C:/Users/Admin/Desktop/Personal-Projects/Kaggle/MITSUI&CO. Commodity Prediction Challenge/ensemble_{fold}.pth",
            map_location=torch.device("cuda"),
        )
    )

    for date_id in tqdm(dates_valid):
        period = range(date_id - max(CONFIG.LAG_SEQ_LEN.values()) + 1, date_id + 1)

        df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
        valid_lags = {lag: df_valid_date.select(features).to_numpy().astype(np.float64) for lag, features in CONFIG.LAG_FEATURES.items()}
        valid_lags = {lag: valid_lags[lag][-seq_len:] for lag, seq_len in CONFIG.LAG_SEQ_LEN.items()}

        if model_fold.refit and (cnt_dates > 0):
            df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            df_upd_lags = {lag: df_upd.select(features).to_numpy().astype(np.float64) for lag, features in CONFIG.LAG_FEATURES.items()}
            df_upd_lags = {lag: df_upd_lags[lag][-seq_len:] for lag, seq_len in CONFIG.LAG_SEQ_LEN.items()}

            df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(date_id)).drop(CONFIG.DATE_COL).to_numpy()

            if len(df_upd) > 0:
                model_save.update(df_upd_lags, df_upd_current_y)

        preds_i = model_save.predict(valid_lags)

        preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

        cnt_dates += 1

    preds = np.array(preds)

    score = rank_correlation_sharpe(
        df_valid_current_y.drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
        preds,
    )
    scores.append(score)

    print(f"LAST VALIDIDATION Sharpe: {score:.5f}")

    model_real = copy.deepcopy(model_save)
    preds = []
    cnt_dates = 0
    for date_id in tqdm(real_dates_unique):
        # print(date_id)
        period = range(date_id - max(CONFIG.LAG_SEQ_LEN.values()) + 1, date_id + 1)
        df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)

        valid_lags = {lag: df_valid_date.select(features).to_numpy().astype(np.float64) for lag, features in CONFIG.LAG_FEATURES.items()}
        valid_lags = {lag: valid_lags[lag][-seq_len:] for lag, seq_len in CONFIG.LAG_SEQ_LEN.items()}

        if model_fold.refit and (cnt_dates > 0):
            df_upd = retrain_x.filter(pl.col(CONFIG.DATE_COL).is_in(period)).drop(CONFIG.DATE_COL)
            df_upd_lags = {lag: df_upd.select(features).to_numpy().astype(np.float64) for lag, features in CONFIG.LAG_FEATURES.items()}
            df_upd_lags = {lag: df_upd_lags[lag][-seq_len:] for lag, seq_len in CONFIG.LAG_SEQ_LEN.items()}

            df_upd_current_y = retrain_y.filter(pl.col(CONFIG.DATE_COL).is_in(date_id)).drop(CONFIG.DATE_COL).to_numpy()

            if len(df_upd) > 0:
                model_real.update(df_upd_lags, df_upd_current_y)

        # print(df_upd[:, 0])
        # print(df_upd_current_y[:, -1])
        # print(df_upd_true_delta[:, -1])
        preds_i = model_real.predict(valid_lags)

        preds += list(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))
        # print(preds_i[-1].reshape(-1, CONFIG.NUM_TARGET_COLUMNS))

        cnt_dates += 1
        # if cnt_dates == 2:
        #     break

    preds = np.array(preds)

    score = rank_correlation_sharpe(
        train_y.filter(pl.col(CONFIG.DATE_COL).is_in(real_dates_unique)).drop(CONFIG.DATE_COL).to_numpy().astype(np.float64),
        preds,
    )
    scores.append(score)
    print(f"REAL Sharpe: {score:.5f}")

--------------------Fold 4--------------------
Train dates from 2 to 1522
Valid dates from 1523 to 1826
Device: cuda:0
Epoch | Train Loss | Train ICIR Loss | Train MSE Loss | Train Ranking Loss | Val Loss | Val ICIR Loss | Val MSE Loss | Val Ranking Loss | Train sharpe | Val sharpe |   LR   
------------------------------------------------------------


In [None]:
retrain_ds = SequentialDataset(df_valid_retrain, df_valid_current_y_retrain)
pl.DataFrame(next(iter(retrain_ds))[0][4].numpy())

column_0,column_1,column_2,column_3,column_4,column_5,column_6,column_7,column_8,column_9,column_10,column_11,column_12,column_13,column_14,column_15,column_16,column_17,column_18,column_19,column_20,column_21,column_22,column_23,column_24,column_25,column_26,column_27,column_28,column_29,column_30,column_31,column_32,column_33,column_34,column_35,column_36,…,column_63,column_64,column_65,column_66,column_67,column_68,column_69,column_70,column_71,column_72,column_73,column_74,column_75,column_76,column_77,column_78,column_79,column_80,column_81,column_82,column_83,column_84,column_85,column_86,column_87,column_88,column_89,column_90,column_91,column_92,column_93,column_94,column_95,column_96,column_97,column_98,column_99
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
-0.600796,-0.914647,0.030829,-0.708012,-0.054668,-0.948001,-1.359432,-0.146292,-0.137402,-0.850851,0.078772,-0.123519,0.000983,-0.884438,1.125735,1.143567,1.159554,-0.505677,1.106106,-0.168141,-1.12163,-0.141626,-0.418873,-2.22735,-0.018628,0.876863,-0.263154,-0.659306,-0.949431,-1.160557,-0.547408,-0.144907,-0.32823,0.011029,0.008649,-1.469505,-0.192384,…,0.416911,0.49945,0.415239,0.01134,0.505069,0.011571,0.011571,-0.065728,0.428339,0.427753,-1.43725,-1.056561,-0.017627,0.110134,0.15424,0.503649,0.620564,0.06471,0.399062,0.604878,0.691247,-0.02008,0.693985,0.638001,-1.114075,0.182135,0.662117,0.721626,-0.987616,0.682037,-0.065506,1.024815,-0.410165,1.093586,-0.052384,0.030002,0.118704
-0.331305,-0.799951,0.04221,-0.609153,0.222382,-1.083803,-1.329361,-0.524006,-0.866288,-0.78775,-0.705297,0.011299,-0.016925,-0.681306,1.074381,1.087834,1.111351,-0.04803,1.053658,-0.819396,-0.831125,-0.809522,-0.274026,-2.369183,-0.068554,0.866089,-1.459333,-0.513191,-0.818227,-1.076879,-0.287044,-1.180784,-0.213305,-0.002656,-0.586793,-1.471487,-1.138588,…,0.343683,0.436692,0.33969,-0.005089,0.441262,-0.024865,-0.024865,-0.018148,0.358543,0.355994,-1.439224,-0.91452,0.011029,-0.09682,-1.293827,0.296467,0.540513,-0.136164,-0.716517,0.52334,0.608413,0.168394,0.620309,0.637208,-1.226812,0.063022,0.663466,0.724069,-0.763981,0.68042,-0.068423,0.949595,-1.30037,1.074553,-0.389326,0.056125,-0.120663
-0.672513,-0.636656,0.046717,-0.600888,0.41067,-0.007206,-1.275968,-0.859563,-1.660442,-0.744369,-0.602834,0.000385,0.01921,-0.459605,1.024923,1.03614,1.055555,0.60446,1.007326,-1.579592,-0.377803,-1.605468,-0.090251,2.471409,0.191179,0.707784,-1.358112,-0.466072,-0.894953,-0.814633,-0.734001,-0.546881,-0.124805,0.021448,-0.777764,-1.104306,-0.532989,…,0.26818,0.357569,0.263673,-0.026372,0.358072,-0.02419,-0.02419,-0.026771,0.275264,0.273948,-1.451764,-0.862717,-0.002656,-0.346436,0.475959,0.182605,0.4661,-0.360702,0.214843,0.454266,0.53341,0.190971,0.535629,0.640734,-1.364923,0.327149,0.667453,0.727474,-0.545925,0.684033,-0.044396,0.847165,-1.70006,1.060013,-0.580176,0.024639,-0.258854
-0.669441,-0.958308,0.041761,-0.694853,0.48666,-0.399988,-0.731586,-0.852627,-2.475681,-0.705136,-0.313858,0.07329,-0.023334,-0.469034,0.978591,0.98603,1.011364,0.021094,0.956699,-2.387139,-0.162747,-2.413514,0.981804,2.242544,0.101565,0.63733,-1.347286,-0.552309,-0.894854,-0.839749,-1.0248,-1.514108,-0.186853,0.007691,-0.70138,-0.926617,-1.330135,…,0.214499,0.285683,0.209884,-0.039141,0.285819,0.022193,0.022193,0.014136,0.230739,0.231165,-1.503957,-1.070332,0.021448,-0.551869,0.168374,0.124773,0.392771,-0.570181,-0.43412,0.375534,0.457342,0.102894,0.465667,0.642076,-1.394,0.275905,0.663425,0.725267,-0.496921,0.686843,-0.002049,0.731776,-1.717889,1.030099,-0.786002,-0.005382,-0.331825
0.514794,-1.189119,0.022192,-0.588277,-0.064362,0.077836,-0.398386,-0.778215,-1.876163,-0.639785,0.000009,0.172795,0.027207,-0.403682,0.921649,0.928434,0.955571,-0.079817,0.901019,-1.822546,-0.215668,-1.818873,1.469027,2.26922,0.238014,0.374133,0.139196,-0.685932,-2.453737,-0.124573,-1.139299,0.327368,-0.422237,0.056293,-0.277295,-1.232394,0.412006,…,0.175401,0.234107,0.171287,-0.052108,0.235389,-0.001036,-0.001036,-0.050332,0.180022,0.179639,-1.502217,-1.165954,0.007691,-0.839695,0.691321,-0.00693,0.302254,-0.867949,0.295432,0.281222,0.363977,0.029722,0.379238,0.651947,-1.258663,1.451046,0.673701,0.73948,-0.279643,0.698398,0.066064,0.660564,-1.259132,0.981152,-1.028614,-0.045712,-0.475671
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
-1.96814,-1.115693,0.033142,-1.344751,-0.543341,-0.076323,-0.211837,-0.478293,-0.217636,0.145765,-1.331024,0.119196,0.002533,-0.561447,0.05448,0.00422,0.019327,0.790045,0.015917,-0.228357,-0.329444,-0.169107,-2.771951,-1.062265,-0.246845,-0.21535,-0.480388,-1.460805,1.496517,-0.998931,-0.992053,1.257294,-0.334431,-0.008477,0.021521,-2.449189,1.093991,…,-0.283008,-0.295676,-0.26737,0.043099,-0.28509,-0.004691,-0.004691,-0.00335,-0.242449,-0.2158,-1.54215,-1.05603,-0.024786,0.509546,1.476047,-1.482851,-1.464848,0.488219,1.545948,-1.46654,-1.427171,-0.709411,-1.470181,1.328351,0.459814,1.925516,1.338632,1.381834,-0.543535,1.348985,-0.03714,-1.139946,1.362018,1.432309,0.455582,0.029711,-1.85691
-1.70836,-0.869532,-0.015545,-1.241043,-0.206872,-0.400984,-0.207813,1.166677,-0.335576,0.141556,-1.170128,0.162678,0.023578,-0.152517,0.080209,0.048139,0.057744,2.32708,0.040351,-0.323255,0.204012,-0.284087,-1.718454,0.374419,-0.225106,-0.316583,2.1017,-1.446151,1.390184,-0.346507,0.030289,0.818277,-0.4362,0.02143,0.340718,-2.046823,0.699786,…,-0.170766,-0.166887,-0.154509,0.018265,-0.156789,0.024597,0.024597,0.01073,-0.110283,-0.082167,-1.492885,-0.789841,-0.008477,0.531188,4.144293,-1.123518,-1.453697,0.503802,3.791299,-1.456301,-1.395617,-0.68776,-1.446778,1.359627,0.333251,2.365141,1.369433,1.410887,0.361639,1.378128,-0.006543,-1.049528,2.475384,1.524135,0.308651,0.00552,-1.507814
-1.547518,-0.963822,0.020784,-1.080271,0.105217,-0.067506,-0.218394,1.493208,-1.039451,0.141422,-0.624688,0.13518,0.049707,0.037228,0.100481,0.07362,0.09134,1.947745,0.056703,-0.951019,0.508251,-0.974819,-1.768716,1.074763,-0.027022,-0.351408,1.895022,-1.420101,0.613906,0.266544,-0.332149,0.489066,-0.445044,0.027328,0.35058,-2.403541,0.431322,…,-0.05274,-0.050915,-0.035895,-0.04864,-0.042202,0.0,0.0,0.042943,0.013415,0.039093,-1.458768,-0.892023,0.02143,0.357499,3.792241,-0.7788,-1.451631,0.326401,2.929815,-1.465609,-1.391209,-0.679369,-1.426131,1.390869,0.390423,2.136362,1.398071,1.439649,0.580624,1.409553,-0.046906,-0.96927,2.26241,1.567311,-0.050363,0.027104,-1.175144
-1.447757,-1.476177,-0.049767,-1.001616,0.269588,0.445854,-0.304137,1.586509,-1.46205,0.14727,-0.728888,-0.023395,0.053843,0.317731,0.106147,0.084324,0.102986,1.501789,0.065393,-1.430824,1.146053,-1.434283,-1.65314,0.931555,-0.115491,-0.330579,1.271166,-1.414736,0.108802,0.021706,-0.035344,-0.439081,-0.440282,0.026983,0.48243,-2.462531,-0.378354,…,0.02136,0.041744,0.036118,0.0,0.050104,0.000707,0.000707,0.030855,0.04767,0.075274,-1.416233,-1.141376,0.027328,0.020897,2.598391,-0.548398,-1.481715,-0.015513,1.763104,-1.489823,-1.412691,-0.798053,-1.442874,1.419367,0.590982,0.947132,1.426985,1.468707,0.813362,1.436473,0.008007,-0.89995,0.95295,1.55111,-0.109264,-0.022066,-0.952475


In [None]:
train_x.filter(pl.col(CONFIG.DATE_COL).is_in(range(date_id - max(CONFIG.LAG_SEQ_LEN.values()) + 1, date_id + 1))).select(
    [CONFIG.DATE_COL] + CONFIG.LAG_FEATURES[1]
)

date_id,LME_CA_Close_vol_20,US_Stock_SPTL_adj_close_log_ret_return_lag_3,target_334,target_334_return_lag_0,FX_AUDUSD_log_ret_return_lag_3,US_Stock_VGLT_adj_close_log_ret_return_lag_3,FX_USDCHF_log_ret_return_lag_4,US_Stock_CCJ_adj_high_log_ret,US_Stock_CCJ_adj_high_log_ret_return_lag_0,US_Stock_VGK_adj_high_log_ret_return_lag_4,US_Stock_URA_adj_high_log_ret_return_lag_1,US_Stock_IGSB_adj_high_log_ret_return_lag_4,US_Stock_OKE_adj_close_vol_5,target_353,target_353_return_lag_0,US_Stock_VEA_adj_high_log_ret_return_lag_4,FX_GBPAUD_log_ret_return_lag_2,target_352_return_lag_5,US_Stock_EWY_adj_high_log_ret_return_lag_4,US_Stock_VALE_adj_high_log_ret_return_lag_4,target_342_return_lag_4,US_Stock_EFA_adj_high_log_ret_return_lag_4,US_Stock_GLD_adj_high_log_ret_return_lag_1,target_282_return_lag_4,US_Stock_EWJ_adj_high_log_ret_return_lag_4,US_Stock_RSP_adj_low_log_ret_return_lag_4,US_Stock_BNDX_adj_close_log_ret_return_lag_5,target_391_return_lag_5,US_Stock_EWY_adj_open_log_ret_return_lag_4,US_Stock_VALE_adj_open_log_ret_return_lag_4,US_Stock_BNDX_adj_low_vol_5,FX_USDJPY_log_ret_return_lag_3,US_Stock_IGSB_adj_high_log_ret_return_lag_2,US_Stock_SPTL_adj_low_log_ret_return_lag_3,LME_PB_Close_log_ret_return_lag_3,target_293_return_lag_3,…,US_Stock_YINN_adj_open_log_ret_return_lag_4,US_Stock_VCSH_adj_low_log_ret_return_lag_4,US_Stock_KMI_adj_low_log_ret_return_lag_1,US_Stock_BKR_adj_close_log_ret,US_Stock_BKR_adj_close_log_ret_return_lag_0,US_Stock_BP_adj_high_log_ret_return_lag_4,US_Stock_BKR_adj_close_log_ret_return_lag_4,US_Stock_CVX_adj_low_log_ret_return_lag_4,US_Stock_SPIB_adj_open_log_ret_return_lag_5,US_Stock_SHY_adj_low_log_ret_return_lag_3,US_Stock_YINN_adj_low_log_ret_return_lag_4,US_Stock_BSV_adj_low_log_ret_return_lag_3,US_Stock_CCJ_adj_close_log_ret_return_lag_4,target_355_return_lag_2,US_Stock_FXI_adj_low_log_ret_return_lag_4,US_Stock_BNDX_adj_low_log_ret_return_lag_3,JPX_Gold_Standard_Futures_Low_log_ret,JPX_Gold_Standard_Futures_Low_log_ret_return_lag_0,target_396_return_lag_5,US_Stock_EWJ_adj_low_log_ret_return_lag_4,JPX_Platinum_Standard_Futures_Close_log_ret_return_lag_4,JPX_Gold_Mini_Futures_Low_log_ret,JPX_Gold_Mini_Futures_Low_log_ret_return_lag_0,US_Stock_VGLT_adj_low_log_ret_return_lag_3,US_Stock_VWO_adj_close_log_ret,US_Stock_VWO_adj_close_log_ret_return_lag_0,FX_USDJPY_sma_5_ratio,US_Stock_MPC_adj_close_vol_20,JPX_Platinum_Standard_Futures_Low_log_ret,JPX_Platinum_Standard_Futures_Low_log_ret_return_lag_0,US_Stock_TECK_adj_open_log_ret,US_Stock_TECK_adj_open_log_ret_return_lag_0,US_Stock_LYB_adj_low_log_ret_return_lag_4,FX_AUDCAD_log_ret_return_lag_2,US_Stock_EWZ_adj_low_log_ret,US_Stock_EWZ_adj_low_log_ret_return_lag_0,US_Stock_FXI_adj_open_log_ret_return_lag_5
i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
1492,-0.511044,0.173562,-0.033112,-0.033112,1.533998,0.234326,1.023351,-2.188488,-2.188488,-0.663893,-2.809797,-0.074308,1.073938,-0.022895,-0.022895,-0.519893,-0.154931,-0.023216,-0.240476,-0.428015,0.023171,-0.455075,-2.550884,-0.03502,0.215218,-0.611564,-0.224083,0.027589,-0.173789,-0.15977,-0.008468,-0.157877,0.954141,-1.089345,1.486385,-0.037555,…,0.161582,-1.27943,-2.127097,-0.125517,-0.125517,0.738528,0.674825,0.727557,0.234615,0.581767,-0.088414,-0.000459,-0.356188,0.032037,-0.118028,-1.095543,-2.022053,-2.022053,0.026556,0.024484,0.329002,-2.035179,-2.035179,-1.097548,-1.284433,-1.284433,-0.23474,-0.97524,-2.99021,-2.99021,-1.378905,-1.378905,0.077251,1.309706,-1.638625,-1.638625,-0.276408
1493,-0.312962,0.101284,-0.016007,-0.016007,0.02741,0.097307,-1.096954,-1.957091,-1.957091,0.558081,-1.654699,-0.45129,1.165789,-0.038005,-0.038005,0.381586,0.147775,0.013702,0.254698,0.281916,-0.008606,0.401456,-0.312567,0.010238,-0.371636,0.65835,-0.746884,0.027614,-0.65199,-0.364345,0.081926,0.015751,-1.643458,1.254481,-0.893886,0.039143,…,-0.678046,0.356886,-0.603849,-2.20992,-2.20992,0.192488,0.377309,0.222367,-0.321926,0.819286,-0.350581,1.449974,0.765745,0.063368,-0.373034,1.393187,-1.040454,-1.040454,0.032931,-0.255789,-1.269722,-1.061661,-1.061661,1.242034,-0.444676,-0.444676,-0.329176,-0.468101,-1.172184,-1.172184,-0.616882,-0.616882,0.241764,-0.687277,-0.346604,-0.346604,0.125275
1494,-0.080944,-1.401107,0.009703,0.009703,-1.289793,-1.485501,0.17002,0.772106,0.772106,0.922436,-1.991366,0.954141,-0.212533,-0.039579,-0.039579,0.578321,1.378222,0.057261,0.453239,0.766982,-0.005012,0.561487,-0.524425,-0.015106,-0.730568,0.091347,-0.320084,0.009189,1.373382,1.248753,0.08233,0.421902,-0.549668,-1.361923,-0.915021,0.047038,…,1.210833,0.913825,-0.450888,0.25656,0.25656,-0.375223,-1.910836,-0.696194,-1.614106,-0.785505,0.702767,-1.214816,-1.681274,0.041355,0.708703,-1.167691,0.024832,0.024832,0.034368,-1.076575,1.646769,0.064427,0.064427,-1.371962,0.376309,0.376309,-0.597941,-0.037599,-0.070476,-0.070476,-0.706681,-0.706681,-0.04385,-1.06743,-0.217693,-0.217693,-0.675665
1495,-0.023508,-1.782801,0.014661,0.014661,-1.058782,-1.810959,0.550438,0.951175,0.951175,-1.918621,1.034759,-1.643458,-0.15478,0.006229,0.006229,-1.997166,0.264006,0.046996,-1.335938,-0.496724,0.001405,-1.952236,-0.335142,-0.047527,-1.464685,-1.4316,0.367832,0.004801,-1.476968,-0.339328,-0.725367,-0.601002,-0.075723,-1.991876,-1.027569,0.079335,…,-0.56788,-1.111427,0.032094,0.518414,0.518414,-1.231643,-1.203543,-1.104401,1.425645,-0.469598,-0.243584,-0.84612,-2.57197,0.04096,-0.284574,-1.158817,-0.937016,-0.937016,0.036215,-1.057066,-0.526809,-0.938397,-0.938397,-1.98672,1.300727,1.300727,0.067641,0.207813,-0.981562,-0.981562,1.040375,1.040375,-1.239508,0.998291,-0.73244,-0.73244,1.204968
1496,-0.087871,1.149808,0.03664,0.03664,0.386462,1.17987,0.652076,-0.445462,-0.445462,-1.495579,1.244723,-0.549668,1.883954,0.024519,0.024519,-1.797999,-0.453268,0.01164,-0.660344,-0.620099,-0.016771,-1.679115,0.833929,-0.060649,-1.061548,-1.184204,-1.078687,0.054209,-0.969331,-1.025861,-0.669347,-0.206991,0.767765,0.563411,-0.029653,-0.011485,…,-1.334306,-1.70101,0.590575,1.586124,1.586124,-1.752598,-0.125517,0.091868,-1.071626,0.379826,-1.151713,0.225303,-0.697418,-0.012123,-1.184028,-0.032004,-0.067485,-0.067485,0.097434,-2.147739,-2.552266,-0.067455,-0.067455,0.593217,-0.649562,-0.649562,-0.423256,0.270992,0.005253,0.005253,-0.059475,-0.059475,0.415142,0.624903,1.460085,1.460085,-0.57029
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1519,-0.960542,-0.804647,-0.092706,-0.092706,-0.459747,-0.776192,-1.49708,1.472992,1.472992,1.215112,-0.260884,1.932424,1.12496,-0.04171,-0.04171,1.758383,0.87107,-0.046063,2.226214,0.373327,-0.016667,1.730929,-0.455415,0.032896,2.10888,2.74635,1.170535,0.062301,2.11611,0.43695,-0.091291,0.567891,0.328829,-1.097286,0.347239,-0.002168,…,1.041632,1.615778,-0.450138,0.258001,0.258001,-0.860806,-0.226697,1.13054,2.63622,-0.476386,1.224654,-0.849864,-1.087795,-0.004142,1.255055,-1.095059,-0.434664,-0.434664,0.030135,1.969524,0.00129,-0.417545,-0.417545,-1.124315,-0.994871,-0.994871,0.820444,-0.314741,-1.865203,-1.865203,-0.468029,-0.468029,1.419654,-0.695311,-0.387257,-0.387257,0.790494
1520,-1.009707,1.275387,-0.086301,-0.086301,-1.20894,1.294288,0.116636,-0.003635,-0.003635,-0.293308,0.560836,-1.368154,0.325644,-0.023029,-0.023029,-0.009204,0.591108,0.011156,3.997054,0.021724,0.000381,-0.653924,0.091066,0.043951,-1.373731,-0.306983,1.229914,0.001028,4.169611,0.363435,-0.084726,0.340282,-0.072727,1.213365,0.303273,0.009394,…,0.971451,-1.060379,-0.717461,0.658274,0.658274,0.346922,-0.441287,0.298863,2.205608,0.009822,0.495286,0.144293,-1.397655,-0.024666,0.521689,1.254137,-0.378165,-0.378165,-0.029611,-0.815667,-0.597778,-0.388318,-0.388318,1.198703,0.732307,0.732307,0.576604,-0.243254,-1.013675,-1.013675,-0.052795,-0.052795,-0.617059,-0.370447,0.851918,0.851918,0.949652
1521,-1.011257,1.426783,-0.044674,-0.044674,-0.513088,1.401085,0.173089,1.623886,1.623886,-0.772582,-0.254333,0.328829,-0.814267,0.013741,0.013741,-1.378117,0.199196,0.017206,-1.731424,-0.352703,0.02319,-1.056797,-1.222856,-0.001765,-1.382764,-0.254804,-1.169921,0.000309,-2.278234,-0.138944,-0.806531,0.436513,-0.2722,0.937505,0.071482,0.056605,…,-0.794188,0.042823,0.285481,0.169348,0.169348,-1.453235,-1.799926,-1.542321,-1.098614,0.170228,-0.71455,0.23943,0.299185,0.044893,-0.748403,1.051235,-0.431901,-0.431901,-0.038768,-1.475391,-1.42638,-0.423974,-0.423974,0.943803,0.175352,0.175352,0.37805,-0.228861,-1.26094,-1.26094,0.808815,0.808815,-0.731644,-1.005815,-0.202861,-0.202861,0.991861
1522,-0.971642,-1.937158,-0.01281,-0.01281,-0.907734,-1.965644,-0.135633,-0.325053,-0.325053,0.516323,1.927278,-0.072727,-0.118373,0.032001,0.032001,-0.06372,0.08499,-0.016568,-1.374882,-0.389102,0.025895,0.089888,0.049997,-0.028327,-1.309477,-0.272417,1.553362,-0.036827,-0.863802,-0.601661,0.167284,0.442838,-0.372716,-1.892661,-0.033468,0.065226,…,-0.118559,0.442015,0.651406,0.331279,0.331279,-0.617521,-0.238593,-0.410269,-0.127812,-1.052224,0.029804,-0.987344,-0.917916,0.042114,0.004822,-0.964702,-0.255183,-0.255183,-0.000566,-1.452903,-0.797426,-0.228269,-0.228269,-1.868637,2.466483,2.466483,-0.752396,-0.239898,0.131002,0.131002,3.295977,3.295977,0.00009,-0.162136,2.252719,2.252719,-0.816642


In [None]:
for date_id in tqdm(dates_valid):
    df_valid_date = train_x.filter(pl.col(CONFIG.DATE_COL).is_in(range(date_id - max(CONFIG.LAG_SEQ_LEN.values()) + 1, date_id + 1))).drop(
        CONFIG.DATE_COL
    )
    valid_lags = {lag: df_valid_date.select(features).to_numpy().astype(np.float64) for lag, features in CONFIG.LAG_FEATURES.items()}
    valid_lags = {lag: valid_lags[lag][-seq_len:] for lag, seq_len in CONFIG.LAG_SEQ_LEN.items()}
    break

  0%|          | 0/304 [00:00<?, ?it/s]

In [None]:
# Step 1: Prepare the numpy input
X_np = (
    df_valid.drop(
        CONFIG.DATE_COL,
    )
    .to_numpy()
    .reshape(-1, 8, 173)
)

# Step 2: Create leaf tensor on CUDA with requires_grad=True
X_sample = torch.tensor(X_np, dtype=torch.float32, device="cuda", requires_grad=True)  # ✅ single step

# Step 3: Put model into train mode to support CuDNN RNN backward
model_real.model.train()

# Step 4: Forward and backward pass
output = model_real.model(X_sample)  # (B, ...)
loss = output.mean()
loss.backward()

# Step 5: Retrieve input gradients — now this will work without warnings
grads = X_sample.grad.detach().cpu().numpy()  # shape: (B, T, F)
importance = np.abs(grads).mean(axis=(0, 1))  # shape: (F,)
importance /= importance.sum()

ValueError: cannot reshape array of size 2831456 into shape (8,173)

In [None]:
pl.DataFrame(importance).with_columns(pl.Series(name="feats", values=df_train.drop(CONFIG.DATE_COL).columns)).sort(by="column_0", descending=True)

column_0,feats
f32,str
0.011635,"""US_Stock_GLD_adj_low_log_ret_r…"
0.011528,"""US_Stock_NUGT_adj_close_log_re…"
0.010101,"""US_Stock_IGSB_adj_open_log_ret…"
0.00998,"""US_Stock_RY_adj_high_log_ret"""
0.009936,"""US_Stock_MPC_adj_low"""
…,…
0.003558,"""US_Stock_VCIT_adj_open_log_ret…"
0.003554,"""US_Stock_AMP_adj_low_log_ret"""
0.003419,"""US_Stock_LYB_adj_close_log_ret…"
0.003284,"""US_Stock_VGK_adj_close_log_ret"""
