In [None]:
!pip install wandb
!pip install pytorch-lightning
!pip install timm
!pip install pyro-ppl
!pip install zoobot

In [None]:
import gc
import logging
import os
import time

import numpy as np
import pytorch_lightning as pl
import torch
from google.colab import drive, userdata

# Configure root logging for the notebook. `force=True` makes basicConfig remove any
# handlers already attached to the root logger, so re-running this cell takes effect.
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s', force=True
)
# Module-level logger used by the helper functions and classes defined in this notebook.
logger = logging.getLogger(__name__)

import h5py
import huggingface_hub
import matplotlib.pyplot as plt
import pandas as pd
import timm
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import wandb
from IPython.display import clear_output
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from scipy import stats
from sklearn.model_selection import StratifiedKFold, train_test_split
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

# Colour-blind-friendly palette (hex codes) shared by the plots in this notebook.
colors = dict(
    blue='#377eb8',
    orange='#ff7f00',
    green='#4daf4a',
    pink='#f781bf',
    brown='#a65628',
    purple='#984ea3',
    gray='#999999',
    red='#e41a1c',
    yellow='#dede00',
)

# Reproducibility: one seed for every RNG Lightning knows about (workers=True also
# covers DataLoader worker processes).
SEED = 42
pl.seed_everything(SEED, workers=True)
torch.set_float32_matmul_precision('high')

# Global numeric settings referenced by the dataset, loss and model-loading code.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DTYPE = torch.float32
EPSILON = 1e-8

# Mount Google Drive under /content/drive (Colab only).
drive.mount('/content/drive')

### Functions

In [None]:
def load_data(data_path):
    """
    Read images, labels and object identifiers from an HDF5 file.

    The file must contain the datasets ``label``, ``images`` and ``known_id``; each one is
    read fully into memory.

    Args:
        data_path (str): Path to the HDF5 file.

    Returns:
        tuple: ``(data, label, obj_id)`` where
            data (np.ndarray): ``images`` as float32 with NaNs replaced by 0.0,
            label (np.ndarray): ``label`` as float32,
            obj_id (np.ndarray): ``known_id`` converted with ``astype(str)``.
    """
    with h5py.File(data_path, 'r') as h5_file:
        raw_label = h5_file['label'][...]
        label = raw_label.astype(np.float32)

        raw_images = h5_file['images'][...]
        data = np.nan_to_num(raw_images, nan=0.0).astype(np.float32)

        raw_ids = h5_file['known_id'][...]
        obj_id = raw_ids.astype(str)

    logger.info(f'Data loaded. Image shape: {data.shape}, Label shape: {label.shape}')
    return data, label, obj_id


def init_wandb(key_name='WANDB_API_KEY'):
    """
    Log in to Weights & Biases using an API key stored in Colab Secrets.

    Any failure (missing secret, login error, error while finishing a stale run) is logged
    and reported through the return value instead of being raised, so the notebook can
    continue without W&B.

    Args:
        key_name (str): Name of the Colab secret holding the W&B API key.

    Returns:
        bool: True when the login succeeded, False otherwise.
    """
    try:
        api_key = userdata.get(key_name)
        if not api_key:
            # Report the secret that was actually looked up, not a hard-coded name.
            logger.error(f'{key_name} not found in Colab Secrets. Please add it.')
            return False

        wandb.login(key=api_key)
        logger.info('Successfully logged into W&B.')

        # Finish any lingering run from a previous execution in the same session.
        if wandb.run is not None:
            logger.warning('Detected an existing W&B run. Finishing it...')
            wandb.finish()
        return True
    except Exception as e:
        # Deliberately broad: W&B is optional, so every failure degrades to "disabled".
        logger.error(f'Error during W&B login: {e}')
        return False


def init_hf():
    """
    Log in to the Hugging Face Hub with the ``HF_TOKEN`` Colab secret.

    Best-effort: both the secret lookup and the login are wrapped in the same handler, so
    a missing or inaccessible secret is logged (like a failed login) instead of raising.

    Returns:
        None
    """
    try:
        # The lookup sits inside the try so a failure here is handled like a login failure.
        hf_token = userdata.get('HF_TOKEN')
        logger.info('Attempting Hugging Face Hub login...')
        huggingface_hub.login(token=hf_token)
        logger.info('Hugging Face Hub login successful.')
    except Exception as e:
        logger.error(f'Error during Hugging Face Hub login: {e}')


def prepare_kfold_splits(labels, test_size=0.1, n_splits=10, num_strat_bins=5, seed=42):
    """
    Hold out a stratified test set and prepare stratified K-fold splits of the remainder.

    Continuous labels are discretised into ``num_strat_bins`` equal-width bins over [0, 1]
    and those bin ids are used as the stratification target for both steps.

    Args:
        labels (np.ndarray): Continuous labels, one per sample.
        test_size (float): Fraction (or count) passed to ``train_test_split`` for the test set.
        n_splits (int): Number of folds for ``StratifiedKFold``.
        num_strat_bins (int): Number of equal-width stratification bins.
        seed (int): Random state for the test split and the fold shuffling.

    Returns:
        tuple: ``(fold_generator, train_val_indices, test_indices)``.
            ``fold_generator`` is the one-shot generator returned by ``skf.split``; the
            index arrays it yields are positions *within* ``train_val_indices``, not
            indices into the original data.
            ``train_val_indices`` / ``test_indices`` index the original ``labels``.
    """
    sample_indices = np.arange(len(labels))

    # Only the interior edges are passed to digitize, so bin ids run 0..num_strat_bins-1.
    edges = np.linspace(0, 1, num_strat_bins + 1)
    strat_bins = np.digitize(labels, bins=edges[1:-1])
    logger.info(f'Labels binned into {num_strat_bins} bins for stratification.')
    logger.info(f'Bin counts (overall): {np.bincount(strat_bins, minlength=num_strat_bins)}')

    # Step 1: stratified hold-out of the test set.
    train_val_indices, test_indices = train_test_split(
        sample_indices, test_size=test_size, random_state=seed, stratify=strat_bins
    )
    logger.info(f'Initial split: Train/Val={len(train_val_indices)}, Test={len(test_indices)}')
    test_bin_counts = np.bincount(strat_bins[test_indices], minlength=num_strat_bins)
    logger.info(f'Bin counts (Test): {test_bin_counts}')

    # Step 2: stratified K-fold over the remaining train/validation samples.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    logger.info(f'StratifiedKFold initialized with {n_splits} splits.')

    placeholder_features = np.zeros(len(train_val_indices))
    fold_generator = skf.split(placeholder_features, strat_bins[train_val_indices])
    return fold_generator, train_val_indices, test_indices


def create_weighted_sampler(train_labels, num_weighting_bins=10, epsilon=1e-6, alpha=0.5):
    """
    Build a WeightedRandomSampler that counteracts imbalance in continuous labels.

    Labels are histogrammed into ``num_weighting_bins`` equal-width bins over [0, 1]; every
    sample receives the (dampened, normalised) inverse frequency of its bin as its
    sampling weight.

    Args:
        train_labels (np.ndarray): Continuous labels for the training set only. Any shape
            is accepted; the array is flattened to 1-D.
        num_weighting_bins (int): Number of bins to use for calculating weights.
        epsilon (float): Small value added to bin counts to avoid division by zero.
        alpha (float): Dampening exponent (0 < alpha <= 1). 1=full inverse, <1 less aggressive.

    Returns:
        torch.utils.data.WeightedRandomSampler: Sampler drawing ``len(train_labels)``
        samples per epoch with replacement.
    """
    # Flatten so column-vector labels (N, 1) produce the 1-D weight vector the sampler needs.
    labels = np.asarray(train_labels, dtype=np.float64).reshape(-1)
    num_train_samples = len(labels)

    # Per-bin counts; np.histogram bins are left-closed, the last one also includes 1.0.
    hist, bin_edges = np.histogram(labels, bins=num_weighting_bins, range=(0, 1))

    # Dampened inverse frequency per bin, normalised to sum to 1.
    weights_per_bin = 1.0 / np.power(hist + epsilon, alpha)
    weights_per_bin = weights_per_bin / np.sum(weights_per_bin)
    logger.info(f'Sampler bin weights: {weights_per_bin}')

    # Assign each label to the same bin np.histogram counted it in: digitizing against the
    # interior edges (left-closed) sends exactly 1.0 to the last bin. The clip keeps labels
    # that fall outside [0, 1] on the first/last bin instead of indexing out of range.
    bin_indices_train = np.digitize(labels, bin_edges[1:-1])
    bin_indices_train = np.clip(bin_indices_train, 0, num_weighting_bins - 1)

    sample_weights = weights_per_bin[bin_indices_train]

    # WeightedRandomSampler works on double-precision weights.
    train_sample_weights_tensor = torch.as_tensor(sample_weights, dtype=torch.double)

    return WeightedRandomSampler(
        weights=train_sample_weights_tensor, num_samples=num_train_samples, replacement=True
    )


class SimpleDataset(Dataset):
    """In-memory dataset pairing an array of inputs with an array of labels."""

    def __init__(self, inputs, labels, transform=None, preprocess=None):
        """
        Args:
            inputs (list or ndarray): The input features.
            labels (list or ndarray): The labels corresponding to the inputs.
            transform (callable, optional): Applied to each image tensor when it is fetched.
            preprocess (callable, optional): Applied once, up front, to the whole ``inputs``
                collection.
        """
        # One-off preprocessing of the full input collection happens at construction time.
        self.inputs = inputs if preprocess is None else preprocess(inputs)
        self.labels = labels
        self.transform = transform
        self.preprocess = preprocess

    def __len__(self):
        """Number of samples, taken from the (preprocessed) inputs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors of dtype ``DTYPE`` for sample ``idx``."""
        sample = torch.tensor(self.inputs[idx], dtype=DTYPE)
        # squeeze() drops singleton dimensions so a single label comes back as a scalar.
        target = torch.tensor(self.labels[idx], dtype=DTYPE).squeeze()

        if self.transform is None:
            return sample, target
        return self.transform(sample), target


### Lightning Module ###
class ZooBot_lightning(pl.LightningModule):
    """
    Lightning module that fine-tunes a Zoobot ConvNeXt encoder on soft binary labels.

    The wrapped ``FinetuneableZoobotClassifier`` produces two logits per image; the softmax
    probability in column 1 is used as the positive-class prediction. Training uses either
    a KL-divergence or a focal loss against (optionally smoothed) soft labels. Validation
    logs loss, Brier score and calibration error, and at the end of each validation epoch
    uploads diagnostic figures to W&B when a W&B-style logger is attached.
    """

    # Number of equal-width probability bins used by the epoch-level calibration plots.
    _CALIBRATION_BINS = 5

    def __init__(
        self,
        zoobot_size,
        zoobot_blocks,
        learning_rate,
        learning_decay,
        weight_decay,
        label_smoothing,
        loss_type='kld',
        focal_gamma=2.0,
    ):
        """
        Args:
            zoobot_size: Suffix of the Hugging Face encoder name
                (``hf_hub:mwalmsley/zoobot-encoder-convnext_{zoobot_size}``).
            zoobot_blocks: Number of encoder blocks to fine-tune.
            learning_rate: Base learning rate (used for the head).
            learning_decay: Per-block learning-rate decay factor for deeper blocks.
            weight_decay: Weight decay passed to AdamW.
            label_smoothing: Smoothing strength applied to training labels.
            loss_type: ``'kld'`` or ``'focal'``.
            focal_gamma: Focusing exponent of the focal loss.

        Raises:
            ValueError: If ``loss_type`` is neither ``'focal'`` nor ``'kld'``.
        """
        super().__init__()
        self.save_hyperparameters()  # Saves all arguments for checkpointing

        if self.hparams.loss_type not in ['focal', 'kld']:
            raise ValueError(
                f"Invalid loss_type: {self.hparams.loss_type}. Choose 'focal' or 'kld'."
            )

        # Define the model
        self.model = FinetuneableZoobotClassifier(
            name=f'hf_hub:mwalmsley/zoobot-encoder-convnext_{zoobot_size}',
            n_blocks=zoobot_blocks,  # Finetune this many blocks
            learning_rate=learning_rate,
            lr_decay=learning_decay,
            num_classes=2,  # Number of output classes
        )

        # Per-epoch buffers, emptied by the corresponding *_epoch_end hook.
        self.train_step_outputs = []
        self.valid_step_outputs = []
        self.validation_outputs = []
        self.validation_targets = []

        # Calibration statistics of the most recent validation epoch (None until then).
        self.bin_data = None

    def forward(self, x):
        """Return the raw two-class logits of the wrapped classifier."""
        return self.model(x)

    # ------------------------------------------------------------------ losses

    def _compute_focal_loss(self, probabilities, targets):
        """
        Computes the Focal loss.
        Args:
            probabilities: Predicted probabilities (shape: [batch_size, 2]).
            targets: True labels (shape: [batch_size]).
        Returns:
            Mean Focal loss for the batch.
        """
        gamma = self.hparams.focal_gamma
        p_pred = probabilities[:, 1]
        # Only the arguments of log() are clamped; the modulating factors use the raw value.
        p_pred_clamped = p_pred.clamp(min=EPSILON, max=1.0 - EPSILON)
        loss_positive = -targets * torch.pow(1.0 - p_pred, gamma) * torch.log(p_pred_clamped)
        loss_negative = (
            -(1.0 - targets) * torch.pow(p_pred, gamma) * torch.log(1.0 - p_pred_clamped)
        )
        return (loss_positive + loss_negative).mean()

    def _compute_kld_loss(self, logits, targets):
        """
        Computes the KL Divergence loss.
        Args:
            logits: Raw logits output from the model (shape: [batch_size, 2]).
            targets: Soft labels (probabilities for positive class) (shape: [batch_size]).
        Returns:
            Mean KL divergence loss for the batch.
        """
        # Target distribution [P(class 0), P(class 1)]
        target_dist = torch.stack([1.0 - targets, targets], dim=1)
        log_probabilities = F.log_softmax(logits, dim=1)
        # log_target=False because target_dist holds probabilities, not log-probabilities
        return F.kl_div(log_probabilities, target_dist, reduction='batchmean', log_target=False)

    def _compute_loss(self, logits, targets):
        """Dispatch to the loss selected by ``hparams.loss_type``."""
        if self.hparams.loss_type == 'focal':
            return self._compute_focal_loss(F.softmax(logits, dim=1), targets)
        if self.hparams.loss_type == 'kld':
            return self._compute_kld_loss(logits, targets)
        # Should be unreachable because __init__ validates loss_type.
        raise ValueError(f'Invalid loss_type specified: {self.hparams.loss_type}')

    def apply_label_smoothing(self, labels, alpha=None):
        """Move labels towards 0.5 with strength ``alpha``; no-op for ``None`` or 0."""
        if alpha is None or alpha == 0.0:
            return labels
        return labels * (1 - alpha) + 0.5 * alpha

    # ------------------------------------------------------------- calibration

    @staticmethod
    def _as_numpy(values):
        """Return ``values`` as a NumPy array, detaching/moving tensors to CPU first."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def expected_calibration_error(self, preds, soft_labels, n_bins=5):
        """
        Calculate ECE for soft labels.

        Args:
            preds: Predicted probabilities (tensor or array-like)
            soft_labels: Soft label probabilities (tensor or array-like)
            n_bins: Number of equal-width bins over [0, 1]

        Returns:
            ece: The Expected Calibration Error
            bin_data: Dictionary with keys 'sizes', 'avg_preds', 'avg_labels', 'boundaries'
        """
        preds_np = self._as_numpy(preds)
        labels_np = self._as_numpy(soft_labels)

        # Equal-width bins; the clip folds predictions at/above 1.0 (or below 0) into the
        # outermost bins instead of dropping them.
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_indices = np.clip(np.digitize(preds_np, bin_boundaries) - 1, 0, n_bins - 1)

        bin_sizes = np.zeros(n_bins)
        bin_avg_preds = np.zeros(n_bins)
        bin_avg_labels = np.zeros(n_bins)

        for bin_idx in range(n_bins):
            mask = bin_indices == bin_idx
            bin_sizes[bin_idx] = mask.sum()
            if bin_sizes[bin_idx] > 0:
                bin_avg_preds[bin_idx] = preds_np[mask].mean()
                bin_avg_labels[bin_idx] = labels_np[mask].mean()

        # Size-weighted mean absolute gap between average prediction and average label.
        total_samples = len(preds_np)
        ece = np.sum((bin_sizes / total_samples) * np.abs(bin_avg_preds - bin_avg_labels))

        bin_data = {
            'sizes': bin_sizes,
            'avg_preds': bin_avg_preds,
            'avg_labels': bin_avg_labels,
            'boundaries': bin_boundaries,
        }
        return ece, bin_data

    # ------------------------------------------------------------------- steps

    def training_step(self, batch, batch_idx):
        """Compute the training loss on smoothed labels and buffer it for epoch logging."""
        images, labels = batch
        p_true = self.apply_label_smoothing(labels, alpha=self.hparams.label_smoothing)

        logits = self(images)  # Logits shape [batch, 2]
        loss = self._compute_loss(logits, p_true)

        # Buffer a detached copy: keeping the attached tensor would hold on to each
        # step's autograd graph until the end of the epoch.
        self.train_step_outputs.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation metrics on the original (unsmoothed) labels."""
        images, labels = batch
        logits = self(images)

        probabilities = F.softmax(logits, dim=1)
        positive_class_probs = probabilities[:, 1]

        # 1. Loss on the original labels
        val_loss = self._compute_loss(logits, labels)

        # 2. Brier Score (MSE)
        brier_score = F.mse_loss(positive_class_probs, labels)

        # 3. Per-batch soft-label ECE (the epoch-level bins are computed at epoch end)
        ece, _ = self.expected_calibration_error(positive_class_probs, labels)

        # Buffer predictions/targets for the epoch-level diagnostics.
        self.validation_outputs.append(positive_class_probs.detach().cpu().numpy())
        self.validation_targets.append(labels.detach().cpu().numpy())
        self.valid_step_outputs.append(val_loss.item())

        self.log_dict(
            {
                'valid_loss': val_loss,
                'val_brier': brier_score,
                'val_calibration': ece,
            },
            prog_bar=True,
        )
        return val_loss

    def on_train_epoch_end(self):
        """Log the mean training loss of the epoch and reset the buffer."""
        if self.train_step_outputs:  # nothing to average if no step ran
            train_loss = torch.stack(self.train_step_outputs).mean()
            self.log('train_loss', train_loss)
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Log the epoch validation loss and upload diagnostic figures to W&B."""
        # NOTE(review): 'valid_loss' is also logged from validation_step; this unweighted
        # mean of batch losses is kept so the monitored metric behaves as before.
        if self.valid_step_outputs:
            valid_loss = sum(self.valid_step_outputs) / len(self.valid_step_outputs)
            self.log('valid_loss', valid_loss)
        self.valid_step_outputs.clear()

        if not self.validation_outputs:
            self.validation_targets.clear()
            return

        outputs = np.concatenate(self.validation_outputs)
        targets = np.concatenate(self.validation_targets)
        # Reset the buffers right away so a plotting failure cannot leak into the next epoch.
        self.validation_outputs.clear()
        self.validation_targets.clear()

        # Calibration statistics over the whole validation epoch, not a single batch.
        ece, bin_data = self.expected_calibration_error(
            outputs, targets, n_bins=self._CALIBRATION_BINS
        )
        self.bin_data = bin_data

        experiment = self._wandb_experiment()
        if experiment is None:
            return

        # Figures are built lazily so only one is alive at a time.
        figure_builders = (
            ('Prediction Distribution', lambda: self._plot_prediction_distribution(outputs, targets)),
            ('Reliability Curve', lambda: self._plot_reliability_curve(bin_data)),
            ('Calibration', lambda: self._plot_calibration_bars(bin_data, ece)),
            ('Error Distribution', lambda: self._plot_error_distribution(outputs, targets)),
        )
        for key, build_figure in figure_builders:
            fig = build_figure()
            try:
                experiment.log({key: wandb.Image(fig)})
            finally:
                plt.close(fig)

    # ---------------------------------------------------------------- plotting

    def _wandb_experiment(self):
        """Return the logger's experiment if it offers ``log``; otherwise None."""
        if self.logger is None:
            return None
        experiment = getattr(self.logger, 'experiment', None)
        return experiment if hasattr(experiment, 'log') else None

    @staticmethod
    def _plot_prediction_distribution(outputs, targets):
        """2-D histogram of predictions against true labels with the identity line."""
        fig, ax = plt.subplots(figsize=(6, 6))
        hist = ax.hist2d(
            targets, outputs, bins=[25, 25], range=[[0, 1], [0, 1]], cmap='viridis', norm='log'
        )
        ax.plot([0, 1], [0, 1], c='r', linestyle='--')
        ax.axhline(targets.mean(), linestyle='--', color='b')
        fig.colorbar(hist[3], ax=ax)
        ax.set_ylabel('Model Predictions')
        ax.set_xlabel('Expert Classifications')
        ax.set_title('Prediction vs Truth Distribution')
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_reliability_curve(bin_data):
        """Mean true label against mean prediction for every populated bin."""
        populated = bin_data['sizes'] > 0
        mean_predicted = bin_data['avg_preds'][populated]
        mean_true = bin_data['avg_labels'][populated]
        counts = bin_data['sizes'][populated].astype(int)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')

        if len(mean_predicted) > 0:
            ax.plot(mean_predicted, mean_true, 'bo-', label='Model Calibration')
            # Annotate each point with its sample count, nudged upwards for legibility.
            for x_pos, y_pos, count in zip(mean_predicted, mean_true, counts):
                ax.text(x_pos, y_pos + 0.02, f'n={count}', ha='center', va='bottom', fontsize=8)
        else:
            ax.text(0.5, 0.5, 'No data in validation bins', ha='center', va='center')

        ax.set_xlabel('Mean Predicted Probability (Confidence)')
        ax.set_ylabel('Mean True Label (Accuracy)')
        ax.set_title('Reliability Curve')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_xlim([-0.05, 1.05])
        ax.set_ylim([-0.05, 1.05])
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_calibration_bars(bin_data, ece):
        """Side-by-side bars of average label and average prediction per bin."""
        avg_preds = bin_data['avg_preds']
        avg_labels = bin_data['avg_labels']
        sizes = bin_data['sizes']
        n_bins_bar = len(avg_preds)
        bin_boundaries = bin_data.get('boundaries', np.linspace(0, 1, n_bins_bar + 1))

        bar_width = 0.35
        positions = np.arange(n_bins_bar)

        fig, ax = plt.subplots(figsize=(7, 6))
        ax.bar(
            positions - bar_width / 2,
            avg_labels,
            bar_width,
            label='Avg Label (Truth)',
            color=colors['blue'],
        )
        ax.bar(
            positions + bar_width / 2,
            avg_preds,
            bar_width,
            label='Avg Prediction',
            color=colors['orange'],
        )

        # Sample counts above the taller bar of every populated bin.
        for i in range(n_bins_bar):
            if sizes[i] > 0:
                ax.text(
                    positions[i],
                    max(avg_labels[i], avg_preds[i]) + 0.02,
                    f'n={int(sizes[i])}',
                    ha='center',
                    va='bottom',
                    fontsize=8,
                )

        # Bin centres plotted against bin index act as the y=x reference.
        bin_centers_approx = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2
        ax.plot(positions, bin_centers_approx, 'r--', label='Perfect Calibration')

        tick_labels = [
            f'{bin_boundaries[i]:.1f}-{bin_boundaries[i+1]:.1f}' for i in range(n_bins_bar)
        ]
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_xlabel('Prediction Bin Index')
        ax.set_ylabel('Probability')
        ax.set_ylim([-0.05, 1.05])
        ax.set_title(f'Calibration Bar Chart (ECE = {float(ece):.4f})')
        ax.legend()
        ax.grid(True, axis='y', linestyle='--', alpha=0.6)
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_error_distribution(outputs, targets):
        """Density histogram of the signed prediction error (prediction - label)."""
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.hist(outputs - targets, bins=50, density=True)
        ax.set_xlabel('Prediction Error')
        ax.set_ylabel('Density')
        ax.set_title('Error Distribution')
        fig.tight_layout()
        return fig

    # --------------------------------------------------------------- optimizer

    def configure_optimizers(self):
        """
        Configures the optimizer (AdamW) and LR scheduler (ReduceLROnPlateau).

        The head is always optimised at the base learning rate. When
        ``hparams.zoobot_blocks`` > 0, the last that many encoder blocks (ConvNeXt stem +
        stages, counted from the output side) are added with the learning rate multiplied
        by ``hparams.learning_decay`` once per block of depth.

        Returns:
            dict: Lightning optimizer/scheduler configuration; the scheduler monitors
            ``valid_loss`` once per epoch.

        Raises:
            AttributeError: If a required hyperparameter or ``self.model.head`` is missing.
            ValueError: If blocks are to be tuned and the encoder is not a timm ConvNeXt.
        """
        try:
            lr = self.hparams.learning_rate
            lr_decay_factor = self.hparams.learning_decay
            num_blocks_to_tune = self.hparams.zoobot_blocks
            weight_decay = self.hparams.weight_decay
        except AttributeError as e:
            logger.error(
                f'Optimizer config failed: Missing hyperparameter. Ensure learning_rate, learning_decay, zoobot_blocks, (and optionally weight_decay) are saved. Error: {e}'
            )
            raise

        if not hasattr(self.model, 'head'):
            raise AttributeError(
                "self.model does not have a 'head' attribute. Ensure FinetuneableZoobotClassifier initialization was successful."
            )

        # The head always trains at the undecayed base learning rate.
        params_to_optimize = [{'params': self.model.head.parameters(), 'lr': lr}]
        logger.info(f'Opt: Initializing Optimizer. Base LR: {lr}, Weight Decay: {weight_decay}')
        logger.info(f'Opt: Head parameters included with LR {lr}')

        if num_blocks_to_tune > 0:
            logger.info(f'Opt: Fine-tuning last {num_blocks_to_tune} encoder blocks/stages.')
            logger.info(f'Opt: Encoder architecture: {type(self.model.encoder).__name__}')

            if isinstance(self.model.encoder, timm.models.ConvNeXt):
                # For ConvNeXt: stem followed by the stages
                tuneable_blocks_or_stages = [self.model.encoder.stem] + list(
                    self.model.encoder.stages
                )
                logger.info(
                    f'Opt: Identified {len(tuneable_blocks_or_stages)} tuneable blocks/stages for ConvNeXt (stem + stages).'
                )
            else:
                raise ValueError(
                    f'Opt: Encoder architecture {type(self.model.encoder).__name__} not explicitly handled in custom configure_optimizers.'
                )

            if num_blocks_to_tune > len(tuneable_blocks_or_stages):
                logger.info(
                    f'Opt: Requested {num_blocks_to_tune} blocks, but only {len(tuneable_blocks_or_stages)} available. Tuning all available.'
                )
                num_blocks_to_tune = len(tuneable_blocks_or_stages)

            # Order from the last block (closest to the head) to the first.
            tuneable_blocks_or_stages.reverse()
            blocks_to_tune = tuneable_blocks_or_stages[:num_blocks_to_tune]

            for i, block in enumerate(blocks_to_tune):
                # i=0 is the last block; each step further from the head decays the LR once more.
                block_lr = lr * (lr_decay_factor**i)
                block_params = list(block.parameters())
                if not block_params:
                    logger.info(
                        f'Opt: Block {i} (type {type(block).__name__}) has no parameters. Skipping.'
                    )
                    continue
                params_to_optimize.append({'params': block_params, 'lr': block_lr})
                logger.info(
                    f'Opt: Including block {i} (type {type(block).__name__}) with LR {block_lr:.2e}'
                )
        else:
            logger.info('Opt: num_blocks_to_tune is 0. Only training the head.')

        logger.info(f'Opt: Total parameter groups for optimizer: {len(params_to_optimize)}')

        optimizer = optim.AdamW(
            params_to_optimize,
            lr=lr,  # Default LR; every group above overrides it explicitly
            weight_decay=weight_decay,
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',  # Minimize the monitored metric (loss)
            factor=0.75,  # Multiply the LR by 0.75 when the metric plateaus
            patience=5,  # Number of epochs with no improvement to wait
            min_lr=1e-6,  # Minimum learning rate
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'valid_loss',  # Metric to monitor for scheduler
                'interval': 'epoch',  # Check metric at the end of each epoch
                'frequency': 1,  # Check every 1 epoch
            },
        }


class AddBimodalNoise(object):
    """
    Transform adding Gaussian noise whose scale depends on the nearer of two modes.

    Every element of the input is assigned to whichever of the two ``means`` it is closer
    to, receives zero-mean Gaussian noise with that mode's standard deviation (times
    ``noise_scale``), and the result is clamped to [-1, 1].
    """

    def __init__(self, means, stds=None, weights=None, noise_scale=0.1):
        """
        Args:
            means: Two mode centres [mean1, mean2] (list, array or tensor).
            stds: Two standard deviations [std1, std2]; defaults to [0.1, 0.1].
            weights: Two mode weights [w1, w2]; defaults to [0.5, 0.5]. Stored on the
                instance but not used by ``__call__``.
            noise_scale: Scale factor for noise (default 0.1)
        """
        # `is None` instead of truthiness: arrays/tensors have no unambiguous truth value.
        self.means = torch.as_tensor(means)
        self.stds = torch.as_tensor(stds if stds is not None else [0.1, 0.1])
        self.weights = torch.as_tensor(weights if weights is not None else [0.5, 0.5])
        self.noise_scale = noise_scale

    def __call__(self, tensor):
        """
        Add mode-dependent noise to ``tensor`` and clamp the result to [-1, 1].

        Args:
            tensor (torch.Tensor): Input values.

        Returns:
            torch.Tensor: Noisy tensor of the same shape.
        """
        # An element belongs to mode 1 when it is strictly closer to means[0].
        closer_to_first = torch.abs(tensor - self.means[0]) < torch.abs(tensor - self.means[1])

        # Two independent draws, one per mode.
        noise1 = torch.randn_like(tensor) * self.stds[0] * self.noise_scale
        noise2 = torch.randn_like(tensor) * self.stds[1] * self.noise_scale

        noise = torch.where(closer_to_first, noise1, noise2)
        # Replace any NaN in the noise with 0 before adding it.
        noise = torch.nan_to_num(noise, nan=0.0)

        return torch.clamp(tensor + noise, -1.0, 1.0)


def load_model(model_path):
    """
    Load a ``ZooBot_lightning`` checkpoint for inference.

    The checkpoint is read once: ``load_from_checkpoint`` rebuilds the module from the
    hyperparameters stored by ``save_hyperparameters()`` and ``map_location`` remaps the
    stored tensors onto ``DEVICE`` while loading.

    Args:
        model_path (str): Path to the Lightning checkpoint file.

    Returns:
        ZooBot_lightning: Frozen model in eval mode, placed on ``DEVICE``.
    """
    checkpoint_model = ZooBot_lightning.load_from_checkpoint(
        checkpoint_path=model_path, map_location=DEVICE
    )

    # Inference only: no gradients, dropout/batch-norm in eval behaviour.
    checkpoint_model.freeze()
    checkpoint_model.eval()

    return checkpoint_model.to(DEVICE)


class SimpleProgressCallback(Callback):
    """Minimal progress reporting through the notebook logger."""

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the number of the finished epoch, then clear the cell output."""
        finished_epochs = trainer.current_epoch + 1
        logger.info(f'Completed epoch {finished_epochs}/{trainer.max_epochs}')
        # wait=True defers the clear until new output arrives.
        clear_output(wait=True)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log ``valid_loss`` after every real (non sanity-check) validation epoch."""
        if not trainer.sanity_checking:
            # Falls back to 0 when the metric has not been logged yet.
            valid_loss = float(trainer.callback_metrics.get('valid_loss', 0))
            logger.info(f'Validation loss: {valid_loss:.4f}')


def train_fold(
    fold_idx,
    train_fold_indices,
    val_fold_indices,
    all_data,
    all_labels,
    transform,
    wandb_enabled,
    with_weighted_sampler=False,
):
    """
    Trains the model for a single fold.

    Args:
        fold_idx: identifier of the fold; used in log messages, the W&B run name
            and the checkpoint sub-directory (``fold_<fold_idx>``).
        train_fold_indices: indices into ``all_data`` / ``all_labels`` used for training.
        val_fold_indices: indices into ``all_data`` / ``all_labels`` used for validation.
        all_data: indexable container holding every sample.
        all_labels: indexable container holding every label.
        transform: augmentation applied to the training dataset only.
        wandb_enabled (bool): whether W&B logging is available. When falsy the
            fold is skipped and None is returned.
        with_weighted_sampler (bool): draw training batches through
            ``create_weighted_sampler`` instead of plain shuffling.

    Returns:
        The best checkpoint path reported by the checkpoint callback, or None
        when W&B is disabled or training raised an exception.
    """
    fold_start_time = time.time()
    logger.info(f'\n===== Starting Fold {fold_idx}/{N_SPLITS} =====')

    # W&B is mandatory for this workflow. Bail out before any data is sliced or
    # any loader/model is built, so a disabled run allocates nothing.
    if not wandb_enabled:
        logger.error('W&B logging is disabled. Exiting.')
        return None

    # 1. Get Data for Fold
    X_train_fold = all_data[train_fold_indices]
    y_train_fold = all_labels[train_fold_indices]
    X_val_fold = all_data[val_fold_indices]
    y_val_fold = all_labels[val_fold_indices]
    logger.info(f'Fold {fold_idx}: Train size={len(X_train_fold)}, Val size={len(X_val_fold)}')

    # Diagnostics only: a failure here is logged and must not abort the fold.
    try:
        if 'NUM_STRATIFICATION_BINS' not in globals():
            raise NameError('NUM_STRATIFICATION_BINS not found in global scope.')

        # Calculate bin edges (must match the ones used in prepare_kfold_splits)
        bin_edges = np.linspace(0, 1, NUM_STRATIFICATION_BINS + 1)

        # Use bins=bin_edges[1:-1] to match the definition in prepare_kfold_splits.
        # minlength keeps empty bins in the counts.
        train_bin_indices = np.digitize(y_train_fold, bins=bin_edges[1:-1])
        train_bin_counts = np.bincount(train_bin_indices, minlength=NUM_STRATIFICATION_BINS)

        val_bin_indices = np.digitize(y_val_fold, bins=bin_edges[1:-1])
        val_bin_counts = np.bincount(val_bin_indices, minlength=NUM_STRATIFICATION_BINS)

        # Log the counts
        logger.info(f'Fold {fold_idx}: Stratification Bin Edges: {np.round(bin_edges, 3)}')
        logger.info(
            f'Fold {fold_idx}: **Train Set** Bin Counts (Bins 0 to {NUM_STRATIFICATION_BINS-1}): {train_bin_counts}'
        )
        logger.info(
            f'Fold {fold_idx}: **Validation Set** Bin Counts (Bins 0 to {NUM_STRATIFICATION_BINS-1}): {val_bin_counts}'
        )

    except Exception as e:
        logger.error(f'Fold {fold_idx}: Failed to calculate or log bin counts. Error: {e}')

    # 2. Create Datasets
    train_dataset_fold = SimpleDataset(X_train_fold, y_train_fold, transform=transform)
    # No augmentation for validation set
    val_dataset_fold = SimpleDataset(X_val_fold, y_val_fold, transform=None)

    # 3. Create Weighted Sampler (stays None when plain shuffling is used)
    train_sampler_fold = None
    if with_weighted_sampler:
        train_sampler_fold = create_weighted_sampler(
            y_train_fold,
            num_weighting_bins=NUM_STRATIFICATION_BINS,
            alpha=SAMPLER_ALPHA,
        )
        logger.info(f'Fold {fold_idx}: Created weighted sampler for training data.')

    # 4. Create DataLoaders
    use_persistent_workers = NUM_WORKERS > 0
    train_loader_fold = DataLoader(
        train_dataset_fold,
        batch_size=BATCH_SIZE,
        sampler=train_sampler_fold,
        # DataLoader rejects shuffle=True together with a sampler.
        shuffle=not with_weighted_sampler,
        num_workers=NUM_WORKERS,
        persistent_workers=use_persistent_workers,
        pin_memory=True,
    )
    val_loader_fold = DataLoader(
        val_dataset_fold,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        persistent_workers=use_persistent_workers,
        pin_memory=True,
    )

    # 5. Initialize Model, Trainer, Logger, Callbacks
    model = ZooBot_lightning(
        zoobot_size=ZOOBOT_SIZE,
        zoobot_blocks=ZOOBOT_BLOCKS,
        learning_rate=LEARNING_RATE,
        learning_decay=LEARNING_DECAY,
        weight_decay=WEIGHT_DECAY,
        label_smoothing=LABEL_SMOOTHING,
        loss_type=LOSS_TYPE,
        focal_gamma=FOCAL_GAMMA,
    )

    # W&B Logger - log each fold as a separate run within the same project
    run_name = f'{ZOOBOT_SIZE}_fold{fold_idx}_lr{LEARNING_RATE}_wd{WEIGHT_DECAY}_{LOSS_TYPE}-loss'
    wandb_logger = WandbLogger(
        project=WANDB_PROJECT,
        group=WANDB_GROUP,
        name=run_name,
    )
    logger.info(
        f"Fold {fold_idx}: Initialized W&B logger for run '{run_name}' in group '{WANDB_GROUP}'."
    )

    # Checkpoint Callback - save the best model for this fold
    fold_save_dir = os.path.join(BASE_SAVE_DIR, f'fold_{fold_idx}')
    os.makedirs(fold_save_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        monitor='valid_loss',
        dirpath=fold_save_dir,
        filename='best_model-{epoch:02d}-{valid_loss:.4f}',
        save_top_k=1,
        mode='min',
    )

    # Early Stopping Callback
    early_stop_callback = EarlyStopping(
        monitor='valid_loss', patience=EARLY_STOPPING_PATIENCE, verbose=True, mode='min'
    )

    # Trainer
    trainer = Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback, early_stop_callback, SimpleProgressCallback()],
        max_epochs=EPOCHS,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        # One logging step per pass over the training loader.
        log_every_n_steps=len(train_loader_fold),
        enable_progress_bar=False,
    )

    # 6. Train the Model
    try:
        logger.info(f'Fold {fold_idx}: Starting training...')
        trainer.fit(model, train_loader_fold, val_loader_fold)
        logger.info(f'Fold {fold_idx}: Training finished.')

        # 7. Store Best Checkpoint Path
        best_model_path = checkpoint_callback.best_model_path
        logger.info(f'Fold {fold_idx}: Best model saved at: {best_model_path}')
        # The score is None when no checkpoint was recorded; formatting None with
        # ':.4f' would raise and misreport a finished fit as a training failure.
        best_score = checkpoint_callback.best_model_score
        if best_score is not None:
            logger.info(f'Fold {fold_idx}: Best validation loss: {best_score:.4f}')
        else:
            logger.warning(f'Fold {fold_idx}: No best validation loss was recorded.')
    except Exception as e:
        # logger.exception keeps the traceback, which the plain message loses.
        logger.exception(f'Fold {fold_idx}: Training failed with error: {e}')
        best_model_path = None  # Indicate failure
    finally:
        # 8. Clean Up
        wandb.finish()  # Finish the W&B run for this fold
        del model, trainer, train_loader_fold, val_loader_fold, train_dataset_fold, val_dataset_fold
        del train_sampler_fold
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info(f'Fold {fold_idx}: Resources cleaned up.')
        fold_end_time = time.time()
        logger.info(
            f'===== Fold {fold_idx} Duration: {(fold_end_time - fold_start_time) / 60:.2f} minutes ====='
        )

    return best_model_path

### Load the data

In [None]:
# Relative path: it resolves against the current working directory.
# NOTE(review): Drive is mounted at /content/drive in the setup cell, so this
# presumably relies on the working directory being /content — confirm.
PATH_DATA = 'drive/MyDrive/training_data/h5/training_data_v2.h5'
# The three returned arrays are later indexed with the same index arrays
# (e.g. test_indices), so they are expected to be aligned sample-by-sample.
DATA, LABEL, OBJ_ID = load_data(PATH_DATA)

### Initialize HF and W&B

In [None]:
# Initialize the Hugging Face session
init_hf()
# Initialize W&B. The result is passed to train_fold, which skips training
# (returns None) when it is falsy.
wandb_enabled = init_wandb(key_name='WANDB_API_KEY_KFOLD')

### Configuration

In [None]:
# Data
# Passed as test_size to prepare_kfold_splits below.
TEST_SIZE = 0.1
# NOTE(review): not referenced in the cells visible here — confirm it is still needed.
VAL_SIZE = 0.1
N_SPLITS = 10  # K for K-Fold
# Number of label bins on [0, 1]; used for the stratified splits, the per-fold
# bin-count logging and the weighted sampler.
NUM_STRATIFICATION_BINS = 5
# Parameters of the two modes handed to AddBimodalNoise in the augmentation pipeline.
NOISE_MEAN = [-0.312967269, 0.30682983]
NOISE_STD = [0.079292069, 0.08915830]
NOISE_WEIGHT = [0.51760047, 0.482399529]
# Multiplies the per-mode std when noise is drawn.
NOISE_SCALE = 0.3

# W&B
WANDB_PROJECT = 'unions_dwarfs_kfold'
# Timestamped when this cell runs; every fold's run is logged under this group.
WANDB_GROUP = f"kfold_run_{time.strftime('%Y%m%d_%H%M%S')}"

# Hyperparameters (forwarded to ZooBot_lightning in train_fold)
LABEL_SMOOTHING = 0.01
LEARNING_RATE = 5e-5
LEARNING_DECAY = 0.75
WEIGHT_DECAY = 0.05
LOSS_TYPE = 'kld'
FOCAL_GAMMA = None
ZOOBOT_SIZE = 'nano'  # pico, nano, tiny, small, base, large
ZOOBOT_BLOCKS = 5
# When True, train_fold samples training batches through create_weighted_sampler.
WITH_WEIGHTED_SAMPLER = False
SAMPLER_ALPHA = (
    0.5  # Controls oversampling scale, [0, 1], higher means more aggressive oversampling
)

# Training
NUM_WORKERS = 2
BATCH_SIZE = 128
EPOCHS = 100
EARLY_STOPPING_PATIENCE = 20

# Model
# Root directory for checkpoints; each fold writes to <BASE_SAVE_DIR>/fold_<idx>.
BASE_SAVE_DIR = 'drive/MyDrive/zoobot_models/kfold'

# 1. Define augmentations (train_fold applies them to the training split only)
transform = transforms.Compose(
    [
        transforms.RandomChoice(
            [
                transforms.RandomRotation((0, 0)),
                transforms.RandomRotation((90, 90)),
                transforms.RandomRotation((180, 180)),
                transforms.RandomRotation((270, 270)),
            ]
        ),  # Random rotation by multiples of 90 degrees
        transforms.RandomHorizontalFlip(),  # Random horizontal flip
        transforms.RandomVerticalFlip(),  # Random vertical flip
        # Brightness and contrast jitter are applied independently of each other.
        transforms.RandomApply([transforms.ColorJitter(brightness=(0.9, 1.1))], p=0.6),
        transforms.RandomApply([transforms.ColorJitter(contrast=(0.9, 1.1))], p=0.4),
        AddBimodalNoise(
            means=NOISE_MEAN, stds=NOISE_STD, weights=NOISE_WEIGHT, noise_scale=NOISE_SCALE
        ),
    ]
)

# 2. Prepare test and k-fold splits
# The fold splits are later used to index train_val_indices, while test_indices
# indexes the full arrays directly.
kfold_splitter, train_val_indices, test_indices = prepare_kfold_splits(
    labels=LABEL,
    test_size=TEST_SIZE,
    n_splits=N_SPLITS,
    num_strat_bins=NUM_STRATIFICATION_BINS,
    seed=SEED,
)

### Train K-fold ###

In [None]:
# Maps fold index -> best checkpoint path; folds that fail are left out.
best_fold_model_paths = {}

for fold_idx, (train_fold_indices_rel, val_fold_indices_rel) in enumerate(kfold_splitter):
    # The splitter's indices address train_val_indices; translate them into
    # indices of the full DATA / LABEL arrays.
    train_fold_indices_abs = train_val_indices[train_fold_indices_rel]
    val_fold_indices_abs = train_val_indices[val_fold_indices_rel]

    # Run this fold end to end
    best_path = train_fold(
        fold_idx=fold_idx,
        train_fold_indices=train_fold_indices_abs,
        val_fold_indices=val_fold_indices_abs,
        all_data=DATA,
        all_labels=LABEL,
        transform=transform,
        wandb_enabled=wandb_enabled,
        with_weighted_sampler=WITH_WEIGHTED_SAMPLER,
    )

    # A falsy path means the fold produced no usable checkpoint.
    if not best_path:
        logger.warning(f'Fold {fold_idx} did not produce a best model path.')
        continue
    best_fold_model_paths[fold_idx] = best_path

logger.info('\n===== K-Fold Training Complete =====')
logger.info(f'Best model paths per fold: {best_fold_model_paths}')

In [None]:
# Hard-coded checkpoint paths. Running this cell replaces whatever mapping the
# training loop stored under the same name.
_KFOLD_CKPT_ROOT = '/content/drive/MyDrive/zoobot_models/kfold'
# One entry per fold, in fold order: the epoch / validation-loss part of the file name.
_BEST_CKPT_STEMS = (
    'epoch=07-valid_loss=0.0312',
    'epoch=39-valid_loss=0.0322',
    'epoch=14-valid_loss=0.0378',
    'epoch=75-valid_loss=0.0292',
    'epoch=14-valid_loss=0.0272',
    'epoch=19-valid_loss=0.0244',
    'epoch=33-valid_loss=0.0425',
    'epoch=21-valid_loss=0.0315',
    'epoch=45-valid_loss=0.0265',
    'epoch=41-valid_loss=0.0266',
)
best_fold_model_paths = {
    fold: f'{_KFOLD_CKPT_ROOT}/fold_{fold}/best_model-{stem}.ckpt'
    for fold, stem in enumerate(_BEST_CKPT_STEMS)
}

### Evaluate ###

In [None]:
# Held-out test split: test_indices index the full arrays directly.
X_test = DATA[test_indices]
y_test = LABEL[test_indices]
id_test = OBJ_ID[test_indices]

In [None]:
# Number of distinct label values present in the test split.
len(np.unique(y_test))

In [None]:
def evaluate_kfold_ensemble(
    best_fold_model_paths, test_indices, all_data, all_labels, batch_size=64
):
    """
    Evaluates an ensemble of models from K-fold training on the test set.

    Each checkpoint is loaded in turn and its positive-class probability
    (softmax column 1) is computed for every test sample. The ensemble
    prediction is the per-sample mean over the models that evaluated
    successfully. A fold whose model fails to load or predict is logged and
    left out of the ensemble.

    Args:
        best_fold_model_paths (dict): Dictionary mapping fold indices to model checkpoint paths
        test_indices (np.ndarray): Indices of test data
        all_data (np.ndarray): The complete dataset
        all_labels (np.ndarray): The complete label set
        batch_size (int): Batch size for evaluation

    Returns:
        dict: Dictionary containing:
            - 'ensemble_predictions': Average predictions across all models
            - 'individual_predictions': Predictions from each individual model,
              keyed by fold index
            - 'prediction_std': Per-sample standard deviation across models
            - 'test_labels': True labels for the test set
            - 'metrics': 'brier_score' (mean squared error), 'mae', 'ece' and
              'bin_data' (the per-bin statistics behind the ECE)

    Raises:
        ValueError: if no model produced predictions.
    """
    # 1. Data Preparation
    logger.info(f'Extracting test data (n={len(test_indices)}) from full dataset')
    X_test = all_data[test_indices]
    y_test = all_labels[test_indices]

    test_dataset = SimpleDataset(X_test, y_test)

    # shuffle=False keeps predictions aligned with y_test.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False
    )

    logger.info(f'Created test DataLoader with {len(test_loader)} batches')

    # 2. Model Loading and Prediction
    individual_predictions = {}
    num_models = len(best_fold_model_paths)

    logger.info(f'Starting evaluation of {num_models} models')

    for fold_idx, model_path in best_fold_model_paths.items():
        logger.info(f'Evaluating model from fold {fold_idx}')

        # Bound before the try so the cleanup below is valid even when loading fails.
        model = None
        try:
            logger.info(f'Loading model from {model_path}')
            model = load_model(model_path)
            model.eval()  # Ensure model is in evaluation mode

            fold_predictions = []
            with torch.no_grad():
                for images, _ in tqdm(test_loader, desc=f'Fold {fold_idx} Predictions'):
                    logits = model(images.to(DEVICE))
                    # Probability of the positive class
                    probs = F.softmax(logits, dim=1)[:, 1]
                    fold_predictions.append(probs.cpu().numpy())

            individual_predictions[fold_idx] = np.concatenate(fold_predictions)

        except Exception as e:
            # Best effort: keep the traceback, skip this fold, continue with the rest.
            logger.exception(f'Error evaluating model from fold {fold_idx}: {e}')
        finally:
            # Release the model on success and on failure, so a fold that fails
            # part-way does not keep its weights alive while the next one loads.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # 3. Ensemble Prediction
    if not individual_predictions:
        raise ValueError('No valid predictions from any model')
    if len(individual_predictions) < num_models:
        logger.warning(
            f'Ensemble uses {len(individual_predictions)} of {num_models} models; '
            'the remaining folds failed to evaluate.'
        )

    # Stack to (n_models, n_samples) for easier operations
    all_preds = np.array(list(individual_predictions.values()))

    # Calculate mean and standard deviation across models
    ensemble_predictions = np.mean(all_preds, axis=0)
    prediction_std = np.std(all_preds, axis=0)

    logger.info(f'Created ensemble predictions with shape {ensemble_predictions.shape}')

    # 4. Performance Evaluation
    mse = np.mean((ensemble_predictions - y_test) ** 2)  # Brier score
    mae = np.mean(np.abs(ensemble_predictions - y_test))

    # Expected Calibration Error (ECE) over equal-width prediction bins
    n_bins = 10
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_indices = np.digitize(ensemble_predictions, bin_boundaries) - 1
    # A prediction of exactly 1.0 lands past the last edge; fold it into the last bin.
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)

    bin_sizes = np.zeros(n_bins)
    bin_avg_preds = np.zeros(n_bins)
    bin_avg_labels = np.zeros(n_bins)

    for bin_idx in range(n_bins):
        mask = bin_indices == bin_idx
        bin_sizes[bin_idx] = mask.sum()

        # Empty bins keep zero averages and contribute nothing to the ECE.
        if bin_sizes[bin_idx] > 0:
            bin_avg_preds[bin_idx] = ensemble_predictions[mask].mean()
            bin_avg_labels[bin_idx] = y_test[mask].mean()

    ece = np.sum((bin_sizes / len(ensemble_predictions)) * np.abs(bin_avg_preds - bin_avg_labels))

    metrics = {
        'brier_score': mse,
        'mae': mae,
        'ece': ece,
        'bin_data': {
            'sizes': bin_sizes,
            'avg_preds': bin_avg_preds,
            'avg_labels': bin_avg_labels,
            'boundaries': bin_boundaries,
        },
    }

    logger.info(f'Brier Score: {mse:.4f}, MAE: {mae:.4f}, ECE: {ece:.4f}')

    # 5. Return results
    return {
        'ensemble_predictions': ensemble_predictions,
        'individual_predictions': individual_predictions,
        'prediction_std': prediction_std,
        'test_labels': y_test,
        'metrics': metrics,
    }

In [None]:
# Evaluate every fold's best checkpoint on the held-out test split and time it.
start = time.time()
# NOTE(review): binding the result to `stats` shadows `from scipy import stats`
# imported at the top of the notebook. The cells below read `stats[...]`, so a
# rename has to be applied to all of them together.
stats = evaluate_kfold_ensemble(
    best_fold_model_paths=best_fold_model_paths,
    test_indices=test_indices,
    all_data=DATA,
    all_labels=LABEL,
    batch_size=64,
)
print(f'Finished in {time.time()-start:.2f} seconds.')

In [None]:
# Ensemble-mean predictions and the matching true labels of the test split.
pred, true = stats['ensemble_predictions'], stats['test_labels']

In [None]:
# Mean ensemble prediction over the test split.
np.mean(pred)

In [None]:
# Displays the whole results dict, including the per-model prediction arrays.
stats

In [None]:
# Persist object ID, true label and ensemble prediction for the test split.
# `pred` follows the order of test_indices because the evaluation loader does not shuffle.
data = {'ID': OBJ_ID[test_indices], 'label_v2': LABEL[test_indices], 'zoobot_pred_v2': pred}
df_test_v2 = pd.DataFrame(data)
# NOTE(review): relative path; the target directory is not created here.
df_test_v2.to_csv('drive/MyDrive/zoobot_models/testset_eval/test_predictions_v2.csv', index=False)

In [None]:
def performance_histogram(true, pred, bins=5, save_path=None):
    """
    Plot one histogram of predicted probabilities per bin of the true label.

    The true labels are split into ``bins`` equal-width intervals on [0, 1].
    Intervals are half-open (``[lo, hi)``) except the last, which also includes
    1.0, so a label that sits exactly on an edge is counted once. Labels
    outside [0, 1] are ignored.

    Args:
        true (array-like): true labels.
        pred (array-like): predicted probabilities, aligned with ``true``.
        bins (int): number of true-label bins. The palette is cycled when
            there are more bins than colours.
        save_path (str, optional): if given, the figure is also written there.
    """
    palette = [
        colors['blue'],
        colors['orange'],
        colors['green'],
        colors['red'],
        colors['purple'],
    ]
    true = np.asarray(true)
    pred = np.asarray(pred)

    # Interior edges only, the same convention as the stratification bins:
    # digitize then returns 0..bins-1 and a value on an edge goes to the upper bin.
    edges = np.linspace(0.0, 1.0, bins + 1)
    bin_of = np.digitize(true, edges[1:-1])
    in_range = (true >= 0.0) & (true <= 1.0)

    plt.close('all')
    plt.figure(figsize=(8, 6))
    for i in range(bins):
        arg_in = np.where(in_range & (bin_of == i))[0]
        closing = ']' if i == bins - 1 else ')'
        plt.hist(
            pred[arg_in],
            alpha=0.5,
            label=f'[{edges[i]:.2f}, {edges[i + 1]:.2f}{closing}',
            edgecolor='k',
            # Wrap around instead of raising IndexError for bins > len(palette).
            color=palette[i % len(palette)],
            histtype='stepfilled',
            linewidth=2,
        )
    plt.semilogy()
    plt.legend(loc='best', fontsize=15)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.ylabel('Number of objects', fontsize=18)
    plt.xlabel('Predicted probability', fontsize=18)

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_gallery(
    data_val,
    label_val,
    pred,
    bins=10,
    per_bin=10,
    figsize=(50, 50),
    title_fontsize=50,
    label_fontsize=35,
    wspace=0.0,
    hspace=0.1,
    seed=None,
    save_path=None,
):
    """
    Generate and save a gallery of image cutouts arranged in a grid with columns
    corresponding to prediction bins and rows containing individual cutouts.
    If a bin has fewer samples than requested, empty spaces will be left.

    Bins are equal-width intervals on [0, 1]. They are half-open (``[lo, hi)``)
    except the last one, which also includes 1.0, so a prediction of exactly
    0.0, 1.0 or an interior edge is assigned to exactly one column.

    Parameters
    ----------
    data_val : array-like
        Array of image data; each item is displayed after moving axis 0 to the
        end (i.e. channels-first images).
    label_val : array-like
        Array of labels corresponding to the images.
    pred : array-like
        Array of prediction values used for binning.
    bins : int, optional
        Number of bins (columns) to use (default is 10).
    per_bin : int, optional
        Number of images per bin (rows) (default is 10).
    figsize : tuple, optional
        Figure size in inches (default is (50, 50)).
    title_fontsize : int, optional
        Font size for the column titles (bin ranges).
    label_fontsize : int, optional
        Font size for the x-label under each image.
    wspace : float, optional
        The amount of width reserved for blank space between subplots (default is 0.0).
    hspace : float, optional
        The amount of height reserved for white space between subplots (default is 0.1).
    seed : int or None, optional
        Random seed for reproducibility. If provided, the same images will be chosen
        each time; the global NumPy random state is left untouched.
    save_path: str, optional
        Path to save the figure. If None, the figure is shown instead.
    """
    pred = np.asarray(pred)

    # A private RandomState gives the same draws as np.random.seed(seed) followed
    # by np.random.choice, without reseeding the global generator as a side effect.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(nrows=per_bin, ncols=bins, figsize=figsize, squeeze=False)

    # Placeholder size for empty slots: follow the data when its shape is known,
    # otherwise fall back to 64x64.
    data_shape = getattr(data_val, 'shape', ())
    blank_h, blank_w = tuple(data_shape[-2:]) if len(data_shape) >= 4 else (64, 64)

    edges = np.linspace(0.0, 1.0, bins + 1)

    # Loop over each bin (column)
    for i in range(bins):
        bin_lower, bin_upper = edges[i], edges[i + 1]
        is_last_bin = i == bins - 1

        # Select indices for images whose predictions fall within the bin.
        if is_last_bin:
            in_bin = (pred >= bin_lower) & (pred <= bin_upper)
        else:
            in_bin = (pred >= bin_lower) & (pred < bin_upper)
        bin_indices = np.where(in_bin)[0]

        # Sample without replacement, but only up to the available number.
        actual_samples = min(len(bin_indices), per_bin)
        if actual_samples > 0:
            idx = rng.choice(bin_indices, size=actual_samples, replace=False)
        else:
            idx = []

        closing = ']' if is_last_bin else ')'

        # Loop over each row for the given bin.
        for j in range(per_bin):
            ax = axes[j, i]

            if j < actual_samples:
                cutout_rgb = data_val[idx[j]]
                cutout_label = label_val[idx[j]]

                # Move the channel axis last, as imshow expects.
                cutout_rgb = np.moveaxis(cutout_rgb, 0, -1)
                ax.imshow(np.clip(cutout_rgb, 0, 1))

                # Set the label below the image.
                ax.set_xlabel(f'{cutout_label:.3f}', fontsize=label_fontsize, labelpad=5)
            else:
                # No sample available: draw a light gray placeholder.
                blank = np.ones((blank_h, blank_w, 3)) * 0.9
                ax.imshow(blank)
                ax.set_xlabel('', fontsize=label_fontsize)

            # Remove ticks for all subplots
            ax.set_xticks([])
            ax.set_yticks([])

            # Only add a title to the top row, showing the bin range.
            if j == 0:
                ax.set_title(
                    f'[{bin_lower:.1f}, {bin_upper:.1f}{closing}',
                    fontsize=title_fontsize,
                    fontweight='bold',
                    pad=20,
                )

    # Adjust spacing between subplots.
    fig.subplots_adjust(wspace=wspace, hspace=hspace)

    # Save the figure if a save path is provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()


def visualize_ensemble_results(results, save_dir=None):
    """
    Visualizes the results from the ensemble evaluation.

    Five figures are produced: a 2D histogram of prediction vs. true label, a
    reliability curve built from the pre-computed calibration bins, the
    individual and ensemble predictions sorted by true label, the ensemble
    predictions coloured by their across-model standard deviation, and the
    distribution of prediction errors.

    Args:
        results (dict): The results dictionary from evaluate_kfold_ensemble
        save_dir (str, optional): Directory to save figures; created if missing.
            If None, figures are displayed only.

    Returns:
        list: The five matplotlib figures, in the order listed above.
    """
    ensemble_predictions = results['ensemble_predictions']
    individual_predictions = results['individual_predictions']
    prediction_std = results['prediction_std']
    test_labels = results['test_labels']
    metrics = results['metrics']

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def _finish(fig, filename):
        # Lay out and (optionally) save through the figure object itself, so
        # every figure gets tight_layout and nothing depends on which figure is
        # currently "active" in pyplot.
        fig.tight_layout()
        if save_dir:
            fig.savefig(f'{save_dir}/{filename}', dpi=300, bbox_inches='tight')

    # 1. Prediction Distribution (2D histogram)
    fig1, ax1 = plt.subplots(figsize=(8, 7))
    h = ax1.hist2d(
        test_labels,
        ensemble_predictions,
        bins=[25, 25],
        range=[[0, 1], [0, 1]],
        cmap='viridis',
        norm='log',
    )
    ax1.plot([0, 1], [0, 1], 'r--', label='Perfect Prediction')
    ax1.axhline(test_labels.mean(), linestyle='--', color='b', label='Mean True Label')
    fig1.colorbar(h[3], ax=ax1)
    ax1.set_ylabel('Model Predictions')
    ax1.set_xlabel('True Probabilities')
    ax1.set_title(
        f'Ensemble Prediction vs Truth Distribution\nBrier Score: {metrics["brier_score"]:.4f}'
    )
    ax1.legend()
    _finish(fig1, 'prediction_distribution.png')

    # 2. Reliability (Calibration) Curve from metrics
    fig2, ax2 = plt.subplots(figsize=(8, 7))
    bin_data = metrics['bin_data']

    # Keep only bins that contain samples
    valid_bins = bin_data['sizes'] > 0
    bin_avg_preds_valid = bin_data['avg_preds'][valid_bins]
    bin_avg_labels_valid = bin_data['avg_labels'][valid_bins]
    bin_sizes_valid = bin_data['sizes'][valid_bins]

    # Diagonal representing perfect calibration
    ax2.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')

    # Calibration curve
    ax2.plot(bin_avg_preds_valid, bin_avg_labels_valid, 'bo-', label='Model Calibration')

    # Sample count annotations, slightly above each point
    for x_pos, y_pos, count in zip(bin_avg_preds_valid, bin_avg_labels_valid, bin_sizes_valid):
        ax2.text(
            x_pos,
            y_pos + 0.02,
            f'n={int(count)}',
            ha='center',
            va='bottom',
            fontsize=8,
        )

    ax2.set_xlabel('Mean Predicted Probability (Confidence)')
    ax2.set_ylabel('Mean True Label (Accuracy)')
    ax2.set_title(f'Reliability Curve\nECE: {metrics["ece"]:.4f}')
    ax2.legend()
    ax2.grid(True, linestyle='--', alpha=0.6)
    ax2.set_xlim([-0.05, 1.05])
    ax2.set_ylim([-0.05, 1.05])
    _finish(fig2, 'reliability_curve.png')

    # 3. Individual Model Predictions
    fig3, ax3 = plt.subplots(figsize=(10, 6))

    # Sort test data by true label for clarity
    sort_idx = np.argsort(test_labels)
    sorted_test_labels = test_labels[sort_idx]
    sample_axis = range(len(sorted_test_labels))

    # True labels
    ax3.plot(sample_axis, sorted_test_labels, 'k-', label='True Labels', linewidth=2)

    # Each model's predictions (semi-transparent)
    for preds in individual_predictions.values():
        sorted_preds = preds[sort_idx]
        ax3.plot(range(len(sorted_preds)), sorted_preds, 'b-', alpha=0.2, linewidth=1)

    # Ensemble predictions
    sorted_ensemble_preds = ensemble_predictions[sort_idx]
    ax3.plot(
        range(len(sorted_ensemble_preds)),
        sorted_ensemble_preds,
        'r-',
        label='Ensemble Prediction',
        linewidth=2,
    )

    ax3.set_xlabel('Test Sample (sorted by true label)')
    ax3.set_ylabel('Probability')
    ax3.set_title('Ensemble vs Individual Model Predictions')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    _finish(fig3, 'model_predictions.png')

    # 4. Prediction Uncertainty
    fig4, ax4 = plt.subplots(figsize=(8, 7))

    sorted_std = prediction_std[sort_idx]

    # Scatter of ensemble predictions; colour encodes the across-model std
    scatter = ax4.scatter(
        sample_axis,
        sorted_ensemble_preds,
        c=sorted_std,
        cmap='viridis_r',
        s=15,
        alpha=0.8,
    )

    # True labels
    ax4.plot(sample_axis, sorted_test_labels, 'k-', label='True Labels', linewidth=1.5)

    # Colorbar for standard deviation
    cbar = fig4.colorbar(scatter, ax=ax4)
    cbar.set_label('Standard Deviation (Uncertainty)')

    ax4.set_xlabel('Test Sample (sorted by true label)')
    ax4.set_ylabel('Probability')
    ax4.set_title('Ensemble Predictions with Uncertainty')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    _finish(fig4, 'prediction_uncertainty.png')

    # 5. Error Analysis
    fig5, ax5 = plt.subplots(figsize=(8, 7))

    errors = ensemble_predictions - test_labels
    mean_error = np.mean(errors)

    # Histogram of signed errors; uses the notebook-wide colour-blind palette
    ax5.hist(errors, bins=50, alpha=0.7, color=colors['blue'])
    ax5.axvline(0, color='r', linestyle='--', linewidth=1)
    ax5.axvline(
        mean_error,
        color='g',
        linestyle='-',
        linewidth=1,
        label=f'Mean Error: {mean_error:.4f}',
    )

    ax5.set_xlabel('Prediction Error (Predicted - True)')
    ax5.set_ylabel('Count')
    ax5.set_title(f'Error Distribution\nMAE: {metrics["mae"]:.4f}')
    ax5.legend()
    _finish(fig5, 'error_distribution.png')

    # Show all plots
    if not save_dir:
        plt.show()

    return [fig1, fig2, fig3, fig4, fig5]

In [None]:
from matplotlib.colors import to_rgba
from matplotlib.legend_handler import HandlerErrorbar


def _bin_predictions(true, pred, n_bins, min_samples):
    """
    Groups predictions by true label and summarises every non-empty group.

    Args:
        true (np.ndarray): True labels.
        pred (np.ndarray): Predictions, aligned element-wise with `true`.
        n_bins (int or None): None or a value <= 0 groups by the unique values of
            `true`; a positive int uses that many equal-width bins on [0, 1].
        min_samples (int): Groups with fewer samples than this are dropped.

    Returns:
        tuple: (x_label, x_values, means, stds, counts). The last four are
        equal-length NumPy arrays holding one entry per retained group.
    """
    x_values, means, stds, counts = [], [], [], []

    def _summarise(x_position, members):
        in_bin = pred[members]
        count = len(in_bin)
        # An empty group has no mean (np.mean would return NaN and warn), so it is
        # dropped even when min_samples is 0.
        if count == 0 or count < min_samples:
            return
        x_values.append(x_position)
        means.append(np.mean(in_bin))
        stds.append(np.std(in_bin) if count > 1 else 0.0)
        counts.append(count)

    if n_bins is None or n_bins <= 0:
        # --- Mode 1: one group per unique true label value ---
        x_label = 'Label'
        for value in np.unique(true):
            _summarise(value, true == value)
    else:
        # --- Mode 2: n_bins equal-width bins between 0 and 1 ---
        x_label = 'Label Bin Center'
        edges = np.linspace(0, 1, n_bins + 1)
        for i in range(n_bins):
            lower, upper = edges[i], edges[i + 1]
            # Bins are half-open [lower, upper) except the last one, which also
            # includes its upper edge so that labels equal to 1.0 are kept.
            below_upper = (true <= upper) if i == n_bins - 1 else (true < upper)
            _summarise((lower + upper) / 2.0, (true >= lower) & below_upper)

    return x_label, np.array(x_values), np.array(means), np.array(stds), np.array(counts)


def _marker_sizes(bin_counts, size_by_count):
    """Returns one scatter marker area per bin (constant, or log-scaled by count)."""
    if not size_by_count:
        return np.ones_like(bin_counts) * 80

    min_count = np.min(bin_counts)
    max_count = np.max(bin_counts)
    if max_count == min_count:
        # All bins equally populated: the normalisation below would divide by zero.
        return np.ones_like(bin_counts) * 100

    # Logarithmic scaling keeps sparsely populated bins visible next to large ones.
    log_counts = np.log1p(bin_counts)
    log_min = np.log1p(min_count)
    log_max = np.log1p(max_count)
    normalized_sizes = (log_counts - log_min) / (log_max - log_min)
    # Marker areas range from 20 (smallest bin) to 400 (largest bin).
    return 20 + normalized_sizes * 380


def _boxes_overlap(box1, box2, tolerance):
    """Checks whether two (x0, y0, x1, y1) boxes overlap after shrinking each side by `tolerance`."""

    def _shrink(box):
        width = box[2] - box[0]
        height = box[3] - box[1]
        return (
            box[0] + width * tolerance,
            box[1] + height * tolerance,
            box[2] - width * tolerance,
            box[3] - height * tolerance,
        )

    a = _shrink(box1)
    b = _shrink(box2)
    return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]


def _place_count_labels(
    fig, ax, x_values, y_values, counts, overlap_tolerance, base_offset=10, step=3, max_attempts=30
):
    """
    Annotates every point with 'n=<count>', raising a label until it clears earlier ones.

    Points are processed from left to right. Each label starts `base_offset` points
    above its marker and is moved up in `step`-point increments while it overlaps an
    already placed label. After `max_attempts` failed tries the label is left at the
    next offset regardless of overlap.
    """
    renderer = fig.canvas.get_renderer()

    def _figure_box(label):
        # Text extent in display pixels, converted to figure-fraction coordinates.
        extent = label.get_window_extent(renderer).transformed(fig.transFigure.inverted())
        return (extent.x0, extent.y0, extent.x1, extent.y1)

    placed_boxes = []
    for idx in np.argsort(x_values):
        label = ax.annotate(
            f'n={counts[idx]}',
            xy=(x_values[idx], y_values[idx]),
            xytext=(0, base_offset),
            textcoords='offset points',
            ha='center',
            va='bottom',
            fontsize=9,
            bbox=dict(boxstyle='round,pad=0.1', fc='white', alpha=0.7),
        )

        offset = base_offset
        for _ in range(max_attempts):
            label.xyann = (0, offset)
            box = _figure_box(label)
            if not any(_boxes_overlap(box, other, overlap_tolerance) for other in placed_boxes):
                break
            offset += step
        else:
            # No free position found: keep the label at the last (untested) offset.
            label.xyann = (0, offset)
            box = _figure_box(label)

        placed_boxes.append(box)


def performance_plot(
    true,
    pred,
    n_bins=None,
    min_samples=0,
    title='Model Performance',
    figsize=(8, 6),
    show_counts=True,
    size_by_count=True,
    overlap_tolerance=0.0,
    save_path=None,
):
    """
    Generates a performance plot showing mean predicted probability vs. true label.

    Each retained bin is drawn as a marker at the mean prediction with an error bar
    of one standard deviation, next to the y = x reference line. All open figures
    are closed first, then the new figure is shown with `plt.show()`. Nothing is
    returned.

    Args:
        true (np.ndarray): Array of true labels (continuous, assumed 0-1).
        pred (np.ndarray): Array of predicted probabilities (continuous, 0-1).
        n_bins (int, optional):
            - If None (default) or <= 0: Bins are created based on the unique values
              present in the `true` array. X-axis shows the unique true values.
            - If int > 0: Creates `n_bins` equally sized bins for the `true`
              labels between 0 and 1. X-axis shows the bin centers.
        min_samples (int): Minimum number of samples required for a bin to be plotted.
            Empty bins are never plotted, whatever this value is.
        title (str): Accepted for compatibility; currently unused because the
            title line is disabled.
        figsize (tuple): Figure size for the plot.
        show_counts (bool): Whether to show sample counts as text annotations.
        size_by_count (bool): Whether to size points proportionally to sample counts.
        overlap_tolerance (float): Allowed overlap between annotation boxes. Values
            are clamped to the range 0.0-0.5 and only matter when `show_counts` is True.
            - 0.0 means no overlap allowed (strict)
            - Higher values allow more overlap (0.2 is a reasonable value)
            - Max recommended value is 0.5 (half box overlap)
        save_path (str, optional): Path to save the figure. If None, the figure is not saved.
    """
    # Ensure overlap_tolerance is within reasonable bounds
    overlap_tolerance = max(0.0, min(0.5, overlap_tolerance))

    true = np.asarray(true)
    pred = np.asarray(pred)

    plt.close('all')
    fig = plt.figure(figsize=figsize, clear=True)
    ax = fig.add_subplot(111)

    # Fix the view limits up front. Count labels are positioned by measuring their
    # on-screen extent, so the data-to-pixel mapping must already be the final one
    # and must not be autoscaled around the off-screen legend proxy created below.
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])

    x_label, x_values, mean_pred_bin, std_pred_bin, bin_counts = _bin_predictions(
        true, pred, n_bins, min_samples
    )

    # The custom legend handler is only registered when the errorbar proxy exists;
    # otherwise the legend falls back to matplotlib's defaults.
    legend_kwargs = {'fontsize': 13}

    if len(x_values) > 0:
        point_sizes = _marker_sizes(bin_counts, size_by_count)

        # Error bars only; the markers are drawn separately so they can be sized.
        ax.errorbar(
            x_values,
            mean_pred_bin,
            yerr=std_pred_bin,
            fmt='none',
            ecolor=colors['gray'],
            elinewidth=1.2,
            capsize=5,
            alpha=1,
            zorder=0,
        )

        ax.scatter(
            x_values,
            mean_pred_bin,
            s=point_sizes,
            color=colors['blue'],
            alpha=0.7,
            edgecolor='darkblue',
            linewidth=1,
        )

        # Single off-screen point whose only job is to give the legend one combined
        # "marker + error bar" entry styled like the real data.
        legend_proxy = ax.errorbar(
            [-1e9],
            [-1e9],
            yerr=[0.1],
            fmt='o',
            linestyle='None',
            ms=7,
            mfc=to_rgba(colors['blue'], alpha=0.7),
            mec=to_rgba('darkblue', alpha=0.7),
            mew=1,
            ecolor=to_rgba(colors['gray'], alpha=1),
            elinewidth=1.2,
            capsize=4,
            capthick=2,
            label='Mean ± Std Dev',
        )
        legend_kwargs['handler_map'] = {type(legend_proxy): HandlerErrorbar(yerr_size=0.6)}

        if show_counts:
            _place_count_labels(
                fig, ax, x_values, mean_pred_bin, bin_counts, overlap_tolerance
            )
    else:
        ax.text(
            0.5,
            0.5,
            'No data points found for plotting',
            ha='center',
            va='center',
            transform=ax.transAxes,
        )

    # Plot the perfect calibration line (y=x)
    ax.plot(
        [0, 1],
        [0, 1],
        color=colors.get('red', '#e41a1c'),
        linestyle='--',
        label='Perfect Calibration',
    )

    # Add labels, legend, grid
    ax.set_xlabel(x_label, fontsize=18)
    ax.set_ylabel('Predicted Probability', fontsize=18)
    ax.tick_params(axis='both', labelsize=12)
    # ax.set_title(title, fontsize=16)
    ax.legend(**legend_kwargs)
    ax.grid(True, linestyle='--', alpha=0.6)

    fig.tight_layout()

    # Save the figure if a save path is provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Draw the performance plot without per-bin count labels and save it as a PDF.
# Because show_counts=False, no labels are placed and overlap_tolerance has no effect.
# NOTE(review): `true` and `pred` are defined outside this section — presumably the
# test-set labels and predictions; confirm against the cell that creates them.
# Alternative PNG output path:
#'drive/MyDrive/unions_dwarfs_results/figures/model_performance.png'
performance_plot(
    true,
    pred,
    save_path='drive/MyDrive/unions_dwarfs_results/figures/model_performance.pdf',
    overlap_tolerance=0.08,
    show_counts=False,
)

In [None]:
# Call performance_histogram on the same `true` / `pred` arrays, passing a PDF output path.
# NOTE(review): performance_histogram is defined outside this section; `bins=4`
# presumably sets the number of label bins — confirm against its definition.
performance_histogram(
    true,
    pred,
    bins=4,
    save_path='drive/MyDrive/unions_dwarfs_results/figures/performance_histogram.pdf',
)

In [None]:
from matplotlib.gridspec import GridSpec


def plot_calibration_soft_labels(soft_labels, model_probabilities, n_bins=10, strategy='uniform'):
    """
    Illustrates the calibration of a model trained on soft labels.

    This function generates a reliability diagram and calculates the
    Expected Calibration Error (ECE), adapted for soft labels. Samples are
    binned by predicted probability; within each bin the mean soft label plays
    the role of accuracy and the mean prediction the role of confidence.

    Args:
        soft_labels (np.ndarray): A 1D NumPy array of soft labels, where each
                                   value is between 0 and 1, representing the
                                   true probability of the positive class for each sample.
        model_probabilities (np.ndarray): A 1D NumPy array of the model's
                                          predicted probabilities for the
                                          positive class for each sample.
        n_bins (int): The number of bins to use for the reliability diagram.
                      Must be at least 1.
        strategy (str): Strategy used to define the widths of the bins.
                        'uniform': All bins have identical widths.
                        Any other value (including 'quantile', which is not
                        implemented) raises ValueError.
    Returns:
        tuple: (fig, ece)
            - fig (matplotlib.figure.Figure): The figure object for the plot.
            - ece (float): The Expected Calibration Error adapted for soft labels.

    Raises:
        ValueError: If the inputs are not NumPy arrays, differ in shape, contain
            values outside [0, 1] (NaN included), if `n_bins` is not a positive
            integer, or if `strategy` is not 'uniform'.
    """
    if not isinstance(soft_labels, np.ndarray) or not isinstance(model_probabilities, np.ndarray):
        raise ValueError('soft_labels and model_probabilities must be NumPy arrays.')
    if soft_labels.shape != model_probabilities.shape:
        raise ValueError('soft_labels and model_probabilities must have the same shape.')
    # Phrased as "not all inside" rather than "any outside" so that NaN, which
    # fails every comparison, is rejected instead of slipping through.
    if not np.all((soft_labels >= 0) & (soft_labels <= 1)):
        raise ValueError('Soft labels must be between 0 and 1.')
    if not np.all((model_probabilities >= 0) & (model_probabilities <= 1)):
        raise ValueError('Model probabilities must be between 0 and 1.')
    # Reject unusable bin counts before a figure is allocated.
    if not isinstance(n_bins, (int, np.integer)) or n_bins < 1:
        raise ValueError('n_bins must be a positive integer.')

    # For multi-class, this function assumes soft_labels and model_probabilities
    # are for a specific class of interest, or for the positive class in a
    # binary/pseudo-binary setup.

    if strategy != 'uniform':
        raise ValueError(f"Unknown binning strategy: {strategy}. Choose 'uniform'.")
    bin_limits = np.linspace(0, 1, n_bins + 1)

    bin_lowers = bin_limits[:-1]
    bin_uppers = bin_limits[1:]
    bin_centers = (bin_lowers + bin_uppers) / 2

    # Empty bins keep these zero defaults. They carry zero weight in the ECE, so
    # the placeholder value does not influence the result.
    binned_soft_labels_mean = np.zeros(n_bins)
    binned_model_probs_mean = np.zeros(n_bins)
    bin_counts = np.zeros(n_bins)

    for i in range(n_bins):
        in_bin = model_probabilities >= bin_lowers[i]
        if i == n_bins - 1:
            # The last bin is closed on the right so that 1.0 is included.
            in_bin = in_bin & (model_probabilities <= bin_uppers[i])
        else:
            in_bin = in_bin & (model_probabilities < bin_uppers[i])

        bin_counts[i] = np.sum(in_bin)
        if bin_counts[i] > 0:
            binned_soft_labels_mean[i] = np.mean(soft_labels[in_bin])
            binned_model_probs_mean[i] = np.mean(model_probabilities[in_bin])

    bin_gaps = np.abs(binned_soft_labels_mean - binned_model_probs_mean)

    # Calculate Expected Calibration Error (ECE) for soft labels
    # ECE = sum_{m=1}^{M} ( |B_m| / N ) * | acc(B_m) - conf(B_m) |
    # Here, acc(B_m) is the average soft label in bin m,
    # and conf(B_m) is the average model probability in bin m.
    total_samples = model_probabilities.size
    if total_samples == 0:
        ece = 0.0
    else:
        ece = float(np.sum((bin_counts / total_samples) * bin_gaps))

    # Plotting: reliability diagram on the top three quarters, histogram below.
    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(4, 1)

    # Reliability diagram
    ax1 = fig.add_subplot(gs[:3, 0])
    ax1.plot([0, 1], [0, 1], 'k:', label='Perfectly calibrated')

    # Each bar spans from the smaller to the larger of the two bin averages, so
    # its height is the calibration gap of that bin.
    bar_bottoms = np.minimum(binned_soft_labels_mean, binned_model_probs_mean)
    ax1.bar(
        bin_centers,
        bin_gaps,
        bottom=bar_bottoms,
        width=(bin_uppers[0] - bin_lowers[0]) * 0.9,  # Leave a small gap between bars
        color='lightcoral',
        edgecolor='firebrick',
        alpha=0.7,
        label='Gap (Over/Under-confidence)',
    )

    # Plot average soft label vs average model probability, only for bins with samples
    valid_bins = bin_counts > 0
    ax1.plot(
        binned_model_probs_mean[valid_bins],
        binned_soft_labels_mean[valid_bins],
        's-',
        color='navy',
        markersize=8,
        label='Model calibration',
    )

    ax1.set_xlabel('Average Predicted Probability (Confidence) in bin', fontsize=12)
    ax1.set_ylabel('Average True Probability (Soft Label) in bin', fontsize=12)
    ax1.set_title(f'Reliability Diagram (Soft Labels)\nECE: {ece:.4f}', fontsize=14)
    ax1.legend(loc='upper left', fontsize=10)
    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.set_xlim([0, 1])
    ax1.set_ylim([0, 1])

    # Histogram of predicted probabilities, using the same bins as the diagram
    ax2 = fig.add_subplot(gs[3, 0])
    ax2.hist(
        model_probabilities,
        range=(0, 1),
        bins=n_bins,
        color='cornflowerblue',
        edgecolor='black',
        alpha=0.8,
    )
    ax2.set_xlabel('Predicted Probability', fontsize=12)
    ax2.set_ylabel('Number of Samples', fontsize=12)
    ax2.set_title('Histogram of Predicted Probabilities', fontsize=14)
    ax2.grid(axis='y', linestyle='--', alpha=0.7)
    ax2.set_xlim([0, 1])

    fig.tight_layout()
    return fig, ece

In [None]:
# Reliability diagram of `pred` against the soft labels `true`, using 10 uniform bins.
# The call returns (fig, ece); as the last expression of the cell, that tuple is
# the cell's output value.
plot_calibration_soft_labels(
    soft_labels=true, model_probabilities=pred, n_bins=10, strategy='uniform'
)

In [None]:
# Quick check: number of entries in `pred`.
len(pred)

In [None]:
# Select the entries of DATA and LABEL at `test_indices`, then pass them with `pred`
# to plot_gallery.
# NOTE(review): DATA, LABEL, test_indices and plot_gallery are defined outside this
# section; `pred` is presumably aligned with `test_indices` — confirm before relying
# on the gallery ordering.
data_test = DATA[test_indices]
label_test = LABEL[test_indices]

plot_gallery(
    data_test,
    label_test,
    pred,
    bins=10,
    per_bin=5,
    figsize=(30, 15),
    title_fontsize=28,
    label_fontsize=23,
    wspace=0.2,
    hspace=0.15,
    seed=42,
    save_path='drive/MyDrive/unions_dwarfs_results/figures/gallery_kfold.pdf',
)
# Alternative output path:
# 'drive/MyDrive/unions_dwarfs_results/figures/gallery_kfold_5bin.pdf'

In [None]:
# Call visualize_ensemble_results with no output directory (save_dir=None).
# NOTE(review): `stats` is also the name bound by `from scipy import stats` in the
# import cell; confirm it has been rebound to the ensemble results before this cell
# runs, otherwise the scipy module is passed instead.
visualize_ensemble_results(stats, save_dir=None)