In [None]:
import torch
import random
import pandas as pd
import numpy as np
import os
from torch import nn

from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
import cv2

import json
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset

from sklearn.utils.class_weight import compute_class_weight

from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, StandardScaler

# Local application/library imports
from utils import load_search_space

import optuna

from sklearn.metrics import (
    RocCurveDisplay, PrecisionRecallDisplay,
    ConfusionMatrixDisplay, roc_auc_score, average_precision_score
)

## DATASET

In [None]:
SEED = 64

# Seed every global RNG this notebook draws from (PyTorch, Python's `random`,
# NumPy's legacy global generator) so stochastic steps repeat run to run.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn

In [None]:
# Active dataset. Other dataset names used with this notebook include:
# adult_income_cleaned, framingham_cleaned, preprocessed_heloc, diabetes
# (each needs a matching dataset_subpath and task_type).
task_type = 'Regression'
dataset_name = 'boston'
dataset_subpath = 'Regression/boston'

In [None]:
# Raw table lives at ./data/<subpath>/<name>.csv
df = pd.read_csv("./data/{}/{}.csv".format(dataset_subpath, dataset_name))

In [None]:
# (rows, columns) of the raw table
len(df.index), len(df.columns)

In [None]:
# First five rows of the raw table
df.iloc[:5]

## LOAD AND PREPROCESS

In [None]:
def prepare_target_tensor(y, task):
    """Convert a target vector into the tensor layout the loss functions expect.

    Args:
        y: targets as a pandas Series, a list, or anything ``torch.as_tensor`` accepts.
        task: 'regression', 'binary' or 'multiclass' (case-insensitive).

    Returns:
        float32 tensor of shape (n, 1) for regression/binary, or a 1-D int64
        tensor of class indices for multiclass.

    Raises:
        ValueError: for any other task name (reported in lowercase).
    """
    kind = task.lower()

    if isinstance(y, pd.Series):
        values = y.to_numpy()
    elif isinstance(y, list):
        values = np.array(y)
    else:
        values = y

    if kind in ("regression", "binary"):
        # Column vector so it lines up with a single-output model head.
        return torch.as_tensor(values, dtype=torch.float32).reshape(-1, 1)
    if kind == "multiclass":
        return torch.as_tensor(values, dtype=torch.long)
    raise ValueError(f"Unsupported task type: {kind}")

In [None]:
def load_and_preprocess_data(df, dataset_name, task_type, seed=42, batch_size=32, device='cpu'):
    """Split, encode and wrap a tabular dataset into PyTorch DataLoaders.

    The preprocessing recipe is read from ``./configs/preprocess/{dataset_name}.json``.
    It must define ``categorical_cols``, ``numerical_cols`` and an ``encoding`` dict.
    Keys read from ``encoding``: ``numerical_features`` (always),
    ``categorical_features`` (only when there are categorical columns), and the
    optional ``target``. Features are the configured numerical + categorical
    columns; the target is the LAST column of `df`.

    The data are split 70% train / 15% validation / 15% test, stratified on the
    target for classification tasks. Scalers/encoders are fit on the training
    split only, then applied to validation and test.

    Args:
        df: raw DataFrame holding the feature columns and the target as its last column.
        dataset_name: name of the preprocessing config file (without ``.json``).
        task_type: 'regression', 'binary' or 'multiclass' (case-insensitive).
        seed: ``random_state`` used for both splits.
        batch_size: batch size for all three DataLoaders.
        device: device the class-weight tensor is moved to. Pinned host memory is
            enabled only when this is a CUDA device and CUDA is available.

    Returns:
        ``(train_loader, val_loader, test_loader, attributes, le, class_weight)``.
        - `attributes` is the number of model input features after encoding.
        - `le` is the fitted ``LabelEncoder``, or None when the target is not
          label-encoded.
        - `class_weight` is a float32 tensor for classification (scalar
          ``pos_weight`` for binary, per-class weights for multiclass), or None.
    """
    task_type = task_type.lower()

    # Load config
    with open(f"./configs/preprocess/{dataset_name}.json") as f:
        config = json.load(f)

    categorical_cols = config["categorical_cols"]
    numerical_cols = config["numerical_cols"]
    encoding = config["encoding"]

    # Features come from the config; the target is assumed to be the last column.
    X = df[numerical_cols + categorical_cols].copy()
    y = df.iloc[:, -1].copy()

    # Encode target if requested
    le = None
    label_mapping = None
    if encoding.get("target") == "label":
        le = LabelEncoder()
        y = le.fit_transform(y)
        label_mapping = dict(zip(le.classes_, le.transform(le.classes_)))

    # Split the raw data before fitting any transformer, to avoid leakage.
    # Stratify only for classification; regression targets are continuous.
    is_regression = task_type == "regression"
    X_train_raw, X_temp_raw, y_train, y_temp = train_test_split(
        X, y, test_size=0.3, random_state=seed, stratify=None if is_regression else y
    )
    X_val_raw, X_test_raw, y_val, y_test = train_test_split(
        X_temp_raw, y_temp, test_size=0.5, random_state=seed,
        stratify=None if is_regression else y_temp
    )

    # Compute class weights for classification
    class_weight = None
    if task_type in ("binary", "multiclass"):
        classes_sorted = np.unique(y_train)  # np.unique already returns sorted values
        class_weight_values = compute_class_weight(class_weight="balanced", classes=classes_sorted, y=y_train)

        if task_type == "binary":
            # pos_weight = weight for class 1 / weight for class 0 (targets must be 0/1)
            weight_dict = dict(zip(classes_sorted, class_weight_values))
            pos_weight = weight_dict[1] / weight_dict[0]
            class_weight = torch.tensor(pos_weight, dtype=torch.float32).to(device)
            print(f"Binary pos_weight (for BCEWithLogitsLoss): {class_weight.item()}")
        else:
            class_weight = torch.tensor(class_weight_values, dtype=torch.float32).to(device)
            print(f"Multiclass class weights (for CrossEntropyLoss): {class_weight.tolist()}")

    # Transform numerical and categorical features
    transformers = []
    if encoding["numerical_features"] == "minmax":
        transformers.append(("num", MinMaxScaler(), numerical_cols))
    elif encoding["numerical_features"] == "standard":
        transformers.append(("num", StandardScaler(), numerical_cols))

    if categorical_cols and encoding["categorical_features"] == "onehot":
        transformers.append(("cat", OneHotEncoder(sparse_output=False, handle_unknown="ignore"), categorical_cols))

    if transformers:
        # remainder="passthrough" keeps columns no transformer claims (e.g. unscaled
        # numericals, or categoricals that are not one-hot encoded). The default
        # "drop" would silently discard them.
        # Names come from the fitted transformer, so they always match the array
        # width. They are num -> cat -> remainder; with verbose names off,
        # one-hot names are "<col>_<value>".
        preprocessor = ColumnTransformer(
            transformers=transformers,
            remainder="passthrough",
            verbose_feature_names_out=False,
        )
        X_train = preprocessor.fit_transform(X_train_raw)
        X_val = preprocessor.transform(X_val_raw)
        X_test = preprocessor.transform(X_test_raw)
        all_feature_names = list(preprocessor.get_feature_names_out())

        X_train_num = pd.DataFrame(X_train, columns=all_feature_names, index=X_train_raw.index)
        X_val_num = pd.DataFrame(X_val, columns=all_feature_names, index=X_val_raw.index)
        X_test_num = pd.DataFrame(X_test, columns=all_feature_names, index=X_test_raw.index)
    else:
        # No transformation configured: use the raw columns (numerical then categorical).
        X_train_num = X_train_raw.copy()
        X_val_num = X_val_raw.copy()
        X_test_num = X_test_raw.copy()

    print(f"Shapes — Train: {X_train_num.shape}, Val: {X_val_num.shape}, Test: {X_test_num.shape}")
    print(f"Numerical features: {len(numerical_cols)} — {numerical_cols}")
    print(f"Categorical features: {len(categorical_cols)} — {categorical_cols}")
    print(f"Total features: {X_train_num.shape[1]}")
    if label_mapping:
        print(f"Target label mapping: {label_mapping}")

    attributes = len(X_train_num.columns)
    print("Attributes: ", attributes)

    # Convert data to PyTorch tensors
    X_train_num_tensor = torch.as_tensor(X_train_num.values, dtype=torch.float32)
    X_val_num_tensor = torch.as_tensor(X_val_num.values, dtype=torch.float32)
    X_test_num_tensor = torch.as_tensor(X_test_num.values, dtype=torch.float32)
    y_train_tensor = prepare_target_tensor(y_train, task_type)
    y_val_tensor = prepare_target_tensor(y_val, task_type)
    y_test_tensor = prepare_target_tensor(y_test, task_type)

    # Create DataLoaders. Pinned memory only helps host->GPU copies and warns
    # when no accelerator is present, so enable it only for CUDA targets.
    pin_memory = torch.cuda.is_available() and str(device).startswith("cuda")
    train_dataset = TensorDataset(X_train_num_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_num_tensor, y_val_tensor)
    test_dataset = TensorDataset(X_test_num_tensor, y_test_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    return train_loader, val_loader, test_loader, attributes, le, class_weight

## MODEL ARCHITECTURES

### MLP

In [None]:
class MLP(nn.Module):
    """Plain feed-forward network: [Linear -> ReLU] * len(hidden dims) -> Linear.

    Args:
        attributes: number of input features.
        params: dict whose ``"mlp_hidden_dims"`` entry lists the hidden-layer widths.
        task: 'regression' or 'binary' gives one output unit; any other value
            gives `num_classes` output units.
        num_classes: output width for the non-regression/binary case.
    """

    def __init__(self, attributes, params, task, num_classes=None):
        super().__init__()

        layers = []
        width = attributes
        for hidden in params["mlp_hidden_dims"]:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden

        single_output = task in ['regression', 'binary']
        layers.append(nn.Linear(width, 1 if single_output else num_classes))
        self.mlp = nn.Sequential(*layers)

        # Output head is left raw (identity); swap this module to add a final activation.
        self.activation = nn.Identity()

    def forward(self, num_input):
        return self.activation(self.mlp(num_input))


## COMPILE AND FIT

In [None]:
import gc
import copy

from models.utils import get_loss_fn, calculate_metrics, calculate_metrics_from_numpy, get_class_weighted_loss_fn

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.optim.lr_scheduler import OneCycleLR
import matplotlib.pyplot as plt
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import os

def _stack_epoch_outputs(preds, targets, task):
    """Join the per-sample outputs/targets collected with ``list.extend`` during an epoch.

    Multiclass: `preds` holds one output row per sample and is stacked with
    ``np.vstack``. `targets` holds per-sample class-index scalars, which
    ``np.concatenate`` cannot join, so the list is returned unchanged.
    Other tasks: every element is a length-1 row, so both are flattened with
    ``np.concatenate``.
    """
    if task == 'multiclass':
        return np.vstack(preds), targets
    return np.concatenate(preds), np.concatenate(targets)


def compile_and_fit(model, train_loader, val_loader, test_loader, dataset_name, 
                    model_name, trial_name=None, task='regression', epochs=200, max_lr=1, 
                    div_factor=10, final_div_factor=1, device='cuda', weight_decay=1e-2, save_model=False, class_weights=None, save_dir=None, study=None, verbose=False):
    """Train `model` with AdamW and a one-cycle LR schedule, then report best-epoch metrics.

    Trains for exactly `epochs` epochs (no early stopping). It keeps the weights
    with the lowest mean validation loss, restores them at the end, and
    recomputes train/val/test metrics with them via `calculate_metrics`.

    Args:
        model: ``nn.Module`` mapping a feature batch to outputs.
        train_loader, val_loader, test_loader: DataLoaders yielding ``(features, targets)``.
        dataset_name: accepted for interface compatibility; not used here.
        model_name: first path component under `save_dir` when `save_model` is True.
        trial_name: last path component: ``<save_dir>/<model_name>/best_model/<trial_name>``.
        task: 'regression', 'binary' or 'multiclass'.
        epochs: number of training epochs.
        max_lr: peak learning rate of the one-cycle schedule.
        div_factor: the schedule starts at ``max_lr / div_factor`` (reported as 'min_lr').
        final_div_factor: forwarded to ``OneCycleLR``.
        device: torch device. A CUDA device falls back to 'cpu' when CUDA is
            unavailable, so callers relying on the default still run on CPU-only machines.
        weight_decay: AdamW weight decay.
        save_model: if True, write plots, a metrics text file and the best state dict.
        class_weights: optional tensor; when given, the loss comes from
            `get_class_weighted_loss_fn`.
        save_dir: root output directory (required when `save_model` is True).
        study: accepted for interface compatibility; not used here.
        verbose: print a training summary when True.

    Returns:
        dict with 'train_loss', 'val_loss', 'test_loss', 'min_lr', 'max_lr',
        'total_time', 'average_epoch_time', plus every other metric returned by
        `calculate_metrics`, prefixed with 'train_', 'val_' or 'test_'.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    model = model.to(device)

    if class_weights is not None:
        loss_fn = get_class_weighted_loss_fn(task, class_weights)
    else:
        loss_fn = get_loss_fn(task)

    # OneCycleLR starts from max_lr / div_factor; report that as min_lr.
    min_lr = max_lr / div_factor
    optimizer = optim.AdamW(model.parameters(), lr=min_lr, weight_decay=weight_decay)

    # The scheduler is stepped once per batch, hence epochs * batches total steps.
    total_steps = epochs * len(train_loader)
    scheduler = OneCycleLR(optimizer, max_lr=max_lr, div_factor=div_factor, final_div_factor=final_div_factor, total_steps=total_steps, pct_start=0.3, anneal_strategy="cos")

    best_val_loss = float('inf')
    best_model = None
    best_epoch = 0

    history = {'train_loss': [], 'val_loss': [], 'learning_rate': [], 'epoch_time': []}
    if task == 'regression':
        history.update({'train_mse': [], 'val_mse': [], 'train_mae': [], 'val_mae': [], 'train_rmse': [], 'val_rmse': [], 'train_r2': [], 'val_r2': []})
    elif task in ['binary', 'multiclass']:
        history.update({'train_accuracy': [], 'val_accuracy': [], 'train_precision': [], 'val_precision': [], 'train_recall': [], 'val_recall': [], 'train_f1': [], 'val_f1': []})

    start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_preds, train_targets = [], []
        for num_data, targets in train_loader:
            num_data = num_data.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(num_data)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            train_preds.extend(outputs.detach().cpu().numpy())
            train_targets.extend(targets.cpu().numpy())

        train_loss /= len(train_loader)
        y_train_pred, y_train_true = _stack_epoch_outputs(train_preds, train_targets, task)
        train_metrics = calculate_metrics_from_numpy(y_train_true, y_train_pred, task)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_preds, val_targets = [], []
        with torch.no_grad():
            for num_data, targets in val_loader:
                num_data = num_data.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                outputs = model(num_data)
                loss = loss_fn(outputs, targets)

                val_loss += loss.item()
                val_preds.extend(outputs.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

        val_loss /= len(val_loader)
        y_val_pred, y_val_true = _stack_epoch_outputs(val_preds, val_targets, task)
        val_metrics = calculate_metrics_from_numpy(y_val_true, y_val_pred, task)

        # The optimizer has a single param group, so record its LR as a scalar.
        current_lr = scheduler.get_last_lr()[0]
        epoch_time = time.time() - epoch_start_time

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['learning_rate'].append(current_lr)
        history['epoch_time'].append(epoch_time)

        # setdefault: metric helpers may return keys not pre-declared above.
        for k, v in train_metrics.items():
            history.setdefault(f'train_{k}', []).append(v)
        for k, v in val_metrics.items():
            history.setdefault(f'val_{k}', []).append(v)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1

    total_time = time.time() - start_time

    if best_model is None:
        # Validation loss never dropped below +inf (e.g. NaN every epoch): keep the
        # final weights instead of crashing in load_state_dict(None).
        best_model = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_model)

    # Recompute metrics using the best model
    train_metrics, y_true_train, y_pred_train, y_prob_train = calculate_metrics(model, train_loader, device, class_weights, task)
    val_metrics, y_true_val, y_pred_val, y_prob_val = calculate_metrics(model, val_loader, device, class_weights, task)
    test_metrics, y_true_test, y_pred_test, y_prob_test = calculate_metrics(model, test_loader, device, class_weights, task)

    metrics = {
        'train_loss': train_metrics['loss'],
        'val_loss': val_metrics['loss'],
        'test_loss': test_metrics['loss'],
        'min_lr': min_lr,
        'max_lr': max_lr,
        'total_time': total_time,
        'average_epoch_time': sum(history['epoch_time']) / len(history['epoch_time'])
    }

    # Add task-specific metrics with split prefixes
    for prefix, split_metrics in (('train', train_metrics), ('val', val_metrics), ('test', test_metrics)):
        for k, v in split_metrics.items():
            if k != 'loss':
                metrics[f'{prefix}_{k}'] = v

    if verbose:
        print(f"\nTraining completed in {total_time:.2f} seconds")
        print(f"Best model found at epoch {best_epoch}/{epochs}")
        print(f"Best Train Loss: {metrics['train_loss']:.4f}, Best Val Loss: {metrics['val_loss']:.4f}")
        print(metrics)

    if save_model:
        save_path = os.path.join(save_dir, f"{model_name}/best_model/{trial_name}")
        os.makedirs(save_path, exist_ok=True)

        plot_metric(history['train_loss'], history['val_loss'], 'Loss', save_path)
        if task == 'regression':
            plot_metric(history['train_mse'], history['val_mse'], 'MSE', save_path)
            plot_metric(history['train_rmse'], history['val_rmse'], 'RMSE', save_path)
        else:
            plot_metric(history['train_accuracy'], history['val_accuracy'], 'Accuracy', save_path)
            plot_metric(history['train_f1'], history['val_f1'], 'F1', save_path)

        plot_learning_rate(history['learning_rate'], save_path)

        # Save metrics
        with open(f'{save_path}/best_model_metrics.txt', 'w') as f:
            for key, value in metrics.items():
                f.write(f'{key}: {value}\n')

        # Save the best state dict (not the whole module)
        torch.save(best_model, f"{save_path}/best_model.pth")
        print(f"Best model saved to {save_path}/best_model.pth")

        # Additional plots for binary classification
        if task == "binary":
            plot_extra("Train", y_true_train, y_pred_train, y_prob_train, save_path)
            plot_extra("Validation", y_true_val, y_pred_val, y_prob_val, save_path)
            plot_extra("Test", y_true_test, y_pred_test, y_prob_test, save_path)

    # Release this function's reference and cached GPU memory between trials.
    del model
    torch.cuda.empty_cache()
    gc.collect()

    return metrics


def plot_extra(split_name, y_true, y_pred, y_prob, save_path):
    """Save ROC, precision-recall and confusion-matrix plots for one binary split.

    Args:
        split_name: label used in titles; its lowercase form prefixes the file names.
        y_true: ground-truth labels (any shape that flattens to 1-D).
        y_pred: predicted hard labels (flattened to 1-D).
        y_prob: positive-class scores (flattened to 1-D).
        save_path: existing directory that receives the PNG files.
    """
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    y_prob = np.ravel(y_prob)
    prefix = split_name.lower()

    # ROC curve, with a chance diagonal for reference
    fig, ax = plt.subplots()
    RocCurveDisplay.from_predictions(y_true, y_prob, ax=ax)
    auc_score = roc_auc_score(y_true, y_prob)
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random')
    ax.set_title(f"{split_name} ROC Curve (AUC = {auc_score:.2f})")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(save_path, f"{prefix}_roc_curve.png"))
    plt.close(fig)

    # Precision-Recall curve
    fig, ax = plt.subplots()
    PrecisionRecallDisplay.from_predictions(y_true, y_prob, ax=ax)
    avg_prec = average_precision_score(y_true, y_prob)
    ax.set_title(f"{split_name} PR Curve (AP = {avg_prec:.2f})")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.grid(True)
    fig.savefig(os.path.join(save_path, f"{prefix}_pr_curve.png"))
    plt.close(fig)

    # Confusion matrices (row-normalized, then raw counts). from_predictions
    # draws straight onto `ax`, so no separate .plot() call (and figure) is needed.
    for normalize, suffix, title in (('true', 'normalized', 'Normalized'), (None, 'counts', 'Counts')):
        fig, ax = plt.subplots()
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize=normalize, cmap='Blues', ax=ax)
        ax.set_title(f"{split_name} Confusion Matrix ({title})")
        ax.grid(False)
        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("True Label")
        fig.savefig(os.path.join(save_path, f"{prefix}_confusion_matrix_{suffix}.png"))
        plt.close(fig)


def plot_metric(train_metric, val_metric, metric_name, save_path):
    """Plot train vs. validation values of `metric_name` and save them.

    Writes ``<save_path>/<metric_name lowercased>_plot.png``. Each series is
    plotted against its index, labelled as the epoch.
    """
    out_file = f"{save_path}/{metric_name.lower()}_plot.png"

    plt.figure()
    for series, split in ((train_metric, 'Train'), (val_metric, 'Validation')):
        plt.plot(series, label=f'{split} {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()
    plt.title(f'{metric_name} vs. Epoch')
    plt.savefig(out_file)
    plt.close("all")

def plot_learning_rate(learning_rates, save_path):
    """Plot the recorded learning-rate sequence and save it.

    Writes ``<save_path>/learning_rate_plot.png``. The x-axis is labelled
    'Epoch', so one value per epoch is expected.
    """
    out_file = f"{save_path}/learning_rate_plot.png"

    plt.figure()
    plt.plot(learning_rates)
    plt.gca().set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate vs. Epoch')
    plt.savefig(out_file)
    plt.close("all")

# EXPERIMENTS

## MLP

In [None]:
model_name = "mlp"
save_dir = os.path.join("logs", task_type, dataset_name)

# Batch size comes from the same per-dataset preprocessing config.
with open(f"./configs/preprocess/{dataset_name}.json") as f:
    config = json.load(f)
batch_size = config["batch_size"]

epochs = [100, 200]  # candidate epoch counts passed to `objective`
n_trials = 50

# Output width for the MLP head; only meaningful for multiclass.
num_classes = df.iloc[:, -1].nunique() if task_type.lower() == 'multiclass' else 1

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
def objective(trial, model_name, task_type, 
              train_loader, val_loader, test_loader,
              attributes, num_classes=None,
              device='cuda', save_dir=None, class_weight=None, epochs=100):
    """Optuna objective: build an MLP from sampled hyperparameters, train it, return a val score.

    The architecture comes from `load_search_space(model_name, trial)`.
    Optimizer/schedule ranges come from
    ``configs/optuna_search/<model_name>.json`` under ``[model_name]["fit"]``.
    NOTE(review): each range entry is read as ``entry[1]``/``entry[2]`` =
    low/high (entry[0] presumably a type tag) — confirm against the config files.

    Args:
        trial: the Optuna trial.
        model_name: search-space / config name, also used for the log directory.
        task_type: 'regression', 'binary' or 'multiclass' (case-insensitive).
        train_loader, val_loader, test_loader: DataLoaders passed to `compile_and_fit`.
        attributes: number of input features.
        num_classes: MLP output width for multiclass.
        device: device to train on.
        save_dir: if given, trial results are appended to
            ``<save_dir>/<model_name>/optuna/optuna_trials_log.txt``.
        class_weight: optional class-weight tensor for the loss.
        epochs: candidate epoch counts for ``trial.suggest_categorical("epochs", ...)``;
            a single int is treated as the only candidate.

    Returns:
        Validation RMSE (regression), ROC-AUC (binary) or accuracy (multiclass).

    Raises:
        ValueError: for an unsupported task, before any training happens.
    """
    task = task_type.lower()

    # Objective metric and its log label per task. Validated up front so an
    # unsupported task fails before spending a full training run.
    score_specs = {
        'regression': ("val_rmse", "VAL-RMSE"),
        'binary': ("val_roc_auc", "VAL-AUC"),
        'multiclass': ("val_accuracy", "VAL-Accuracy"),
    }
    if task not in score_specs:
        raise ValueError(f"Unsupported task type: {task_type}")

    params = load_search_space(model_name, trial)
    params["mlp_hidden_dims"] = json.loads(params["mlp_hidden_dims"])

    with open(f"configs/optuna_search/{model_name}.json", "r") as f:
        full_config = json.load(f)
    config = full_config[model_name]["fit"]

    epoch_choices = list(epochs) if isinstance(epochs, (list, tuple)) else [epochs]

    # Build and train model (suggest_* calls run in keyword order below)
    model = MLP(attributes, params, task, num_classes)
    metrics = compile_and_fit(
        model,
        train_loader, val_loader, test_loader,
        dataset_name=dataset_name,
        model_name=f"trial_{trial.number}",
        task=task,
        max_lr=trial.suggest_float("max_lr", config["max_lr"][1], config["max_lr"][2], log=True),
        div_factor=trial.suggest_int("div_factor", config["div_factor"][1], config["div_factor"][2]),
        final_div_factor=trial.suggest_int("final_div_factor", config["final_div_factor"][1], config["final_div_factor"][2]),
        weight_decay=trial.suggest_float("weight_decay", config["weight_decay"][1], config["weight_decay"][2], log=True),
        epochs=trial.suggest_categorical("epochs", epoch_choices),
        device=device,
        save_model=False,
        class_weights=class_weight
    )

    metric_key, label = score_specs[task]
    score = metrics[metric_key]

    if save_dir is not None:
        log_dir = os.path.join(save_dir, model_name, "optuna")
        os.makedirs(log_dir, exist_ok=True)
        with open(f"{log_dir}/optuna_trials_log.txt", "a") as f:
            f.write(f"Trial {trial.number} - {label}: {score:.4f}, Params: {params}\n")
            f.write("=" * 60 + "\n")

    return score


In [None]:
def evaluate_best_model(best_trial, train_loader, val_loader, test_loader, 
                        dataset_name, task_type, save_dir, attributes, trial_name,
                        class_weight=None, num_classes=None, epochs=10,
                        model_name="mlp", device=None):
    """Retrain an MLP with `best_trial`'s hyperparameters, save its artefacts, return metrics.

    Artefacts (plots, metrics, state dict, ``best_params.json``) all go to
    ``<save_dir>/<model_name>/best_model/<trial_name>``.

    Args:
        best_trial: a completed Optuna trial. Its params must contain 'max_lr',
            'div_factor', 'final_div_factor', 'weight_decay' (as suggested in
            `objective`). 'mlp_hidden_dims' is presumably provided by the search
            space — verify.
        train_loader, val_loader, test_loader: DataLoaders for `compile_and_fit`.
        dataset_name: forwarded to `compile_and_fit`.
        task_type: 'regression', 'binary' or 'multiclass' (case-insensitive).
        save_dir: root output directory.
        attributes: number of input features.
        trial_name: output sub-directory name.
        class_weight: optional class-weight tensor for the loss.
        num_classes: MLP output width for multiclass.
        epochs: used only when the trial did not tune 'epochs'.
        model_name: output sub-directory under `save_dir` (this notebook uses "mlp").
        device: training device; defaults to CUDA when available, else CPU.

    Returns:
        The metrics dict returned by `compile_and_fit`.
    """
    task = task_type.lower()
    best_params = best_trial.params
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"\nBest Trial: {best_trial.number}")
    print(f"  Best Score: {best_trial.value:.4f}")
    print("  Best Hyperparameters:")
    for k, v in best_params.items():
        print(f"    {k}: {v}")

    # Only the architecture parameters are needed to build the model.
    architecture_params = {k: v for k, v in best_params.items() if k == "mlp_hidden_dims"}

    # Optuna stores the hidden dims as a JSON string; decode it to a list.
    if isinstance(architecture_params.get("mlp_hidden_dims"), str):
        architecture_params["mlp_hidden_dims"] = json.loads(architecture_params["mlp_hidden_dims"])

    model = MLP(attributes, architecture_params, task, num_classes)

    # Train and evaluate; trial_name is used for both save locations so they agree.
    metrics = compile_and_fit(
        model,
        train_loader, val_loader, test_loader,
        dataset_name=dataset_name,
        model_name=model_name,
        trial_name=trial_name,
        task=task,
        max_lr=best_params["max_lr"],
        div_factor=best_params["div_factor"],
        final_div_factor=best_params["final_div_factor"],
        weight_decay=best_params["weight_decay"],
        epochs=best_params.get("epochs", epochs),
        device=device,
        save_model=True,
        class_weights=class_weight,
        save_dir=save_dir
    )

    # Save best hyperparameters next to the model artefacts
    params_file = os.path.join(save_dir, f"{model_name}/best_model/{trial_name}", "best_params.json")
    os.makedirs(os.path.dirname(params_file), exist_ok=True)
    with open(params_file, "w") as f:
        json.dump(best_params, f, indent=4)

    return metrics

In [None]:
import random
import numpy as np
import torch

def set_model_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: value handed to each generator's seeding function.
    """
    seeders = (
        random.seed,                 # Python built-in RNG
        np.random.seed,              # NumPy legacy global RNG
        torch.manual_seed,           # Torch RNG
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


### EXPERIMENT

In [None]:
# Build the three DataLoaders once; every Optuna trial and seed re-run reuses them.
train_loader, val_loader, test_loader, attributes, label_encoder, class_weight = load_and_preprocess_data(
    df,
    dataset_name,
    task_type,
    seed=SEED,
    batch_size=batch_size,
    device=device,
)

In [None]:
# RMSE is minimized for regression; AUC / accuracy are maximized otherwise.
study_direction = "minimize" if task_type.lower() == "regression" else "maximize"

# Seed the sampler too: the global torch/numpy/random seeds do not cover
# Optuna's own RNG, so without this the sampled hyperparameters differ per run.
study = optuna.create_study(
    direction=study_direction,
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(lambda trial: objective(
    trial=trial,
    model_name=model_name,
    task_type=task_type,
    num_classes=num_classes,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    attributes=attributes,
    device=device,
    save_dir=save_dir,
    class_weight=class_weight,
    epochs=epochs
), n_trials=n_trials)

In [None]:
from numbers import Number

# --- Configure which seeds to use for stability reporting ---
model_seeds = [0, 1, 2, 3, 4]   # change as needed
numeric_keys = None  # inferred from the first run's metrics

if not model_seeds:
    raise ValueError("model_seeds must contain at least one seed.")


def is_minimize_study(study):
    """Return True when the study's (first) objective is minimized; True if it can't be determined."""
    try:
        return study.direction == optuna.study.StudyDirection.MINIMIZE
    except Exception:
        # `study.direction` is unavailable for multi-objective studies; use the first objective.
        try:
            return study.directions[0] == optuna.study.StudyDirection.MINIMIZE
        except Exception:
            return True  # fallback


def is_numeric(value):
    """True for Python numbers and NumPy numeric scalars."""
    return isinstance(value, (Number, np.floating, np.integer))


minimize = is_minimize_study(study)

# --- Pick the single best completed trial ---
completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
if not completed:
    raise RuntimeError("No completed trials in the study.")

best_trial = (min if minimize else max)(completed, key=lambda t: t.value)

trial_name = f"trial_{best_trial.number}"
print(f"\nEvaluating overall best trial "
      f"(Trial {best_trial.number}, ValObjective: {best_trial.value:.4f})")

save_path = os.path.join(save_dir, f"{model_name}/best_model/{trial_name}")
os.makedirs(save_path, exist_ok=True)

per_seed_metrics = []

for s in model_seeds:
    # Re-seed before each run so weight init and batch shuffling differ only by seed.
    set_model_seed(s)
    metrics = evaluate_best_model(
        best_trial,
        train_loader, val_loader, test_loader,
        dataset_name=dataset_name,
        task_type=task_type,
        save_dir=save_dir,
        attributes=attributes,
        class_weight=class_weight,
        num_classes=num_classes,
        epochs=epochs,
        trial_name=trial_name
    )
    if not isinstance(metrics, dict):
        raise TypeError(f"evaluate_best_model must return dict, got: {type(metrics)}")

    # Infer the numeric metric keys once (ints, floats, numpy scalars)
    if numeric_keys is None:
        numeric_keys = [k for k, v in metrics.items() if is_numeric(v)]
    per_seed_metrics.append(metrics)

    # Brief per-seed printout
    log_bits = [f"{k}={float(metrics[k]):.6f}"
                for k in ("test_loss", "val_loss", "train_loss")
                if k in metrics and is_numeric(metrics[k])]
    print(f"  Seed {s}: " + (", ".join(log_bits) if log_bits else str(metrics)))

# Aggregate mean / sample std (ddof=1) per numeric key
aggregates = {}
for k in numeric_keys:
    vals = [float(m[k]) for m in per_seed_metrics]
    mean_k = float(np.mean(vals))
    std_k = float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0
    aggregates[k] = {"mean": mean_k, "std": std_k}

# Save a YAML-like text summary
out_file = os.path.join(save_path, "best_results_mean.txt")
with open(out_file, "w", encoding="utf-8") as f:
    f.write("# Overall best trial re-evaluation across model seeds\n")
    f.write(f"trial_number: {best_trial.number}\n")
    f.write(f"val_objective_best: {best_trial.value:.6f}\n")
    f.write(f"direction: {'minimize' if minimize else 'maximize'}\n")
    f.write(f"seeds: {model_seeds}\n")
    f.write("per_seed_metrics:\n")
    for s, m in zip(model_seeds, per_seed_metrics):
        f.write(f"  - seed: {s}\n")
        for k in numeric_keys:
            f.write(f"      {k}: {float(m[k]):.6f}\n")
    f.write("aggregates:\n")
    for k, mm in aggregates.items():
        f.write(f"  {k}:\n")
        f.write(f"    mean: {mm['mean']:.6f}\n")
        f.write(f"    std: {mm['std']:.6f}\n")

# Console summary
if "test_loss" in aggregates:
    print("  → test_loss Mean ± Std: "
          f"{aggregates['test_loss']['mean']:.6f} ± {aggregates['test_loss']['std']:.6f}")
elif "val_loss" in aggregates:
    print("  → val_loss Mean ± Std: "
          f"{aggregates['val_loss']['mean']:.6f} ± {aggregates['val_loss']['std']:.6f}")

print(f"Saved to: {out_file}")