In [None]:
# All-in-one DIET finetuning experiment
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.model_selection import train_test_split
from torchvision import transforms  # For image transformations in AIM fallback
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision
from torchvision.models import resnet50, ResNet50_Weights
import copy  # Add this import here
# Add imports for DINOv2

from transformers import AutoImageProcessor, AutoModel
import torch.nn.functional as F
print("Setting up experiment...")
import torchvision.datasets as datasets
# MedMNIST is optional; MEDMNIST_AVAILABLE records whether it (and its INFO table) imported.
try:
    import medmnist
    from medmnist import INFO
except ImportError:
    print("MedMNIST not found, install with: pip install medmnist")
    MEDMNIST_AVAILABLE = False
else:
    MEDMNIST_AVAILABLE = True

# Basic settings: use CUDA when torch reports it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
import os
import wandb
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from io import BytesIO
import wandb

def init_wandb(args):
    """Start a W&B run for a DIET finetuning experiment.

    The run is created in project ``DIET-Finetuning`` and named
    ``DIET_<backbone_type>_<model_size>_<dataset_name>_<YYYYmmdd_HHMMSS>``.
    It is tagged with the backbone, model size and dataset, plus ``"DIET"``
    when ``args['label_smoothing'] > 0`` and ``"Baseline"`` otherwise.

    Args:
        args: Experiment configuration dict. It must contain ``backbone_type``,
            ``model_size``, ``dataset_name`` and ``label_smoothing``. The
            whole dict is passed as the run config.

    Returns:
        run: The object returned by ``wandb.init``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backbone = args['backbone_type']
    size = args['model_size']
    dataset = args['dataset_name']

    # NOTE(review): the DIET/Baseline tag is decided by label smoothing alone;
    # confirm this is the intended criterion.
    variant_tag = "Baseline"
    if args['label_smoothing'] > 0:
        variant_tag = "DIET"

    run = wandb.init(
        project="DIET-Finetuning",
        name=f"DIET_{backbone}_{size}_{dataset}_{stamp}",
        config=args,
        settings=wandb.Settings(start_method="thread"),
        tags=[backbone, size, dataset, variant_tag],
    )

    print(f"WandB initialized: {run.name}")
    return run

def log_training_metrics(run, metrics, epoch, lr=None):
    """Send one epoch of training metrics to W&B in a single ``run.log`` call.

    Args:
        run: wandb run object (anything with a ``log(dict)`` method).
        metrics: Dict providing ``diet_loss``, ``probe_loss`` and ``accuracy``.
            Other keys are ignored.
        epoch: Current epoch number, logged under ``epoch``.
        lr: Optional learning rate, logged as ``train/learning_rate`` only
            when it is not None.
    """
    payload = {
        f"train/{key}": metrics[key]
        for key in ("diet_loss", "probe_loss", "accuracy")
    }
    payload["epoch"] = epoch

    if lr is not None:
        payload["train/learning_rate"] = lr

    run.log(payload)

def log_evaluation_metrics(run, metrics, epoch):
    """Send the evaluation accuracy for one epoch to W&B.

    Args:
        run: wandb run object (anything with a ``log(dict)`` method).
        metrics: Dict providing ``accuracy``. Other keys are ignored.
        epoch: Current epoch number, logged under ``epoch``.
    """
    run.log({"eval/accuracy": metrics["accuracy"], "epoch": epoch})

def log_zero_shot_metrics(run, metrics, epoch, initial_metrics=None):
    """Log zero-shot evaluation metrics (and optionally their change) to wandb.

    Every entry of ``metrics`` is logged as ``zero_shot/<name>``. When
    ``initial_metrics`` is given, each metric that also appears in it gets two
    extra entries:

    * ``zero_shot/<name>_improvement``: the absolute change.
    * ``zero_shot/<name>_relative_improvement``: the percent change.

    ``zero_shot/average_improvement`` (mean absolute change) is also logged.

    The percent change is measured relative to ``abs(initial)``, so its sign
    always matches the sign of the absolute change. This matters for scores
    that can be negative, such as the adjusted Rand index. With a zero
    baseline the percent change is ``+inf``/``-inf`` for a gain/loss and
    ``0.0`` when the value is unchanged.

    Args:
        run: wandb run object (anything with a ``log(dict)`` method).
        metrics: Dictionary of zero-shot metrics (name -> score).
        epoch: Current epoch number.
        initial_metrics: Initial zero-shot metrics for comparison (optional).
            Metrics missing from it are logged without improvement entries.
    """
    log_dict = {"epoch": epoch}
    improvements = []

    for metric_name, value in metrics.items():
        log_dict[f"zero_shot/{metric_name}"] = value

        # Without a baseline for this metric there is nothing to compare against.
        if initial_metrics is None or metric_name not in initial_metrics:
            continue

        baseline = initial_metrics[metric_name]
        improvement = value - baseline
        if baseline != 0:
            relative_improvement = improvement / abs(baseline) * 100
        elif improvement > 0:
            relative_improvement = float('inf')
        elif improvement < 0:
            relative_improvement = float('-inf')
        else:
            relative_improvement = 0.0

        log_dict[f"zero_shot/{metric_name}_improvement"] = improvement
        log_dict[f"zero_shot/{metric_name}_relative_improvement"] = relative_improvement
        improvements.append(improvement)

    # Only report an average when at least one metric could be compared
    # (np.mean of an empty list would yield NaN and a RuntimeWarning).
    if improvements:
        log_dict["zero_shot/average_improvement"] = float(np.mean(improvements))

    run.log(log_dict)

def save_model_checkpoint(run, model, optimizer, projection_head, W_probe, W_diet, epoch, metrics,
                          save_dir="checkpoints", log_every=5, is_final=False):
    """Save a checkpoint locally and upload selected checkpoints to wandb.

    A file ``checkpoint_epoch_<epoch>.pt`` is always written to ``save_dir``.

    If ``metrics["test_acc"]`` beats the best accuracy seen so far, three
    things happen. The function attribute ``save_model_checkpoint.best_acc``
    is updated, the checkpoint is copied to ``best_checkpoint.pt``, and that
    copy is uploaded as a ``best_model_<run.id>`` artifact.

    The per-epoch file is uploaded as ``model_e<epoch>_<run.id>`` when any of
    these hold:

    * ``epoch`` is a multiple of ``log_every``;
    * ``is_final`` is True;
    * it is a new best.

    ``save_model_checkpoint.best_acc`` persists for the lifetime of the
    interpreter/kernel. Reset it to 0 before starting a new run, otherwise
    that run's checkpoints are compared against the previous run's best.

    Args:
        run: wandb run object (only used when something is uploaded).
        model: The backbone model.
        optimizer: Optimizer.
        projection_head: DIET projection head.
        W_probe: Probe linear layer.
        W_diet: DIET linear layer.
        epoch: Current epoch number.
        metrics: Metrics dict; its ``test_acc`` (default 0) decides "best".
        save_dir: Directory to save checkpoints locally.
        log_every: Upload the per-epoch checkpoint every ``log_every`` epochs.
            0 or None disables periodic uploads.
        is_final: Pass True on the last epoch so its checkpoint is uploaded
            even when ``epoch`` is not a multiple of ``log_every``.
    """
    # Create checkpoint directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Determine if this is the best checkpoint based on test accuracy
    test_acc = metrics.get("test_acc", 0)
    is_best = test_acc > getattr(save_model_checkpoint, "best_acc", 0)
    if is_best:
        save_model_checkpoint.best_acc = test_acc

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'projection_head_state_dict': projection_head.state_dict(),
        'W_probe_state_dict': W_probe.state_dict(),
        'W_diet_state_dict': W_diet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'test_acc': test_acc,
        'metrics': metrics
    }

    # Save checkpoint locally
    checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pt")
    torch.save(checkpoint, checkpoint_path)

    # Save best checkpoint separately and upload it
    if is_best:
        best_path = os.path.join(save_dir, "best_checkpoint.pt")
        torch.save(checkpoint, best_path)

        best_artifact = wandb.Artifact(
            name=f"best_model_{run.id}",
            type="model",
            description=f"Best model checkpoint (epoch {epoch}, acc={test_acc:.4f})"
        )
        best_artifact.add_file(best_path)
        run.log_artifact(best_artifact)

    # Upload the per-epoch checkpoint periodically, on the final epoch, or when
    # it is a new best (so the best epoch is also retrievable by epoch name).
    # bool(log_every) guards against modulo-by-zero when periodic uploads are off.
    periodic = bool(log_every) and epoch % log_every == 0
    if periodic or is_final or is_best:
        artifact = wandb.Artifact(
            name=f"model_e{epoch}_{run.id}",
            type="model",
            description=f"Model checkpoint from epoch {epoch}"
        )
        artifact.add_file(checkpoint_path)
        run.log_artifact(artifact)

# Initialize static variable for best accuracy (reset to 0 before each new run)
save_model_checkpoint.best_acc = 0

def log_model_architecture(run, model, projection_head, W_probe, W_diet):
    """Log a parameter-count table for the backbone and heads to wandb.

    For each component (backbone, projection head, probe head, DIET head) the
    table lists the total and trainable (``requires_grad``) parameter counts
    and its share of all parameters. The table is logged as ``wandb.Html``
    under ``model_architecture``, so it is rendered as an HTML table rather
    than Markdown (which ``wandb.Html`` would show as raw text).

    Args:
        run: wandb run object
        model: The backbone model
        projection_head: DIET projection head
        W_probe: Probe linear layer
        W_diet: DIET linear layer
    """
    components = [
        ("Backbone Model", model),
        ("Projection Head", projection_head),
        ("Classification Head", W_probe),
        ("DIET Head", W_diet),
    ]

    # (label, total, trainable) per component. Heads may also be frozen, so
    # trainable counts are computed for every component, not only the backbone.
    counts = []
    for label, module in components:
        params = list(module.parameters())
        total = sum(p.numel() for p in params)
        trainable = sum(p.numel() for p in params if p.requires_grad)
        counts.append((label, total, trainable))

    total_params = sum(total for _, total, _ in counts)
    total_trainable = sum(trainable for _, _, trainable in counts)

    def _share(n):
        # Avoid ZeroDivisionError when every component is parameter-free.
        return f"{100 * n / total_params:.2f}%" if total_params else "0.00%"

    rows = "".join(
        f"<tr><td>{label}</td><td>{total:,}</td><td>{trainable:,}</td><td>{_share(total)}</td></tr>"
        for label, total, trainable in counts
    )
    architecture_html = (
        "<h1>Model Architecture</h1>"
        "<h2>Parameter Counts</h2>"
        "<table>"
        "<tr><th>Component</th><th>Total Parameters</th>"
        "<th>Trainable Parameters</th><th>% of Total</th></tr>"
        f"{rows}"
        f"<tr><td><b>Total</b></td><td><b>{total_params:,}</b></td>"
        f"<td><b>{total_trainable:,}</b></td><td><b>100%</b></td></tr>"
        "</table>"
    )

    run.log({"model_architecture": wandb.Html(architecture_html)})

def log_zero_shot_comparison_table(run, metrics_history, tracked_epochs, metrics_list):
    """Log a per-epoch table of zero-shot metrics to wandb as ``zero_shot_comparison``.

    Each row is ``[epoch, metrics_history["zero_shot_metrics"][epoch][m] for m
    in metrics_list]``. Columns are ``"Epoch"`` followed by the raw metric names.

    NOTE(review): a function with this same name is defined again in a later
    notebook cell (logging under ``zero_shot_progression``). Once that cell
    has run, this definition is shadowed.

    Args:
        run: wandb run object
        metrics_history: Dictionary of metrics history
        tracked_epochs: List of epochs to include in the table
        metrics_list: List of metric names to include
    """
    rows = [
        [epoch] + [metrics_history["zero_shot_metrics"][epoch][name] for name in metrics_list]
        for epoch in tracked_epochs
    ]
    comparison = wandb.Table(columns=["Epoch"] + metrics_list, data=rows)
    run.log({"zero_shot_comparison": comparison})

def log_figure_to_wandb(run, figure, name):
    """Render a Matplotlib figure to PNG in memory and log it as a wandb image.

    NOTE(review): confirm the installed wandb version's ``wandb.Image``
    accepts a file-like buffer such as ``BytesIO``.

    Args:
        run: wandb run object
        figure: Matplotlib figure
        name: Key under which the image is logged
    """
    png_buffer = BytesIO()
    figure.savefig(png_buffer, format='png')
    # Rewind so the consumer reads the PNG from its first byte.
    png_buffer.seek(0)

    image = wandb.Image(png_buffer)
    run.log({name: image})

def create_zero_shot_progression_plot(metrics_history, tracked_epochs, metrics_list):
    """Create a grid of line plots showing how each zero-shot metric evolves.

    One subplot per metric, two per row. Each subplot plots
    ``metrics_history["zero_shot_metrics"][epoch][metric]`` against
    ``tracked_epochs`` and annotates the first and last values.

    Args:
        metrics_history: Dictionary of metrics history
        tracked_epochs: List of epochs to track
        metrics_list: List of metric names to include

    Returns:
        fig: Matplotlib figure
    """
    ncols = 2
    # At least two rows keeps the original 2x2 layout for up to four metrics.
    # Further metrics add rows instead of overflowing a fixed 2x2 grid, where
    # add_subplot would fail on the fifth subplot.
    nrows = max(2, (len(metrics_list) + ncols - 1) // ncols)
    fig = Figure(figsize=(15, 5 * nrows))

    # Plot each metric's progression
    for i, metric in enumerate(metrics_list):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        values = [metrics_history["zero_shot_metrics"][e][metric] for e in tracked_epochs]
        ax.plot(tracked_epochs, values, marker='o', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(f'{metric} Score')
        ax.set_title(f'Zero-shot {metric} Progression')
        ax.grid(True)

        # Annotate initial and final values; with no tracked epochs there is
        # nothing to annotate (indexing values[0] would raise IndexError).
        if values:
            ax.annotate(f'{values[0]:.4f}', (tracked_epochs[0], values[0]),
                        textcoords="offset points", xytext=(0, 10), ha='center')
            ax.annotate(f'{values[-1]:.4f}', (tracked_epochs[-1], values[-1]),
                        textcoords="offset points", xytext=(0, 10), ha='center')

    fig.tight_layout()
    fig.suptitle('Zero-shot Metrics Progression During Training', fontsize=16)
    fig.subplots_adjust(top=0.9)

    return fig

def create_training_progress_plot(metrics_history):
    """Build a 1x3 figure summarising training progress.

    Panels, left to right:
      1. DIET and probe training loss per epoch.
      2. Train and test accuracy per epoch.
      3. Grouped bars comparing the first and last tracked zero-shot metrics.
         This panel is left empty when fewer than two epochs were tracked.

    Args:
        metrics_history: Dict with plottable sequences under ``train_loss_diet``,
            ``train_loss_probe``, ``train_acc`` and ``test_acc``. Its
            ``zero_shot_metrics`` entry maps each tracked epoch to a
            {metric name: score} dict.

    Returns:
        fig: Matplotlib figure
    """
    fig = Figure(figsize=(15, 5))

    # (y label, title, [(history key, legend label), ...]) for the two line panels.
    line_panels = [
        ("Loss", "Training Loss",
         [("train_loss_diet", "DIET Loss"), ("train_loss_probe", "Probe Loss")]),
        ("Accuracy", "Model Accuracy",
         [("train_acc", "Train Accuracy"), ("test_acc", "Test Accuracy")]),
    ]
    for position, (y_label, title, series) in enumerate(line_panels, start=1):
        panel = fig.add_subplot(1, 3, position)
        for history_key, legend_label in series:
            panel.plot(metrics_history[history_key], label=legend_label)
        panel.set_xlabel("Epoch")
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.legend()
        panel.grid(True)

    zero_shot_panel = fig.add_subplot(1, 3, 3)

    zero_shot = metrics_history["zero_shot_metrics"]
    epochs = sorted(zero_shot.keys())
    if len(epochs) >= 2:
        first_scores = zero_shot[epochs[0]]
        last_scores = zero_shot[epochs[-1]]

        names = list(first_scores.keys())
        positions = range(len(names))
        bar_width = 0.35

        zero_shot_panel.bar(positions, [first_scores[n] for n in names],
                            bar_width, label='Initial')
        zero_shot_panel.bar([p + bar_width for p in positions], [last_scores[n] for n in names],
                            bar_width, label='Final')
        zero_shot_panel.set_xlabel("Metrics")
        zero_shot_panel.set_ylabel("Score")
        zero_shot_panel.set_title("Zero-Shot Performance")
        zero_shot_panel.set_xticks([p + bar_width / 2 for p in positions])
        zero_shot_panel.set_xticklabels(names)
        zero_shot_panel.legend()
        zero_shot_panel.grid(True)

    fig.tight_layout()

    return fig

In [None]:
# Add this function to your codebase to log zero-shot metrics to W&B tables
def log_metrics_table(run, metrics_history):
    """
    Log zero-shot metrics to W&B tables.

    Two tables are logged:
      * ``zero_shot_metrics_table``: one row per tracked epoch, with the value,
        absolute change and percent change of each metric.
      * ``zero_shot_summary_table``: one row per metric, comparing the
        earliest and latest tracked epochs.

    The baseline is the earliest tracked epoch, not a hard-coded epoch 0. The
    percent change is relative to ``abs(baseline)``, so its sign matches the
    absolute change even for metrics that can be negative (e.g. ARI). With a
    zero baseline it is ``+inf``/``-inf`` for a gain/loss and ``0.0`` when
    unchanged.

    Args:
        run: The wandb run object
        metrics_history: Dictionary containing metrics history with
            ``zero_shot_metrics`` (epoch -> {metric name: score}).
    """
    zero_shot = metrics_history["zero_shot_metrics"]
    tracked_epochs = sorted(zero_shot.keys())
    if not tracked_epochs:
        print("No zero-shot metrics recorded; skipping zero-shot tables.")
        return
    metrics_list = list(zero_shot[tracked_epochs[0]].keys())

    def _relative_change(change, baseline):
        # Percent change w.r.t. |baseline|; signed infinity for a zero baseline.
        if baseline != 0:
            return change / abs(baseline) * 100
        if change > 0:
            return float('inf')
        if change < 0:
            return float('-inf')
        return 0.0

    # Create table columns
    columns = ["epoch"]
    for metric in metrics_list:
        columns.extend([f"{metric}", f"{metric}_change", f"{metric}_relative_change"])
    zero_shot_table = wandb.Table(columns=columns)

    # Baseline values: the earliest tracked epoch (epoch 0 may not be tracked).
    initial_values = zero_shot[tracked_epochs[0]]

    for epoch in tracked_epochs:
        current_metrics = zero_shot[epoch]
        row = [epoch]
        for metric in metrics_list:
            current_value = current_metrics[metric]
            change = current_value - initial_values[metric]
            row.extend([current_value, change, _relative_change(change, initial_values[metric])])
        zero_shot_table.add_data(*row)

    run.log({"zero_shot_metrics_table": zero_shot_table})

    # Summary table comparing the first and last tracked epochs
    final_metrics = zero_shot[tracked_epochs[-1]]
    summary_columns = ["metric", "initial", "final", "change", "relative_change"]
    summary_table = wandb.Table(columns=summary_columns)

    for metric in metrics_list:
        initial = initial_values[metric]
        final = final_metrics[metric]
        change = final - initial
        summary_table.add_data(metric, initial, final, change, _relative_change(change, initial))

    run.log({"zero_shot_summary_table": summary_table})


# Add this function to improve visualization of metrics with proper names
def log_zero_shot_comparison_table(run, metrics_history, tracked_epochs, metrics_list):
    """
    Log a per-epoch zero-shot metrics table with human-readable column names.

    Known metric keys (``knn_acc``, ``kmeans_ari``, ``kmeans_nmi``,
    ``linear_acc``) are shown under friendly headers; any other key is used
    verbatim. The table is logged under ``zero_shot_progression``. Because
    this function reuses the name of an earlier definition, it replaces that
    definition once this cell has run.

    Args:
        run: The wandb run object
        metrics_history: Dictionary containing metrics history
        tracked_epochs: List of epochs to include
        metrics_list: List of metrics to include
    """
    friendly_names = {
        "knn_acc": "K-NN Accuracy",
        "kmeans_ari": "K-Means ARI",
        "kmeans_nmi": "K-Means NMI",
        "linear_acc": "Linear Probe Accuracy"
    }

    header = ["Epoch"] + [friendly_names.get(key, key) for key in metrics_list]
    progression = wandb.Table(columns=header)

    for epoch in tracked_epochs:
        progression.add_data(
            epoch,
            *(metrics_history["zero_shot_metrics"][epoch][key] for key in metrics_list)
        )

    run.log({"zero_shot_progression": progression})


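# EXAMPLE USAGE:
# Call these functions during final logging, alongside the other
# visualizations of your results: right after the "Create a table for the
# report" section of the training script (around line 850 of paste.txt).
#
#     log_metrics_table(run, metrics_history)
#     tracked_epochs = sorted(metrics_history["zero_shot_metrics"].keys())
#     metrics_list = list(metrics_history["zero_shot_metrics"][tracked_epochs[0]].keys())
#     log_zero_shot_comparison_table(run, metrics_history, tracked_epochs, metrics_list)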


In [None]:
# Add this to your code - this creates a more advanced table with visualization helpers

def log_enhanced_metrics_table(run, metrics_history, initial_results, final_results):
    """
    Create enhanced metrics tables with trend markers for easier interpretation.

    Logs ``metrics_detailed_progression``, with one row per (epoch, metric),
    and ``metrics_final_summary``, with one row per metric. The summary
    includes a significance label: |percent| > 20 is "High", > 5 is
    "Medium", otherwise "Low".

    Percent changes are relative to ``abs(initial)``, so their sign always
    matches the absolute change. This also holds for metrics that can be
    negative, such as the adjusted Rand index; for those the old ``initial > 0``
    test forced the percent (and significance) to 0/"Low". A zero baseline
    still yields a percent change of 0.

    Args:
        run: The wandb run object
        metrics_history: Dictionary with metrics history (``zero_shot_metrics``
            maps epoch -> {metric name: score})
        initial_results: Dictionary with initial metrics
        final_results: Dictionary with final metrics
    """
    def _percent_change(change, initial):
        # Relative to |initial| so negative baselines keep the correct sign.
        return (change / abs(initial)) * 100 if initial != 0 else 0

    def _trend(change):
        # Simplistic: assumes higher is better for every metric.
        return "↑" if change > 0 else "↓" if change < 0 else "→"

    tracked_epochs = sorted(metrics_history["zero_shot_metrics"].keys())
    metrics_list = list(initial_results.keys())

    columns = ["epoch", "metric", "value", "change_from_init", "percent_change", "trend"]
    table = wandb.Table(columns=columns)

    for epoch in tracked_epochs:
        epoch_metrics = metrics_history["zero_shot_metrics"][epoch]

        for metric in metrics_list:
            # Skip metrics this epoch did not record instead of raising KeyError.
            if metric not in epoch_metrics:
                continue
            current = epoch_metrics[metric]
            initial = initial_results[metric]
            change = current - initial
            table.add_data(epoch, metric, current, change,
                           _percent_change(change, initial), _trend(change))

    run.log({"metrics_detailed_progression": table})

    summary_columns = ["metric", "initial", "final", "absolute_change",
                       "percent_change", "trend", "significance"]
    summary_table = wandb.Table(columns=summary_columns)

    for metric in metrics_list:
        initial = initial_results[metric]
        final = final_results[metric]
        change = final - initial
        percent = _percent_change(change, initial)

        # Significance level (arbitrary thresholds)
        if abs(percent) > 20:
            significance = "High"
        elif abs(percent) > 5:
            significance = "Medium"
        else:
            significance = "Low"

        summary_table.add_data(
            metric, initial, final, change, percent, _trend(change), significance
        )

    run.log({"metrics_final_summary": summary_table})



In [None]:
def log_sanity_check_results(run, results_dict, model_type):
    """
    Log one model's k-NN sanity-check results to W&B tables.

    Two tables are logged:
      * ``sanity_check_<model_type>_knn``: accuracy (converted to percent)
        for every k value.
      * ``sanity_check_<model_type>_summary``: best k-NN accuracy and its k.
        When ``results_dict`` has ``linear_probe_acc``, a row for the linear
        probe is added as well.

    Args:
        run: W&B run object
        results_dict: Dictionary of sanity check results with ``k_values``,
            ``accuracies``, ``best_acc``, ``best_k`` and optionally
            ``linear_probe_acc``. When None, nothing is logged.
        model_type: Type of model that was evaluated
    """
    if results_dict is None:
        print(f"No results to log for {model_type}")
        return

    k_values = results_dict['k_values']
    accuracies = results_dict['accuracies']
    best_acc = results_dict['best_acc']
    best_k = results_dict['best_k']

    # Accuracy per k value
    k_table = wandb.Table(columns=["model", "k_value", "accuracy"])
    for k, acc in zip(k_values, accuracies):
        k_table.add_data(model_type, k, acc * 100)  # Convert to percentage
    run.log({f"sanity_check_{model_type}_knn": k_table})

    # Summary table (linear probe results are present for I-JEPA)
    if 'linear_probe_acc' in results_dict:
        summary_table = wandb.Table(columns=["model", "method", "accuracy", "best_k"])
        summary_table.add_data(model_type, "k-NN", best_acc * 100, best_k)
        # A linear probe has no k. Use None instead of a string like "N/A":
        # wandb Tables validate column types by default, and the column
        # already holds an int from the k-NN row.
        summary_table.add_data(model_type, "Linear Probe",
                               results_dict['linear_probe_acc'] * 100, None)
    else:
        summary_table = wandb.Table(columns=["model", "best_k_value", "best_accuracy"])
        summary_table.add_data(model_type, best_k, best_acc * 100)

    run.log({f"sanity_check_{model_type}_summary": summary_table})


def log_combined_sanity_check_results(run, results_dict):
    """
    Log a combined table of all sanity-check results.

    One row per model for k-NN, plus a linear-probe row when a model's
    results contain ``linear_probe_acc``. Accuracies are converted to
    percent. ``passed_check`` is "✓" when the accuracy reaches the model's
    threshold; unknown model types use 0.85. Models whose results are None
    are skipped.

    Args:
        run: W&B run object
        results_dict: Dictionary mapping model types to their sanity check results
    """
    combined_table = wandb.Table(
        columns=["model", "method", "best_accuracy", "best_k", "passed_check"]
    )

    # Expected accuracy thresholds (fractions) for each model type
    thresholds = {
        "dinov2": 0.91,
        "mae": 0.85,
        "mambavision": 0.85,
        "ijepa": 0.85,
        "aim": 0.75
    }

    for model_type, results in results_dict.items():
        if results is None:
            continue

        threshold = thresholds.get(model_type, 0.85)

        best_acc = results['best_acc']
        combined_table.add_data(
            model_type,
            "k-NN",
            best_acc * 100,  # Convert to percentage
            results['best_k'],
            "✓" if best_acc >= threshold else "✗"
        )

        if 'linear_probe_acc' in results:
            linear_acc = results['linear_probe_acc']
            # A linear probe has no k. None keeps the best_k column
            # type-consistent (wandb Tables validate column types by default,
            # so a string like "N/A" next to int k values is rejected).
            combined_table.add_data(
                model_type,
                "Linear Probe",
                linear_acc * 100,  # Convert to percentage
                None,
                "✓" if linear_acc >= threshold else "✗"
            )

    run.log({"sanity_check_combined_results": combined_table})

# Example usage:
# sanity_results = {
#     "dinov2": sanity_results_dinov2,
#     "mae": sanity_results_mae,
#     "mambavision": sanity_results_mambavision,
#     "ijepa": sanity_results_ijepa,
#     "aim": sanity_results_aim
# }
# 
# # Log individual results
# for model_type, results in sanity_results.items():
#     log_sanity_check_results(run, results, model_type)
# 
# # Log combined results
# log_combined_sanity_check_results(run, sanity_results)

In [None]:
def unified_sanity_check(
    model_type,
    model_size=None,
    model_variant=None,
    expected_threshold=None,  # Now optional, will be set based on model_type
    batch_size=None,          # Now optional, will be set based on model_type
    k_values=None,
    num_workers=0,
    log_to_wandb=True        # Added parameter to control W&B logging
):
    """
    Unified sanity check with integrated W&B logging: Evaluate model's zero-shot 
    performance on CIFAR10 using k-NN.
    
    Args:
        model_type (str): Type of model ("dinov2", "mae", "mambavision", "ijepa", "aim")
        model_size (str, optional): Model size for relevant models. Defaults based on model_type.
        model_variant (str, optional): Model variant (for mambavision). Defaults based on model_type.
        expected_threshold (float, optional): Expected accuracy threshold. Defaults based on model_type.
        batch_size (int, optional): Batch size for data loading. Defaults based on model_type.
        k_values (list, optional): List of k values to test. Defaults to [1, 5, 20, 50, 100, 200].
        num_workers (int, optional): Number of workers for data loader. Defaults to 0.
        log_to_wandb (bool, optional): Whether to log results to W&B. Defaults to True.
    
    Returns:
        dict: Results containing accuracies, k values, best accuracy, and best k
    """
    if k_values is None:
        k_values = [1, 5, 20, 50, 100, 200]
    
    # Set model-specific parameters based on model type
    model_defaults = {
        "dinov2": {
            "model_size": "small",
            "expected_threshold": 0.91,
            "batch_size": 256
        },
        "mae": {
            "model_size": "base",
            "expected_threshold": 0.85,
            "batch_size": 256
        },
        "mambavision": {
            "model_variant": "T",
            "expected_threshold": 0.85,
            "batch_size": 64
        },
        "ijepa": {
            "model_size": "b16_1k",
            "expected_threshold": 0.85,
            "batch_size": 64
        },
        "aim": {
            "model_size": "600M",
            "expected_threshold": 0.75,
            "batch_size": 256
        }
    }
    
    # Apply default parameters if not provided
    if model_type in model_defaults:
        defaults = model_defaults[model_type]
        if model_size is None and "model_size" in defaults:
            model_size = defaults["model_size"]
        if model_variant is None and "model_variant" in defaults:
            model_variant = defaults["model_variant"]
        if expected_threshold is None and "expected_threshold" in defaults:
            expected_threshold = defaults["expected_threshold"]
        if batch_size is None and "batch_size" in defaults:
            batch_size = defaults["batch_size"]
    
    # Set some safe defaults if model type is not recognized
    if expected_threshold is None:
        expected_threshold = 0.85
    if batch_size is None:
        batch_size = 128
    
    print("\n" + "="*70)
    print(f"SANITY CHECK: {model_type.upper()} ZERO-SHOT k-NN ON CIFAR10")
    print("="*70)
    
    # Initialize W&B tracking
    run = None
    if log_to_wandb:
        if wandb.run is None:
            # No active run, create one for this sanity check
            run_name = f"sanity_check_{model_type}"
            if model_size:
                run_name += f"_{model_size}"
            elif model_variant:
                run_name += f"_{model_variant}"
                
            run = wandb.init(project="diet_finetuning", name=run_name, config={
                "model_type": model_type,
                "model_size": model_size,
                "model_variant": model_variant,
                "expected_threshold": expected_threshold,
                "batch_size": batch_size,
                "k_values": k_values
            })
        else:
            # Use existing run
            run = wandb.run
            # Log config to existing run
            run.config.update({
                f"sanity_{model_type}_model_size": model_size,
                f"sanity_{model_type}_model_variant": model_variant,
                f"sanity_{model_type}_expected_threshold": expected_threshold,
                f"sanity_{model_type}_batch_size": batch_size,
                f"sanity_{model_type}_k_values": k_values
            }, allow_val_change=True)
    
    # Load CIFAR10 dataset
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), 
                                         (0.2470, 0.2435, 0.2616))
    ])
    
    cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    
    # Create data loaders
    train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    # Create and load the appropriate model
    print(f"Loading fresh {model_type} model...")
    
    try:
        if model_type == "dinov2":
            sanity_model, embedding_dim = get_dinov2_model(device, model_size=model_size)
        elif model_type == "mae":
            sanity_model, embedding_dim = get_mae_model(device, model_size=model_size)
        elif model_type == "mambavision":
            sanity_model, embedding_dim = get_mambavision_model(device, model_variant=model_variant)
        elif model_type == "ijepa":
            sanity_model, embedding_dim = get_ijepa_model(device, model_size=model_size)
        elif model_type == "aim":
            sanity_model, embedding_dim = get_aim_model(device, model_size=model_size)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
    except Exception as e:
        error_msg = f"Failed to load model: {e}"
        print(error_msg)
        if log_to_wandb and run:
            run.log({f"sanity_check_{model_type}_error": error_msg})
        return None
    
    sanity_model.eval()  # Set to evaluation mode
    
    # Special case for I-JEPA: check if we need to run both kNN and linear probe
    run_linear_probe = (model_type == "ijepa")
    linear_accuracy = None
    
    # Extract features from training set
    print("Extracting features from CIFAR10 training set...")
    train_features = []
    train_labels = []
    
    with torch.no_grad():
        for x, y in tqdm(train_loader, desc="Extracting train features"):
            x = x.to(device)
            try:
                feat = sanity_model(x)
                train_features.append(feat.cpu().numpy())
                train_labels.append(y.numpy())
            except Exception as e:
                print(f"Error processing batch: {e}")
                continue
    
    if not train_features:
        error_msg = "No features were extracted. Sanity check failed."
        print(error_msg)
        if log_to_wandb and run:
            run.log({f"sanity_check_{model_type}_error": error_msg})
        return None
    
    train_features = np.vstack(train_features)
    train_labels = np.concatenate(train_labels)
    
    # Extract features from test set
    print("Extracting features from CIFAR10 test set...")
    test_features = []
    test_labels = []
    
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Extracting test features"):
            x = x.to(device)
            try:
                feat = sanity_model(x)
                test_features.append(feat.cpu().numpy())
                test_labels.append(y.numpy())
            except Exception as e:
                print(f"Error processing batch: {e}")
                continue
    
    if not test_features:
        error_msg = "No test features were extracted. Sanity check failed."
        print(error_msg)
        if log_to_wandb and run:
            run.log({f"sanity_check_{model_type}_error": error_msg})
        return None
    
    test_features = np.vstack(test_features)
    test_labels = np.concatenate(test_labels)
    
    print(f"Features extracted: {train_features.shape} train, {test_features.shape} test")
    
    # Normalize features (important for k-NN)
    train_features_normalized = train_features / np.linalg.norm(train_features, axis=1, keepdims=True)
    test_features_normalized = test_features / np.linalg.norm(test_features, axis=1, keepdims=True)
    
    # Run k-NN evaluation
    if run_linear_probe:
        print("\n" + "="*50)
        print("k-NN EVALUATION")
        print("="*50)
    
    print("\nEvaluating k-NN performance:")
    print("-"*50)
    print(f"{'k value':<10} {'Accuracy':<10}")
    print("-"*50)
    
    best_acc = 0
    best_k = 0
    accuracies = []
    
    # Create a W&B table for k-NN results
    if log_to_wandb and run:
        knn_table = wandb.Table(columns=["k_value", "accuracy"])
    
    for k in k_values:
        knn = KNeighborsClassifier(n_neighbors=k, metric='cosine')
        knn.fit(train_features_normalized, train_labels)
        predictions = knn.predict(test_features_normalized)
        accuracy = accuracy_score(test_labels, predictions)
        accuracies.append(accuracy)
        print(f"{k:<10} {accuracy*100:.2f}%")
        
        # Log to W&B table
        if log_to_wandb and run:
            knn_table.add_data(k, accuracy * 100)  # Convert to percentage
        
        if accuracy > best_acc:
            best_acc = accuracy
            best_k = k
    
    print("-"*50)
    
    # Log the k-NN table to W&B
    if log_to_wandb and run:
        run.log({f"sanity_check_{model_type}_knn": knn_table})
    
    # Linear probe evaluation for I-JEPA
    if run_linear_probe:
        print("\n" + "="*50)
        print("LINEAR PROBE EVALUATION")
        print("="*50)
        
        # Convert features to PyTorch tensors
        train_features_tensor = torch.FloatTensor(train_features).to(device)
        train_labels_tensor = torch.LongTensor(train_labels).to(device)
        test_features_tensor = torch.FloatTensor(test_features).to(device)
        test_labels_tensor = torch.LongTensor(test_labels).to(device)
        
        # Set up linear probe
        num_classes = 10  # CIFAR10 has 10 classes
        linear_probe = nn.Linear(embedding_dim, num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(linear_probe.parameters(), lr=0.001)
        
        # Train linear probe
        num_epochs = 50
        linear_batch_size = 1024
        linear_probe.train()
        
        # Prepare data for batch training
        dataset = torch.utils.data.TensorDataset(train_features_tensor, train_labels_tensor)
        loader = torch.utils.data.DataLoader(dataset, batch_size=linear_batch_size, shuffle=True)
        
        # Create list to track loss for W&B
        if log_to_wandb and run:
            loss_history = []
        
        print("Training linear probe...")
        for epoch in range(num_epochs):
            total_loss = 0
            for batch_features, batch_labels in loader:
                optimizer.zero_grad()
                logits = linear_probe(batch_features)
                loss = criterion(logits, batch_labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            
            epoch_loss = total_loss / len(loader)
            if log_to_wandb and run:
                loss_history.append(epoch_loss)
                
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        
        # Log linear probe training curve
        if log_to_wandb and run:
            run.log({f"sanity_check_{model_type}_linear_probe_loss": wandb.plot.line(
                table=wandb.Table(data=[[i, loss] for i, loss in enumerate(loss_history)],
                                  columns=["epoch", "loss"]),
                x="epoch",
                y="loss",
                title="Linear Probe Training Loss"
            )})
        
        # Evaluate linear probe
        linear_probe.eval()
        with torch.no_grad():
            logits = linear_probe(test_features_tensor)
            predictions = logits.argmax(dim=1).cpu().numpy()
            linear_accuracy = accuracy_score(test_labels, predictions)
        
        print(f"Linear probe accuracy: {linear_accuracy*100:.2f}%")
        
        # Log linear probe accuracy
        if log_to_wandb and run:
            run.log({f"sanity_check_{model_type}_linear_probe_acc": linear_accuracy * 100})
    
    # Determine if sanity check passed
    passed_check = best_acc >= expected_threshold
    status = "PASSED ✓" if passed_check else "FAILED ✗"
    
    print(f"Best accuracy: {best_acc*100:.2f}% (k={best_k})")
    print(f"Sanity check status: {status}")
    print(f"Expected accuracy: >{expected_threshold*100}%, Achieved: {best_acc*100:.2f}%")
    print("="*70)
    
    # Log summary results to W&B
    if log_to_wandb and run:
        run.log({
            f"sanity_check_{model_type}_best_acc": best_acc * 100,
            f"sanity_check_{model_type}_best_k": best_k,
            f"sanity_check_{model_type}_passed": passed_check
        })
        
        # Create summary table
        summary_table = wandb.Table(
            columns=["model", "method", "best_accuracy", "best_k", "threshold", "passed_check"]
        )
        
        # Add k-NN row
        summary_table.add_data(
            model_type, 
            "k-NN", 
            best_acc * 100, 
            best_k,
            expected_threshold * 100,
            "✓" if passed_check else "✗"
        )
        
        # Add linear probe row if applicable
        if run_linear_probe and linear_accuracy is not None:
            linear_passed = linear_accuracy >= expected_threshold
            summary_table.add_data(
                model_type,
                "Linear Probe",
                linear_accuracy * 100,
                "N/A",
                expected_threshold * 100,
                "✓" if linear_passed else "✗"
            )
        
        # Log the summary table
        run.log({f"sanity_check_{model_type}_summary": summary_table})
    
    # Create visualization plot
    if run_linear_probe and linear_accuracy is not None:
        plt.figure(figsize=(15, 6))
        
        # Plot k-NN results
        plt.subplot(1, 2, 1)
        plt.plot(k_values, [acc*100 for acc in accuracies], marker='o', linewidth=2)
        plt.axhline(y=expected_threshold*100, color='r', linestyle='--', label=f'{expected_threshold*100}% threshold')
        plt.xlabel('k value')
        plt.ylabel('Accuracy (%)')
        plt.title(f'{model_type.upper()} Zero-Shot k-NN Performance on CIFAR10')
        plt.grid(True)
        plt.legend()
        plt.xticks(k_values)
        
        # Plot comparison of methods
        plt.subplot(1, 2, 2)
        methods = ['k-NN (best)', 'Linear Probe']
        method_accuracies = [best_acc*100, linear_accuracy*100]
        
        plt.bar(methods, method_accuracies, color=['blue', 'orange'])
        plt.ylabel('Accuracy (%)')
        plt.title('Zero-Shot Evaluation Methods Comparison')
        plt.grid(axis='y', alpha=0.3)
        plt.axhline(y=expected_threshold*100, color='r', linestyle='--', label=f'{expected_threshold*100}% threshold')
        plt.legend()
        
        # Add text on top of bars
        for i, v in enumerate(method_accuracies):
            plt.text(i, v+1, f"{v:.2f}%", ha='center')
    else:
        plt.figure(figsize=(10, 6))
        plt.plot(k_values, [acc*100 for acc in accuracies], marker='o', linewidth=2)
        plt.axhline(y=expected_threshold*100, color='r', linestyle='--', label=f'Expected threshold ({expected_threshold*100}%)')
        plt.xlabel('k value')
        plt.ylabel('Accuracy (%)')
        plt.title(f'{model_type.upper()} Zero-Shot k-NN Performance on CIFAR10')
        plt.grid(True)
        plt.legend()
        plt.xticks(k_values)
    
    plt.tight_layout()
    
    # Log the figure to W&B
    if log_to_wandb and run:
        run.log({f"sanity_check_{model_type}_plot": wandb.Image(plt)})
    
    # Display the plot
    plt.show()
    
    # Prepare return value
    results = {
        'accuracies': accuracies,
        'k_values': k_values,
        'best_acc': best_acc,
        'best_k': best_k,
        'passed_check': passed_check,
        'expected_threshold': expected_threshold
    }
    
    if run_linear_probe and linear_accuracy is not None:
        results['linear_probe_acc'] = linear_accuracy
    
    return results

# Example usage:
# sanity_results_dinov2 = unified_sanity_check("dinov2")  # W&B logging included
# 
# # If you don't want W&B logging for a specific run:
# sanity_results_mae = unified_sanity_check("mae", log_to_wandb=False)

## AIMModel

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_aim_model(device, model_size="600M"):
    """Create an AIM backbone wrapped for feature extraction.

    Loads ``apple/aim-<model_size>`` through the ``aim`` package, marks every
    backbone parameter as trainable, and wraps the model so that ``forward``
    returns one feature vector per image. The embedding dimension is measured
    with a dummy forward pass instead of being read from a lookup table.

    Args:
        device: The device to load the model on.
        model_size: Size of the model - "600M", "1B", "3B", or "7B".

    Returns:
        model: The wrapped AIM model (an ``nn.Module``).
        embedding_dim: The feature dimension observed on the dummy forward pass.

    Raises:
        ValueError: If ``model_size`` is unknown or the AIM model fails to load.
    """
    print(f"Loading AIM-{model_size} model...")

    # Map model size to model ID
    model_map = {
        "600M": "apple/aim-600M",
        "1B": "apple/aim-1B",
        "3B": "apple/aim-3B",
        "7B": "apple/aim-7B",
    }
    if model_size not in model_map:
        raise ValueError(
            f"Unknown AIM model size {model_size!r}. Choose from {list(model_map)}"
        )
    model_id = model_map[model_size]

    try:
        # Use the proper AIM imports
        from aim.v1.torch.models import AIMForImageClassification
        from aim.v1.torch.data import val_transforms

        print(f"Loading AIM model: {model_id}")
        base_model = AIMForImageClassification.from_pretrained(model_id)
        transform = val_transforms()
        print(f"Successfully loaded AIM model: {model_id}")
    except Exception as e:
        print(f"Error loading AIM model: {e}")
        raise ValueError(f"Failed to load AIM model: {e}") from e

    # UNFREEZE ALL PARAMETERS (the count reported is of parameter tensors)
    print("Unfreezing all AIM parameters...")
    unfrozen_params = 0
    for param in base_model.parameters():
        param.requires_grad = True
        unfrozen_params += 1
    print(f"Unfrozen {unfrozen_params} parameters in AIM backbone")

    # The larger variants run in smaller chunks to save GPU memory.
    chunk_size = 4 if model_size in ("3B", "7B") else 8

    class AIMWrapper(nn.Module):
        """Resize inputs to 224x224 and return the AIM feature output."""

        input_size = 224

        def __init__(self, model, transform, chunk_size=8):
            super().__init__()
            self.model = model
            # Kept for callers; forward() does not apply it.
            self.transform = transform
            self.chunk_size = chunk_size
            self._feature_dim_detected = False
            self._feature_dim = None

        def forward(self, x):
            # Detach from any upstream graph and make the input a leaf that
            # requires grad (the original gradient-flow workaround).
            x = x.detach().requires_grad_(True)

            # Large CUDA batches are processed in chunks to save memory.
            if x.shape[0] > 8 and x.device.type == 'cuda':
                chunks = [
                    self._extract_features(self._resize(x[i:i + self.chunk_size]))
                    for i in range(0, x.shape[0], self.chunk_size)
                ]
                return torch.cat(chunks, dim=0)
            return self._extract_features(self._resize(x))

        def _resize(self, x):
            """Bilinearly resize ``x`` to input_size x input_size if needed."""
            size = self.input_size
            if tuple(x.shape[-2:]) != (size, size):
                x = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)
            return x

        def _extract_features(self, x):
            """Run the AIM model and return its feature tensor."""
            output = self.model(x)

            # Tuple outputs carry the features at index 1. A 1-tuple or a bare
            # tensor is treated as the features themselves.
            if isinstance(output, tuple):
                features = output[1] if len(output) >= 2 else output[0]
            else:
                features = output

            # Report the output structure once, on the first call.
            if not self._feature_dim_detected:
                print(f"AIM model output type: {type(output)}")
                if isinstance(output, tuple):
                    print(f"Output tuple length: {len(output)}")
                    print(f"Features shape: {features.shape}")
                else:
                    print(f"Output shape: {output.shape}")
                self._feature_dim = features.shape[1]
                print(f"Detected feature dimension: {self._feature_dim}")
                self._feature_dim_detected = True

            return features

        @property
        def feature_dim(self):
            """Feature dimension; detected lazily with a dummy input if unknown."""
            if self._feature_dim is None:
                with torch.no_grad():
                    param_device = next(self.parameters()).device
                    dummy_input = torch.randn(
                        1, 3, self.input_size, self.input_size, device=param_device
                    )
                    self._extract_features(dummy_input)
            return self._feature_dim

    model = AIMWrapper(base_model, transform, chunk_size=chunk_size).to(device)

    # Detect the actual embedding dimension with a single dummy image.
    with torch.no_grad():
        model(torch.randn(1, 3, 224, 224, device=device))

    embedding_dim = model.feature_dim
    print(f"AIM-{model_size} loaded. Detected embedding dimension: {embedding_dim}")

    return model, embedding_dim

## MambaVision

In [None]:
def get_mambavision_model(device, model_variant="T"):
    """Create a MambaVision backbone wrapped for pooled-feature extraction.

    Args:
        device: The device to put the model on.
        model_variant: Variant key (T, T2, S, B, L, L2, B-21K, L-21K, ...),
            optionally given with a "MambaVision-" prefix.

    Returns:
        model: Wrapped MambaVision model whose ``forward`` returns the
            average-pooled features.
        embedding_dim: Width of those features. It is measured with a dummy
            forward pass when possible; otherwise the nominal table value is used.

    Raises:
        ValueError: If the variant is unknown or the model fails to load.
    """
    # Nominal configurations. "dim" is re-checked against the loaded model
    # below, because a wrong value breaks any head sized from embedding_dim.
    model_configs = {
        # ImageNet-1K models
        "T": {"id": "nvidia/MambaVision-T-1K", "dim": 640, "res": 224, "params": 31.8},
        "T2": {"id": "nvidia/MambaVision-T2-1K", "dim": 640, "res": 224, "params": 35.1},
        "S": {"id": "nvidia/MambaVision-S-1K", "dim": 768, "res": 224, "params": 50.1},
        "B": {"id": "nvidia/MambaVision-B-1K", "dim": 1024, "res": 224, "params": 97.7},
        "L": {"id": "nvidia/MambaVision-L-1K", "dim": 1568, "res": 224, "params": 227.9},
        "L2": {"id": "nvidia/MambaVision-L2-1K", "dim": 1568, "res": 224, "params": 241.5},

        # ImageNet-21K models
        "B-21K": {"id": "nvidia/MambaVision-B-21K", "dim": 1024, "res": 224, "params": 97.7},
        "L-21K": {"id": "nvidia/MambaVision-L-21K", "dim": 1568, "res": 224, "params": 227.9},
        "L2-512-21K": {"id": "nvidia/MambaVision-L2-512-21K", "dim": 1568, "res": 512, "params": 241.5},
        "L3-256-21K": {"id": "nvidia/MambaVision-L3-256-21K", "dim": 2048, "res": 256, "params": 739.6},
        "L3-512-21K": {"id": "nvidia/MambaVision-L3-512-21K", "dim": 2048, "res": 512, "params": 739.6},
    }

    # Handle full model names too (e.g., "MambaVision-T" or just "T")
    prefix = "MambaVision-"
    if model_variant.startswith(prefix):
        model_variant = model_variant[len(prefix):]

    if model_variant not in model_configs:
        raise ValueError(f"Model variant {model_variant} not supported. Choose from {list(model_configs.keys())}")

    config = model_configs[model_variant]
    model_id = config["id"]
    embedding_dim = config["dim"]
    input_res = config["res"]

    print(f"Loading MambaVision-{model_variant} model...")
    print(f"Model ID: {model_id}")
    print(f"Embedding dimension: {embedding_dim}")
    print(f"Input resolution: {input_res}x{input_res}")
    print(f"Parameter count: {config['params']} million")

    try:
        from transformers import AutoModel

        # Use AutoModel for feature extraction
        base_model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        print("Successfully loaded MambaVision model")

        # Unfreeze all parameters (the count reported is of parameter tensors)
        print("Unfreezing all MambaVision parameters...")
        unfrozen_params = 0
        for param in base_model.parameters():
            param.requires_grad = True
            unfrozen_params += 1
        print(f"Unfrozen {unfrozen_params} parameters in MambaVision backbone")

        class MambaVisionWrapper(nn.Module):
            """Resize inputs and return MambaVision's average-pooled features."""

            # Very small chunks keep the large variants within GPU memory.
            chunk_size = 4

            def __init__(self, model, input_res, emb_dim):
                super().__init__()
                self.model = model
                self.input_res = input_res
                self.emb_dim = emb_dim

            def forward(self, x):
                # Make x require gradients to force gradient flow
                x = x.detach().requires_grad_(True)

                if x.shape[0] > self.chunk_size and x.device.type == 'cuda':
                    parts = [
                        self._embed(x[i:i + self.chunk_size])
                        for i in range(0, x.shape[0], self.chunk_size)
                    ]
                    return torch.cat(parts, dim=0)
                return self._embed(x)

            def _embed(self, x):
                """Resize one (sub-)batch and return its pooled features."""
                res = self.input_res
                if tuple(x.shape[-2:]) != (res, res):
                    x = F.interpolate(x, size=(res, res), mode='bilinear', align_corners=False)
                try:
                    # MambaVision AutoModel returns (avg_pool, features)
                    avg_pool, _ = self.model(x)
                    return avg_pool
                except Exception as e:
                    # Best-effort: keep the run going with zero features.
                    print(f"Error in forward pass: {e}")
                    return x.new_zeros((x.size(0), self.emb_dim))

        model = MambaVisionWrapper(base_model, input_res, embedding_dim).to(device)

        # Measure the real pooled-feature width. This is best-effort: if the probe
        # fails (e.g. kernels unavailable on this device), keep the table value.
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, input_res, input_res, device=device)
                pooled, _ = model.model(dummy)
            detected_dim = int(pooled.shape[-1])
            if detected_dim != embedding_dim:
                print(f"MambaVision-{model_variant} produces {detected_dim}-dim features "
                      f"(table listed {embedding_dim}); using {detected_dim}.")
                embedding_dim = detected_dim
                model.emb_dim = detected_dim
        except Exception as probe_error:
            print(f"Could not probe MambaVision feature dimension ({probe_error}); "
                  f"assuming {embedding_dim}.")

        return model, embedding_dim

    except Exception as e:
        print(f"Error setting up MambaVision: {e}")
        raise ValueError(f"Failed to load MambaVision. Consider using DINOv2 or MAE instead. Error: {e}") from e

## MAE

In [None]:
def get_mae_model(device, model_size="base"):
    """Create an MAE ViT encoder wrapped for CLS-token feature extraction.

    Args:
        device: Device to move the wrapped model to.
        model_size: "base", "large", or "huge".

    Returns:
        model: Wrapped model whose ``forward`` returns the encoder CLS token.
        embedding_dim: Width of that CLS token (768 / 1024 / 1280).

    Raises:
        ValueError: If ``model_size`` is unknown or the model cannot be loaded.
    """
    print(f"Loading MAE-{model_size} model...")

    # Model size to embedding dimension mapping
    dim_map = {
        "base": 768,      # ViT-Base dimension
        "large": 1024,    # ViT-Large dimension
        "huge": 1280      # ViT-Huge dimension
    }

    # Map model size to Hugging Face model ID
    model_map = {
        "base": "facebook/vit-mae-base",
        "large": "facebook/vit-mae-large",
        "huge": "facebook/vit-mae-huge",
    }
    if model_size not in model_map:
        raise ValueError(
            f"Unknown MAE model size {model_size!r}. Choose from {list(model_map)}"
        )
    model_id = model_map[model_size]

    try:
        from transformers import AutoImageProcessor, ViTMAEForPreTraining

        processor = AutoImageProcessor.from_pretrained(model_id)
        base_model = ViTMAEForPreTraining.from_pretrained(model_id)
        print(f"Successfully loaded MAE model from {model_id}")
    except Exception as e:
        print(f"Error loading MAE model: {e}")
        raise ValueError("Could not load MAE model. Please check if transformers is installed.") from e

    # ViTMAE's embedding layer randomly drops `config.mask_ratio` of the patch
    # tokens (0.75 by default) on every forward call, including in eval mode.
    # That makes the extracted features random and computed from only a
    # quarter of the image. Keep every patch for feature extraction.
    base_model.config.mask_ratio = 0.0

    # UNFREEZE ALL PARAMETERS (the count reported is of parameter tensors)
    print("Unfreezing all MAE parameters...")
    unfrozen_params = 0
    for param in base_model.parameters():
        param.requires_grad = True
        unfrozen_params += 1
    print(f"Unfrozen {unfrozen_params} parameters in MAE backbone")

    class MAEWrapper(nn.Module):
        """Resize inputs to 224x224 and return the MAE encoder's CLS token."""

        chunk_size = 16

        def __init__(self, model, processor):
            super().__init__()
            self.model = model
            self.processor = processor

        def forward(self, x):
            # Make x require gradients to force gradient flow
            x = x.detach().requires_grad_(True)

            # Large CUDA batches are processed in chunks to save memory.
            if x.shape[0] > self.chunk_size and x.device.type == 'cuda':
                parts = [
                    self._encode(x[i:i + self.chunk_size])
                    for i in range(0, x.shape[0], self.chunk_size)
                ]
                return torch.cat(parts, dim=0)
            return self._encode(x)

        def _encode(self, x):
            """Resize one (sub-)batch and return its CLS-token features."""
            if tuple(x.shape[-2:]) != (224, 224):
                x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            # Only the MAE encoder (`.vit`) is used; the decoder is ignored.
            return self.model.vit(x).last_hidden_state[:, 0]

    model = MAEWrapper(base_model, processor).to(device)
    embedding_dim = dim_map[model_size]

    print(f"MAE-{model_size} loaded. Embedding dimension: {embedding_dim}")
    return model, embedding_dim

## DinoV2

In [None]:
def get_dinov2_model(device, model_size="small"):
    """Create a DINOv2 backbone wrapped for CLS-token feature extraction.

    Args:
        device: Device to move the wrapped model to.
        model_size: "small", "base", "large", or "giant".

    Returns:
        model: Wrapped model whose ``forward`` returns the CLS token.
        embedding_dim: Width of that CLS token.

    Raises:
        ValueError: If ``model_size`` is not one of the supported sizes.
    """
    print(f"Loading DINOv2-{model_size} model...")

    # Model size to embedding dimension mapping
    dim_map = {
        "small": 384,
        "base": 768,
        "large": 1024,
        "giant": 1536
    }
    # Validate before touching the hub, so a typo fails with a clear message.
    if model_size not in dim_map:
        raise ValueError(
            f"Unknown DINOv2 model size {model_size!r}. Choose from {list(dim_map)}"
        )

    # Load model and processor
    model_name = f"facebook/dinov2-{model_size}"
    processor = AutoImageProcessor.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)

    # Gradient checkpointing is intentionally left disabled: it may interfere
    # with gradient flow.

    # UNFREEZE ALL PARAMETERS (the count reported is of parameter tensors)
    print("Unfreezing all DINOv2 parameters...")
    unfrozen_params = 0
    for param in base_model.parameters():
        param.requires_grad = True
        unfrozen_params += 1
    print(f"Unfrozen {unfrozen_params} parameters in DINOv2 backbone")

    class DINOv2Wrapper(nn.Module):
        """Resize inputs to 224x224 and return the DINOv2 CLS token."""

        chunk_size = 16

        def __init__(self, model, processor=None):
            super().__init__()
            self.model = model
            self.processor = processor

        def forward(self, x):
            # Make x require gradients to force gradient flow
            x = x.detach().requires_grad_(True)

            # Large CUDA batches are processed in chunks to save memory.
            if x.shape[0] > self.chunk_size and x.device.type == 'cuda':
                parts = [
                    self._encode(x[i:i + self.chunk_size])
                    for i in range(0, x.shape[0], self.chunk_size)
                ]
                return torch.cat(parts, dim=0)
            return self._encode(x)

        def _encode(self, x):
            """Resize one (sub-)batch and return its CLS-token features."""
            if tuple(x.shape[-2:]) != (224, 224):
                x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            # Run WITHOUT autocast and with gradient tracking.
            return self.model(x).last_hidden_state[:, 0]

    model = DINOv2Wrapper(base_model, processor).to(device)
    embedding_dim = dim_map[model_size]

    print(f"DINOv2-{model_size} loaded. Embedding dimension: {embedding_dim}")
    return model, embedding_dim

## IJEPA

In [None]:
def get_ijepa_model(device, model_size):
    """
    Create an I-JEPA backbone (Hugging Face transformers) wrapped for
    mean-pooled feature extraction.

    Args:
        device: The device to put the model on.
        model_size: Checkpoint key. The legacy keys "b16_1k", "b16_22k",
            "l14_22k" and "h14_1k" are kept for compatibility, even though their
            names do not match the checkpoints they load. "h16_1k", "g16_22k"
            and "h14_22k" are accurately named aliases.

    Returns:
        model: I-JEPA model wrapped in ``IJEPAWrapper``.
        embedding_dim: Feature width. This is the checkpoint's
            ``config.hidden_size`` when available, otherwise the table value.

    Raises:
        ValueError: If ``model_size`` is unknown or the model fails to load.
    """
    vith16_1k = {"id": "facebook/ijepa_vith16_1k", "dim": 1280, "img_size": 448}
    vitg16_22k = {"id": "facebook/ijepa_vitg16_22k", "dim": 1408, "img_size": 224}
    vith14_22k = {"id": "facebook/ijepa_vith14_22k", "dim": 1280, "img_size": 224}
    vith14_1k = {"id": "facebook/ijepa_vith14_1k", "dim": 1280, "img_size": 224}
    model_map = {
        # Legacy keys (e.g. "b16_1k" actually loads a ViT-H/16 checkpoint).
        "b16_1k": vith16_1k,
        "b16_22k": vitg16_22k,
        "l14_22k": vith14_22k,
        "h14_1k": vith14_1k,
        # Accurately named aliases.
        "h16_1k": vith16_1k,
        "g16_22k": vitg16_22k,
        "h14_22k": vith14_22k,
    }

    if model_size not in model_map:
        raise ValueError(f"Model size {model_size} not supported. Choose from {list(model_map.keys())}")

    model_id = model_map[model_size]["id"]
    embedding_dim = model_map[model_size]["dim"]
    img_size = model_map[model_size]["img_size"]

    print(f"Loading I-JEPA model {model_id} using transformers...")
    print(f"Embedding dimension: {embedding_dim}, Image size: {img_size}x{img_size}")

    try:
        from transformers import AutoModel

        base_model = AutoModel.from_pretrained(model_id)
        print(f"Successfully loaded I-JEPA model {model_id}")
    except Exception as e:
        print(f"Error loading I-JEPA model: {e}")
        raise ValueError("Could not load I-JEPA model. Please check if transformers is installed and the model ID is correct.") from e

    # Prefer the checkpoint's own config over the table above, so inputs are
    # resized to the native resolution and fallback features have the right width.
    model_config = getattr(base_model, "config", None)
    cfg_dim = getattr(model_config, "hidden_size", None)
    if isinstance(cfg_dim, int) and cfg_dim > 0:
        embedding_dim = cfg_dim
    cfg_size = getattr(model_config, "image_size", None)
    if isinstance(cfg_size, (list, tuple)) and cfg_size:
        cfg_size = cfg_size[0]
    if isinstance(cfg_size, int) and cfg_size > 0:
        img_size = cfg_size
    print(f"Using embedding dimension {embedding_dim}, image size {img_size}x{img_size}")

    # Unfreeze all parameters (the count reported is of parameter tensors)
    print("Unfreezing all I-JEPA parameters...")
    unfrozen_params = 0
    for param in base_model.parameters():
        param.requires_grad = True
        unfrozen_params += 1
    print(f"Unfrozen {unfrozen_params} parameters in I-JEPA backbone")

    class IJEPAWrapper(nn.Module):
        """Resize inputs to ``img_size`` and return mean-pooled token features."""

        # Small chunks, because inputs may be as large as 448x448.
        chunk_size = 8

        def __init__(self, model, img_size=448, embedding_dim=1280):
            super().__init__()
            self.model = model
            self.img_size = img_size
            self.embedding_dim = embedding_dim

        def forward(self, x):
            # Make x require gradients for proper gradient flow
            x = x.detach().requires_grad_(True)

            if x.shape[0] > self.chunk_size and x.device.type == 'cuda':
                parts = [
                    self._embed(x[i:i + self.chunk_size])
                    for i in range(0, x.shape[0], self.chunk_size)
                ]
                return torch.cat(parts, dim=0)
            return self._embed(x)

        def _embed(self, x):
            """Resize one (sub-)batch and return its mean-pooled features."""
            size = self.img_size
            if tuple(x.shape[-2:]) != (size, size):
                x = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)
            try:
                outputs = self.model(x)
            except Exception as e:
                print(f"Error in forward pass: {e}")
                try:
                    # Retry, letting the model interpolate its position embeddings.
                    outputs = self.model(x, interpolate_pos_encoding=True)
                except Exception as e2:
                    print(f"Second attempt failed: {e2}")
                    # Return zeros as a last resort to avoid crashing
                    return x.new_zeros((x.size(0), self.embedding_dim))
            # Mean-pool over the token axis.
            return outputs.last_hidden_state.mean(dim=1)

    model = IJEPAWrapper(base_model, img_size=img_size, embedding_dim=embedding_dim).to(device)
    print(f"I-JEPA model wrapper created with image size {img_size}x{img_size}")
    return model, embedding_dim


## Parameter Selection

In [None]:

# Reproducibility is currently off; uncomment these lines to fix the RNG seeds.
# torch.manual_seed(42)
# np.random.seed(42)
# torch.backends.cudnn.deterministic = True

# ---------------------------------------------------------------------------
# Training hyperparameters: edit these values to configure a run.
# Alternatives listed in the comments are notes from earlier experiments.
# ---------------------------------------------------------------------------
num_epoch = 30            # number of training epochs
batch_size = 20           # batch size (notes: DINOv2 128; "aim: 0.01" also noted, meaning unclear)
da_strength = 1           # presumably data-augmentation strength; 3 was also noted
lr = 0.0005               # learning rate; notes: resnest 1e-4 / 5e-4, DINOv2 5e-4 best on
                          # CIFAR so far (1e-6 was terrible, 1e-7 did not work)
weight_decay = 0.05       # weight decay
label_smoothing = 0.3     # notes: resnest 0.3 / 0.5; 0.1 suggested for DINOv2
limit_data = 1000         # dataset-size limit; set to np.inf for the full dataset


## Dataset

In [None]:

class DatasetWithIndices(Dataset):
    """Wrap a dataset so that every item also carries a random DIET class id.

    Each wrapped sample receives one id drawn uniformly from
    ``[0, num_diet_classes)`` when the wrapper is built; the id stays fixed.
    ``__getitem__`` returns ``(x, y, diet_class)`` with ``y`` and
    ``diet_class`` as 1-D tensors (0-d tensors are reshaped to one element).
    Items that are not ``(x, y, ...)`` tuples get a dummy label of 0.
    """

    def __init__(self, dataset, num_diet_classes=200):
        self.dataset = dataset
        self.num_diet_classes = num_diet_classes
        # Fixed random DIET target per sample, drawn once at construction.
        self.class_assignments = torch.randint(0, num_diet_classes, (len(dataset),))

    def __getitem__(self, n):
        # Samplers may hand us a tensor index; normalise it to a Python int.
        index = int(n.item()) if isinstance(n, torch.Tensor) else n
        sample = self.dataset[index]

        if isinstance(sample, tuple) and len(sample) >= 2:
            x, y = sample[0], sample[1]
        else:
            # Unlabelled item: use the whole item as input and a dummy label.
            x, y = sample, torch.tensor(0)

        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.long)
        if y.dim() == 0:
            y = y.view(1)

        diet_class = self.class_assignments[index]
        if diet_class.dim() == 0:
            diet_class = diet_class.view(1)

        return x, y, diet_class

    def __len__(self):
        return int(len(self.dataset))
# Add after your existing functions

def get_dataset(dataset_name="cifar10", root='./data'):
    """Load a supported dataset together with its normalisation statistics.

    Args:
        dataset_name: Case-insensitive key: ``"cifar10"``, any MedMNIST flag
            (when ``medmnist`` is installed), ``"plantnet300k"``,
            ``"galaxy10_decals"`` or ``"crop14_balance"``.
        root: Download/cache directory for torchvision and MedMNIST datasets.

    Returns:
        ``(train_dataset, test_dataset, num_classes, input_size, mean, std,
        is_rgb)``. The datasets are created without a transform. ``mean`` and
        ``std`` are tuples with one entry per channel.

    Raises:
        ImportError: HuggingFace ``datasets`` is missing (PlantNet300K, Galaxy10).
        ValueError: unknown dataset name, or a HuggingFace dataset failed to
            load (for crop14_balance this includes a missing ``datasets``).
    """
    name = dataset_name.lower()

    # Predetermined per-channel statistics, native input size and colour flag.
    # Single-channel entries are 1-tuples so every mean/std is a sequence.
    dataset_stats = {
        "cifar10": {
            "mean": (0.4914, 0.4822, 0.4465),
            "std": (0.2470, 0.2435, 0.2616),
            "input_size": 32,
            "is_rgb": True
        },
        "pathmnist": {
            "mean": (0.5, 0.5, 0.5),
            "std": (0.5, 0.5, 0.5),
            "input_size": 28,
            "is_rgb": True
        },
        "chestmnist": {
            "mean": (0.4984,),
            "std": (0.2483,),
            "input_size": 28,
            "is_rgb": False
        },
        "dermamnist": {
            "mean": (0.7634, 0.5423, 0.5698),
            "std": (0.0841, 0.1246, 0.1043),
            "input_size": 28,
            "is_rgb": True
        },
        "octmnist": {
            "mean": (0.1778,),
            "std": (0.1316,),
            "input_size": 28,
            "is_rgb": False
        },
        "pneumoniamnist": {
            "mean": (0.5060,),
            "std": (0.2537,),
            "input_size": 28,
            "is_rgb": False
        },
        "plantnet300k": {
            "mean": (0.485, 0.456, 0.406),  # ImageNet stats as starting point
            "std": (0.229, 0.224, 0.225),
            "input_size": 224,  # PlantNet images are resized to 224
            "is_rgb": True
        },
        "galaxy10_decals": {
            "mean": (0.097, 0.097, 0.097),  # Approximate for astronomy images (dark background)
            "std": (0.174, 0.164, 0.156),   # Astronomical images have a different distribution
            "input_size": 256,  # Original image size is 256x256
            "is_rgb": True
        },
        "crop14_balance": {
            # Per the dataset card, images are rescaled to a maximum side length of 512.
            "mean": (0.485, 0.456, 0.406),  # ImageNet stats as a placeholder
            "std": (0.229, 0.224, 0.225),
            "input_size": 512,
            "is_rgb": True
        }
    }

    # Local imports keep the HuggingFace/PIL dependencies confined to this helper.
    from torch.utils.data import Dataset
    import torchvision.transforms as transforms
    from PIL import Image

    class HFImageDataset(Dataset):
        """Adapt a HuggingFace image dataset to ``(image, label)`` items.

        Images are converted to PIL when needed, resized to
        ``input_size x input_size`` and then passed through ``transform``.
        """

        def __init__(self, hf_dataset, transform=None, input_size=224):
            self.dataset = hf_dataset
            self.transform = transform
            self.input_size = input_size
            self.resize = transforms.Resize((input_size, input_size))

        def __len__(self):
            return int(len(self.dataset))

        def __getitem__(self, idx):
            # Convert idx to an integer if needed.
            if isinstance(idx, torch.Tensor):
                idx = int(idx.item())
            item = self.dataset[idx]
            image = item['image']
            # The target column is named either 'label' or 'labels'.
            label = item.get('label', item.get('labels'))
            if not isinstance(image, Image.Image):
                try:
                    image = Image.fromarray(image)
                except Exception as e:
                    print(f"Warning: Unexpected image format at index {idx}: {e}")
            image = self.resize(image)
            if self.transform:
                image = self.transform(image)
            return image, label

    # Load dataset based on name
    if name == "cifar10":
        train_dataset = datasets.CIFAR10(root=root, train=True, download=True)
        test_dataset = datasets.CIFAR10(root=root, train=False, download=True)
        num_classes = 10

    elif MEDMNIST_AVAILABLE and name in INFO.keys():
        info = INFO[name]
        DataClass = getattr(medmnist, info['python_class'])

        num_classes = len(info['label'])

        train_dataset = DataClass(split='train', download=True, root=root)
        test_dataset = DataClass(split='test', download=True, root=root)

        print(f"Dataset: {info['description']}")
        print(f"Task: {info['task']}")
        print(f"Number of classes: {num_classes}")

    elif name == "plantnet300k":
        try:
            from datasets import load_dataset

            print("Loading PlantNet300K dataset from HuggingFace...")
            dataset = load_dataset("mikehemberger/plantnet300K")

            # Number of classes is 85 according to the dataset card
            num_classes = 85
            input_size = dataset_stats["plantnet300k"]["input_size"]

            train_dataset = HFImageDataset(dataset['train'], input_size=input_size)
            # Prefer the validation split as the test set when it exists.
            if 'validation' in dataset:
                test_dataset = HFImageDataset(dataset['validation'], input_size=input_size)
            else:
                test_dataset = HFImageDataset(dataset['test'], input_size=input_size)

            print(f"PlantNet300K loaded: {len(train_dataset)} training, {len(test_dataset)} test samples")
            print(f"Number of classes: {num_classes}")
            print(f"All images will be resized to {input_size}x{input_size}")

        except ImportError as e:
            raise ImportError("HuggingFace datasets library is required for PlantNet300K. Install with 'pip install datasets'") from e
        except Exception as e:
            raise ValueError(f"Error loading PlantNet300K dataset: {e}") from e

    elif name == "galaxy10_decals":
        try:
            from datasets import load_dataset

            print("Loading Galaxy10 DECals dataset from HuggingFace...")
            dataset = load_dataset("matthieulel/galaxy10_decals")

            # Number of classes is 10 according to the dataset card
            num_classes = 10
            input_size = dataset_stats["galaxy10_decals"]["input_size"]

            train_dataset = HFImageDataset(dataset['train'], input_size=input_size)
            test_dataset = HFImageDataset(dataset['test'], input_size=input_size)

            print(f"Galaxy10 DECals loaded: {len(train_dataset)} training, {len(test_dataset)} test samples")
            print(f"Number of classes: {num_classes}")
            print(f"All images will be resized to {input_size}x{input_size}")
            print("Galaxy class labels:")
            print("0: Disturbed Galaxies")
            print("1: Merging Galaxies")
            print("2: Round Smooth Galaxies")
            print("3: In-between Round Smooth Galaxies")
            print("4: Cigar Shaped Smooth Galaxies")
            print("5: Barred Spiral Galaxies")
            print("6: Unbarred Tight Spiral Galaxies")
            print("7: Unbarred Loose Spiral Galaxies")
            print("8: Edge-on Galaxies without Bulge")
            print("9: Edge-on Galaxies with Bulge")

        except ImportError as e:
            raise ImportError("HuggingFace datasets library is required for Galaxy10 DECals. Install with 'pip install datasets'") from e
        except Exception as e:
            raise ValueError(f"Error loading Galaxy10 DECals dataset: {e}") from e

    elif name == "crop14_balance":
        try:
            from datasets import load_dataset
            print("Loading crop14_balance dataset from Hugging Face (gary109/crop14_balance)...")
            dataset = load_dataset("gary109/crop14_balance")
            # The dataset provides 'train' and 'validation' splits.
            train_dataset_hf = dataset["train"]
            test_dataset_hf = dataset["validation"]
            num_classes = 14  # As given in the features ("14 classes")
            input_size = dataset_stats["crop14_balance"]["input_size"]
            train_dataset = HFImageDataset(train_dataset_hf, transform=None, input_size=input_size)
            test_dataset = HFImageDataset(test_dataset_hf, transform=None, input_size=input_size)
            print(f"crop14_balance loaded: {len(train_dataset)} training, {len(test_dataset)} test samples")
        except Exception as e:
            raise ValueError(f"Error loading crop14_balance dataset: {e}") from e

    else:
        raise ValueError(f"Dataset {dataset_name} not supported or MedMNIST not installed")

    stats = dataset_stats.get(name)
    if stats is None:
        # Only MedMNIST flags without hand-tuned stats reach this point. Use the
        # MedMNIST metadata for the channel count instead of assuming grayscale.
        n_channels = 1
        if MEDMNIST_AVAILABLE and name in INFO:
            n_channels = INFO[name].get('n_channels', 1)
        stats = {
            "mean": (0.5,) * n_channels,
            "std": (0.5,) * n_channels,
            "input_size": 28,
            "is_rgb": n_channels == 3,
        }

    return train_dataset, test_dataset, num_classes, stats["input_size"], stats["mean"], stats["std"], stats["is_rgb"]

def calculate_dataset_stats(dataset, batch_size=64, max_samples=10000):
    """Estimate the per-channel mean and standard deviation of a dataset.

    The dataset's transform (or that of the dataset wrapped by a Subset-like
    object) is temporarily replaced by ``ToTensor``. Values are therefore
    measured as ``ToTensor`` produces them, without the usual augmentation or
    normalisation. The original transform is always restored afterwards: even
    when it was ``None``, and even if an error occurs. Datasets without a
    ``transform`` attribute are wrapped instead. Statistics are weighted by
    pixel count, so a short final batch does not bias them.

    Args:
        dataset: PyTorch dataset or HuggingFace dataset.
        batch_size: Batch size for loading.
        max_samples: Maximum number of samples to use; larger datasets are
            randomly sub-sampled.

    Returns:
        ``(mean, std)`` as lists with one float per channel.

    Raises:
        ValueError: no tensor batches could be loaded.
    """
    from torch.utils.data import DataLoader, Subset
    import random

    # Limit samples for large datasets
    if hasattr(dataset, '__len__') and len(dataset) > max_samples:
        indices = random.sample(range(len(dataset)), max_samples)
        dataset_subset = Subset(dataset, indices)
    else:
        dataset_subset = dataset

    # Work out whose transform must be swapped for ToTensor.
    if hasattr(dataset, 'transform'):
        transform_owner = dataset
    elif hasattr(dataset, 'dataset') and hasattr(dataset.dataset, 'transform'):
        transform_owner = dataset.dataset
    else:
        transform_owner = None

    original_transform = None
    if transform_owner is None:
        # Wrapper for HuggingFace datasets or other types without a transform.
        class StatsDataset(torch.utils.data.Dataset):
            def __init__(self, original_dataset):
                self.dataset = original_dataset
                self.transform = torchvision.transforms.ToTensor()

            def __len__(self):
                return len(self.dataset)

            def __getitem__(self, idx):
                if not hasattr(self.dataset, '__getitem__'):
                    raise ValueError("Dataset structure not supported for statistics calculation")
                item = self.dataset[idx]
                if isinstance(item, tuple) and len(item) >= 2:
                    img, label = item[0], item[1]
                else:
                    # HuggingFace-style dict rows
                    img = item['image']
                    label = item.get('label', 0)
                if self.transform:
                    img = self.transform(img)
                return img, label

        dataset_subset = StatsDataset(dataset_subset)
    else:
        original_transform = transform_owner.transform
        transform_owner.transform = torchvision.transforms.ToTensor()

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    try:
        loader = DataLoader(
            dataset_subset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0
        )

        print("Calculating dataset statistics...")
        for data, _ in tqdm(loader):
            if not isinstance(data, torch.Tensor):
                print(f"Warning: Expected tensor but got {type(data)}. Skipping batch.")
                continue

            # Handle both single-channel and multi-channel images
            if data.dim() == 3:  # [batch, height, width]
                data = data.unsqueeze(1)  # Add channel dimension

            # Accumulate in float64 over batch, height and width (not channels).
            data = data.double()
            batch_sum = data.sum(dim=[0, 2, 3])
            batch_sq_sum = (data ** 2).sum(dim=[0, 2, 3])
            channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
            channel_sq_sum = batch_sq_sum if channel_sq_sum is None else channel_sq_sum + batch_sq_sum
            pixel_count += data.size(0) * data.size(2) * data.size(3)
    finally:
        # Restore unconditionally: a None transform must be put back as well.
        if transform_owner is not None:
            transform_owner.transform = original_transform

    if pixel_count == 0:
        raise ValueError("No image tensors were loaded; cannot compute dataset statistics")

    mean = channel_sum / pixel_count
    # Clamp tiny negative variances caused by floating-point cancellation.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp_min(0.0)
    std = variance.sqrt()

    return mean.tolist(), std.tolist()



## Metrics

In [None]:
def zero_shot_eval(net, test_loader, num_classes, eval_id=None):
    """Score the backbone's features with k-NN, k-means and a linear probe.

    Features come from ``net`` applied to every batch of ``test_loader``.
    Batches may be ``(x, y)`` or ``(x, y, diet_class)``. ``net`` is put in
    eval mode for extraction, and its previous train/eval mode is restored
    afterwards.

    Args:
        net: Backbone returning one feature row per input sample.
        test_loader: Iterable of ``(x, y)`` or ``(x, y, diet_class)`` batches.
        num_classes: Number of ground-truth classes. Used as the k-means
            cluster count and the linear-probe output size.
        eval_id: Optional identifier; defaults to a time-derived value.

    Returns:
        Dict with these keys:
        - ``knn_acc``: leave-one-out 10-NN accuracy; a sample never votes for
          itself.
        - ``kmeans_ari`` and ``kmeans_nmi``.
        - ``linear_acc``: a linear probe trained on a fixed random half of the
          samples and scored on the other half.

    Raises:
        ValueError: a batch is not a 2- or 3-element list/tuple.
    """
    if eval_id is None:
        eval_id = int(time.time()) % 10000

    start_time = time.time()
    print("Extracting features for zero-shot evaluation...")

    feature_batches = []
    label_batches = []
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Extracting features"):
                if not (isinstance(batch, (list, tuple)) and len(batch) in (2, 3)):
                    raise ValueError("Unexpected batch structure")
                x, y = batch[0], batch[1]
                feat = net(x.to(device))
                feature_batches.append(feat.cpu().numpy())
                label_batches.append(y.cpu().numpy())
    finally:
        # Do not leave the caller's model stuck in eval mode.
        net.train(was_training)

    features = np.vstack(feature_batches)
    labels = np.concatenate(label_batches).ravel()
    n_samples = len(labels)

    feature_hash = hash(str(features[:3].sum()))
    print(f"Feature hash: {feature_hash} (should change between evaluations)")
    print(f"Features extracted: {features.shape}, time: {time.time() - start_time:.2f}s")

    results = {}
    print("Running k-NN evaluation...")
    knn_time = time.time()
    # Leave-one-out k-NN. Calling kneighbors() without X returns each fitted
    # point's neighbours excluding the point itself. Predicting on the training
    # points directly would let every sample vote for its own label.
    n_neighbors = max(1, min(10, n_samples - 1))
    knn = KNeighborsClassifier(n_neighbors=n_neighbors)
    knn.fit(features, labels)
    neighbor_idx = knn.kneighbors(return_distance=False)
    classes, encoded = np.unique(labels, return_inverse=True)
    encoded = encoded.ravel()
    # Majority vote over neighbour labels; ties go to the smallest label.
    knn_pred = classes[[np.bincount(encoded[row], minlength=len(classes)).argmax()
                        for row in neighbor_idx]]
    results["knn_acc"] = accuracy_score(labels, knn_pred)
    print(f"k-NN accuracy: {results['knn_acc']:.4f}, time: {time.time() - knn_time:.2f}s")

    print("Running k-means clustering evaluation...")
    kmeans_time = time.time()
    kmeans = KMeans(n_clusters=num_classes, random_state=0, n_init=10)
    cluster_pred = kmeans.fit_predict(features)
    results["kmeans_ari"] = adjusted_rand_score(labels, cluster_pred)
    results["kmeans_nmi"] = normalized_mutual_info_score(labels, cluster_pred)
    print(f"k-means ARI: {results['kmeans_ari']:.4f}, NMI: {results['kmeans_nmi']:.4f}, time: {time.time() - kmeans_time:.2f}s")

    print("Running linear probe evaluation...")
    linear_time = time.time()
    # Shuffle with a fixed seed before halving. A class-sorted loader would
    # otherwise put whole classes on one side of the split. The fixed seed
    # keeps the split identical across evaluations.
    perm = np.random.default_rng(0).permutation(n_samples)
    half = n_samples // 2
    train_idx, test_idx = perm[:half], perm[half:]
    X_train, X_test = features[train_idx], features[test_idx]
    y_train, y_test = labels[train_idx], labels[test_idx]

    linear_clf = nn.Linear(features.shape[1], num_classes).to(device)
    optimizer = torch.optim.Adam(linear_clf.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    X_train_tensor = torch.as_tensor(X_train, dtype=torch.float32, device=device)
    y_train_tensor = torch.as_tensor(y_train, dtype=torch.long, device=device)
    linear_clf.train()
    for _ in range(50):
        optimizer.zero_grad()
        loss = criterion(linear_clf(X_train_tensor), y_train_tensor)
        loss.backward()
        optimizer.step()

    linear_clf.eval()
    with torch.no_grad():
        X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32, device=device)
        pred = linear_clf(X_test_tensor).argmax(dim=1).cpu().numpy()
        results["linear_acc"] = accuracy_score(y_test, pred)

    print(f"Linear probe accuracy: {results['linear_acc']:.4f}, time: {time.time() - linear_time:.2f}s")
    print(f"Total zero-shot evaluation time: {time.time() - start_time:.2f}s")
    return results


## Dataset Selection

In [None]:
# ---- Dataset selection ----
# Supported names: "cifar10", MedMNIST flags such as "pathmnist", "chestmnist",
# "dermamnist", plus "plantnet300k", "galaxy10_decals" and "crop14_balance".
dataset_name = "crop14_balance"
# Number of random DIET pseudo-classes (earlier notes suggest 200 for PlantNet300K).
num_diet_classes = 10

print(f"Loading {dataset_name} dataset...")
(training_data_raw, test_data_raw, num_classes,
 input_size, mean, std, is_rgb) = get_dataset(dataset_name)
print(f"Dataset loaded: input_size={input_size}, mean={mean}, std={std}, is_rgb={is_rgb}")

# Evaluation pipeline: tensor conversion followed by normalisation.
test_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean, std)]
)

if da_strength <= 0:
    # No augmentation requested: train on exactly the evaluation pipeline.
    train_transform = test_transform
else:
    # Level >= 1: random resized crop and horizontal flip.
    aug_list = [
        torchvision.transforms.RandomResizedCrop(input_size, antialias=True),
        torchvision.transforms.RandomHorizontalFlip(),
    ]
    # Level >= 2, colour images only: colour jitter and random grayscale.
    if da_strength > 1 and is_rgb:
        aug_list += [
            torchvision.transforms.RandomApply(
                [torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.3
            ),
            torchvision.transforms.RandomGrayscale(p=0.2),
        ]
    # Level >= 3, colour images only: occasional Gaussian blur.
    if da_strength > 2 and is_rgb:
        aug_list.append(
            torchvision.transforms.RandomApply(
                [torchvision.transforms.GaussianBlur((3, 3), (1.0, 2.0))], p=0.2
            )
        )
    # Convert to a normalised tensor; tensor-only ops must come after this.
    aug_list += [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean, std)]
    # Level >= 3: random erasing on the normalised tensor.
    if da_strength > 2:
        aug_list.append(torchvision.transforms.RandomErasing(p=0.25))
    train_transform = torchvision.transforms.Compose(aug_list)



# Define a custom collate function for handling tensor dimensions in HuggingFace datasets
def custom_collate_fn(batch):
    """Collate ``(image, label, diet_class)`` triples into batched tensors.

    Images are stacked along a new leading batch dimension. For labels and
    DIET classes: if the first entry is a tensor, every entry is flattened with
    ``reshape(-1)`` and the results are concatenated. Otherwise the whole
    column is handed to ``torch.tensor``.
    """
    def _merge(column):
        if not isinstance(column[0], torch.Tensor):
            return torch.tensor(column)
        return torch.cat([entry.reshape(-1) for entry in column])

    images, labels, diet_classes = zip(*batch)
    return torch.stack(images), _merge(labels), _merge(diet_classes)

## Apply transforms to the datasets
if dataset_name.lower() in ("plantnet300k", "galaxy10_decals", "crop14_balance"):
    # HuggingFace-backed wrappers: attach transforms in place rather than
    # deep-copying the underlying (potentially large) HF dataset.
    training_data = training_data_raw
    test_data = test_data_raw
    if hasattr(training_data, 'transform'):
        training_data.transform = train_transform
        test_data.transform = test_transform
    else:
        print(f"Note: {dataset_name} dataset structure is using custom transform handling")
else:
    # Standard datasets: copy first so the raw datasets keep their own
    # transforms. Both copies are made before the try so that test_data is
    # always defined, even if attaching a transform fails.
    training_data = copy.deepcopy(training_data_raw)
    test_data = copy.deepcopy(test_data_raw)
    try:
        training_data.transform = train_transform
        test_data.transform = test_transform
    except AttributeError:
        # Dataset does not accept a transform attribute.
        print("Note: Using dataset that requires special transform handling")

# Limit training data if specified
if limit_data < np.inf and limit_data < len(training_data):
    print(f"Limiting training data to {limit_data} samples (out of {len(training_data)})")
    # Plain Python ints, so wrapped datasets never receive tensor indices.
    indices = torch.randperm(len(training_data))[:int(limit_data)].tolist()
    training_data = Subset(training_data, indices)
else:
    print(f"Using full training set: {len(training_data)} samples")

# Galaxy10 DECals: rebuild the train/test datasets with a dedicated wrapper that
# always yields (image, label, diet_class) with 1-element label tensors.
# Every other dataset is wrapped with DatasetWithIndices instead.
if dataset_name.lower() == "galaxy10_decals":
    print("\n===== REBUILDING GALAXY DATASET FROM SCRATCH =====")

    # Reload the raw HuggingFace splits.
    from datasets import load_dataset
    from PIL import Image

    raw_dataset = load_dataset("matthieulel/galaxy10_decals")
    train_data = raw_dataset["train"]
    test_data = raw_dataset["test"]

    class RobustGalaxyDataset(torch.utils.data.Dataset):
        """Galaxy10 wrapper returning ``(image, label, diet_class)``.

        Args:
            hf_dataset: HuggingFace split whose rows hold 'image' and 'label'.
            transform: Applied after resizing.
            diet_classes: Number of random DIET pseudo-classes.
            limit_samples: Optional cap; a random subset of that size is used.
            image_size: Side length every image is resized to.
        """

        def __init__(self, hf_dataset, transform=None, diet_classes=100,
                     limit_samples=None, image_size=256):
            self.dataset = hf_dataset
            self.transform = transform

            # Optional random sub-sampling.
            if limit_samples is not None and limit_samples < len(hf_dataset):
                self.indices = torch.randperm(len(hf_dataset))[:int(limit_samples)].tolist()
            else:
                self.indices = list(range(len(hf_dataset)))

            # One fixed random DIET class per (sub-sampled) sample.
            self.diet_classes = torch.randint(0, diet_classes, (len(self.indices),))

            # Resize so that every image in a batch has the same dimensions.
            self.resize = torchvision.transforms.Resize((image_size, image_size))

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, idx):
            sample = self.dataset[self.indices[idx]]

            image = sample['image']
            label = torch.tensor([sample['label']], dtype=torch.long)  # 1-element tensor

            if not isinstance(image, Image.Image):
                try:
                    image = Image.fromarray(image)
                except Exception as e:
                    print(f"Warning: Could not convert image to PIL at index {idx}: {e}")

            image = self.resize(image)
            if self.transform:
                image = self.transform(image)

            diet_class = torch.tensor([self.diet_classes[idx].item()], dtype=torch.long)
            return image, label, diet_class

    print("Creating robust galaxy dataset...")
    training_data = RobustGalaxyDataset(
        train_data,
        transform=train_transform,
        diet_classes=num_diet_classes,
        limit_samples=limit_data if limit_data < np.inf else None,
        image_size=input_size,
    )
    test_data = RobustGalaxyDataset(
        test_data,
        transform=test_transform,
        diet_classes=num_diet_classes,
        image_size=input_size,
    )

    print(f"Created robust datasets: {len(training_data)} training, {len(test_data)} test")
    print("===== GALAXY DATASET REBUILDING COMPLETE =====\n")
else:
    # For non-Galaxy datasets, use the regular DatasetWithIndices wrapper
    training_data = DatasetWithIndices(training_data, num_diet_classes=num_diet_classes)

# Report the evaluation split size before building the loaders.
print(f"Test set size: {len(test_data)} samples")

# Training loader: reshuffled every epoch; the final partial batch is kept.
training_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True,
                             drop_last=False, num_workers=0)

# Evaluation loader: fixed order, loaded in the main process.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                         drop_last=False, num_workers=0)

## Backbone and Model Selection

In [None]:


# Select the pretrained backbone and its size variant.
#   backbone_type: "resnet50", "dinov2", "ijepa", "mae", "mambavision", "aim"
#   model_size, per backbone:
#     dinov2      -> "small", "base", "large", "giant"
#     ijepa       -> "b16_1k", "l14", "h14"
#     mae         -> "base", "large", "huge"
#     mambavision -> "T", "S", "B", "L", "L2", "L3"
#     aim         -> "600M", "1B", "3B", "7B"
backbone_type = "dinov2"
model_size = "small"

# Sanity-check the selected backbone. The variable name is historical: it holds
# the result for whichever backbone_type is selected above.
sanity_results_dinov2 = unified_sanity_check(backbone_type, model_size)

In [None]:

def _report_trainable_parameters(model, label):
    """Print counts of parameter tensors (not elements) in ``model``.

    Reports the total, the trainable ones, and the trainable ones whose name
    contains 'model' (i.e. parameters inside the wrapped backbone).
    """
    print("\nChecking trainable parameters...")
    total = trainable = trainable_wrapped = 0
    for name, param in model.named_parameters():
        total += 1
        if param.requires_grad:
            trainable += 1
            if 'model' in name:
                trainable_wrapped += 1
    print(f"Total parameters: {total}")
    print(f"Trainable parameters: {trainable}")
    print(f"Trainable wrapper {label} parameters: {trainable_wrapped}")


def _freeze_embeddings_only(module, banner, listing_title):
    """Freeze parameters whose name contains 'embeddings'; unfreeze all others.

    Prints ``banner``, the frozen/total count, then every parameter name
    marked trainable (check) or frozen (cross) under ``listing_title``.
    """
    print(banner)
    frozen = total = 0
    for name, param in module.named_parameters():
        total += 1
        is_embedding = 'embeddings' in name
        param.requires_grad = not is_embedding
        if is_embedding:
            frozen += 1
    print(f"Frozen {frozen} out of {total} parameters")

    print(f"\n{listing_title}")
    for name, param in module.named_parameters():
        print(f"  {'✓' if param.requires_grad else '✗'} {name}")


# Match case-insensitively so that e.g. "MAE" and "mae" both work.
_backbone_key = backbone_type.lower()

if _backbone_key == "resnet50":
    print("Creating ResNet50 model with ImageNet pre-training...")
    net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    if input_size <= 64:
        # Small inputs (e.g. 32x32 CIFAR-10): use a stride-1 3x3 stem without
        # max-pooling to keep spatial resolution. This replaces the pretrained
        # conv1 with a freshly initialised layer, so it is skipped for larger
        # images, where the stock stem is appropriate.
        print(f"Adapting ResNet50 for {input_size}x{input_size} images...")
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        net.maxpool = nn.Identity()
    embedding_dim = net.fc.in_features
    net.fc = nn.Identity()

elif _backbone_key == "mae":
    print(f"Creating MAE-{model_size} model...")
    net, embedding_dim = get_mae_model(device, model_size=model_size)
    # get_mae_model is meant to leave every parameter trainable; verify that.
    _report_trainable_parameters(net, "MAE")

elif _backbone_key == "ijepa":
    print(f"Creating I-JEPA-{model_size} model...")
    net, embedding_dim = get_ijepa_model(device, model_size=model_size)
    # Freeze only the embedding layers; all encoder layers stay trainable.
    _freeze_embeddings_only(net.model,
                            "Applying strategic freezing to I-JEPA...",
                            "Trainable layers in I-Jepa:")
    _report_trainable_parameters(net, "I-jepa")

elif _backbone_key == "mambavision":
    print(f"Creating MambaVision {model_size} model...")
    net, embedding_dim = get_mambavision_model(device, model_variant=model_size)
    _report_trainable_parameters(net, "MambaVision")

elif _backbone_key == "dinov2":
    print(f"Creating DINOv2-{model_size} model...")
    net, embedding_dim = get_dinov2_model(device, model_size=model_size)
    # Freeze only the embedding layers; all transformer layers stay trainable.
    _freeze_embeddings_only(net.model,
                            "Applying minimal freezing to DINOv2 (allowing more gradients)...",
                            "Trainable layers in DINOv2:")
    _report_trainable_parameters(net, "DINOv2")

elif _backbone_key == "aim":
    print(f"Creating AIM-{model_size} model...")
    net, embedding_dim = get_aim_model(device, model_size=model_size)
    _report_trainable_parameters(net, "AIM")

else:
    # Fail here rather than later with a NameError on `net`.
    raise ValueError(
        f"Unsupported backbone_type {backbone_type!r}; expected one of "
        "'resnet50', 'mae', 'ijepa', 'mambavision', 'dinov2', 'aim'"
    )

## The layers are organized in a hierarchical fashion, with earlier layers (0, 1, 2) processing more basic features and later layers (9, 10, 11) handling more complex semantic features.

## Training Section


In [None]:

# Place the backbone on the selected device before the heads are built.
net = net.to(device)
print("Using {} backbone with embedding dimension: {}".format(backbone_type, embedding_dim))

# Add projection head definition here
class ProjectionHead(nn.Module):
    """Two-layer MLP used to project backbone features for the DIET branch.

    Pipeline: Linear(in_dim -> hidden_dim) -> BatchNorm1d -> ReLU
    -> Dropout(p=0.1) -> Linear(hidden_dim -> out_dim).

    Args:
        in_dim: size of each incoming feature vector.
        hidden_dim: width of the hidden layer.
        out_dim: size of the projected output.

    The submodule attribute names (``layer1``, ``norm``, ``dropout``,
    ``layer2``) define the state_dict keys, so they are kept stable for
    checkpoint compatibility.

    NOTE: BatchNorm1d in training mode needs more than one sample per batch.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Submodules are created in this order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.norm = nn.BatchNorm1d(hidden_dim)   # normalise the hidden activations
        self.dropout = nn.Dropout(0.1)           # light regularisation
        self.layer2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        hidden = F.relu(self.norm(self.layer1(x)))
        return self.layer2(self.dropout(hidden))
    
    
# Create the projection head (backbone features -> projection_dim) for the DIET branch.
# NOTE: the creation order of projection_head / W_probe / W_diet determines their
# random initialisation, and the concatenation order in the optimizer below fixes
# the optimizer's parameter ordering; keep both stable.
projection_dim = 256  # You can experiment with this value
projection_head = ProjectionHead(embedding_dim, embedding_dim, projection_dim).to(device)


# Create heads for DIET and probing
# W_probe: linear classifier on the raw backbone features, one logit per class.
W_probe = nn.Linear(embedding_dim, num_classes).to(device)

# W_diet: bias-free linear classifier over the projected features with
# num_diet_classes outputs (presumably one per training datum, as in DIET — confirm
# against the dataset cell that defines num_diet_classes).
W_diet = nn.Linear(projection_dim, num_diet_classes, bias=False).to(device)

# Create optimizer: a single AdamW over backbone + both heads + projection head,
# all sharing the same lr and weight decay.
print(f"Creating optimizer with lr={lr}, weight_decay={weight_decay}")
optimizer = torch.optim.AdamW(
    list(net.parameters()) + list(W_probe.parameters()) + 
    list(W_diet.parameters()) + list(projection_head.parameters()),
    lr=lr, weight_decay=weight_decay
)


# Add learning rate scheduler
# T_max=num_epoch: the cosine schedule spans the whole run when it is stepped
# once per epoch (which is what the training loop does), decaying to eta_min.
print("Creating learning rate scheduler (cosine annealing)")
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch, eta_min=1e-5)


# Loss functions
# criterion: plain cross-entropy for the supervised probe.
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
# criterion_diet: cross-entropy for the DIET head, with the configured label smoothing.
criterion_diet = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# DIET active status
# NOTE(review): this flag only reflects label_smoothing > 0; the training loop still
# computes and optimises the DIET loss whenever a batch carries a DIET target,
# regardless of this flag. It is only used for reporting — confirm that is intended.
is_diet_active = label_smoothing > 0
print(f"DIET is {'active' if is_diet_active else 'inactive'} (label_smoothing={label_smoothing})")

# Run initial zero-shot evaluation
# Baseline of the pretrained backbone before any finetuning; stored as epoch 0.
print("\n" + "="*50)
print("INITIAL ZERO-SHOT EVALUATION (BEFORE TRAINING)")
print("="*50)
initial_time = time.time()
# TODO(author): originally marked "CHANGE THIS" — revisit this evaluation call.
initial_results = zero_shot_eval(net, test_loader, num_classes, eval_id=0)


print(f"Initial evaluation completed in {time.time() - initial_time:.2f}s")

# Print hyperparameters used in the experiment
print("==========================================")
print("Experiment Hyperparameters:")
print("==========================================")
print(f"Number of Epochs     : {num_epoch}")
print(f"Batch Size           : {batch_size}")
print(f"Data Augmentation Strength : {da_strength}")
print(f"Learning Rate        : {lr}")
print(f"Weight Decay         : {weight_decay}")
print(f"Label Smoothing      : {label_smoothing}")
print(f"Data Limit           : {limit_data}")
print(f"num_diet_classes           : {num_diet_classes}")

print("==========================================")









In [None]:
def evaluate_test_set(net, test_loader, device, W_probe, verbose=True):
    """Evaluate the linear probe ``W_probe`` on top of ``net`` over ``test_loader``.

    Accuracy is computed as total-correct / total-samples. The previous version
    averaged per-batch accuracies, which over-weights a smaller final batch.

    Args:
        net: backbone module. It is switched to eval mode and left in eval mode.
        test_loader: iterable of ``(x, y)`` or ``(x, y, n)`` batches. Batches of
            any other type or length are skipped with a warning.
        device: device the batch tensors are moved to.
        W_probe: callable mapping backbone features to class logits.
        verbose: if True (default), print per-batch unpacking/accuracy lines.

    Returns:
        float: overall accuracy in [0, 1], or 0 if no batch was evaluated.
    """
    print("\nStarting evaluation on test set:")
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # Flexible unpacking: works with (x, y) or (x, y, n)
            if not isinstance(batch, (list, tuple)):
                print(f"Warning: Unexpected test batch type at batch {i}. Skipping.")
                continue
            if len(batch) == 3:
                x, y, _ = batch
                layout = "(x, y, n)"
            elif len(batch) == 2:
                x, y = batch
                layout = "(x, y)"
            else:
                print(f"Warning: Unexpected test batch length ({len(batch)}) at batch {i}. Skipping.")
                continue
            if verbose:
                print(f"Test Batch {i}: Unpacked as {layout}")

            x = x.to(device)
            # Flatten labels such as (B, 1) to (B,) so they compare element-wise
            # with the (B,) predictions instead of broadcasting.
            y = y.to(device).reshape(-1)
            preds = W_probe(net(x)).argmax(dim=1)

            batch_correct = (preds == y).sum().item()
            batch_size_seen = y.numel()
            n_correct += batch_correct
            n_seen += batch_size_seen
            if verbose:
                print(f"Test Batch {i}: Accuracy={batch_correct / max(batch_size_seen, 1):.4f}")

    test_acc = n_correct / n_seen if n_seen else 0
    print(f"\nOverall Test Accuracy: {test_acc:.4f}")
    return test_acc


In [None]:
# Training cell setup: checkpoint folder, run banner, W&B run and metric stores.
import wandb

import os
os.makedirs("checkpoints", exist_ok=True)  # per-epoch checkpoints are saved here

# Training loop with extensive error handling
separator = "=" * 50
print("\n" + separator)
print(f"STARTING TRAINING FOR {num_epoch} EPOCHS (SAFE MODE)")
print(separator)
print(f"Device being used: {device}")
print(f"Dataset: {dataset_name} | Size: {len(training_data)} samples")
print(f"Number of batches per epoch: {len(training_loader)}")
print(f"Batch size: {batch_size}")

# Experiment configuration sent to W&B (keyword order == key order)
experiment_config = dict(
    # Model parameters
    backbone_type=backbone_type,
    model_size=model_size,
    embedding_dim=embedding_dim,
    projection_dim=projection_dim,
    # Dataset parameters
    dataset_name=dataset_name,
    num_classes=num_classes,
    num_diet_classes=num_diet_classes,
    input_size=input_size,
    is_rgb=is_rgb,
    limit_data=limit_data if 'limit_data' in locals() else None,
    # Training parameters
    num_epoch=num_epoch,
    batch_size=batch_size,
    lr=lr,
    weight_decay=weight_decay,
    da_strength=da_strength if 'da_strength' in locals() else None,
    label_smoothing=label_smoothing,
    # DIET-specific settings
    is_diet_active=is_diet_active,
)

# Start the W&B run, then record the architecture and the epoch-0 zero-shot baseline
run = init_wandb(experiment_config)
log_model_architecture(run, net, projection_head, W_probe, W_diet)
log_zero_shot_metrics(run, initial_results, 0)

train_start_time = time.time()
epoch_times = []
# Per-epoch curves, plus zero-shot results keyed by epoch (0 = before training)
metrics_history = {
    name: [] for name in ("train_loss_diet", "train_loss_probe", "train_acc", "test_acc")
}
metrics_history["zero_shot_metrics"] = {0: initial_results}

# ----- Begin Training & Evaluation Loop -----
# DIET logits are divided by this temperature before the cross-entropy. It is a
# constant, so it is set once here rather than re-assigned on every batch.
temperature = 3

for epoch in range(num_epoch):
    epoch_start = time.time()
    net.train()
    run_loss_diet, run_loss_probe, run_acc = [], [], []

    print(f"\n==========================")
    print(f"Starting epoch {epoch+1}/{num_epoch} at {time.strftime('%H:%M:%S')}")
    print(f"==========================\n")

    print("Initializing training loop...")
    # leave=False prevents finished per-epoch progress bars from piling up
    pbar = tqdm(training_loader, desc=f"Epoch {epoch+1}/{num_epoch}",
                leave=False)

    for i, batch in enumerate(pbar):
        batch_start = time.time()
        # Flexible batch unpacking: (x, y, n) carries a DIET target, (x, y) does not
        if isinstance(batch, (list, tuple)):
            if len(batch) == 3:
                x, y, n = batch
            elif len(batch) == 2:
                x, y = batch
                n = None  # No diet class provided
            else:
                print(f"Warning: Unexpected batch length ({len(batch)}) at batch {i}. Skipping.")
                continue
        else:
            print(f"Warning: Unexpected batch type at batch {i}. Skipping.")
            continue

        # Send tensors to device; flatten labels such as (B, 1) to (B,)
        x = x.to(device)
        y = y.to(device)
        if y.dim() > 1:
            y = y.view(-1)
        if n is not None:
            n = n.to(device).long()
            if n.dim() > 1:
                n = n.view(-1)

        # Forward pass
        z = net(x)                             # backbone features
        z_norm = F.normalize(z, p=2, dim=1)    # L2-normalise before projecting
        z_proj = projection_head(z_norm)       # projection MLP (DIET branch only)

        if n is not None:
            logits_diet = W_diet(z_proj) / temperature
            loss_diet = criterion_diet(logits_diet, n)
        else:
            loss_diet = torch.tensor(0.0, device=device)

        # NOTE(review): z is not detached, so the supervised probe loss also
        # back-propagates into the backbone. That makes this joint supervised +
        # DIET finetuning rather than DIET with an evaluation-only online probe;
        # use W_probe(z.detach()) if the probe should not shape the backbone.
        logits_probe = W_probe(z)
        loss_probe = criterion(logits_probe, y)

        # Combine losses: shift weight from DIET towards the probe as training advances
        if n is not None:
            if epoch < 15:
                loss = 0.6 * loss_diet + 0.4 * loss_probe
            elif epoch < 25:
                loss = 0.4 * loss_diet + 0.6 * loss_probe
            else:
                loss = 0.2 * loss_diet + 0.8 * loss_probe
        else:
            loss = loss_probe

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every 10 batches, report the global L2 norm of the backbone gradients,
        # sqrt(sum_p ||g_p||^2), to confirm the backbone is actually updating.
        # (Summing the per-tensor norms, as before, is not a norm and grows with
        # the number of parameter tensors.)
        if i % 10 == 0:
            grads = [p.grad.detach() for p in net.parameters() if p.grad is not None]
            if grads:
                total_grad_norm = torch.stack([g.norm(2).float() for g in grads]).norm(2).item()
            else:
                total_grad_norm = 0.0
            # tqdm.write avoids corrupting the progress bar
            tqdm.write(f"Batch {i}: grad_norm={total_grad_norm:.4f} across {len(grads)} trainable params")

        # Record training metrics
        run_loss_diet.append(loss_diet.item())
        run_loss_probe.append(loss_probe.item())
        preds = logits_probe.argmax(dim=1)
        batch_acc = torch.mean((y == preds).float()).item()
        run_acc.append(batch_acc)

        batch_time = time.time() - batch_start
        pbar.set_postfix({
            'DIET loss': f"{np.mean(run_loss_diet):.4e}",
            'Probe loss': f"{np.mean(run_loss_probe):.4e}",
            'Acc': f"{np.mean(run_acc):.4f}",
            'Batch time': f"{batch_time:.3f}s"
        })

    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    print(f"\nEpoch {epoch+1} completed in {epoch_time:.2f}s\n")

    # Save metrics
    metrics_history["train_loss_diet"].append(np.mean(run_loss_diet))
    metrics_history["train_loss_probe"].append(np.mean(run_loss_probe))
    metrics_history["train_acc"].append(np.mean(run_acc))

    # Step the (per-epoch) cosine learning rate scheduler
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f"Learning rate updated to: {current_lr:.6f}")

    print(f"Epoch {epoch+1} Metrics - DIET Loss: {np.mean(run_loss_diet):.4e}, Probe Loss: {np.mean(run_loss_probe):.4e}, Accuracy: {np.mean(run_acc):.4f}")

    # Log training metrics to wandb
    log_training_metrics(
        run,
        {"diet_loss": np.mean(run_loss_diet), "probe_loss": np.mean(run_loss_probe), "accuracy": np.mean(run_acc)},
        epoch+1,
        current_lr
    )

    # ----- Test Evaluation Loop -----
    # Accuracy is total-correct / total-samples so a short final batch is not
    # over-weighted (averaging per-batch accuracies biased the result).
    net.eval()
    test_correct, test_seen = 0, 0
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # Flexible unpacking for test batches (supports (x, y, n) and (x, y))
            if not isinstance(batch, (tuple, list)):
                raise ValueError("Test batch is not a tuple or list")
            if len(batch) == 3:
                x, y, _ = batch
            elif len(batch) == 2:
                x, y = batch
            else:
                raise ValueError("Unexpected batch structure in test set")

            x = x.to(device)
            y = y.to(device).reshape(-1)
            logits_probe = W_probe(net(x))
            test_correct += (logits_probe.argmax(1) == y).sum().item()
            test_seen += y.numel()

    test_acc = test_correct / test_seen if test_seen else 0
    metrics_history["test_acc"].append(test_acc)

    # Log test accuracy to wandb
    log_evaluation_metrics(run, {"accuracy": test_acc}, epoch+1)
    # ----- End Test Evaluation Loop -----

    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epoch} summary:")
    print(f"  Train - DIET loss: {np.mean(run_loss_diet) if run_loss_diet else 0:.4e}, "
          f"Probe loss: {np.mean(run_loss_probe) if run_loss_probe else 0:.4e}, "
          f"Acc: {np.mean(run_acc) if run_acc else 0:.4f}")
    print(f"  Test  - Acc: {test_acc:.4f}")

    # ----- Zero-Shot Evaluation (every 5 epochs or last epoch) -----
    if (epoch + 1) % 5 == 0 or epoch == num_epoch - 1:
        print(f"\nRunning zero-shot evaluation at epoch {epoch+1}...")
        try:
            # The first trainable parameter gets a tiny (1e-6) perturbation for the
            # duration of the evaluation and is then restored. The original intent
            # was "force a different evaluation state" (presumably to defeat some
            # caching inside zero_shot_eval — TODO confirm this is still needed).
            perturbed_param = next((p for p in net.parameters() if p.requires_grad), None)
            original_data = None
            try:
                if perturbed_param is not None:
                    with torch.no_grad():
                        original_data = perturbed_param.data.clone()
                        perturbed_param.data.add_(torch.randn_like(perturbed_param) * 1e-6)

                epoch_zero_shot = zero_shot_eval(net, test_loader, num_classes, eval_id=epoch+1)
            finally:
                # Always undo the perturbation, even if the evaluation raised
                if original_data is not None:
                    with torch.no_grad():
                        perturbed_param.data.copy_(original_data)

            metrics_history["zero_shot_metrics"][epoch+1] = copy.deepcopy(epoch_zero_shot)

            # Log zero-shot metrics to wandb
            log_zero_shot_metrics(run, epoch_zero_shot, epoch+1, initial_results)

            # Print formatted zero-shot performance table
            print(f"\nZero-shot Performance at epoch {epoch+1}:")
            print("-" * 60)
            print(f"{'Metric':<15} {'Initial':<10} {'Current':<10} {'Change':<10} {'Relative %':<10}")
            print("-" * 60)
            for metric in epoch_zero_shot.keys():
                initial = initial_results[metric]
                current = epoch_zero_shot[metric]
                change = current - initial
                rel_change = (change / initial) * 100 if initial > 0 else float('inf')
                print(f"{metric:<15} {initial:.4f}     {current:.4f}     {change:+.4f}     {rel_change:+.2f}%")
        except Exception as zs_err:
            # Best-effort: a failing zero-shot eval must not abort training
            print(f"Error in zero-shot evaluation: {zs_err}")
    # ----- End Zero-Shot Evaluation -----

    # Save model checkpoint (locally and to wandb)
    checkpoint_metrics = {
        "train_acc": np.mean(run_acc) if run_acc else 0,
        "test_acc": test_acc,
        "train_loss_diet": np.mean(run_loss_diet) if run_loss_diet else 0,
        "train_loss_probe": np.mean(run_loss_probe) if run_loss_probe else 0,
    }
    save_model_checkpoint(
        run, net, optimizer, projection_head, W_probe, W_diet,
        epoch+1, checkpoint_metrics, save_dir="checkpoints"
    )

# End of Epoch Loop
print(f"\nTraining completed in {time.time() - train_start_time:.2f}s")

# ---------------------------------------------------------------------------
# Post-training plots, final zero-shot evaluation and W&B logging.
#
# Every figure is created with an explicit handle and logged through that
# handle after plt.show(). With Jupyter's inline backend, plt.show() closes all
# open figures, so the former `plt.gcf()` calls returned a brand-new, empty
# figure, and that empty figure is what got logged. A closed Figure object can
# still be saved/rendered, so logging via the handle works on every backend.
# The training-progress figure is built only once the final results exist, so
# its third (initial vs final) panel lands on the same figure as the
# loss/accuracy panels instead of on a stray new figure.
# ---------------------------------------------------------------------------

# ----- Zero-shot metrics progression during training -----
tracked_epochs = sorted(metrics_history["zero_shot_metrics"].keys())
metrics_list = list(initial_results.keys())

# Two panels per row; add rows instead of overflowing a fixed 2x2 grid
zs_rows = max(1, (len(metrics_list) + 1) // 2)
zero_shot_fig = plt.figure(figsize=(15, 5 * zs_rows))

for i, metric in enumerate(metrics_list):
    plt.subplot(zs_rows, 2, i+1)
    values = [metrics_history["zero_shot_metrics"][e][metric] for e in tracked_epochs]
    plt.plot(tracked_epochs, values, marker='o', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel(f'{metric} Score')
    plt.title(f'Zero-shot {metric} Progression')
    plt.grid(True)

    # Annotate the first and last tracked values
    plt.annotate(f'{values[0]:.4f}', (tracked_epochs[0], values[0]),
                 textcoords="offset points", xytext=(0,10), ha='center')
    plt.annotate(f'{values[-1]:.4f}', (tracked_epochs[-1], values[-1]),
                 textcoords="offset points", xytext=(0,10), ha='center')

plt.tight_layout()
plt.suptitle('Zero-shot Metrics Progression During Training', fontsize=16)
plt.subplots_adjust(top=0.9)
plt.savefig(f'{dataset_name}_zeroshot_progression.png')
plt.show()
log_figure_to_wandb(run, zero_shot_fig, "zero_shot_progression")

# ----- Zero-shot progression table (console) -----
print("\nZero-shot Progression Table:")
print("-"*80)
header = "Epoch".ljust(10)
for metric in metrics_list:
    header += f"{metric}".ljust(15)
print(header)
print("-"*80)

for epoch in tracked_epochs:
    row = f"{epoch}".ljust(10)
    for metric in metrics_list:
        value = metrics_history["zero_shot_metrics"][epoch][metric]
        row += f"{value:.4f}".ljust(15)
    print(row)
print("-"*80)

# Log metrics tables to W&B (the comparison table used to be logged twice)
log_metrics_table(run, metrics_history)
log_zero_shot_comparison_table(run, metrics_history, tracked_epochs, metrics_list)

# ----- Final zero-shot evaluation -----
print("\n" + "="*50)
print("FINAL ZERO-SHOT EVALUATION (AFTER TRAINING)")
print("="*50)
final_time = time.time()
final_results = zero_shot_eval(net, test_loader, num_classes, eval_id=num_epoch+1)
print(f"Final evaluation completed in {time.time() - final_time:.2f}s")

# Log final zero-shot results to wandb
log_zero_shot_metrics(run, final_results, num_epoch+1, initial_results)

# Absolute change of every zero-shot metric (final - initial)
improvements = {
    f"improvement_{k}": final_results[k] - initial_results[k]
    for k in initial_results.keys()
}

# ----- Training progress figure: loss | accuracy | zero-shot initial vs final -----
training_fig = plt.figure(figsize=(15, 5))

# Plot loss
plt.subplot(1, 3, 1)
plt.plot(metrics_history["train_loss_diet"], label="DIET Loss")
plt.plot(metrics_history["train_loss_probe"], label="Probe Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.grid(True)

# Plot accuracy
plt.subplot(1, 3, 2)
plt.plot(metrics_history["train_acc"], label="Train Accuracy")
plt.plot(metrics_history["test_acc"], label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Model Accuracy")
plt.legend()
plt.grid(True)

# Plot zero-shot metrics (initial vs final, side-by-side bars)
plt.subplot(1, 3, 3)
metrics = list(initial_results.keys())
x = range(len(metrics))
width = 0.35
plt.bar(x, [initial_results[m] for m in metrics], width, label='Initial')
plt.bar([i + width for i in x], [final_results[m] for m in metrics], width, label='Final')
plt.xlabel("Metrics")
plt.ylabel("Score")
plt.title("Zero-Shot Performance")
plt.xticks([i + width/2 for i in x], metrics)
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
log_figure_to_wandb(run, training_fig, "training_progress")

# ----- Final results table (W&B) -----
final_results_table = wandb.Table(
    columns=["Metric", "Initial", "Final", "Improvement", "Relative %"]
)

for metric in metrics:
    initial = initial_results[metric]
    final = final_results[metric]
    imp = improvements[f"improvement_{metric}"]
    rel_imp = (imp / initial) * 100 if initial > 0 else float('inf')
    final_results_table.add_data(metric, initial, final, imp, rel_imp)

# The table used to be built but never logged
run.log({"final_results_table": final_results_table})











# Final summary (console). The model label uses the same backbone list as the
# W&B summary, so AIM/MAE/I-JEPA runs are no longer mislabelled "ImageNet
# pre-trained". The f-string also avoids failing when model_size is not a str.
model_desc = (f"({model_size})" if backbone_type in ['dinov2', 'mae', 'ijepa', 'aim']
              else '(ImageNet pre-trained)')

print("\n" + "="*50)
print("EXPERIMENT RESULTS SUMMARY")
print("="*50)
print(f"Dataset: {dataset_name.upper()}")
print(f"Model: {backbone_type.upper()} {model_desc}")
print(f"DIET label smoothing: {label_smoothing}" + (" (DIET active)" if is_diet_active else " (DIET inactive)"))
print(f"Training samples: {len(training_data)}, Epochs: {num_epoch}")
print("\nZero-shot performance:")
print("-"*60)

print(f"{'Metric':<15} {'Initial':<10} {'Final':<10} {'Improvement':<10} {'Relative %':<10}")
print("-"*60)
for metric in metrics:
    initial = initial_results[metric]
    final = final_results[metric]
    imp = improvements[f"improvement_{metric}"]
    rel_imp = (imp / initial) * 100 if initial > 0 else float('inf')
    print(f"{metric:<15} {initial:.4f}     {final:.4f}     {imp:+.4f}     {rel_imp:+.2f}%")

print("\nCONCLUSION:")
# Mean absolute change across all zero-shot metrics (also used by the W&B summary)
avg_improvement = np.mean([improvements[f"improvement_{k}"] for k in initial_results.keys()])
mean_initial = np.mean(list(initial_results.values()))
if avg_improvement > 0:
    # Guard the relative figure against a zero/negative mean baseline
    rel_text = f"{(avg_improvement / mean_initial) * 100:.2f}%" if mean_initial > 0 else "n/a"
    print(f"DIET finetuning {'improved' if is_diet_active else 'would likely improve'} zero-shot performance " +
          f"by an average of {avg_improvement:.4f} ({rel_text})")
else:
    print(f"DIET finetuning {'did not improve' if is_diet_active else 'would likely not improve'} zero-shot performance")











# Log a human-readable summary plus headline numbers to W&B, then close the run.
import html

summary_text = f"""
# DIET Experiment Summary

## Configuration
- **Dataset**: {dataset_name.upper()}
- **Model**: {backbone_type.upper()} {f"({model_size})" if backbone_type in ['dinov2', 'mae', 'ijepa', 'aim'] else '(ImageNet pre-trained)'}
- **DIET**: {"Active" if is_diet_active else "Inactive"} (label_smoothing={label_smoothing})
- **Training**: {len(training_data)} samples, {num_epoch} epochs

## Zero-shot Performance
| Metric | Initial | Final | Improvement | Relative % |
|--------|---------|-------|------------|------------|
"""

for metric in metrics:
    initial = initial_results[metric]
    final = final_results[metric]
    imp = improvements[f"improvement_{metric}"]
    rel_imp = (imp / initial) * 100 if initial > 0 else float('inf')
    summary_text += f"| {metric} | {initial:.4f} | {final:.4f} | {imp:+.4f} | {rel_imp:+.2f}% |\n"

# wandb.Html renders HTML, not Markdown: raw Markdown would have its newlines
# collapsed into a single run-on line. Wrap the escaped text in <pre> so the
# layout (headings, table rows) is preserved as written.
run.log({"experiment_summary": wandb.Html(f"<pre>{html.escape(summary_text)}</pre>")})

# Update run summary with final metrics
mean_initial = np.mean(list(initial_results.values()))
run.summary.update({
    "avg_improvement": avg_improvement,
    "avg_relative_improvement": (avg_improvement / mean_initial) * 100 if mean_initial > 0 else 0,
    "final_test_acc": test_acc,
    # NOTE(review): only differs from final_test_acc if save_model_checkpoint
    # stores a `best_acc` attribute on itself — confirm that it does.
    "best_test_acc": getattr(save_model_checkpoint, "best_acc", test_acc),
    "training_time": time.time() - train_start_time
})

# Finish the wandb run
run.finish()








In [None]:
import json
import os
import pandas as pd
import uuid
from datetime import datetime

def _json_default(obj):
    """JSON fallback for scalar-like objects exposing ``.item()`` (numpy/torch scalars).

    Anything else still raises ``TypeError``, as plain ``json.dump`` would,
    rather than being silently stringified.
    """
    to_item = getattr(obj, "item", None)
    if callable(to_item):
        try:
            return to_item()
        except (TypeError, ValueError, RuntimeError):
            pass  # e.g. a multi-element array: fall through to the TypeError
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_experiment_results(model_name, dataset_name, metrics_history, hyperparams, save_dir="results"):
    """Save experiment results to a fresh, uniquely-named folder under ``save_dir``.

    Writes ``experiment_info.json``, ``metrics_history.json``,
    ``progression_table.csv`` (when zero-shot metrics exist), ``summary_plot.png``
    and ``results_summary.txt``.

    Args:
        model_name: label used in the folder name and plot titles.
        dataset_name: dataset label used in the folder name and plot titles.
        metrics_history: dict with per-epoch lists ``train_loss_diet``,
            ``train_loss_probe``, ``train_acc``, ``test_acc`` and a
            ``zero_shot_metrics`` dict of ``{epoch: {metric: value}}``.
        hyperparams: dict that must contain ``lr``, ``num_diet_classes``,
            ``label_smoothing`` and ``weight_decay``. numpy/torch scalar values
            are converted to Python numbers for JSON.
        save_dir: parent directory (created if missing).

    Returns:
        str: path of the created experiment folder.
    """
    # Create directory structure if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Unique experiment ID: timestamp (down to microseconds via %f) + short UUID
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    exp_id = f"{timestamp}_{str(uuid.uuid4())[:8]}"

    # Key hyperparameters are encoded in the folder name for easy browsing
    lr_str = f"lr{hyperparams['lr']:.1e}"
    diet_str = f"diet{hyperparams['num_diet_classes']}"
    ls_str = f"ls{hyperparams['label_smoothing']}"

    exp_folder = f"{model_name}_{dataset_name}_{lr_str}_{diet_str}_{ls_str}_{exp_id}"
    full_path = f"{save_dir}/{exp_folder}"
    os.makedirs(full_path, exist_ok=True)

    # Zero-shot results keyed by epoch; may be missing or empty
    zero_shot = metrics_history.get("zero_shot_metrics") or {}
    zs_epochs = sorted(zero_shot.keys())

    # 1. Save experiment description with ALL hyperparameters
    with open(f"{full_path}/experiment_info.json", 'w', encoding="utf-8") as f:
        metadata = {
            "model": model_name,
            "dataset": dataset_name,
            "experiment_id": exp_id,
            "timestamp": timestamp,
            "hyperparameters": hyperparams
        }
        json.dump(metadata, f, indent=2, default=_json_default)

    # 2. Save metrics history as JSON (numpy values -> Python floats)
    with open(f"{full_path}/metrics_history.json", 'w', encoding="utf-8") as f:
        serializable_metrics = {}
        for key, value in metrics_history.items():
            if key == "zero_shot_metrics":
                serializable_metrics[key] = {
                    str(epoch): {m: float(v) for m, v in metrics.items()}
                    for epoch, metrics in value.items()
                }
            else:
                serializable_metrics[key] = [float(v) for v in value]

        json.dump(serializable_metrics, f, indent=2)

    # 3. Create and save progression table as CSV (skipped when there is no zero-shot data)
    if zs_epochs:
        metrics_list = list(zero_shot[zs_epochs[0]].keys())
        table_data = [
            {"epoch": epoch, **{metric: zero_shot[epoch][metric] for metric in metrics_list}}
            for epoch in zs_epochs
        ]
        pd.DataFrame(table_data).to_csv(f"{full_path}/progression_table.csv", index=False)

    # 4. Save plots
    plt.figure(figsize=(15, 10))

    # Training metrics
    plt.subplot(2, 2, 1)
    plt.plot(metrics_history["train_loss_diet"], label="DIET Loss")
    plt.plot(metrics_history["train_loss_probe"], label="Probe Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 2, 2)
    plt.plot(metrics_history["train_acc"], label="Train Accuracy")
    plt.plot(metrics_history["test_acc"], label="Test Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Model Accuracy")
    plt.legend()
    plt.grid(True)

    # Zero-shot progression (first two metrics only; the grid has two free slots)
    if zs_epochs:
        metrics_list = list(zero_shot[zs_epochs[0]].keys())
        for i, metric in enumerate(metrics_list[:2]):
            plt.subplot(2, 2, 3+i)
            values = [zero_shot[e][metric] for e in zs_epochs]
            plt.plot(zs_epochs, values, marker='o', linewidth=2)
            plt.xlabel('Epoch')
            plt.ylabel(f'{metric}')
            plt.title(f'Zero-shot {metric} Progression')
            plt.grid(True)

    plt.tight_layout()
    plt.suptitle(f"{model_name} on {dataset_name} - {diet_str}, {ls_str}", fontsize=14)
    plt.subplots_adjust(top=0.92)
    plt.savefig(f"{full_path}/summary_plot.png", dpi=300)

    # 5. Also save a text summary with key results
    with open(f"{full_path}/results_summary.txt", 'w', encoding="utf-8") as f:
        f.write(f"Experiment Summary: {model_name} on {dataset_name}\n")
        f.write("="*50 + "\n\n")

        f.write("Key Hyperparameters:\n")
        f.write(f"- Learning Rate: {hyperparams['lr']}\n")
        f.write(f"- DIET Classes: {hyperparams['num_diet_classes']}\n")
        f.write(f"- Label Smoothing: {hyperparams['label_smoothing']}\n")
        f.write(f"- Weight Decay: {hyperparams['weight_decay']}\n\n")

        if zs_epochs:
            initial_metrics = zero_shot[zs_epochs[0]]
            final_metrics = zero_shot[zs_epochs[-1]]

            f.write("Zero-Shot Performance:\n")
            f.write("-"*50 + "\n")
            f.write(f"{'Metric':<15} {'Initial':<10} {'Final':<10} {'Change':<10} {'Relative %':<10}\n")
            f.write("-"*50 + "\n")

            for metric in initial_metrics.keys():
                initial = initial_metrics[metric]
                final = final_metrics[metric]
                change = final - initial
                rel_change = (change / initial) * 100 if initial > 0 else float('inf')
                f.write(f"{metric:<15} {initial:.4f}     {final:.4f}     {change:+.4f}     {rel_change:+.2f}%\n")

    print(f"Experiment results saved to {full_path}")
    return full_path


In [None]:
# At the end of your script, after training
# Record this run's hyperparameters and persist metrics/plots via save_experiment_results.
# Backbones whose name is qualified by `model_size` (same list as the W&B summary);
# AIM is built as AIM-{model_size}, so its size was previously dropped here.
sized_backbone = backbone_type in ['dinov2', 'mae', 'ijepa', 'aim']

hyperparams = {
    "num_epoch": num_epoch,
    "batch_size": batch_size,
    "da_strength": da_strength,
    "lr": lr,
    "weight_decay": weight_decay,
    "label_smoothing": label_smoothing,
    "num_diet_classes": num_diet_classes,
    # DIET logit temperature actually used by the training loop. This was
    # hard-coded to 0.1, although the loop divides the DIET logits by 3.
    "temperature": temperature if 'temperature' in locals() else 3,
    "model_size": model_size if sized_backbone else "n/a",
    # NOTE(review): "last2" is recorded for DINOv2 only — confirm it matches the
    # layer freezing applied when the backbone was built.
    "freezing_layers": "last2" if backbone_type == "dinov2" else "n/a"
}

model_name = f"{backbone_type}_{model_size}" if sized_backbone else backbone_type
save_path = save_experiment_results(model_name, dataset_name, metrics_history, hyperparams)

## Sanity Check