In [1]:
import copy
import os
import warnings
import logging
from typing import List, Tuple, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.optimize import minimize
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder, QuantileTransformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
import lightgbm as lgb

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

######################################################
# Configuration
######################################################
WARMUP = True
N_SPLITS = 5
RANDOM_STATE = 1335
EPOCHS = 8
LR = [1e-3, 3e-3]
DATA_PATH = "/kaggle/input/child-mind-institute-problematic-internet-use"
DROP_NAN = True
CYCLES = 1
QWK_WEIGHT = 0.25
CE_WEIGHT = 0.75
WARMUP_RATIO = [0.0, 1.0]
BATCH_SIZE = 32
DROPOUT = 0.3

# Columns selected for training as per previous code
COLS = [
    "Basic_Demos-Enroll_Season", "CGAS-Season", "Physical-Season", "Fitness_Endurance-Season",
    "FGC-Season", "BIA-Season", "PAQ_C-Season", "SDS-Season", "PreInt_EduHx-Season",
    "FGC-FGC_PU", "BIA-BIA_SMM", "BIA-BIA_BMR", "BIA-BIA_FFMI", "BIA-BIA_TBW", "Basic_Demos-Sex",
    "BIA-BIA_LDM", "Fitness_Endurance-Time_Mins", "FGC-FGC_GSND", "Basic_Demos-Age", "Physical-HeartRate",
    "FGC-FGC_SRL", "Physical-Waist_Circumference", "Physical-Systolic_BP", "CGAS-CGAS_Score",
    "BIA-BIA_ECW", "PAQ_A-PAQ_A_Total", "FGC-FGC_SRR", "PreInt_EduHx-computerinternet_hoursday",
    "SDS-SDS_Total_Raw", "FGC-FGC_GSD", "PAQ_C-PAQ_C_Total", "BIA-BIA_BMI", "Fitness_Endurance-Time_Sec",
    "Physical-Height", "SDS-SDS_Total_T", "FGC-FGC_CU", "Physical-Weight", "FGC-FGC_TL",
    "Physical-Diastolic_BP", "Physical-BMI", "Fitness_Endurance-Max_Stage", "BIA-BIA_FMI", "BIA-BIA_BMC",
    "BIA-BIA_DEE", "BIA-BIA_ICW", "BIA-BIA_Fat", "BIA-BIA_LST", "BIA-BIA_Activity_Level_num",
]

######################################################
# Utility Functions
######################################################
def set_seed(seed: int):
    """Fix random seeds for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def quadratic_weighted_kappa(predictions, targets):
    """Compute Quadratic Weighted Kappa for given predictions and targets."""
    return cohen_kappa_score(predictions, targets, weights="quadratic")

def threshold_rounder(oof_non_rounded: np.ndarray, thresholds: List[float]) -> np.ndarray:
    """Round continuous predictions into discrete classes based on thresholds."""
    return np.where(
        oof_non_rounded < thresholds[0],
        0,
        np.where(
            oof_non_rounded < thresholds[1],
            1,
            np.where(oof_non_rounded < thresholds[2], 2, 3),
        ),
    )

def evaluate_predictions(thresholds: List[float], y_true: np.ndarray, oof_non_rounded: np.ndarray) -> float:
    """Evaluate predictions by computing negative QWK for use in optimization."""
    rounded_p = threshold_rounder(oof_non_rounded, thresholds)
    return -quadratic_weighted_kappa(y_true, rounded_p)

def create_preprocessor(categorical_features: List[int], numerical_features: List[int]):
    """Create a ColumnTransformer for preprocessing categorical and numerical features."""
    # More robust imputations can be done here (like KNNImputer), but here we keep it simple.
    return ColumnTransformer(
        transformers=[
            (
                "num",
                Pipeline(
                    [
                        ("imputer", SimpleImputer(strategy="median")),
                        ("scaler", QuantileTransformer()),
                    ]
                ),
                numerical_features,
            ),
            (
                "cat",
                Pipeline(
                    [
                        (
                            "imputer",
                            SimpleImputer(strategy="constant", fill_value="missing"),
                        ),
                        (
                            "encoder",
                            OrdinalEncoder(
                                handle_unknown="use_encoded_value", unknown_value=-1
                            ),
                        ),
                    ]
                ),
                categorical_features,
            ),
        ]
    )

######################################################
# Dataset and Model Classes
######################################################
class HybridDataset(Dataset):
    """PyTorch Dataset to handle hybrid (tabular + time-series) data."""

    def __init__(self, tabular_data: np.ndarray, ids: np.ndarray, targets: np.ndarray = None, cat_len=9, is_test=False):
        self.categorical_data = torch.LongTensor(tabular_data[:, :cat_len])
        self.numerical_data = torch.FloatTensor(tabular_data[:, cat_len:-1])
        self.ts_indicator = tabular_data[:, -1]
        self.ids = ids
        self.is_test = is_test
        if targets is not None:
            self.targets = torch.LongTensor(targets)
        else:
            self.targets = None

    def __len__(self):
        return len(self.numerical_data)

    def __getitem__(self, idx: int):
        categorical = self.categorical_data[idx]
        numerical = self.numerical_data[idx]

        # Check if time-series data exists
        if self.ts_indicator[idx] == 1:
            file_path = os.path.join(
                DATA_PATH,
                f"series_{'test' if self.is_test else 'train'}.parquet/id={self.ids[idx]}/part-0.parquet",
            )
            if not os.path.exists(file_path):
                # Gracefully handle missing files
                logger.warning(f"Time-series file not found for id={self.ids[idx]}, using zeros.")
                time_series = torch.zeros((1, 7))
            else:
                time_series_df = pd.read_parquet(file_path)
                # Basic sanity checks
                if len(time_series_df) == 0:
                    logger.warning(f"No time-series data for id={self.ids[idx]}, using zeros.")
                    time_series = torch.zeros((1, 7))
                else:
                    # Extract relevant columns
                    numerical_ts = time_series_df.iloc[:, [1, 2, 3, 4, 5]].values
                    timestamp = (time_series_df.iloc[:, 9] / 5000000000).round(3).values
                    weekday = time_series_df.iloc[:, 10].values - 1
                    # Ensure weekday bounds
                    weekday = np.clip(weekday, 0, 6)
                    time_series_arr = np.column_stack([numerical_ts, timestamp, weekday])
                    time_series = torch.FloatTensor(time_series_arr[::100])
        else:
            time_series = torch.zeros((1, 7))

        if self.targets is not None:
            target = self.targets[idx]
            return categorical, numerical, time_series, target
        return categorical, numerical, time_series

def collate_fn(batch: List[Any]):
    """Collate function for train/validation data loaders."""
    categorical, numerical, time_series, targets = zip(*batch)
    time_series_padded = pad_sequence(time_series, batch_first=True)
    return (
        torch.stack(categorical),
        torch.stack(numerical),
        time_series_padded,
        torch.stack(targets),
    )

def collate_fn_test(batch: List[Any]):
    """Collate function for test data loader."""
    categorical, numerical, time_series = zip(*batch)
    time_series_padded = pad_sequence(time_series, batch_first=True)
    return (
        torch.stack(categorical),
        torch.stack(numerical),
        time_series_padded,
    )

class SimpleTimeSeriesEncoder(nn.Module):
    """Simple time-series encoder using Conv + LSTM + Embedding for weekday."""
    def __init__(self, input_dim, hidden_dim, n_layers=2):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, n_layers, batch_first=True)
        self.weekday_embedding = nn.Embedding(7, 8)
        self.fc = nn.Linear(hidden_dim + 8, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: [batch, seq_len, 7], last col is weekday
        x_num, x_cat = x[:, :, :-1], x[:, :, -1].long()

        x_num = x_num.transpose(1, 2)
        x_num = self.conv(x_num)
        x_num = x_num.transpose(1, 2)

        _, (h_n, _) = self.lstm(x_num)
        x_num = h_n[-1]

        x_cat = self.weekday_embedding(x_cat[:, -1])
        x_combined = torch.cat([x_num, x_cat], dim=1)
        x_combined = self.fc(x_combined)
        x_combined = self.layer_norm(x_combined)
        return self.dropout(x_combined)

class ResidualBlock(nn.Module):
    """A simple Residual Block for numerical feature encoding."""
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = 0.7 * out + 0.3 * residual
        out = self.relu(out)
        return out

class NumericalEncoder(nn.Module):
    """Encoder for numerical features using a small CNN + ResidualBlocks."""
    def __init__(self, numerical_dim, hidden_dim):
        super(NumericalEncoder, self).__init__()
        self.initial_conv = nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1)
        self.bn_initial = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.res_block1 = ResidualBlock(hidden_dim, hidden_dim)
        self.res_block2 = ResidualBlock(hidden_dim, hidden_dim)
        self.final_conv = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.bn_final = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(DROPOUT)
        self.fc = nn.Linear(hidden_dim * numerical_dim, hidden_dim)

    def forward(self, x):
        # x: [batch, numerical_dim]
        x = x.unsqueeze(1)
        x = self.relu(self.bn_initial(self.initial_conv(x)))
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.relu(self.bn_final(self.final_conv(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

class HybridModel(nn.Module):
    """Hybrid model combining categorical embeddings, numerical encoding and time-series encoding."""
    def __init__(
        self,
        categorical_dims,
        numerical_dim,
        time_series_dim,
        embedding_dim,
        hidden_dim,
        num_classes,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(dim + 1, embedding_dim, padding_idx=dim)
                for dim in categorical_dims
            ]
        )

        self.numerical_encoder = NumericalEncoder(numerical_dim, hidden_dim)
        self.time_series_encoder = SimpleTimeSeriesEncoder(
            time_series_dim - 1, hidden_dim
        )

        combined_dim = len(categorical_dims) * embedding_dim + hidden_dim * 2
        self.classifier = nn.Sequential(
            NumericalEncoder(combined_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, categorical, numerical, time_series):
        embedded = []
        for i, emb in enumerate(self.embeddings):
            clamped_indices = torch.clamp(categorical[:, i], 0, emb.num_embeddings - 1)
            embedded.append(emb(clamped_indices))
        embedded = torch.cat(embedded, dim=1)

        numerical_features = self.numerical_encoder(numerical)
        time_series_features = self.time_series_encoder(time_series)

        combined = torch.cat(
            [embedded, numerical_features, time_series_features], dim=1
        )
        return self.classifier(combined)

class QuadraticWeightedKappaLoss(nn.Module):
    """A loss function approximating QWK for training."""
    def __init__(self, num_classes, epsilon=1e-10):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        weight_mat = torch.zeros(
            (self.num_classes, self.num_classes), device=pred.device
        )
        for i in range(self.num_classes):
            for j in range(self.num_classes):
                weight_mat[i, j] = (i - j) ** 2

        conf_mat = torch.zeros((self.num_classes, self.num_classes), device=pred.device)
        for i in range(self.num_classes):
            for j in range(self.num_classes):
                conf_mat[i, j] = torch.sum((target == i) * pred[:, j])

        conf_mat = conf_mat / torch.sum(conf_mat)
        row_sum = torch.sum(conf_mat, dim=1)
        col_sum = torch.sum(conf_mat, dim=0)
        expected = torch.outer(row_sum, col_sum)

        numerator = torch.sum(weight_mat * conf_mat)
        denominator = torch.sum(weight_mat * expected)
        qwk = numerator / (denominator + self.epsilon)

        return qwk

######################################################
# Training and Evaluation Functions
######################################################
def train_and_evaluate(
    model,
    train_loader,
    val_loader,
    epochs,
    lr,
    device,
    patience=10,
    fold=1,
    wur=0.0,
    it=0,
):
    """Train and evaluate the model for a single fold."""
    ce_loss = nn.CrossEntropyLoss()
    qwk_loss = QuadraticWeightedKappaLoss(num_classes=4)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    total_steps = len(train_loader) * epochs
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(wur * total_steps) if WARMUP else 0,
        num_training_steps=total_steps,
        num_cycles=1,
    )

    best_val_qwk = 0
    epochs_no_improve = 0
    pbar = tqdm(range(epochs), desc=f"Fold {fold}, LR={lr}")
    for _ in pbar:
        model.train()
        train_loss = 0.0
        for categorical, numerical, time_series, targets in train_loader:
            categorical, numerical, time_series, targets = (
                categorical.to(device),
                numerical.to(device),
                time_series.to(device),
                targets.to(device),
            )
            optimizer.zero_grad()
            outputs = model(categorical, numerical, time_series)
            loss = CE_WEIGHT * ce_loss(outputs, targets) + QWK_WEIGHT * qwk_loss(
                outputs, targets
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []
        with torch.no_grad():
            for (categorical, numerical, time_series, targets) in val_loader:
                categorical, numerical, time_series, targets = (
                    categorical.to(device),
                    numerical.to(device),
                    time_series.to(device),
                    targets.to(device),
                )
                outputs = model(categorical, numerical, time_series)
                loss = CE_WEIGHT * ce_loss(outputs, targets) + QWK_WEIGHT * qwk_loss(
                    outputs, targets
                )
                val_loss += loss.item()
                temp_val_preds = (
                    (outputs.softmax(dim=1) * torch.tensor([0, 1, 2, 3], device=device))
                    .sum(dim=1)
                    .cpu()
                    .numpy()
                )
                temp_targets = targets.cpu().numpy()
                val_preds.extend(temp_val_preds)
                val_targets.extend(temp_targets)

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        val_preds_rounded = np.round(val_preds).astype(int)
        val_qwk = quadratic_weighted_kappa(val_targets, val_preds_rounded)

        pbar.set_postfix(
            {
                "Train Loss": f"{train_loss:.4f}",
                "Val Loss": f"{val_loss:.4f}",
                "Val QWK": f"{val_qwk:.4f}",
            }
        )

        if val_qwk > best_val_qwk:
            best_val_qwk = val_qwk
            torch.save(model.state_dict(), f"best_model_fold_{fold}_{it}.pth")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            pbar.write("Early stopping triggered")
            break

    # Load best model and compute OOF predictions
    oof_predictions = np.zeros(len(val_loader.dataset))
    oof_targets = np.zeros(len(val_loader.dataset))
    model.load_state_dict(torch.load(f"best_model_fold_{fold}_{it}.pth"))
    model.eval()
    with torch.no_grad():
        start_idx = 0
        for i, (categorical, numerical, time_series, targets) in enumerate(val_loader):
            categorical, numerical, time_series, targets = (
                categorical.to(device),
                numerical.to(device),
                time_series.to(device),
                targets.to(device),
            )
            outputs = model(categorical, numerical, time_series)
            temp_val_preds = (
                (outputs.softmax(dim=1) * torch.tensor([0, 1, 2, 3], device=device))
                .sum(dim=1)
                .cpu()
                .numpy()
            )
            temp_targets = targets.cpu().numpy()
            batch_size = targets.size(0)
            end_idx = start_idx + batch_size
            oof_predictions[start_idx:end_idx] = temp_val_preds
            oof_targets[start_idx:end_idx] = temp_targets
            start_idx = end_idx

    fold_qwk = cohen_kappa_score(
        oof_predictions.round(), oof_targets, weights="quadratic"
    )
    logger.info(f"Fold {fold} QWK: {fold_qwk:.4f}")
    return fold_qwk, oof_predictions, oof_targets

######################################################
# Main Training Function
######################################################
def train_main():
    """Main training function with cross-validation and threshold optimization."""
    train_df = pd.read_csv(os.path.join(DATA_PATH, "train.csv"))
    if DROP_NAN:
        # Ensure we have target values
        train_df = train_df.dropna(subset=["sii"])
    else:
        # If not dropping, we can impute or ignore, here we fill with 0
        train_df["sii"] = train_df["sii"].fillna(0)

    # Basic data checks
    if train_df["sii"].isnull().any():
        raise ValueError("Some target values are still missing!")

    # Extract features and target
    tabular_data = train_df[COLS]
    targets = train_df["sii"].values
    ids = train_df["id"].values

    categorical_features = list(range(9))
    numerical_features = list(range(9, len(COLS)))

    # Stratified split
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

    fold_qwks = []
    all_oof_predictions = np.zeros(len(targets))
    all_oof_targets = np.zeros(len(targets))
    preprocessors = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Training on device: {device}")

    for fold, (train_idx, val_idx) in enumerate(skf.split(tabular_data, targets), 1):
        logger.info(f"Starting fold {fold}")
        preprocessor = create_preprocessor(categorical_features, numerical_features)

        fold_train_data = tabular_data.iloc[train_idx]
        fold_val_data = tabular_data.iloc[val_idx]

        tabular_train = preprocessor.fit_transform(fold_train_data)
        tabular_val = preprocessor.transform(fold_val_data)

        preprocessors.append(preprocessor)

        ids_train, ids_val = ids[train_idx], ids[val_idx]
        y_train, y_val = targets[train_idx], targets[val_idx]

        # Check time-series data existence
        def ts_exists(x, mode='train'):
            path = os.path.join(DATA_PATH, f"series_{mode}.parquet/id={x}/part-0.parquet")
            return 1 if os.path.exists(path) else 0

        train_ts_indicator = np.array([ts_exists(x, 'train') for x in ids_train]).reshape(-1, 1)
        val_ts_indicator = np.array([ts_exists(x, 'train') for x in ids_val]).reshape(-1, 1)

        tabular_train = np.column_stack([tabular_train, train_ts_indicator])
        tabular_val = np.column_stack([tabular_val, val_ts_indicator])

        categorical_dims = [
            len(
                preprocessor.named_transformers_["cat"]
                .named_steps["encoder"]
                .categories_[i]
            )
            for i in range(len(categorical_features))
        ]

        train_dataset = HybridDataset(tabular_train, ids_train, y_train)
        val_dataset = HybridDataset(tabular_val, ids_val, y_val)

        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=4,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=4,
        )

        model = HybridModel(
            categorical_dims=categorical_dims,
            numerical_dim=tabular_train.shape[1] - len(categorical_dims) - 1,
            time_series_dim=7,
            embedding_dim=8,
            hidden_dim=128,
            num_classes=4,
        ).to(device)

        # Train first phase
        fold_qwk_1, oof_predictions_1, oof_targets_1 = train_and_evaluate(
            model,
            train_loader,
            val_loader,
            epochs=EPOCHS,
            lr=LR[0],
            device=device,
            patience=15,
            fold=fold,
            wur=WARMUP_RATIO[0],
            it=0,
        )

        # Train second phase (with different LR and warmup)
        fold_qwk_2, oof_predictions_2, oof_targets_2 = train_and_evaluate(
            copy.deepcopy(model),
            train_loader,
            val_loader,
            epochs=EPOCHS,
            lr=LR[1],
            device=device,
            patience=15,
            fold=fold,
            wur=WARMUP_RATIO[1],
            it=1,
        )

        # Combine predictions from two training runs
        combined_predictions = (oof_predictions_1 + oof_predictions_2) / 2
        combined_targets = (oof_targets_1 + oof_targets_2) / 2

        fold_qwks.extend([fold_qwk_1, fold_qwk_2])
        all_oof_predictions[val_idx] = combined_predictions
        all_oof_targets[val_idx] = combined_targets

    # Optimize thresholds on OOF predictions
    Kappa_optimizer = minimize(
        evaluate_predictions,
        x0=[0.5, 1.5, 2.5],
        args=(all_oof_targets, all_oof_predictions),
        method="Nelder-Mead",
    )

    if not Kappa_optimizer.success:
        logger.warning("Threshold optimization did not converge. Using default thresholds.")
        optimal_thresholds = [0.5, 1.5, 2.5]
    else:
        optimal_thresholds = Kappa_optimizer.x

    oof_tuned = threshold_rounder(all_oof_predictions, optimal_thresholds)
    kappa_optimized = quadratic_weighted_kappa(all_oof_targets, oof_tuned)

    default_thresholds = [0.5, 1.5, 2.5]
    oof_not_tuned = threshold_rounder(all_oof_predictions, default_thresholds)
    oof_qwk = quadratic_weighted_kappa(all_oof_targets, oof_not_tuned)

    logger.info(f"Mean of fold QWKs: {np.mean(fold_qwks):.4f}")
    logger.info(f"OOF QWK (not optimized): {oof_qwk:.4f}")
    logger.info(f"OOF QWK (optimized): {kappa_optimized:.4f}")
    logger.info(f"Optimal thresholds: {optimal_thresholds}")

    return preprocessors, optimal_thresholds

######################################################
# LightGBM Ensemble (Optional)
######################################################
def train_lightgbm(train_df: pd.DataFrame, preprocessors, targets: np.ndarray):
    """Train a LightGBM model on the same features as a fallback or ensemble model."""
    # Use the first preprocessor as a reference (assuming they are quite similar after fit)
    preprocessor = preprocessors[0]

    # Preprocess data fully
    tab_data = preprocessor.transform(train_df[COLS])
    # No time-series indicator here for simplicity
    # You may want to incorporate it if it's a known feature
    # For demonstration, we just assume no TS data used in LightGBM

    # Basic LGBM model
    lgb_params = {
        "objective": "multiclass",
        "num_class": 4,
        "verbosity": -1,
        "seed": RANDOM_STATE
    }
    lgb_train = lgb.Dataset(tab_data, label=targets)
    model = lgb.train(lgb_params, lgb_train, num_boost_round=100)
    return model

######################################################
# Inference Function
######################################################
def inference(preprocessors, optimal_thresholds):
    """Perform inference on the test set using saved models and ensemble the predictions."""
    test_df = pd.read_csv(os.path.join(DATA_PATH, "test.csv"))
    tabular_data = test_df[COLS]
    ids = test_df["id"].values

    # Check for time-series indicator
    def ts_exists(x, mode='test'):
        path = os.path.join(DATA_PATH, f"series_{mode}.parquet/id={x}/part-0.parquet")
        return 1 if os.path.exists(path) else 0

    ts_indicator = np.array([ts_exists(x, 'test') for x in ids]).reshape(-1, 1)

    # Predictions from neural nets
    test_predictions_nn = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Starting inference with neural network models.")

    for fold in range(N_SPLITS):
        # Preprocess test data using fold-specific preprocessor
        tabular_data_processed = preprocessors[fold].transform(tabular_data)
        tabular_data_processed = np.column_stack([tabular_data_processed, ts_indicator])

        # Same structure as training dataset
        categorical_features = list(range(9))
        categorical_dims = [
            len(
                preprocessors[fold].named_transformers_["cat"]
                .named_steps["encoder"]
                .categories_[i]
            )
            for i in range(len(categorical_features))
        ]

        test_dataset = HybridDataset(tabular_data_processed, ids, is_test=True, cat_len=9)
        test_loader = DataLoader(
            test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            collate_fn=collate_fn_test,
            num_workers=4,
        )

        # Rebuild model with same structure
        model = HybridModel(
            categorical_dims=categorical_dims,
            numerical_dim=tabular_data_processed.shape[1] - len(categorical_dims) - 1,
            time_series_dim=7,
            embedding_dim=8,
            hidden_dim=128,
            num_classes=4,
        ).to(device)

        fold_preds = []
        # Ensemble from both training phases (two saved models per fold)
        for it in range(2):
            model_path = f"best_model_fold_{fold+1}_{it}.pth"
            if not os.path.exists(model_path):
                logger.warning(f"Model checkpoint missing: {model_path}")
                continue
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()
            preds_tmp = []
            with torch.no_grad():
                for categorical, numerical, time_series in test_loader:
                    categorical, numerical, time_series = (
                        categorical.to(device),
                        numerical.to(device),
                        time_series.to(device),
                    )
                    outputs = model(categorical, numerical, time_series)
                    preds_tmp.extend(
                        (outputs.softmax(dim=1) * torch.tensor([0, 1, 2, 3], device=device))
                        .sum(dim=1)
                        .cpu()
                        .numpy()
                    )
            fold_preds.append(preds_tmp)

        if len(fold_preds) > 0:
            # Average predictions from the two model variants for this fold
            fold_preds = np.mean(fold_preds, axis=0)
            test_predictions_nn.append(fold_preds)

    # Average predictions across folds
    if len(test_predictions_nn) == 0:
        logger.error("No neural network predictions were generated.")
        test_predictions_nn_final = np.zeros(len(test_df))
    else:
        test_predictions_nn_final = np.mean(test_predictions_nn, axis=0)

    # Ensemble with LightGBM (Optional)
    # Load training data again and train LGBM
    train_df = pd.read_csv(os.path.join(DATA_PATH, "train.csv"))
    if DROP_NAN:
        train_df = train_df.dropna(subset=["sii"])
    else:
        train_df["sii"] = train_df["sii"].fillna(0)
    targets = train_df["sii"].values

    lgb_model = train_lightgbm(train_df, preprocessors, targets)
    tabular_test_processed = preprocessors[0].transform(tabular_data)
    lgb_preds = lgb_model.predict(tabular_test_processed)
    # Convert LGBM multiclass probs to a weighted sum
    lgb_preds_cont = (lgb_preds * np.array([0,1,2,3])).sum(axis=1)

    # Ensemble neural net and LGBM predictions by simple averaging
    final_predictions_cont = (test_predictions_nn_final + lgb_preds_cont) / 2

    # Threshold and round
    final_predictions_class = threshold_rounder(final_predictions_cont, optimal_thresholds)
    test_df["sii"] = final_predictions_class
    test_df[["id", "sii"]].to_csv("submission.csv", index=False)

    logger.info("Inference complete. Submission saved to submission.csv.")

######################################################
# Main Execution
######################################################
if __name__ == "__main__":
    set_seed(RANDOM_STATE)
    preprocessors, optimal_thresholds = train_main()
    inference(preprocessors, optimal_thresholds)


Fold 1, LR=0.001: 100%|██████████| 8/8 [02:47<00:00, 20.90s/it, Train Loss=0.8192, Val Loss=0.8665, Val QWK=0.3264]
  model.load_state_dict(torch.load(f"best_model_fold_{fold}_{it}.pth"))
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
Fold 1, LR=0.003: 100%|██████████| 8/8 [02:38<00:00, 19.79s/it, Train Loss=0.8931, Val Loss=0.9319, Val QWK=0.2688]
  model.load_state_dict(torch.load(f"best_model_fold_{fold}_{it}