
An example of distributed PyTorch training with hyperparameter tuning using [Optuna](https://optuna.org/).  
**Includes checkpointing**: checkpoints are saved to UC Volumes and, once training for each Optuna trial finishes, logged to MLflow.


The cluster is the same as in the previous example; feel free to test other cluster configurations.

```
"spark_version": "16.4.x-scala2.13",
"node_type_id": "g5.12xlarge",
```

Omit `autoscale` and use a fixed number of workers (`num_workers = 4`).


In [0]:
%pip install mlflow[skinny]>=3 optuna nvidia-ml-py3 --upgrade

dbutils.library.restartPython()

In [0]:
# MLflow version: 3.4.0
# Optuna version: 4.5.0
# PyTorch version: 2.6.0+cu124

In [0]:
import time
import os
import io
import base64
import shutil
import tempfile
import uuid
from datetime import datetime, timedelta
from functools import partial

import numpy as np
import torch
import torch.nn as nn  
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image

import mlflow
import optuna
from pyspark.ml.torch.distributor import TorchDistributor
from streaming import StreamingDataset, StreamingDataLoader
from streaming.base.util import clean_stale_shared_memory

# Report the library versions actually loaded in this session
# (one print call; sep="\n" gives the same output as three separate prints).
print(
    f"MLflow version: {mlflow.__version__}",
    f"Optuna version: {optuna.__version__}",
    f"PyTorch version: {torch.__version__}",
    sep="\n",
)

In [0]:
# ============================================================
# DATASET CONFIGURATION
# ============================================================
# Unity Catalog location (catalog / schema / volume) that holds the MDS data
CATALOG, SCHEMA, VOLUME_NAME = "mmt", "pytorch", "torch_data"

# Dataset paths: MDS directory names, relative to the volume root
mds_train_dir = "imagenet_tiny200_mds_train"
mds_val_dir = "imagenet_tiny200_mds_val"
data_storage_location = "/".join(("/Volumes", CATALOG, SCHEMA, VOLUME_NAME))

# Dataset parameters
num_classes = 200  # ImageNet Tiny has 200 classes
num_workers = 4

# ============================================================
# TRAINING CONFIGURATION
# ============================================================
NUM_EPOCHS = 2  # previously 5
N_TRIALS = 3    # number of Optuna trials (previously 10)

# ============================================================
# CHECKPOINTING CONFIGURATION
# ============================================================
ENABLE_CHECKPOINTING = True
CHECKPOINT_FREQUENCY = 1  # save a checkpoint every N epochs
RESUME_FROM_CHECKPOINT = False

# ============================================================
# EXPERIMENT-SPECIFIC PATH CONFIGURATION
# ============================================================
import os
from datetime import datetime
import mlflow

# The same volume addressed two ways: a POSIX path (for os.* calls) and a dbfs: URI
# (used for the MLflow artifact location below).
BASE_VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"
BASE_VOLUME_DBFS = "dbfs:" + BASE_VOLUME_PATH

# Experiment identifier; the timestamp gives each run its own checkpoint folder
EXPERIMENT_SHORT_NAME = "imagenet_mobilenetv2_hpt_chkpt"
EXPERIMENT_TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# The MLflow experiment lives under the current user's workspace folder
USER_NAME = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .userName()
    .get()
)
EXPERIMENT_NAME = f"/Users/{USER_NAME}/mlflow_experiments/pytorch_{EXPERIMENT_SHORT_NAME}"

# Experiment root directory, in both addressing forms
EXPERIMENT_ROOT = os.path.join(BASE_VOLUME_PATH, EXPERIMENT_SHORT_NAME)
EXPERIMENT_ROOT_DBFS = os.path.join(BASE_VOLUME_DBFS, EXPERIMENT_SHORT_NAME)

# Checkpoints (one folder per run timestamp) are kept apart from MLflow artifacts
CHECKPOINT_BASE_DIR = os.path.join(EXPERIMENT_ROOT, "checkpoints", EXPERIMENT_TIMESTAMP)
MLFLOW_ARTIFACT_LOCATION = EXPERIMENT_ROOT_DBFS + "/mlflow_artifacts"

# Create directories (no-op when they already exist)
os.makedirs(CHECKPOINT_BASE_DIR, exist_ok=True)
os.makedirs(os.path.join(EXPERIMENT_ROOT, "mlflow_artifacts"), exist_ok=True)

# MLflow setup: Databricks tracking server and Unity Catalog model registry
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
mlflow.enable_system_metrics_logging()

# ============================================================
# GET OR CREATE EXPERIMENT (REUSE IF EXISTS)
# ============================================================

experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)

# A previously deleted experiment may still be returned by name (lifecycle_stage
# "deleted"). mlflow.set_experiment() below refuses to activate a deleted
# experiment, and the name is still taken for create_experiment(), so restore it.
if experiment is not None and experiment.lifecycle_stage == "deleted":
    print(f"Note: MLflow experiment {EXPERIMENT_NAME} was deleted -- restoring it")
    mlflow.tracking.MlflowClient().restore_experiment(experiment.experiment_id)
    experiment = mlflow.get_experiment(experiment.experiment_id)

if experiment:
    experiment_id = experiment.experiment_id
    print("✓ Reusing existing MLflow experiment")
    print(f"  Name: {EXPERIMENT_NAME}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Artifact Location: {experiment.artifact_location}")

    # An existing experiment keeps the artifact location it was created with;
    # adopt it so later code reports the location that is really used.
    if experiment.artifact_location != MLFLOW_ARTIFACT_LOCATION:
        print("\nNote: Artifact location differs from desired")
        print(f"     Desired:  {MLFLOW_ARTIFACT_LOCATION}")
        print(f"     Actual:   {experiment.artifact_location}")
        print("     This is OK - using existing location for consistency")
        MLFLOW_ARTIFACT_LOCATION = experiment.artifact_location
else:
    experiment_id = mlflow.create_experiment(
        name=EXPERIMENT_NAME,
        artifact_location=MLFLOW_ARTIFACT_LOCATION
    )
    experiment = mlflow.get_experiment(experiment_id)
    print("✓ Created new MLflow experiment")
    print(f"  Name: {EXPERIMENT_NAME}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Artifact Location: {MLFLOW_ARTIFACT_LOCATION}")

# Set the experiment as active
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

# ============================================================
# CONFIGURATION SUMMARY
# ============================================================
# One print call with sep="\n" produces exactly the text of one print() per line.
print(
    f"\n{'=' * 80}",
    "EXPERIMENT CONFIGURATION",
    "=" * 80,
    f"Experiment: {EXPERIMENT_SHORT_NAME}",
    f"Timestamp:  {EXPERIMENT_TIMESTAMP}",
    "\nDataset Configuration:",
    f"  Storage Location: {data_storage_location}",
    f"  Training Data:    {data_storage_location}/{mds_train_dir}",
    f"  Validation Data:  {data_storage_location}/{mds_val_dir}",
    f"  Num Classes:      {num_classes}",
    "\nDirectory Structure:",
    f"  Root: {EXPERIMENT_ROOT}/",
    "  ├── checkpoints/",
    f"  │   └── {EXPERIMENT_TIMESTAMP}/",
    "  │       ├── trial_000_*/",
    "  │       ├── trial_001_*/",
    "  │       └── ...",
    "  └── mlflow_artifacts/",
    "      ├── {parent_run_id}/artifacts/",
    "      ├── {trial_0_run_id}/artifacts/",
    "      └── ...",
    "\nPaths:",
    f"  Experiment Root:     {EXPERIMENT_ROOT}",
    f"  Checkpoint Base:     {CHECKPOINT_BASE_DIR}",
    f"  MLflow Artifacts:    {MLFLOW_ARTIFACT_LOCATION}",
    f"  MLflow Experiment:   {EXPERIMENT_NAME}",
    f"  Experiment ID:       {experiment_id}",
    "\nTraining Configuration:",
    f"  Num Classes:      {num_classes}",
    f"  Num Epochs:       {NUM_EPOCHS}",
    f"  Num Workers:      {num_workers}",
    f"  Num Trials:       {N_TRIALS}",
    f"  Checkpointing:    {ENABLE_CHECKPOINTING}",
    f"  Checkpoint Freq:  {CHECKPOINT_FREQUENCY}",
    f"  Resume Enabled:   {RESUME_FROM_CHECKPOINT}",
    "=" * 80 + "\n",
    sep="\n",
)

# ============================================================
# VERIFY PATHS
# ============================================================
# Fail fast if the input MDS data or the output directories created above are missing.
print("Verifying paths...")
assert os.path.exists(data_storage_location), (
    f"Data storage location not found: {data_storage_location}"
)
assert os.path.exists(f"{data_storage_location}/{mds_train_dir}"), (
    f"Training data not found: {data_storage_location}/{mds_train_dir}"
)
assert os.path.exists(f"{data_storage_location}/{mds_val_dir}"), (
    f"Validation data not found: {data_storage_location}/{mds_val_dir}"
)
assert os.path.exists(CHECKPOINT_BASE_DIR), (
    f"Checkpoint directory not created: {CHECKPOINT_BASE_DIR}"
)
assert os.path.exists(f"{EXPERIMENT_ROOT}/mlflow_artifacts"), (
    "MLflow artifacts directory not created"
)
# Report entry counts of the two data directories, then a blank line.
print(
    "✓ All paths verified",
    f"  ✓ Training data: {len(os.listdir(f'{data_storage_location}/{mds_train_dir}'))} files",
    f"  ✓ Validation data: {len(os.listdir(f'{data_storage_location}/{mds_val_dir}'))} files",
    "",
    sep="\n",
)

# ============================================================
# GLOBAL VARIABLES FOR OBJECTIVE FUNCTION
# ============================================================
# MLflow run ID of the parent (study-level) run; stays None until that run starts.
EXPERIMENT_RUN_ID = None

In [0]:
# /Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/
# └── imagenet_mobilenetv2_hpt_chkpt/                          # Experiment root
#     ├── checkpoints/                          # Operational checkpoints
#     │   ├── 20251013_104011/                  # Run timestamp
#     │   │   ├── trial_0_AdamW_lr0.001/
#     │   │   │   ├── checkpoint_epoch_0.pt
#     │   │   │   ├── checkpoint_epoch_1.pt
#     │   │   │   └── best_checkpoint.pt
#     │   │   ├── trial_1_SGD_lr0.01/
#     │   │   └── ...
#     │   ├── 20251013_150000/                  # Another run
#     │   └── ...
#     └── mlflow_artifacts/                     # MLflow tracking
#         ├── {parent_run_id}/
#         │   └── artifacts/
#         │       └── study_results.json
#         ├── {trial_0_run_id}/
#         │   └── artifacts/
#         │       ├── model/                    # if logged model
#         │       ├── plots/                    # if logged Training curves
#         │       └── best_checkpoint.pt        # Copy for governance
#         └── {trial_1_run_id}/
#             └── artifacts/

In [0]:
def get_dataloader_with_mosaic(remote_path, local_path, batch_size, rank=0,
                               cache_root="/local_disk0/tmp"):
    """Create a StreamingDataset + StreamingDataLoader backed by a fresh local cache.

    Every call creates a new cache directory. Its name combines the rank, the
    current time (whole seconds) and a short hash of ``remote_path``, so the train
    and val datasets of one rank never share a cache. The original author notes
    this helped speed up data handling.

    Args:
        remote_path: Location of the MDS dataset (e.g. a UC Volumes directory).
        local_path: Accepted for backward compatibility but NOT used. The cache
            directory is always generated under ``cache_root``. Callers that want
            to delete the cache afterwards must use the path this function returns.
        batch_size: Per-process batch size, passed to both the dataset and the loader.
        rank: Process rank, used in log messages and in the cache directory name.
        cache_root: Parent directory for the generated cache directory. Defaults
            to the previously hard-coded ``/local_disk0/tmp``.

    Returns:
        tuple: ``(dataloader, cache_dir)``, where ``cache_dir`` is the directory
        actually used as the StreamingDataset ``local`` cache.

    Raises:
        Any exception raised while building the dataset/loader is re-raised after
        the cache directory has been removed.
    """
    import hashlib

    print(f"Rank {rank}: Getting optimized MDS data from {remote_path}")

    try:
        clean_stale_shared_memory()
    except Exception as e:
        print(f"Shared memory cleanup warning: {e}")

    # Unique cache dir per (time, rank, remote path) so train and val never collide.
    # md5 is only a short fingerprint here, not a security primitive.
    path_hash = hashlib.md5(remote_path.encode(), usedforsecurity=False).hexdigest()[:8]
    unique_id = f"{int(time.time())}_{rank}_{path_hash}"
    local_path_unique = os.path.join(cache_root, f"mds_cache_rank_{rank}_{unique_id}")

    # Normally only hit if the same rank/path is requested twice within one second.
    if os.path.exists(local_path_unique):
        try:
            shutil.rmtree(local_path_unique)
            time.sleep(0.1)  # Brief pause to ensure cleanup
        except Exception as e:
            print(f"Warning: Could not remove existing cache: {e}")

    os.makedirs(local_path_unique, mode=0o755, exist_ok=True)
    print(f"Rank {rank}: Created unique cache directory: {local_path_unique}")

    try:
        # MDS configuration for network storage
        dataset = StreamingDataset(
            remote=remote_path,
            local=local_path_unique,
            shuffle=True,
            batch_size=batch_size,
            num_canonical_nodes=1,
            predownload=batch_size * 4,  # Reasonable predownload
            keep_zip=False,  # Don't keep zip files to save space
            download_retry=2,  # Fewer retries
            download_timeout=300,  # 5 minute timeout
            validate_hash=False,
            epoch_size=None,
        )

        print(f"Rank {rank}: Created optimized StreamingDataset with {len(dataset)} samples")

        dataloader = StreamingDataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0,  # No multiprocessing to avoid contention
            pin_memory=True,  # Use pinned memory for faster GPU transfer
            drop_last=True,
            persistent_workers=False,
        )

        print(f"Rank {rank}: Created optimized dataloader with {len(dataloader)} batches")
        return dataloader, local_path_unique

    except Exception as e:
        print(f"Rank {rank}: Error creating optimized dataloader: {e}")
        # Best-effort cleanup: never let a cleanup failure mask the original error
        # (ignore_errors also covers the directory already being gone).
        shutil.rmtree(local_path_unique, ignore_errors=True)
        raise

## e.g. Update configuration - for testing
# NUM_EPOCHS = 1  # Test with just 1 epoch first
# num_workers = 2  # Reduce to 2 workers to decrease I/O contention

In [0]:
### NOTE: your data layout may differ; if so, update verify_mds_dataset accordingly.

import io
import tempfile
import shutil
import os
from streaming import StreamingDataset

def create_comprehensive_label_mapping(remote_path, max_batches=50, expected_num_classes=200):
    """Build a ``{class_name: index}`` mapping by sampling an MDS dataset.

    Reads shuffled 32-sample batches from ``remote_path`` into a throwaway local
    cache and collects the distinct ``'class_name'`` values. Names are sorted
    before indices are assigned, so the mapping is deterministic for a given set
    of observed classes.

    Args:
        remote_path: MDS dataset to sample.
        max_batches: Maximum number of batches to read. The default of 50 is the
            previously hard-coded value. ``None`` scans the whole dataset.
        expected_num_classes: Stop early once this many classes have been seen.
            It is also used to warn when fewer classes were found. The default of
            200 is the previously hard-coded ImageNet Tiny-200 value. ``None``
            disables both the early stop and the warning.

    Returns:
        dict: class name -> index. On any error, the placeholder mapping from
        ``create_imagenet_tiny200_mapping()`` is returned instead.
    """
    print(f"Creating comprehensive label mapping from: {remote_path}")

    temp_cache = tempfile.mkdtemp(prefix="label_mapping_")

    try:
        dataset = StreamingDataset(
            remote=remote_path,
            local=temp_cache,
            shuffle=True,
            batch_size=32  # Use a reasonable batch size
        )

        # Create DataLoader
        from torch.utils.data import DataLoader
        dataloader = DataLoader(
            dataset,
            batch_size=32,
            num_workers=0  # Use 0 to avoid multiprocessing issues during class extraction
        )

        print(f"Dataset created with {len(dataset)} total samples")

        unique_classes = set()
        sample_count = 0

        for batch_idx, batch in enumerate(dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            # batch['class_name'] is now a list of class names
            class_names = batch['class_name']
            unique_classes.update(class_names)
            sample_count += len(class_names)

            if batch_idx % 10 == 0:
                print(f"  Processed {batch_idx} batches, {sample_count} samples, found {len(unique_classes)} unique classes")

            if expected_num_classes is not None and len(unique_classes) >= expected_num_classes:
                break

        classes = sorted(unique_classes)
        print(f"Found {len(classes)} unique classes")
        print(f"Sample classes: {classes[:5]}")

        # A partial mapping shifts every index after a missing class, and unseen
        # classes have no entry at all -- make the shortfall visible.
        if expected_num_classes is not None and len(classes) < expected_num_classes:
            print(f"WARNING: expected {expected_num_classes} classes but found only "
                  f"{len(classes)} in {sample_count} samples; unseen classes are missing "
                  f"from the mapping. Consider increasing max_batches.")

        label_to_idx = {cls: idx for idx, cls in enumerate(classes)}

        print(f"\n=== Label Mapping Summary ===")
        print(f"Total classes: {len(label_to_idx)}")

        return label_to_idx

    except Exception as e:
        print(f"Error creating label mapping: {e}")
        import traceback
        print(f"Full error: {traceback.format_exc()}")
        return create_imagenet_tiny200_mapping()

    finally:
        # The sampled shards are only needed for this scan
        if os.path.exists(temp_cache):
            shutil.rmtree(temp_cache, ignore_errors=True)

def create_imagenet_tiny200_mapping(num_classes=200):
    """Return a placeholder mapping ``{"class_000": 0, "class_001": 1, ...}``.

    WARNING: the placeholder names do not match the real class names stored in
    the MDS data. Batch conversion in this notebook maps names missing from the
    mapping to the default index 0, so training with this mapping learns nothing
    useful. It exists only as a last-resort fallback so downstream code can run.

    Args:
        num_classes: Number of placeholder classes (default 200, ImageNet Tiny-200).

    Returns:
        dict: placeholder class name -> index.
    """
    print("Creating fallback ImageNet Tiny-200 mapping...")
    print("WARNING: fallback class names are placeholders and will not match the "
          "dataset's real class names; fix the label extraction before training.")
    return {f"class_{i:03d}": i for i in range(num_classes)}

def verify_mds_dataset(remote_path):
    """Sanity-check that ``remote_path`` looks like an MDS dataset.

    Two layouts are accepted:
      * multi-shard: numbered sub-directories (``0/``, ``1/``, ...). Only the
        first one is inspected for ``*.mds`` / ``*.mds.zstd`` shard files.
      * flat: an ``index.json`` next to ``*.mds`` / ``*.mds.zstd`` shard files
        directly inside ``remote_path``.

    Prints diagnostics along the way.

    Returns:
        bool: True if either layout is found. False otherwise, including when the
        path does not exist or is not a directory.
    """
    print(f"Verifying multi-shard MDS dataset at: {remote_path}")

    if not os.path.exists(remote_path):
        print(f"ERROR: Remote path does not exist: {remote_path}")
        return False
    # os.listdir below would raise NotADirectoryError on a plain file
    if not os.path.isdir(remote_path):
        print(f"ERROR: Remote path is not a directory: {remote_path}")
        return False

    def _is_shard_file(name):
        # Plain or zstd-compressed MDS shard
        return name.endswith('.mds') or name.endswith('.mds.zstd')

    files_in_dir = os.listdir(remote_path)
    print(f"Items in root directory: {len(files_in_dir)}")

    # Check for main index.json
    has_main_index = 'index.json' in files_in_dir
    if has_main_index:
        print("Found main index.json")

    # Numbered shard directories; isdecimal (not isdigit) guarantees int() accepts
    # the name, since e.g. "²".isdigit() is True but int("²") raises.
    shard_dirs = sorted(
        (f for f in files_in_dir
         if f.isdecimal() and os.path.isdir(os.path.join(remote_path, f))),
        key=int,
    )

    print(f"Found {len(shard_dirs)} shard directories")
    if shard_dirs:
        print(f"  Shard range: {shard_dirs[0]} to {shard_dirs[-1]}")
        print(f"  Sample shards: {shard_dirs[:5]}{'...' if len(shard_dirs) > 5 else ''}")

        # Check first shard
        shard_path = os.path.join(remote_path, shard_dirs[0])
        shard_contents = os.listdir(shard_path)
        print(f"Shard {shard_dirs[0]} contents: {shard_contents}")

        shard_files = [f for f in shard_contents if _is_shard_file(f)]
        if shard_files:
            shard_file_path = os.path.join(shard_path, shard_files[0])
            file_size = os.path.getsize(shard_file_path)
            print(f"Shard file size: {file_size:,} bytes")

        print("Verified sample shards")
        return True

    # Flat single-directory layout: index.json plus shard files at the root
    root_shards = [f for f in files_in_dir if _is_shard_file(f)]
    if has_main_index and root_shards:
        print(f"Found flat MDS layout with {len(root_shards)} shard files")
        return True

    print("No valid shard directories found")
    return False


In [0]:
def convert_batch_to_tensors(batch, device=None, class_to_idx=None, rank=None, image_size=224):
    """Decode an MDS batch into ``(images, labels)`` tensors on ``device``.

    Args:
        batch: Mapping with ``'image_data'`` (iterable of encoded image bytes) and
            ``'class_name'`` (sequence of class-name strings).
        device: Target device. Defaults to CUDA when available, else CPU.
        class_to_idx: ``{class_name: index}``. If missing or not a dict, a
            temporary mapping is built from this batch alone. Its indices are then
            NOT comparable across batches. Names absent from the mapping get
            label 0.
        rank: Process rank, only used to gate debug printing. When None it is read
            from torch.distributed if initialized, else 0.
        image_size: Side length that images are resized to (square). Default 224.

    Returns:
        tuple: ``(images, labels)``.
            ``images`` is a float tensor of shape
            ``(num_images, 3, image_size, image_size)``; images that fail to
            decode become all-zero tensors.
            ``labels`` is an int64 tensor with one entry per class name.
    """
    import io
    from PIL import Image
    import torchvision.transforms as transforms
    import torch

    # Auto-detect device if not provided
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Rank only gates debug output; default to 0 outside a distributed run
    if rank is None:
        try:
            dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
            rank = torch.distributed.get_rank() if dist_ready else 0
        except Exception:
            rank = 0

    # ImageNet normalization statistics
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Debug info only for rank 0
    debug_output = (rank == 0)

    # Process images
    images = []
    for i, img_bytes in enumerate(batch['image_data']):
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
            images.append(transform(img))
        except Exception as e:
            if debug_output and i < 3:  # Only show first few errors
                print(f"Error processing image {i}: {e}")
            # Undecodable image -> zero placeholder so the batch keeps its size
            images.append(torch.zeros((3, image_size, image_size)))

    # torch.stack raises on an empty list; return a correctly shaped empty batch
    if images:
        images_tensor = torch.stack(images).to(device)
    else:
        images_tensor = torch.empty((0, 3, image_size, image_size), device=device)

    # Process labels
    class_names = batch['class_name']

    # Ensure class_to_idx is a dictionary (callers have passed an int by mistake)
    if class_to_idx is None or not isinstance(class_to_idx, dict):
        unique_classes = sorted(set(class_names))
        class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}
        if debug_output:
            print(f"Created temporary class mapping with {len(class_to_idx)} classes")

    labels = []
    warned_unknown = set()  # warn once per unknown class per call, not once per sample
    for class_name in class_names:
        idx = class_to_idx.get(class_name)
        if idx is None:
            labels.append(0)  # Default to class 0 for unknown classes
            if debug_output and class_name not in warned_unknown:
                print(f"Warning: Unknown class '{class_name}', using default")
                warned_unknown.add(class_name)
        else:
            labels.append(idx)

    labels_tensor = torch.tensor(labels, dtype=torch.long).to(device)

    return images_tensor, labels_tensor

In [0]:
import torch

# 1. FIRST: Define device globally
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 2. The class mapping is built once, in the "Label Mapping" cell that follows.
#    Building it here as well repeated the same training-data scan, because that
#    cell rebuilds `label_to_idx` from scratch anyway.

# 3. Verify datasets
print("Verifying datasets...")
train_valid = verify_mds_dataset(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_train_dir}")
val_valid = verify_mds_dataset(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_val_dir}")

# Act on the verification results instead of always reporting success. Warn
# rather than raise: verify_mds_dataset may need adapting to other data layouts.
if train_valid and val_valid:
    print("Setup complete! Ready for training.")
else:
    print(f"WARNING: dataset verification failed (train={train_valid}, val={val_valid}); "
          "check the dataset paths or adapt verify_mds_dataset to your data layout.")

In [0]:
# Label Mapping
print("=== Creating Comprehensive Label Mapping ===")
train_path = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_train_dir}"

# ImageNet Tiny-200 has 200 classes. This is a literal rather than the
# `num_classes` global because this cell overwrites `num_classes`, and a re-run
# must not compare the mapping against its own previous result.
EXPECTED_NUM_CLASSES = 200
# Below this many classes the scan is treated as failed and the fallback is used.
MIN_USABLE_CLASSES = 50

try:
    label_to_idx = create_comprehensive_label_mapping(train_path)
    num_classes = len(label_to_idx)

    # If we didn't get enough classes, use the fallback
    if num_classes < MIN_USABLE_CLASSES:
        print("Insufficient classes found, using fallback mapping...")
        label_to_idx = create_imagenet_tiny200_mapping()
        num_classes = len(label_to_idx)

except Exception as e:
    print(f"Error in comprehensive mapping: {e}")
    print("Using fallback mapping...")
    label_to_idx = create_imagenet_tiny200_mapping()
    num_classes = len(label_to_idx)

print(f"\nFinal label mapping created with {num_classes} classes")

# A partial mapping (at least MIN_USABLE_CLASSES but fewer than expected) would
# otherwise pass silently. get_model sizes the head from `num_classes`, and
# samples of classes missing from the mapping fall back to the default label.
if num_classes != EXPECTED_NUM_CLASSES:
    print(f"WARNING: expected {EXPECTED_NUM_CLASSES} classes but the mapping has "
          f"{num_classes}; labels for missing classes will be wrong.")

# Test the mapping with a few samples
print("\n=== Testing Label Mapping ===")
test_classes = ['barrel, cask', 'school bus', 'pizza, pizza pie']
for test_class in test_classes:
    idx = label_to_idx.get(test_class, -1)
    print(f"'{test_class}' -> index {idx}")

# Show some actual mappings that exist
print(f"\nFirst 10 actual mappings:")
for class_name, idx in list(label_to_idx.items())[:10]:
    print(f"  {idx}: {class_name}")

print(f"\nLabel mapping setup complete!")

In [0]:
def get_model(lr=0.001):
    """Return an ImageNet-pretrained MobileNetV2 re-headed for ``num_classes`` outputs.

    The whole network is frozen except ``model.classifier``. The classifier's
    final Linear layer is replaced by a freshly initialised one whose output size
    is the notebook-level ``num_classes``.

    Args:
        lr: Not used by this function; accepted so existing ``get_model(lr=...)``
            calls keep working.

    Returns:
        The modified torchvision MobileNetV2 module.
    """
    from torchvision.models import MobileNet_V2_Weights

    net = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)

    # Freeze every parameter, backbone included ...
    net.requires_grad_(False)

    # ... swap in a new output layer with the same input width ...
    in_features = net.classifier[1].in_features
    net.classifier[1] = torch.nn.Linear(in_features, num_classes)

    # ... and unfreeze only the classifier head.
    net.classifier.requires_grad_(True)

    return net

In [0]:
def distributed_train_and_evaluate(lr=0.001, batch_size=32, optimizer_name='AdamW',
                                 weight_decay=1e-4, step_size=7, gamma=0.1,
                                 dropout_rate=0.2, label_smoothing=0.1,
                                 momentum=0.9, nesterov=False, beta1=0.9, beta2=0.999, eps=1e-8,
                                 data_storage_location=None, mds_train_dir=None, 
                                 mds_val_dir=None, num_epochs=2, num_classes=200,
                                 mlflow_run_id=None, mlflow_tracking_uri=None,
                                 mlflow_experiment_name=None, trial_number=0,
                                 run_name=None, checkpoint_dir=None):
    """
    Distributed training with best checkpoint tracking per trial.
    Uses experiment-specific checkpoint directory structure.
    
    Args:
        num_classes: Number of output classes (added for architecture metadata)
        ... (other args same as before)
    """
    import os
    import time
    import shutil
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from datetime import timedelta
    import numpy as np
    
    # Initialize core variables
    cache_paths = []
    training_start_time = time.time()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    
    # Core metrics storage
    epoch_metrics = []
    essential_model_metrics = {
        'parameter_count': 0,
        'trainable_parameters': 0,
        'model_size_mb': 0.0,
        'gradient_norms': [],
        'weight_norms': [],
        'learning_curves': {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    }
    
    # Checkpointing variables
    start_epoch = 0
    best_val_acc = 0.0
    best_checkpoint_path = None
    best_checkpoint_epoch = -1
    checkpoint_paths = []
    
    # Device info
    if global_rank == 0:
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        if torch.cuda.is_available():
            print(f"Device name: {torch.cuda.get_device_name(0)}")
        print(f"Checkpoint directory: {checkpoint_dir}")
        print(f"Model configuration: {num_classes} classes, dropout={dropout_rate}")

    try:
        if global_rank == 0:
            print(f"Starting training with lr={lr}, batch_size={batch_size}, optimizer={optimizer_name}")
        
        # Initialize distributed training
        if world_size > 1:
            dist.init_process_group("nccl", timeout=timedelta(seconds=1800))
        
        device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            torch.cuda.empty_cache()
        
        # Create model
        model = get_model(lr=lr)
        
        # Add dropout to classifier
        if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Sequential):
            if len(model.classifier) > 1 and hasattr(model.classifier[1], 'in_features'):
                num_features = model.classifier[1].in_features
                model.classifier = nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.Linear(num_features, 512),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout_rate * 0.5),
                    nn.Linear(512, num_classes)  # Use num_classes parameter
                )
        
        model = model.to(device)
        
        # Calculate essential model metrics
        if global_rank == 0:
            essential_model_metrics['parameter_count'] = sum(p.numel() for p in model.parameters())
            essential_model_metrics['trainable_parameters'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
            essential_model_metrics['model_size_mb'] = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            print(f"Model: {essential_model_metrics['parameter_count']:,} total params, "
                  f"{essential_model_metrics['trainable_parameters']:,} trainable")
        
        if world_size > 1:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        
        # Setup training components
        criterion = LabelSmoothingCrossEntropy(label_smoothing) if label_smoothing > 0 else nn.CrossEntropyLoss()
        
        model_params = [p for p in (model.module if world_size > 1 else model).parameters() if p.requires_grad]
        
        if optimizer_name == 'SGD':
            optimizer = torch.optim.SGD(model_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
        elif optimizer_name == 'Adam':
            optimizer = torch.optim.Adam(model_params, lr=lr, weight_decay=weight_decay, betas=(beta1, beta2), eps=eps)
        else:  # AdamW
            optimizer = torch.optim.AdamW(model_params, lr=lr, weight_decay=weight_decay, betas=(beta1, beta2), eps=eps)
        
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
        
        # Check for existing checkpoint to resume from
        if RESUME_FROM_CHECKPOINT and global_rank == 0 and checkpoint_dir:
            latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
            if latest_checkpoint:
                checkpoint_metadata = load_checkpoint(latest_checkpoint, model, optimizer, scheduler, device)
                if checkpoint_metadata:
                    start_epoch = checkpoint_metadata['epoch'] + 1
                    best_val_acc = checkpoint_metadata['best_val_acc']
                    print(f"Resuming from epoch {start_epoch}, best_val_acc: {best_val_acc:.4f}")
        
        # Synchronize start_epoch and best_val_acc across all ranks
        if world_size > 1:
            start_epoch_tensor = torch.tensor(start_epoch, device=device)
            best_val_acc_tensor = torch.tensor(best_val_acc, device=device)
            dist.broadcast(start_epoch_tensor, src=0)
            dist.broadcast(best_val_acc_tensor, src=0)
            start_epoch = int(start_epoch_tensor.item())
            best_val_acc = float(best_val_acc_tensor.item())
        
        # Create data loaders with predictable cache paths
        effective_batch_size = max(batch_size // world_size, 8) if world_size > 1 else batch_size
        
        train_input_remote_path = os.path.join(data_storage_location, mds_train_dir)
        val_input_remote_path = os.path.join(data_storage_location, mds_val_dir)
        
        # Use predictable cache paths based on checkpoint_dir
        if checkpoint_dir:
            train_cache_base = f"{checkpoint_dir}/cache/train"
            val_cache_base = f"{checkpoint_dir}/cache/val"
        else:
            train_cache_base = f"/local_disk0/tmp/train_cache_{run_name or trial_number}"
            val_cache_base = f"/local_disk0/tmp/val_cache_{run_name or trial_number}"
        
        train_cache_path = f"{train_cache_base}_rank{global_rank}"
        val_cache_path = f"{val_cache_base}_rank{global_rank}"
        
        train_dataloader, _ = get_dataloader_with_mosaic(
            train_input_remote_path, train_cache_path, effective_batch_size, global_rank
        )
        cache_paths.append(train_cache_path)
        
        val_dataloader, _ = get_dataloader_with_mosaic(
            val_input_remote_path, val_cache_path, effective_batch_size, global_rank
        )
        cache_paths.append(val_cache_path)
        
        # ============================================================
        # DATALOADER METRICS COLLECTION
        # ============================================================
        if global_rank == 0:
            train_batches = len(train_dataloader) if train_dataloader else 0
            val_batches = len(val_dataloader) if val_dataloader else 0
            
            dataloader_metrics = {
                'dataloader_train_batches_per_epoch': train_batches,
                'dataloader_val_batches_per_epoch': val_batches,
                'dataloader_effective_batch_size': effective_batch_size,
                'dataloader_world_size': world_size,
                'dataloader_samples_per_epoch_train': train_batches * effective_batch_size * world_size,
                'dataloader_samples_per_epoch_val': val_batches * effective_batch_size * world_size,
            }
            
            essential_model_metrics.update(dataloader_metrics)
            
            print(f"\nDataLoader Configuration:")
            print(f"  Train batches/epoch: {train_batches}")
            print(f"  Val batches/epoch: {val_batches}")
            print(f"  Effective batch size: {effective_batch_size}")
            print(f"  World size: {world_size}")
            print(f"  Est. train samples/epoch: {dataloader_metrics['dataloader_samples_per_epoch_train']}")
            print(f"  Est. val samples/epoch: {dataloader_metrics['dataloader_samples_per_epoch_val']}\n")
        
        if world_size > 1:
            dist.barrier()
        
        # ============================================================
        # TRAINING LOOP WITH ARCHITECTURE METADATA IN CHECKPOINTS
        # ============================================================
        for epoch in range(start_epoch, num_epochs):
            if global_rank == 0:
                print(f"Epoch {epoch+1}/{num_epochs}")
            
            # Training phase
            train_results = train_one_epoch(
                model, criterion, optimizer, scheduler, train_dataloader, 
                epoch, device, global_rank, label_to_idx
            )
            
            # Validation phase
            val_results = evaluate(
                model, criterion, val_dataloader, epoch, device, global_rank, label_to_idx
            )
            
            train_loss, train_acc = train_results['loss'], train_results['accuracy']
            val_loss, val_acc = val_results['loss'], val_results['accuracy']
            
            # Check if this epoch is better than all previous epochs
            is_best = val_acc > best_val_acc
            
            if is_best:
                best_val_acc = val_acc
                best_checkpoint_epoch = epoch
                if global_rank == 0:
                    print(f"*** NEW BEST MODEL *** Epoch {epoch+1}, Val Acc: {best_val_acc:.4f}")
            
            # Store metrics
            if global_rank == 0:
                essential_model_metrics['gradient_norms'].append(train_results.get('gradient_norm', 0.0))
                essential_model_metrics['weight_norms'].append(train_results.get('weight_norm', 0.0))
                essential_model_metrics['learning_curves']['train_loss'].append(float(train_loss))
                essential_model_metrics['learning_curves']['val_loss'].append(float(val_loss))
                essential_model_metrics['learning_curves']['train_acc'].append(float(train_acc))
                essential_model_metrics['learning_curves']['val_acc'].append(float(val_acc))
                
                epoch_data = {
                    'epoch': epoch,
                    'train_loss': float(train_loss),
                    'train_acc': float(train_acc),
                    'val_loss': float(val_loss),
                    'val_acc': float(val_acc),
                    'learning_rate': float(optimizer.param_groups[0]['lr']),
                    'gradient_norm': train_results.get('gradient_norm', 0.0),
                    'weight_norm': train_results.get('weight_norm', 0.0),
                    'is_best': is_best,
                    'best_val_acc_so_far': float(best_val_acc)
                }
                epoch_metrics.append(epoch_data)
                
                # Enhanced logging with best indicator
                status = "[BEST]" if is_best else f"[Best: {best_val_acc:.4f}]"
                print(f'Epoch {epoch+1}: Train {train_acc:.4f}, Val {val_acc:.4f}, '
                      f'Loss {train_loss:.4f}/{val_loss:.4f}, '
                      f'LR {optimizer.param_groups[0]["lr"]:.2e} {status}')
                
                # ============================================================
                # SAVE CHECKPOINT WITH ARCHITECTURE METADATA
                # ============================================================
                if ENABLE_CHECKPOINTING and checkpoint_dir and (epoch + 1) % CHECKPOINT_FREQUENCY == 0:
                    checkpoint_filename = f"checkpoint_epoch_{epoch:03d}.pt"
                    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
                    
                    # Create checkpoint with full architecture metadata
                    checkpoint_data = {
                        # Training state
                        'epoch': epoch,
                        'trial_number': trial_number,
                        'run_name': run_name,
                        'best_val_acc': best_val_acc,
                        'val_acc': val_acc,
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        'is_best': is_best,
                        
                        # Model state
                        'model_state_dict': (model.module if world_size > 1 else model).state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        
                        # Architecture metadata (CRITICAL)
                        'num_classes': num_classes,
                        'dropout_rate': dropout_rate,
                        'model_architecture': 'mobilenetv2',
                        
                        # Hyperparameters
                        'hyperparameters': {
                            'lr': lr,
                            'batch_size': batch_size,
                            'optimizer': optimizer_name,
                            'weight_decay': weight_decay,
                            'dropout_rate': dropout_rate,
                            'label_smoothing': label_smoothing,
                            'momentum': momentum,
                            'nesterov': nesterov,
                            'beta1': beta1,
                            'beta2': beta2,
                            'eps': eps,
                            'step_size': step_size,
                            'gamma': gamma
                        },
                        
                        # Metadata
                        'timestamp': time.time(),
                        'world_size': world_size
                    }
                    
                    torch.save(checkpoint_data, checkpoint_path)
                    checkpoint_paths.append(checkpoint_path)
                    print(f" Saved checkpoint: {checkpoint_filename}")
                    
                    # If this is best, also save as best_checkpoint.pt
                    if is_best:
                        best_checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint.pt")
                        torch.save(checkpoint_data, best_checkpoint_path)
                        print(f" Saved best checkpoint: best_checkpoint.pt")
                    
                    # Save metadata JSON
                    metadata = {
                        'run_name': run_name,
                        'trial_number': trial_number,
                        'epoch': epoch,
                        'best_val_acc': float(best_val_acc),
                        'val_acc': float(val_acc),
                        'train_loss': float(train_loss),
                        'val_loss': float(val_loss),
                        'is_best': is_best,
                        'checkpoint_path': checkpoint_path,
                        'num_classes': num_classes,
                        'dropout_rate': dropout_rate,
                        'model_architecture': 'mobilenetv2',
                        'timestamp': time.time()
                    }
                    
                    metadata_filename = f"metadata_epoch_{epoch:03d}.json"
                    metadata_path = os.path.join(checkpoint_dir, metadata_filename)
                    import json
                    with open(metadata_path, 'w') as f:
                        json.dump(metadata, f, indent=2)
        
        training_time = time.time() - training_start_time
        
        # ============================================================
        # SAVE FINAL CHECKPOINT WITH ARCHITECTURE METADATA
        # ============================================================
        if ENABLE_CHECKPOINTING and global_rank == 0 and checkpoint_dir:
            final_checkpoint_filename = f"checkpoint_epoch_{num_epochs-1:03d}.pt"
            final_checkpoint_path = os.path.join(checkpoint_dir, final_checkpoint_filename)
            
            # Only save if not already saved
            if not os.path.exists(final_checkpoint_path):
                final_is_best = (num_epochs - 1 == best_checkpoint_epoch)
                
                checkpoint_data = {
                    'epoch': num_epochs - 1,
                    'trial_number': trial_number,
                    'run_name': run_name,
                    'best_val_acc': best_val_acc,
                    'val_acc': val_acc,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'is_best': final_is_best,
                    'model_state_dict': (model.module if world_size > 1 else model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'num_classes': num_classes,
                    'dropout_rate': dropout_rate,
                    'model_architecture': 'mobilenetv2',
                    'hyperparameters': {
                        'lr': lr, 'batch_size': batch_size, 'optimizer': optimizer_name,
                        'weight_decay': weight_decay, 'dropout_rate': dropout_rate,
                        'label_smoothing': label_smoothing
                    },
                    'timestamp': time.time(),
                    'world_size': world_size
                }
                
                torch.save(checkpoint_data, final_checkpoint_path)
                checkpoint_paths.append(final_checkpoint_path)
                print(f"Saved final checkpoint: {final_checkpoint_filename}")

                # Update best checkpoint if final is best
                if final_is_best:
                    best_checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint.pt")
                    torch.save(checkpoint_data, best_checkpoint_path)
                    print(f"Final epoch is best - saved as best_checkpoint.pt")
        
        # ============================================================
        # CALCULATE DATALOADER EFFICIENCY METRICS
        # ============================================================
        if global_rank == 0 and training_time > 0:
            train_batches = essential_model_metrics.get('dataloader_train_batches_per_epoch', 0)
            val_batches = essential_model_metrics.get('dataloader_val_batches_per_epoch', 0)
            
            if train_batches > 0:
                avg_time_per_epoch = training_time / num_epochs
                avg_time_per_batch = avg_time_per_epoch / train_batches
                samples_per_second = (train_batches * effective_batch_size * world_size) / avg_time_per_epoch
                
                efficiency_metrics = {
                    'dataloader_avg_time_per_epoch': avg_time_per_epoch,
                    'dataloader_avg_time_per_batch': avg_time_per_batch,
                    'dataloader_samples_per_second': samples_per_second,
                    'dataloader_throughput_images_per_sec': samples_per_second,
                }
                
                essential_model_metrics.update(efficiency_metrics)
                
                print(f"\nDataLoader Efficiency:")
                print(f"  Avg time/epoch: {avg_time_per_epoch:.2f}s")
                print(f"  Avg time/batch: {avg_time_per_batch:.4f}s")
                print(f"  Throughput: {samples_per_second:.1f} samples/sec")
        
        # ============================================================
        # FINAL SUMMARY
        # ============================================================
        if global_rank == 0:
            print(f"\n{'='*60}")
            print(f"Training Complete - {run_name or f'Trial {trial_number}'}")
            print(f"{'='*60}")
            print(f"Best Validation Accuracy: {best_val_acc:.4f} (Epoch {best_checkpoint_epoch+1})")
            print(f"Final Validation Accuracy: {val_acc:.4f}")
            print(f"Training Time: {training_time:.2f}s ({training_time/60:.1f} min)")
            if best_checkpoint_path:
                print(f"Best Checkpoint: {best_checkpoint_path}")
            print(f"Checkpoint Directory: {checkpoint_dir}")
            print(f"Total Checkpoints Saved: {len(checkpoint_paths)}")
            print(f"Model Architecture: {num_classes} classes, dropout={dropout_rate}")
            print(f"{'='*60}\n")
        
        # ============================================================
        # RETURN RESULTS WITH ARCHITECTURE INFO
        # ============================================================
        return {
            "val_acc": float(best_val_acc),
            "train_loss": float(train_loss),
            "val_loss": float(val_loss),
            "train_acc": float(train_acc),
            "best_val_acc": float(best_val_acc),
            "best_checkpoint_path": best_checkpoint_path,
            "best_checkpoint_epoch": best_checkpoint_epoch,
            "final_val_acc": float(val_acc),
            "status": "completed",
            "epochs_completed": num_epochs,
            "training_time": training_time,
            "epoch_metrics": epoch_metrics,
            "model_metrics": essential_model_metrics,
            "checkpoint_paths": checkpoint_paths,
            "checkpoint_dir": checkpoint_dir,
            "hyperparameters": {
                "lr": lr,
                "batch_size": batch_size,
                "optimizer": optimizer_name,
                "weight_decay": weight_decay,
                "dropout_rate": dropout_rate,
                "label_smoothing": label_smoothing,
                "num_classes": num_classes  # Include in return
            },
            "architecture": {
                "num_classes": num_classes,
                "dropout_rate": dropout_rate,
                "model_name": "mobilenetv2"
            }
        }
        
    except Exception as e:
        print(f"Rank {global_rank}: Training error: {e}")
        import traceback
        traceback.print_exc()
        return {
            "val_acc": 0.0,
            "status": "failed",
            "error": str(e),
            "checkpoint_paths": checkpoint_paths,
            "best_checkpoint_path": best_checkpoint_path,
            "checkpoint_dir": checkpoint_dir,
            "epoch_metrics": epoch_metrics,
            "model_metrics": essential_model_metrics,
            "architecture": {
                "num_classes": num_classes,
                "dropout_rate": dropout_rate,
                "model_name": "mobilenetv2"
            }
        }
    
    finally:
        # Cleanup cache paths
        for cache_path in cache_paths:
            if os.path.exists(cache_path):
                shutil.rmtree(cache_path, ignore_errors=True)
        
        # Cleanup distributed process group
        if world_size > 1 and dist.is_initialized():
            dist.destroy_process_group()



##----------------------------------------------------------------            

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss computed against label-smoothed one-hot targets.

    With C classes, the target distribution assigns ``1 - smoothing + smoothing / C``
    to the true class and ``smoothing / C`` to every other class. The loss is the
    batch mean of the per-sample cross-entropy against that distribution.
    """
    # `nn` and `torch` are provided by the notebook's top-level import cell.

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Fraction of probability mass spread uniformly over all classes.
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: Unnormalized class scores with classes along dim 1
                (expected (N, C) -- the scatter below indexes dim 1 with target[:, None]).
            target: Integer class indices, one per row of ``pred``.
        """
        num_classes = pred.size(1)
        log_probs = torch.log_softmax(pred, dim=1)
        # Build the one-hot target with the same dtype/device as the log-probabilities.
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, target.unsqueeze(1), 1)
        soft_labels = one_hot * (1 - self.smoothing) + self.smoothing / num_classes
        per_sample_loss = (-soft_labels * log_probs).sum(dim=1)
        return per_sample_loss.mean()


def train_one_epoch(model, criterion, optimizer, scheduler, train_dataloader,
                    epoch, device, global_rank, label_to_idx):
    """Run one training epoch and return rank-local loss/accuracy/norm metrics.

    Batches that raise are skipped (best effort). The first few errors are
    printed, and the number of skipped batches is reported and returned.
    ``scheduler.step()`` is called once at the end of the epoch.

    Args:
        model: The (possibly DDP-wrapped) network being trained.
        criterion: Loss function applied to (outputs, labels).
        optimizer: Optimizer stepping the model parameters.
        scheduler: LR scheduler stepped once per epoch.
        train_dataloader: Iterable of raw batches for convert_batch_to_tensors.
        epoch: Zero-based epoch index (used only in log messages).
        device: Device the batch tensors are moved to.
        global_rank: Process rank; rank 0 collects gradient norms and prints progress.
        label_to_idx: Label mapping passed through to convert_batch_to_tensors.

    Returns:
        dict with 'loss' (mean over successful batches), 'accuracy',
        'gradient_norm' (mean of rank-0 samples taken every 50 steps),
        'weight_norm' and 'skipped_batches'. Metrics are computed from this
        rank's batches only; they are not reduced across ranks.
    """
    # Resolve the world size once. get_world_size() raises when no process group
    # exists. Previously that error fired after backward() and was swallowed by
    # the except below, silently dropping every 50th batch on rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1

    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    successful_batches = 0
    skipped_batches = 0
    gradient_norms = []

    for step, batch in enumerate(train_dataloader):
        try:
            inputs, labels = convert_batch_to_tensors(batch, device, label_to_idx, global_rank)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Sample the gradient norm periodically (after backward, before the step).
            if global_rank == 0 and step % 50 == 0:
                gradient_norms.append(calculate_gradient_norm(model, world_size))

            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            predicted = outputs.detach().argmax(dim=1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            successful_batches += 1

            if global_rank == 0 and step % 100 == 0:
                current_acc = correct_predictions / total_samples if total_samples > 0 else 0
                print(f"  Step {step}: Loss {batch_loss:.4f}, Acc {current_acc:.4f}")

        except Exception as e:
            # Best-effort: skip the bad batch but keep count so it is visible.
            # NOTE(review): under DDP, a rank that fails mid-step can desync collectives.
            skipped_batches += 1
            if step < 5:
                print(f"Training step {step} error: {e}")
            continue

    scheduler.step()

    if skipped_batches:
        print(f"Rank {global_rank}: skipped {skipped_batches} training batch(es) "
              f"in epoch {epoch + 1} due to errors")

    epoch_loss = running_loss / successful_batches if successful_batches > 0 else 0.0
    epoch_acc = correct_predictions / total_samples if total_samples > 0 else 0.0
    avg_gradient_norm = np.mean(gradient_norms) if gradient_norms else 0.0

    return {
        'loss': epoch_loss,
        'accuracy': epoch_acc,
        'gradient_norm': avg_gradient_norm,
        'weight_norm': calculate_weight_norm(model, world_size),
        'skipped_batches': skipped_batches,
    }


def evaluate(model, criterion, val_dataloader, epoch, device, global_rank, label_to_idx):
    """Run a no-grad validation pass and return rank-local loss and accuracy.

    Batches that raise are skipped (best effort). Previously they were dropped
    silently; now the first few errors are printed and the count is reported
    and returned.

    Args:
        model: The (possibly DDP-wrapped) network to evaluate.
        criterion: Loss function applied to (outputs, labels).
        val_dataloader: Iterable of raw batches for convert_batch_to_tensors.
        epoch: Zero-based epoch index (used only in log messages).
        device: Device the batch tensors are moved to.
        global_rank: Process rank (forwarded to convert_batch_to_tensors, used in logs).
        label_to_idx: Label mapping passed through to convert_batch_to_tensors.

    Returns:
        dict with 'loss' (mean over successful batches), 'accuracy' and
        'skipped_batches'. Values cover this rank's batches only.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    successful_batches = 0
    skipped_batches = 0

    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            try:
                inputs, labels = convert_batch_to_tensors(batch, device, label_to_idx, global_rank)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()
                successful_batches += 1

            except Exception as e:
                # Keep evaluating, but make failures visible instead of swallowing them.
                skipped_batches += 1
                if step < 5:
                    print(f"Validation step {step} error: {e}")

    if skipped_batches:
        print(f"Rank {global_rank}: skipped {skipped_batches} validation batch(es) "
              f"in epoch {epoch + 1} due to errors")

    val_loss = running_loss / successful_batches if successful_batches > 0 else 0.0
    val_acc = correct_predictions / total_samples if total_samples > 0 else 0.0

    return {'loss': val_loss, 'accuracy': val_acc, 'skipped_batches': skipped_batches}

In [0]:
# ============================================================
# CHECKPOINT UTILITIES 
# ============================================================
import json
import glob
import shutil
import time
import torch
import os
from pathlib import Path

## This wrapper is defined but not used: the training loop above saves checkpoints inline.
def save_checkpoint(model, optimizer, scheduler, epoch, trial_number,
                     best_val_acc, checkpoint_dir, 
                     global_rank=0, world_size=1, run_name=None,
                     num_classes=200, dropout_rate=0.2, val_acc=0.0, train_loss=0.0, val_loss=0.0, is_best=False,
                     hyperparameters=None):
    """
    Save a model checkpoint with architecture metadata (rank 0 only).

    Files written into ``checkpoint_dir`` (created if missing):
      - ``checkpoint_epoch_{epoch:03d}.pt``: model/optimizer/scheduler state plus metadata
      - ``best_checkpoint.pt``: identical payload, only when ``is_best`` is True
      - ``metadata_epoch_{epoch:03d}.json``: human-readable summary

    Args:
        model: PyTorch model. If it exposes ``.module`` (DDP/DataParallel wrapper),
            the inner module's state is saved so keys carry no "module." prefix.
        optimizer: Optimizer
        scheduler: Learning rate scheduler
        epoch: Current epoch (zero-based; used in filenames)
        trial_number: Trial number
        best_val_acc: Best validation accuracy so far
        checkpoint_dir: Trial-specific checkpoint directory
        global_rank: Process rank (only rank 0 saves)
        world_size: Total number of processes (recorded in the checkpoint)
        run_name: Run name (for metadata only)
        num_classes: Number of output classes (for architecture reconstruction)
        dropout_rate: Dropout rate used in model (for architecture reconstruction)
        val_acc: Current epoch validation accuracy
        train_loss: Current epoch training loss
        val_loss: Current epoch validation loss
        is_best: Whether this is the best checkpoint so far
        hyperparameters: Optional dict stored under 'hyperparameters', matching the
            inline checkpoints written by the training loop. Omitted when None.

    Returns:
        str: Path to the saved epoch checkpoint, or None on non-zero ranks or on error.
    """
    if global_rank != 0:  # Only rank 0 saves checkpoints
        return None

    try:
        # Checkpoint directory should already be trial-specific
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Unwrap by attribute (the same check load_checkpoint uses) rather than by
        # world_size. Otherwise a wrapped model with world_size=1 saves "module."-prefixed
        # keys, and an unwrapped model with world_size>1 raises AttributeError.
        base_model = model.module if hasattr(model, 'module') else model

        # Prepare checkpoint data with architecture metadata
        checkpoint_data = {
            # Training state
            'epoch': epoch,
            'trial_number': trial_number,
            'run_name': run_name,
            'best_val_acc': best_val_acc,
            'val_acc': val_acc,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'is_best': is_best,

            # Model state
            'model_state_dict': base_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),

            # Architecture metadata (CRITICAL for loading)
            'num_classes': num_classes,
            'dropout_rate': dropout_rate,
            'model_architecture': 'mobilenetv2',

            # Metadata
            'timestamp': time.time(),
            'world_size': world_size
        }
        if hyperparameters is not None:
            # Copy so later mutation by the caller cannot alter the saved payload.
            checkpoint_data['hyperparameters'] = dict(hyperparameters)

        # Save epoch checkpoint with zero-padded naming
        checkpoint_filename = f"checkpoint_epoch_{epoch:03d}.pt"
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
        torch.save(checkpoint_data, checkpoint_path)

        # Also save as best_checkpoint.pt if this is the best
        if is_best:
            best_checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint.pt")
            torch.save(checkpoint_data, best_checkpoint_path)
            print(f"★ Saved best checkpoint: best_checkpoint.pt (epoch {epoch})")

        # Save metadata JSON
        metadata = {
            'run_name': run_name,
            'trial_number': trial_number,
            'epoch': epoch,
            'best_val_acc': float(best_val_acc),
            'val_acc': float(val_acc),
            'train_loss': float(train_loss),
            'val_loss': float(val_loss),
            'is_best': is_best,
            'checkpoint_path': checkpoint_path,
            'num_classes': num_classes,
            'dropout_rate': dropout_rate,
            'model_architecture': 'mobilenetv2',
            'timestamp': time.time()
        }

        metadata_filename = f"metadata_epoch_{epoch:03d}.json"
        metadata_path = os.path.join(checkpoint_dir, metadata_filename)
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Saved checkpoint: {checkpoint_filename} (val_acc: {val_acc:.4f})")
        return checkpoint_path

    except Exception as e:
        print(f"⚠ Error saving checkpoint: {e}")
        import traceback
        traceback.print_exc()
        return None


def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, device=None):
    """
    Load a checkpoint into ``model`` (and optionally optimizer/scheduler) and return its metadata.

    Args:
        checkpoint_path: Path to checkpoint file
        model: PyTorch model to load state into; if it exposes ``.module``
            (DDP/DataParallel wrapper), the inner module receives the state.
        optimizer: Optimizer (optional); restored only if the checkpoint has its state
        scheduler: Scheduler (optional); restored only if the checkpoint has its state
        device: Device to map tensors to (default: CUDA if available, else CPU)

    Returns:
        dict: Metadata with 'epoch', 'trial_number', 'run_name', 'best_val_acc' and
        'timestamp', plus the architecture/run fields written by the savers:
        'val_acc', 'is_best', 'num_classes', 'dropout_rate', 'model_architecture'
        and 'hyperparameters' (None / False / {} when absent). Returns None on any error.
    """
    try:
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted
        # locations. Since PyTorch 2.6 (the version noted at the top of this notebook),
        # torch.load defaults to weights_only=True, which limits unpickling to tensors and
        # primitive containers.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load model state into the unwrapped module so keys have no "module." prefix
        target_model = model.module if hasattr(model, 'module') else model
        target_model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state if provided
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state if provided
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        metadata = {
            'epoch': checkpoint.get('epoch', 0),
            'trial_number': checkpoint.get('trial_number', 0),
            'run_name': checkpoint.get('run_name', None),
            'best_val_acc': checkpoint.get('best_val_acc', 0.0),
            'timestamp': checkpoint.get('timestamp', 0),
            # Fields the savers mark as needed to rebuild/resume the model
            'val_acc': checkpoint.get('val_acc', None),
            'is_best': checkpoint.get('is_best', False),
            'num_classes': checkpoint.get('num_classes', None),
            'dropout_rate': checkpoint.get('dropout_rate', None),
            'model_architecture': checkpoint.get('model_architecture', None),
            'hyperparameters': checkpoint.get('hyperparameters', {}),
        }

        print(f"✓ Loaded checkpoint: {os.path.basename(checkpoint_path)}")
        # The stored value is the best accuracy so far, not this epoch's accuracy.
        print(f"  Epoch: {metadata['epoch']}, Best Val Acc: {metadata['best_val_acc']:.4f}")

        return metadata

    except Exception as e:
        print(f"⚠ Error loading checkpoint: {e}")
        import traceback
        traceback.print_exc()
        return None


def find_latest_checkpoint(checkpoint_dir):
    """
    Find the checkpoint with the highest epoch number in a directory.

    Only files named exactly ``checkpoint_epoch_<digits>.pt`` are considered.
    Other names (e.g. ``checkpoint_epoch_final.pt``) are ignored instead of
    crashing the epoch parse.

    Args:
        checkpoint_dir: Directory containing checkpoints (None/empty is allowed)

    Returns:
        str: Path to latest checkpoint, or None if the directory is missing or has
        no matching checkpoints
    """
    import re

    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        return None

    epoch_pattern = re.compile(r"^checkpoint_epoch_(\d+)\.pt$")
    candidates = []
    for name in os.listdir(checkpoint_dir):
        match = epoch_pattern.match(name)
        if match:
            candidates.append((int(match.group(1)), os.path.join(checkpoint_dir, name)))

    if not candidates:
        return None

    # Numeric max, so epoch 1000 beats epoch 999 regardless of zero-padding width
    _, latest = max(candidates)

    print(f"Found latest checkpoint: {os.path.basename(latest)}")
    return latest


def get_checkpoint_info(checkpoint_dir):
    """
    Summarize the checkpoints in a directory.

    Only files named exactly ``checkpoint_epoch_<digits>.pt`` are listed, sorted by
    numeric epoch. Other names are ignored instead of crashing the epoch parse.

    Args:
        checkpoint_dir: Directory containing checkpoints (None/empty is allowed)

    Returns:
        dict: Always has 'exists', 'checkpoint_dir', 'checkpoint_count', 'has_best'
        and 'checkpoints' (list of {'epoch', 'filename', 'path', 'size_mb'}).
        'best_checkpoint' is added when best_checkpoint.pt exists.
    """
    import re

    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        # Same key set as the found case, so callers can read 'has_best' unconditionally
        return {
            'exists': False,
            'checkpoint_dir': checkpoint_dir,
            'checkpoint_count': 0,
            'has_best': False,
            'checkpoints': []
        }

    epoch_pattern = re.compile(r"^checkpoint_epoch_(\d+)\.pt$")
    epoch_files = []
    for name in os.listdir(checkpoint_dir):
        match = epoch_pattern.match(name)
        if match:
            epoch_files.append((int(match.group(1)), name))
    # Numeric sort: a lexical sort would put epoch 1000 before epoch 999
    epoch_files.sort()

    best_checkpoint = os.path.join(checkpoint_dir, "best_checkpoint.pt")

    checkpoint_info = {
        'exists': True,
        'checkpoint_dir': checkpoint_dir,
        'checkpoint_count': len(epoch_files),
        'has_best': os.path.exists(best_checkpoint),
        'checkpoints': []
    }

    for epoch, name in epoch_files:
        cp_path = os.path.join(checkpoint_dir, name)
        cp_size = os.path.getsize(cp_path) / (1024 * 1024)  # MB
        checkpoint_info['checkpoints'].append({
            'epoch': epoch,
            'filename': name,
            'path': cp_path,
            'size_mb': round(cp_size, 2)
        })

    if checkpoint_info['has_best']:
        best_size = os.path.getsize(best_checkpoint) / (1024 * 1024)
        checkpoint_info['best_checkpoint'] = {
            'filename': 'best_checkpoint.pt',
            'path': best_checkpoint,
            'size_mb': round(best_size, 2)
        }

    return checkpoint_info

In [0]:
# # Example 1: Find and load latest checkpoint
# checkpoint_dir = f"{CHECKPOINT_BASE_DIR}/trial_005_AdamW_lr1.23e-03_bs64"
# latest_checkpoint = find_latest_checkpoint(checkpoint_dir)

# if latest_checkpoint:
#     metadata = load_checkpoint(latest_checkpoint, model, optimizer, scheduler)
#     start_epoch = metadata['epoch'] + 1

# # Example 2: Get checkpoint info
# info = get_checkpoint_info(checkpoint_dir)
# print(f"Found {info['checkpoint_count']} checkpoints")
# print(f"Has best checkpoint: {info['has_best']}")

# # Example 3: Load best checkpoint
# best_checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint.pt")
# if os.path.exists(best_checkpoint_path):
#     model_state = torch.load(best_checkpoint_path)
#     model.load_state_dict(model_state['model_state_dict'])

In [0]:
def calculate_gradient_norm(model, world_size):
    """Return the global L2 norm of all parameter gradients as a Python float.

    Parameters whose ``.grad`` is None are skipped. Returns 0.0 when no
    parameter has a gradient.

    Args:
        model: The network, optionally wrapped so the real module sits at
            ``model.module`` (as with DistributedDataParallel).
        world_size: Kept for call compatibility. Unwrapping now depends on the
            presence of ``.module`` rather than this value, so an unwrapped model
            with world_size > 1 no longer raises AttributeError.
    """
    import torch

    base_model = getattr(model, "module", model)
    # Keep per-parameter norms as tensors and reduce them once. This costs a single
    # host sync (.item()) instead of one per parameter; float64 keeps the
    # sum-of-squares accumulation as precise as the previous Python-float loop.
    grad_norms = [p.grad.detach().norm(2).double()
                  for p in base_model.parameters() if p.grad is not None]
    if not grad_norms:
        return 0.0
    target_device = grad_norms[0].device
    stacked = torch.stack([n.to(target_device) for n in grad_norms])
    return torch.linalg.vector_norm(stacked, 2).item()


def calculate_weight_norm(model, world_size):
    """Return the global L2 norm of all model parameters as a Python float.

    Returns 0.0 for a model with no parameters.

    Args:
        model: The network, optionally wrapped so the real module sits at
            ``model.module`` (as with DistributedDataParallel).
        world_size: Kept for call compatibility. Unwrapping now depends on the
            presence of ``.module`` rather than this value, so an unwrapped model
            with world_size > 1 no longer raises AttributeError.
    """
    import torch

    base_model = getattr(model, "module", model)
    # Reduce on-device with one final .item() instead of one host sync per
    # parameter; float64 matches the precision of the former Python-float sum.
    weight_norms = [p.detach().norm(2).double() for p in base_model.parameters()]
    if not weight_norms:
        return 0.0
    target_device = weight_norms[0].device
    stacked = torch.stack([n.to(target_device) for n in weight_norms])
    return torch.linalg.vector_norm(stacked, 2).item()


In [0]:
def _json_default(obj):
    """
    Fallback encoder for ``json.dump`` used by ``objective_function``.

    Converts scalar-like objects exposing ``.item()`` (numpy scalars, 0-d
    tensors) to native Python values and stringifies anything else, so a stray
    non-JSON type in training results cannot abort artifact logging.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except Exception:
            pass  # e.g. multi-element arrays: fall through to str()
    return str(obj)


def objective_function(trial):
    """
    Optuna objective: run one distributed training trial and log it to MLflow.

    Samples hyperparameters from ``trial``, opens a nested MLflow run named
    after the trial, launches ``distributed_train_and_evaluate`` through
    ``TorchDistributor``, then logs checkpoints, per-epoch metrics, summary
    metrics and tags from the dict the training function returns. Checkpoints
    use a predictable per-trial directory under ``CHECKPOINT_BASE_DIR``.

    Relies on notebook globals defined in earlier cells (``CHECKPOINT_BASE_DIR``,
    ``EXPERIMENT_ROOT``, ``EXPERIMENT_RUN_ID``, ``EXPERIMENT_TIMESTAMP``,
    ``EXPERIMENT_SHORT_NAME``, ``NUM_EPOCHS``, ``ENABLE_CHECKPOINTING``,
    ``CHECKPOINT_FREQUENCY``, the dataset settings and
    ``distributed_train_and_evaluate``).

    Args:
        trial: ``optuna.trial.Trial`` used to sample hyperparameters.

    Returns:
        float: best validation accuracy reported by training, or 0.0 if the
        training job raised or returned a non-dict result. Errors raised while
        *logging* results after a successful training run are reported (and
        tagged in MLflow) but do not discard the measured accuracy.
    """
    import json
    import os
    import time
    import traceback
    import numpy as np
    from pathlib import Path
    import mlflow 
    import torch
    from pyspark.ml.torch.distributor import TorchDistributor

    # Assumed training images per epoch, used only for the derived throughput
    # metrics (the Tiny ImageNet-200 train split has 100,000 images).
    # Adjust when training on a different dataset.
    train_samples_per_epoch = 100000
    
    # ============================================================
    # HYPERPARAMETER SAMPLING
    # ============================================================
    
    # Core hyperparameters
    lr = trial.suggest_float('lr', 5e-5, 5e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])
    optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'SGD', 'Adam'])
    
    # Optimizer-specific weight decay
    if optimizer_name == 'AdamW':
        weight_decay = trial.suggest_float('weight_decay', 1e-4, 1e-1, log=True)
    elif optimizer_name == 'SGD':
        weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    else:
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)
    
    # Additional key hyperparameters
    step_size = trial.suggest_int('step_size', 5, 15)
    gamma = trial.suggest_float('gamma', 0.1, 0.7)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    label_smoothing = trial.suggest_float('label_smoothing', 0.0, 0.2)
    
    # Optimizer-specific parameters (fixed defaults when not sampled)
    momentum = trial.suggest_float('momentum', 0.85, 0.99) if optimizer_name == 'SGD' else 0.9
    nesterov = trial.suggest_categorical('nesterov', [True, False]) if optimizer_name == 'SGD' else False
    beta1 = trial.suggest_float('beta1', 0.85, 0.95) if optimizer_name == 'AdamW' else 0.9
    beta2 = trial.suggest_float('beta2', 0.95, 0.999) if optimizer_name == 'AdamW' else 0.999
    eps = trial.suggest_float('eps', 1e-9, 1e-7, log=True) if optimizer_name == 'AdamW' else 1e-8
    
    print(f"Trial {trial.number}: {optimizer_name}, lr={lr:.2e}, bs={batch_size}, wd={weight_decay:.2e}")
    
    # ============================================================
    # PREDICTABLE PATH SETUP
    # ============================================================
    
    # Create descriptive run name (no random components)
    run_name = f"trial_{trial.number:03d}_{optimizer_name}_lr{lr:.2e}_bs{batch_size}"
    
    # Trial-specific checkpoint directory (predictable, organized)
    trial_checkpoint_dir = os.path.join(CHECKPOINT_BASE_DIR, run_name)
    Path(trial_checkpoint_dir).mkdir(parents=True, exist_ok=True)
    
    # MLflow artifact subpath (consistent naming)
    mlflow_artifact_subpath = "checkpoints"
    
    print(f"\n{'='*60}")
    print(f"TRIAL {trial.number} PATHS")
    print(f"{'='*60}")
    print(f"Experiment Root:     {EXPERIMENT_ROOT}")
    print(f"Checkpoint Base:     {CHECKPOINT_BASE_DIR}")
    print(f"Trial Directory:     {trial_checkpoint_dir}")
    print(f"Run Name:            {run_name}")
    print(f"MLflow Artifact Sub: {mlflow_artifact_subpath}")
    print(f"{'='*60}\n")
    
    # ============================================================
    # MLFLOW RUN SETUP
    # ============================================================
    
    with mlflow.start_run(run_name=run_name, nested=True) as trial_run:
        trial_start_time = time.time()
        trial_run_id = trial_run.info.run_id
        
        print(f"MLflow Run ID: {trial_run_id}")
        print(f"Artifact URI:  {trial_run.info.artifact_uri}\n")
        
        # Log essential parameters
        trial_params = {
            'experiment_run_id': EXPERIMENT_RUN_ID,
            'experiment_timestamp': EXPERIMENT_TIMESTAMP,
            'experiment_short_name': EXPERIMENT_SHORT_NAME,
            'trial_number': trial.number,
            'run_name': run_name,
            
            # Hyperparameters
            'lr': lr, 
            'batch_size': batch_size, 
            'optimizer': optimizer_name,
            'weight_decay': weight_decay, 
            'step_size': step_size, 
            'gamma': gamma,
            'dropout_rate': dropout_rate, 
            'label_smoothing': label_smoothing,
            'momentum': momentum, 
            'nesterov': nesterov, 
            'beta1': beta1, 
            'beta2': beta2, 
            'eps': eps,
            
            # Training config
            'num_epochs': NUM_EPOCHS, 
            'num_workers': num_workers,
            'model_architecture': 'mobilenetv2',
            
            # Dataloader configuration
            'dataloader_type': 'streaming_mds',
            'dataloader_library': 'mosaic_streaming',
            'dataloader_num_workers': 0,
            'dataloader_pin_memory': True,
            'dataloader_drop_last': True,
            'dataloader_persistent_workers': False,
            'dataloader_shuffle': True,
            'dataloader_predownload_multiplier': 4,
            'dataloader_keep_zip': False,
            'dataloader_download_retry': 2,
            'dataloader_download_timeout': 300,
            'dataloader_validate_hash': False,
            
            # Dataset configuration
            'dataset_format': 'mds',
            'dataset_storage_type': 'uc_volumes',
            'dataset_train_path': f"{data_storage_location}/{mds_train_dir}",
            'dataset_val_path': f"{data_storage_location}/{mds_val_dir}",
            'dataset_num_classes': num_classes,
            'dataset_cache_location': '/local_disk0/tmp/mds_cache',
            
            # Checkpointing
            'checkpointing_enabled': ENABLE_CHECKPOINTING,
            'checkpoint_frequency': CHECKPOINT_FREQUENCY,
            'checkpoint_base_dir': CHECKPOINT_BASE_DIR,
            'checkpoint_trial_dir': trial_checkpoint_dir,
            'mlflow_artifact_subpath': mlflow_artifact_subpath,
            'experiment_root': EXPERIMENT_ROOT
        }
        mlflow.log_params(trial_params)

        # ============================================================
        # LOG DATASETS AS JSON artifact
        # ============================================================                
        try:
            dataset_metadata = {
                "train": {
                    "source": f"{data_storage_location}/{mds_train_dir}",
                    "format": "mds",
                    "storage": "uc_volumes",
                    "num_classes": num_classes,
                    "catalog": CATALOG,
                    "schema": SCHEMA,
                    "volume": VOLUME_NAME,
                    "split": "train"
                },
                "validation": {
                    "source": f"{data_storage_location}/{mds_val_dir}",
                    "format": "mds",
                    "storage": "uc_volumes",
                    "num_classes": num_classes,
                    "catalog": CATALOG,
                    "schema": SCHEMA,
                    "volume": VOLUME_NAME,
                    "split": "validation"
                },
                "experiment_timestamp": EXPERIMENT_TIMESTAMP,
                "trial_number": trial.number,
                "logged_at": time.time()
            }
            
            # Save to trial checkpoint directory
            dataset_metadata_path = os.path.join(trial_checkpoint_dir, "dataset_metadata.json")
            with open(dataset_metadata_path, 'w') as f:
                json.dump(dataset_metadata, f, indent=2, default=_json_default)
            
            # Log to MLflow
            mlflow.log_artifact(dataset_metadata_path, "datasets")
            print("✓ Logged dataset metadata as artifact")
            
        except Exception as e:
            print(f"⚠ Dataset metadata logging failed: {e}")
            
        # ============================================================
        # DISTRIBUTED TRAINING EXECUTION
        # ============================================================
        # Only a failure of the training job itself marks the trial as failed
        # (objective 0.0). Logging problems afterwards are handled separately
        # below so they cannot throw away a valid accuracy.
        try:
            distributor = TorchDistributor(num_processes=num_workers, local_mode=False, use_gpu=True)
            result = distributor.run(
                distributed_train_and_evaluate,
                lr=lr, 
                batch_size=batch_size, 
                optimizer_name=optimizer_name,
                weight_decay=weight_decay, 
                step_size=step_size, 
                gamma=gamma,
                dropout_rate=dropout_rate, 
                label_smoothing=label_smoothing,
                momentum=momentum, 
                nesterov=nesterov, 
                beta1=beta1, 
                beta2=beta2, 
                eps=eps,
                data_storage_location=data_storage_location,
                mds_train_dir=mds_train_dir, 
                mds_val_dir=mds_val_dir,
                num_epochs=NUM_EPOCHS,
                num_classes=num_classes,  
                mlflow_run_id=trial_run_id,
                mlflow_tracking_uri=mlflow.get_tracking_uri(),
                trial_number=trial.number,
                run_name=run_name,
                checkpoint_dir=trial_checkpoint_dir
            )
        except Exception as e:
            # ============================================================
            # EXCEPTION HANDLING (training job failed)
            # ============================================================
            # Capture the traceback text first, while `e` is the exception
            # being handled, so later best-effort logging cannot affect it.
            error_traceback = traceback.format_exc()
            
            trial_end_time = time.time()
            trial_duration = trial_end_time - trial_start_time
            
            print(f"\n{'='*80}")
            print(f"TRIAL {trial.number:03d} FAILED")
            print(f"{'='*80}")
            print(f"Run Name: {run_name}")
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Message: {e}")
            print(f"Duration before failure: {trial_duration:.1f}s")
            print(f"Checkpoint Directory: {trial_checkpoint_dir}")
            print(f"{'='*80}")
            
            traceback.print_exc()
            print(f"{'='*80}\n")
            
            # Log failure information (best effort: MLflow itself may be down,
            # and raising here would abort the whole Optuna study).
            try:
                mlflow.log_params({
                    "trial_status": "failed",
                    "error_message": str(e)[:500],
                    "error_type": type(e).__name__
                })
                
                mlflow.log_metrics({
                    "final_validation_accuracy": 0.0,
                    "trial_duration_seconds": trial_duration
                })
            except Exception as mlflow_error:
                print(f"⚠ Could not log failure details to MLflow: {mlflow_error}")
            
            # Log error details to file with predictable name
            error_filename = f"error_log_trial_{trial.number:03d}.txt"
            error_path = os.path.join(trial_checkpoint_dir, error_filename)
            
            try:
                with open(error_path, 'w') as f:
                    f.write(f"Trial {trial.number:03d} Error Report\n")
                    f.write(f"{'='*80}\n\n")
                    f.write(f"Run Name: {run_name}\n")
                    f.write(f"Error Type: {type(e).__name__}\n")
                    f.write(f"Error Message: {str(e)}\n\n")
                    f.write(f"Storage Configuration:\n")
                    f.write(f"  Experiment Root: {EXPERIMENT_ROOT}\n")
                    f.write(f"  Checkpoint Base: {CHECKPOINT_BASE_DIR}\n")
                    f.write(f"  Trial Directory: {trial_checkpoint_dir}\n\n")
                    f.write(f"Hyperparameters:\n")
                    f.write(f"  Optimizer: {optimizer_name}\n")
                    f.write(f"  Learning Rate: {lr:.2e}\n")
                    f.write(f"  Batch Size: {batch_size}\n")
                    f.write(f"  Weight Decay: {weight_decay:.2e}\n\n")
                    f.write(f"Full Traceback:\n")
                    f.write(error_traceback)
                
                mlflow.log_artifact(error_path, "error_logs")
                print(f"✓ Error log saved: {error_filename}")
            except Exception as log_error:
                print(f"⚠ Could not save error log: {log_error}")
            
            try:
                mlflow.set_tags({
                    "experiment_run_id": EXPERIMENT_RUN_ID,
                    "experiment_timestamp": EXPERIMENT_TIMESTAMP,
                    "experiment_short_name": EXPERIMENT_SHORT_NAME,
                    "trial_number": f"{trial.number:03d}",
                    "run_name": run_name,
                    "trial_status": "failed",
                    "optimizer": optimizer_name,
                    "checkpointing_enabled": str(ENABLE_CHECKPOINTING),
                    "checkpoint_storage": "uc_volumes_organized",
                    "checkpoint_base_dir": CHECKPOINT_BASE_DIR,
                    "checkpoint_trial_dir": trial_checkpoint_dir,
                    "has_best_checkpoint": "no",
                    "error_type": type(e).__name__,
                    "batch_size": str(batch_size),
                    "learning_rate": f"{lr:.2e}"
                })
            except Exception as mlflow_error:
                print(f"⚠ Could not set failure tags in MLflow: {mlflow_error}")
            
            return 0.0
        
        trial_end_time = time.time()
        trial_duration = trial_end_time - trial_start_time
        
        # ============================================================
        # RESULT VALIDATION
        # ============================================================
        
        if not isinstance(result, dict):
            print(f"WARNING: Result is not a dict, got {type(result)}")
            try:
                mlflow.log_param("result_type_error", f"expected_dict_got_{type(result).__name__}")
                mlflow.log_metric("final_validation_accuracy", 0.0)
                mlflow.set_tag("trial_status", "failed_invalid_result")
            except Exception as mlflow_error:
                print(f"⚠ Could not log invalid-result details to MLflow: {mlflow_error}")
            return 0.0
        
        # Extract key metrics
        trial_metric = result.get('val_acc', 0.0)
        best_val_acc = result.get('best_val_acc', trial_metric)
        final_val_acc = result.get('final_val_acc', trial_metric)
        best_checkpoint_path = result.get('best_checkpoint_path')
        best_checkpoint_epoch = result.get('best_checkpoint_epoch', -1)
        checkpoint_paths = result.get('checkpoint_paths', [])
        returned_checkpoint_dir = result.get('checkpoint_dir', '')
        epoch_metrics = result.get('epoch_metrics', [])
        
        # Defaults so the summary below never references undefined names,
        # even when no checkpoints were returned.
        logged_count = 0
        total_size_mb = 0.0
        
        try:
            # Verify checkpoint directory
            if returned_checkpoint_dir != trial_checkpoint_dir:
                print(f"⚠ WARNING: Checkpoint directory mismatch!")
                print(f"  Expected: {trial_checkpoint_dir}")
                print(f"  Returned: {returned_checkpoint_dir}")
                mlflow.log_param("checkpoint_dir_mismatch", "yes")
            else:
                mlflow.log_param("checkpoint_dir_verified", "yes")
            
            # ============================================================
            # CHECKPOINT LOGGING TO MLFLOW
            # ============================================================

            if checkpoint_paths:
                print(f"\n{'='*60}")
                print(f"LOGGING CHECKPOINTS - Trial {trial.number}")
                print(f"{'='*60}")
                print(f"Total checkpoints: {len(checkpoint_paths)}")
                print(f"Physical location: {trial_checkpoint_dir}")
                print(f"MLflow artifact:   {mlflow_artifact_subpath}")
                
                failed_count = 0
                
                checkpoint_manifest = {
                    "trial_number": trial.number,
                    "run_name": run_name,
                    "physical_directory": trial_checkpoint_dir,
                    "mlflow_artifact_uri": f"{trial_run.info.artifact_uri}/{mlflow_artifact_subpath}",
                    "total_checkpoints": len(checkpoint_paths),
                    "checkpoint_files": []
                }
                
                # Sort checkpoints for consistent ordering
                checkpoint_paths_sorted = sorted(checkpoint_paths)
                
                # ============================================================
                # Log epoch checkpoints
                # ============================================================
                for checkpoint_path in checkpoint_paths_sorted:
                    checkpoint_filename = os.path.basename(checkpoint_path)
                    if not os.path.exists(checkpoint_path):
                        failed_count += 1
                        print(f"  ⚠ Not found: {checkpoint_path}")
                        continue
                    
                    # Upload first, so a problem reading the checkpoint's
                    # metadata can never keep the file itself out of MLflow.
                    try:
                        checkpoint_size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
                        mlflow.log_artifact(checkpoint_path, mlflow_artifact_subpath)
                    except Exception as e:
                        failed_count += 1
                        print(f"  ⚠ Failed to log {checkpoint_filename}: {e}")
                        continue
                    logged_count += 1
                    total_size_mb += checkpoint_size_mb
                    
                    # Read 'epoch' / 'is_best' for the manifest. The metadata is
                    # optional: torch>=2.6 defaults torch.load to weights_only=True,
                    # which rejects some pickled objects.
                    checkpoint_epoch = -1
                    was_best_when_saved = False
                    try:
                        checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
                        checkpoint_epoch = checkpoint_data.get('epoch', -1)
                        was_best_when_saved = bool(checkpoint_data.get('is_best', False))
                        del checkpoint_data  # drop the CPU copy of the weights early
                    except Exception as e:
                        print(f"  ⚠ Could not read metadata from {checkpoint_filename}: {e}")
                    
                    # An unknown epoch (-1) must not match "no best epoch" (-1).
                    is_best = bool(checkpoint_epoch >= 0 and checkpoint_epoch == best_checkpoint_epoch)
                    
                    marker = "★" if is_best else "✓"
                    print(f"  {marker} {checkpoint_filename} ({checkpoint_size_mb:.2f} MB)"
                        f"{' [BEST EPOCH]' if is_best else ''}")
                    
                    # Add to manifest
                    checkpoint_manifest["checkpoint_files"].append({
                        "filename": checkpoint_filename,
                        "epoch": checkpoint_epoch,
                        "size_mb": round(checkpoint_size_mb, 2),
                        "physical_path": checkpoint_path,
                        "is_best": is_best,
                        "was_best_when_saved": was_best_when_saved
                    })
                
                # ============================================================
                # Log best_checkpoint.pt once (outside the epoch loop)
                # ============================================================
                expected_count = len(checkpoint_paths)
                best_checkpoint_file = os.path.join(trial_checkpoint_dir, "best_checkpoint.pt")
                if os.path.exists(best_checkpoint_file):
                    expected_count += 1
                    try:
                        best_size_mb = os.path.getsize(best_checkpoint_file) / (1024 * 1024)
                        mlflow.log_artifact(best_checkpoint_file, mlflow_artifact_subpath)
                        logged_count += 1
                        total_size_mb += best_size_mb
                        
                        print(f"  ★ best_checkpoint.pt ({best_size_mb:.2f} MB) [BEST MODEL]")
                        
                        checkpoint_manifest["checkpoint_files"].append({
                            "filename": "best_checkpoint.pt",
                            "epoch": best_checkpoint_epoch,
                            "size_mb": round(best_size_mb, 2),
                            "physical_path": best_checkpoint_file,
                            "is_best": True,
                            "was_best_when_saved": True,
                            "note": "Copy of best epoch checkpoint"
                        })
                        
                    except Exception as e:
                        failed_count += 1
                        print(f"  ⚠ Failed to log best_checkpoint.pt: {e}")
                
                # ============================================================
                # Finalize manifest
                # ============================================================
                checkpoint_manifest.update({
                    "logged_count": logged_count,
                    "failed_count": failed_count,
                    "total_size_mb": round(total_size_mb, 2),
                    "best_checkpoint": {
                        "filename": os.path.basename(best_checkpoint_path) if best_checkpoint_path else None,
                        "epoch": best_checkpoint_epoch,
                        "validation_accuracy": best_val_acc
                    }
                })
                
                print(f"\nSummary: {logged_count}/{expected_count} logged, {total_size_mb:.2f} MB")
                print(f"{'='*60}\n")
                
                # Log checkpoint summary metrics
                mlflow.log_metrics({
                    "total_checkpoints_saved": len(checkpoint_paths),
                    "checkpoints_logged_to_mlflow": logged_count,
                    "checkpoints_failed_to_log": failed_count,
                    "total_checkpoints_size_mb": total_size_mb
                })
                
                # Save and log manifest with predictable name
                manifest_filename = f"checkpoint_manifest_trial_{trial.number:03d}.json"
                manifest_path = os.path.join(trial_checkpoint_dir, manifest_filename)
                
                with open(manifest_path, 'w') as f:
                    json.dump(checkpoint_manifest, f, indent=2, default=_json_default)
                
                mlflow.log_artifact(manifest_path, mlflow_artifact_subpath)
                print(f"✓ Logged manifest: {manifest_filename}\n")
            
            # ============================================================
            # MODEL METRICS LOGGING
            # ============================================================
            
            model_metrics = result.get('model_metrics', {})
            if model_metrics:
                mlflow.log_metrics({
                    "model_parameter_count": model_metrics.get('parameter_count', 0),
                    "model_trainable_parameters": model_metrics.get('trainable_parameters', 0),
                    "model_size_mb": model_metrics.get('model_size_mb', 0.0)
                })
            
                gradient_norms = model_metrics.get('gradient_norms', [])
                weight_norms = model_metrics.get('weight_norms', [])
                
                if gradient_norms:
                    mlflow.log_metrics({
                        "final_gradient_norm": gradient_norms[-1],
                        "mean_gradient_norm": np.mean(gradient_norms),
                        "max_gradient_norm": np.max(gradient_norms),
                        "gradient_stability": 1.0 / (np.var(gradient_norms) + 1e-8)
                    })
                
                if weight_norms:
                    mlflow.log_metrics({
                        "final_weight_norm": weight_norms[-1],
                        "mean_weight_norm": np.mean(weight_norms),
                        "max_weight_norm": np.max(weight_norms)
                    })
            
            # ============================================================
            # EPOCH-BY-EPOCH METRICS LOGGING
            # ============================================================
            
            if epoch_metrics:
                for epoch_data in epoch_metrics:
                    epoch = epoch_data['epoch']
                    mlflow.log_metrics({
                        "train_loss_epoch": epoch_data['train_loss'],
                        "train_accuracy_epoch": epoch_data['train_acc'],
                        "val_loss_epoch": epoch_data['val_loss'],
                        "val_accuracy_epoch": epoch_data['val_acc'],
                        "learning_rate_epoch": epoch_data['learning_rate'],
                        "gradient_norm_epoch": epoch_data['gradient_norm']
                    }, step=epoch)
                    
                    if epoch_data.get('is_best', False):
                        mlflow.log_metric("is_best_epoch", 1.0, step=epoch)
                
                # Save epoch metrics to file with predictable name
                epoch_metrics_filename = f"epoch_metrics_trial_{trial.number:03d}.json"
                epoch_metrics_path = os.path.join(trial_checkpoint_dir, epoch_metrics_filename)
                
                with open(epoch_metrics_path, 'w') as f:
                    json.dump(epoch_metrics, f, indent=2, default=_json_default)
                
                mlflow.log_artifact(epoch_metrics_path, "metrics")
                print(f"✓ Logged epoch metrics: {epoch_metrics_filename}")
            
            # ============================================================
            # FINAL TRAINING METRICS
            # ============================================================
            
            training_metrics = {
                'best_validation_accuracy': best_val_acc,
                'final_validation_accuracy': final_val_acc,
                'final_train_loss': result.get('train_loss', 0.0),
                'final_val_loss': result.get('val_loss', 0.0),
                'final_train_accuracy': result.get('train_acc', 0.0),
                'epochs_completed': result.get('epochs_completed', NUM_EPOCHS),
                'training_time_seconds': result.get('training_time', 0.0),
                'trial_duration_seconds': trial_duration
            }
            mlflow.log_metrics(training_metrics)
            
            # ============================================================
            # EFFICIENCY METRICS
            # ============================================================
            
            if trial_duration > 0:
                total_samples = train_samples_per_epoch * NUM_EPOCHS
                # 'samples_per_second' and 'throughput_images_per_second' carry the
                # same value; both names are kept for existing dashboards.
                mlflow.log_metrics({
                    'samples_per_second': total_samples / trial_duration,
                    'time_per_epoch_seconds': trial_duration / NUM_EPOCHS,
                    'throughput_images_per_second': total_samples / trial_duration
                })
                
                if best_checkpoint_epoch >= 0:
                    # Approximation: assumes every epoch took the same time.
                    time_to_best = (best_checkpoint_epoch + 1) * (trial_duration / NUM_EPOCHS)
                    mlflow.log_metrics({
                        'time_to_best_checkpoint_seconds': time_to_best,
                        'efficiency_score': best_val_acc / (time_to_best / 60)
                    })
            
            # ============================================================
            # CONVERGENCE METRICS
            # ============================================================
            
            if epoch_metrics and len(epoch_metrics) > 1:
                val_accs = [ep['val_acc'] for ep in epoch_metrics]
                train_losses = [ep['train_loss'] for ep in epoch_metrics]
                val_losses = [ep['val_loss'] for ep in epoch_metrics]
                
                final_acc = val_accs[-1]
                initial_acc = val_accs[0]
                improvement = final_acc - initial_acc
                
                mlflow.log_metrics({
                    "total_accuracy_improvement": improvement,
                    "accuracy_improvement_rate": improvement / len(val_accs)
                })
                
                # Training stability
                if len(val_accs) >= 3:
                    recent_stability = 1.0 / (np.var(val_accs[-3:]) + 1e-8)
                    overall_stability = 1.0 / (np.var(val_accs) + 1e-8)
                    mlflow.log_metrics({
                        "validation_stability": recent_stability,
                        "overall_validation_stability": overall_stability
                    })
                
                # Loss convergence
                if len(train_losses) >= 2:
                    train_loss_improvement = train_losses[0] - train_losses[-1]
                    val_loss_improvement = val_losses[0] - val_losses[-1]
                    mlflow.log_metrics({
                        "train_loss_improvement": train_loss_improvement,
                        "val_loss_improvement": val_loss_improvement
                    })
                
                # Overfitting detection (best epoch strictly before the last one)
                if best_checkpoint_epoch >= 0 and best_checkpoint_epoch < len(val_accs) - 1:
                    overfit_gap = best_val_acc - final_val_acc
                    epochs_since_best = len(val_accs) - 1 - best_checkpoint_epoch
                    
                    mlflow.log_metrics({
                        "overfitting_gap": overfit_gap,
                        "epochs_since_best": epochs_since_best
                    })
                    
                    if overfit_gap > 0.01:
                        mlflow.log_param("overfitting_detected", "yes")
                        print(f"⚠ Overfitting detected: best={best_val_acc:.4f} vs final={final_val_acc:.4f}")
                        print(f"  Gap: {overfit_gap:.4f} ({epochs_since_best} epochs since best)")
                    else:
                        mlflow.log_param("overfitting_detected", "no")
                
                # Learning curve analysis
                if len(val_accs) >= 5:
                    recent_trend = val_accs[-1] - val_accs[-3]
                    mlflow.log_metric("recent_validation_trend", recent_trend)
                    
                    if recent_trend > 0:
                        mlflow.log_param("convergence_status", "still_improving")
                    elif abs(recent_trend) < 0.001:
                        mlflow.log_param("convergence_status", "converged")
                    else:
                        mlflow.log_param("convergence_status", "degrading")
            
            # ============================================================
            # TAGS
            # ============================================================
            
            performance_tier = (
                "excellent" if best_val_acc > 0.7 else 
                "good" if best_val_acc > 0.5 else 
                "fair" if best_val_acc > 0.3 else 
                "poor"
            )
            
            convergence_status = "unknown"
            if epoch_metrics and len(epoch_metrics) > 1:
                if best_checkpoint_epoch == len(epoch_metrics) - 1:
                    convergence_status = "still_improving"
                elif best_checkpoint_epoch >= 0 and best_checkpoint_epoch < len(epoch_metrics) - 3:
                    convergence_status = "early_stopped"
                else:
                    convergence_status = "converged"
            
            mlflow.set_tags({
                "experiment_run_id": EXPERIMENT_RUN_ID,
                "experiment_timestamp": EXPERIMENT_TIMESTAMP,
                "experiment_short_name": EXPERIMENT_SHORT_NAME,
                "trial_number": f"{trial.number:03d}",
                "run_name": run_name,
                "optimizer": optimizer_name,
                "performance_tier": performance_tier,
                "trial_status": "completed",
                "checkpointing_enabled": str(ENABLE_CHECKPOINTING),
                "checkpoint_storage": "uc_volumes_organized",
                "checkpoint_base_dir": CHECKPOINT_BASE_DIR,
                "checkpoint_trial_dir": trial_checkpoint_dir,
                "total_checkpoints": str(len(checkpoint_paths)) if checkpoint_paths else "0",
                "has_best_checkpoint": "yes" if best_checkpoint_path else "no",
                "best_checkpoint_epoch": f"{best_checkpoint_epoch + 1:03d}" if best_checkpoint_epoch >= 0 else "none",
                # NOTE(review): this heuristic uses NUM_EPOCHS, while the
                # overfitting metrics above use the observed epoch count.
                "overfitting": "yes" if (best_checkpoint_epoch >= 0 and best_checkpoint_epoch < NUM_EPOCHS - 2) else "no",
                "convergence_status": convergence_status,
                "batch_size": str(batch_size),
                "learning_rate": f"{lr:.2e}"
            })
            
            # ============================================================
            # SUMMARY OUTPUT
            # ============================================================
            
            print(f"\n{'='*80}")
            print(f"TRIAL {trial.number:03d} COMPLETED SUCCESSFULLY")
            print(f"{'='*80}")
            print(f"Run Name: {run_name}")
            print(f"\nPerformance:")
            print(f"  Best Accuracy:  {best_val_acc:.4f} (Epoch {best_checkpoint_epoch + 1})")
            print(f"  Final Accuracy: {final_val_acc:.4f}")
            print(f"  Performance Tier: {performance_tier}")
            print(f"\nTiming:")
            print(f"  Duration: {trial_duration:.1f}s ({trial_duration/60:.1f} min)")
            print(f"  Time per Epoch: {trial_duration/NUM_EPOCHS:.1f}s")
            if best_checkpoint_epoch >= 0:
                time_to_best = (best_checkpoint_epoch + 1) * (trial_duration / NUM_EPOCHS)
                print(f"  Time to Best: {time_to_best:.1f}s ({time_to_best/60:.1f} min)")
            print(f"\nStorage:")
            print(f"  Experiment Root:   {EXPERIMENT_ROOT}")
            print(f"  Checkpoint Dir:    {trial_checkpoint_dir}")
            print(f"  MLflow Artifacts:  {trial_run.info.artifact_uri}/{mlflow_artifact_subpath}/")
            if checkpoint_paths:
                print(f"\n  Checkpoints:")
                print(f"    Total Saved:      {len(checkpoint_paths)}")
                print(f"    Logged to MLflow: {logged_count}")
                print(f"    Total Storage:    {total_size_mb:.2f} MB")
                if best_checkpoint_path:
                    print(f"    Best Checkpoint:  {os.path.basename(best_checkpoint_path)}")
            print(f"\nMLflow:")
            print(f"  Run ID:   {trial_run_id}")
            print(f"  Run Name: {run_name}")
            print(f"{'='*80}\n")
            
        except Exception as post_error:
            # Training succeeded; only result logging failed. Report it, but
            # still hand Optuna the real accuracy rather than 0.0.
            print(f"\n⚠ Trial {trial.number:03d}: post-training logging failed "
                  f"({type(post_error).__name__}: {post_error}); "
                  f"returning measured best accuracy {best_val_acc}")
            traceback.print_exc()
            try:
                mlflow.set_tags({
                    "trial_status": "completed_with_logging_errors",
                    "post_processing_error_type": type(post_error).__name__
                })
            except Exception as mlflow_error:
                print(f"⚠ Could not tag logging failure in MLflow: {mlflow_error}")
        
        return best_val_acc

In [0]:
# ============================================================
# CHECKPOINT HELPER FUNCTIONS
# ============================================================

def find_latest_checkpoint(checkpoint_dir):
    """
    Find the checkpoint with the highest epoch number in a directory.

    Only files named ``checkpoint_epoch_<N>.pt`` with an integer ``<N>`` are
    considered. Other files -- e.g. ``best_checkpoint.pt`` or a stray
    ``checkpoint_epoch_final.pt`` -- are ignored instead of raising ValueError.

    Args:
        checkpoint_dir: Directory containing checkpoints

    Returns:
        Path to the latest epoch checkpoint, or None if ``checkpoint_dir`` is
        not an existing directory or holds no matching files
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    prefix, suffix = 'checkpoint_epoch_', '.pt'
    epoch_files = []
    for filename in os.listdir(checkpoint_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        epoch_part = filename[len(prefix):-len(suffix)]
        # isdecimal() guarantees int() below cannot raise
        if epoch_part.isdecimal():
            epoch_files.append((int(epoch_part), filename))

    if not epoch_files:
        return None

    # Highest epoch wins
    _, latest_name = max(epoch_files)
    return os.path.join(checkpoint_dir, latest_name)

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
    """
    Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        checkpoint_path: Path to a file written with ``torch.save`` holding at
            least ``'model_state_dict'`` (and ``'optimizer_state_dict'`` when
            an optimizer is given)
        model: Model to load state into. If it exposes ``.module`` (e.g. a DDP
            wrapper) the state is loaded into the wrapped module.
        optimizer: Optimizer to load state into, or None to skip it (e.g. when
            loading for evaluation only)
        scheduler: Scheduler to load state into, or None to skip it. Also
            skipped when the checkpoint has no ``'scheduler_state_dict'``.
        device: Passed to ``torch.load`` as ``map_location``

    Returns:
        Dictionary with checkpoint metadata (epoch, best_val_acc, val_acc,
        train_loss, val_loss; missing keys default to 0 / 0.0), or None if
        loading failed for any reason
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load model state into the unwrapped module when the model is wrapped
        target_model = model.module if hasattr(model, 'module') else model
        target_model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state (a missing key still fails the load, as before)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        metadata = {
            'epoch': checkpoint.get('epoch', 0),
            'best_val_acc': checkpoint.get('best_val_acc', 0.0),
            'val_acc': checkpoint.get('val_acc', 0.0),
            'train_loss': checkpoint.get('train_loss', 0.0),
            'val_loss': checkpoint.get('val_loss', 0.0)
        }

        print(f"✓ Loaded checkpoint: {os.path.basename(checkpoint_path)}")
        print(f"  Epoch: {metadata['epoch']}, Val Acc: {metadata['val_acc']:.4f}")

        return metadata

    except Exception as e:
        # Best-effort: callers treat None as "start from scratch"
        print(f"⚠ Failed to load checkpoint {checkpoint_path}: {e}")
        return None

def get_checkpoint_summary(checkpoint_dir):
    """
    Summarize the ``.pt`` files stored directly inside ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory to inspect

    Returns:
        Dict with ``exists`` (bool), ``total_checkpoints`` (int),
        ``total_size_mb`` (rounded to 2 decimals) and ``checkpoints``: a list of
        ``{'filename', 'size_mb', 'path'}`` dicts in filename order.
    """
    if not os.path.exists(checkpoint_dir):
        return {
            'exists': False,
            'total_checkpoints': 0,
            'total_size_mb': 0.0,
            'checkpoints': []
        }

    bytes_per_mb = 1024 * 1024
    entries = []
    for name in sorted(n for n in os.listdir(checkpoint_dir) if n.endswith('.pt')):
        full_path = os.path.join(checkpoint_dir, name)
        entries.append((name, full_path, os.path.getsize(full_path) / bytes_per_mb))

    # Sum the unrounded sizes; only the reported values are rounded
    return {
        'exists': True,
        'total_checkpoints': len(entries),
        'total_size_mb': round(sum(mb for _, _, mb in entries), 2),
        'checkpoints': [
            {'filename': name, 'size_mb': round(mb, 2), 'path': full_path}
            for name, full_path, mb in entries
        ],
    }

def cleanup_old_checkpoints(checkpoint_base_dir, keep_recent=3):
    """
    Delete all but the ``keep_recent`` most recently modified trial folders.

    Only sub-directories of ``checkpoint_base_dir`` whose name starts with
    ``trial_`` are candidates, ranked by directory mtime. A folder that cannot
    be removed is reported and skipped rather than raising.

    Args:
        checkpoint_base_dir: Base directory containing trial subdirectories
        keep_recent: Number of most recently modified trial directories to keep
    """
    import shutil

    if not os.path.exists(checkpoint_base_dir):
        return

    candidates = []
    for entry in os.listdir(checkpoint_base_dir):
        if entry.startswith('trial_') and os.path.isdir(os.path.join(checkpoint_base_dir, entry)):
            candidates.append(entry)

    if len(candidates) <= keep_recent:
        print(f"Only {len(candidates)} trial directories found, no cleanup needed")
        return

    # Newest first; sorted() is stable, so equal mtimes keep listdir order
    stamped = [
        (entry, os.path.getmtime(os.path.join(checkpoint_base_dir, entry)))
        for entry in candidates
    ]
    newest_first = [entry for entry, _ in sorted(stamped, key=lambda pair: pair[1], reverse=True)]
    kept, doomed = newest_first[:keep_recent], newest_first[keep_recent:]

    print(f"\nCleaning up old checkpoints:")
    print(f"  Keeping {len(kept)} recent trials")
    print(f"  Removing {len(doomed)} old trials")

    for entry in doomed:
        try:
            shutil.rmtree(os.path.join(checkpoint_base_dir, entry))
            print(f"  ✓ Removed: {entry}")
        except Exception as err:
            print(f"  ⚠ Failed to remove {entry}: {err}")

In [0]:
# Training / optimization run configuration
# (num_classes = 200 is already set in the dataset configuration cell)

NUM_EPOCHS = 3  # epochs per Optuna trial; use 1-2 for a quick smoke test

# Re-sets the value from the dataset configuration cell (also 4 there).
# Lowering it (e.g. to 2) may reduce I/O contention on the data volume.
num_workers = 4

# Optuna configuration
N_TRIALS = 3  # number of Optuna trials; increase for a wider search
# NOTE(review): hyperparameter search ranges (batch sizes, optimizers, ...) are
# presumably defined inside the objective function -- adjust them there.

In [0]:
# ============================================================
# RUN OPTUNA OPTIMIZATION WITH ORGANIZED CHECKPOINTING
# ============================================================
import time
from datetime import datetime
import json
import numpy as np

# Verify configuration is set up. Names are looked up in globals() so that a
# fresh kernel (configuration cell not run) gets this actionable message
# instead of a bare NameError.
_required_config = ("EXPERIMENT_ROOT", "CHECKPOINT_BASE_DIR", "MLFLOW_ARTIFACT_LOCATION", "EXPERIMENT_NAME")
_missing_config = [name for name in _required_config if not globals().get(name)]
if _missing_config:
    raise ValueError(
        "Configuration not properly initialized. Run the configuration cell first. "
        f"Missing or empty: {', '.join(_missing_config)}"
    )

print(f"\n{'='*80}")
print(f"STARTING OPTUNA OPTIMIZATION")
print(f"{'='*80}")
print(f"Experiment: {EXPERIMENT_SHORT_NAME}")
print(f"Timestamp: {EXPERIMENT_TIMESTAMP}")
print(f"Trials: {N_TRIALS}")
print(f"Epochs per Trial: {NUM_EPOCHS}")
print(f"Workers: {num_workers}")
print(f"Checkpointing: {'Enabled' if ENABLE_CHECKPOINTING else 'Disabled'}")
print(f"\nPaths:")
print(f"  Experiment Root:  {EXPERIMENT_ROOT}")
print(f"  Checkpoint Base:  {CHECKPOINT_BASE_DIR}")
print(f"  MLflow Artifacts: {MLFLOW_ARTIFACT_LOCATION}")
print(f"  MLflow Experiment: {EXPERIMENT_NAME}")
print(f"{'='*80}\n")

# Create Optuna study (seeded TPE sampler for reproducible suggestions)
study = optuna.create_study(
    study_name=f"pytorch_{EXPERIMENT_SHORT_NAME}_{EXPERIMENT_TIMESTAMP}",
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42)
)

start_time = time.time()

# Start parent MLflow run for the whole study
with mlflow.start_run(run_name=f"optuna_study_{EXPERIMENT_TIMESTAMP}") as parent_run:
    # Global: the objective function reads EXPERIMENT_RUN_ID when setting trial tags
    EXPERIMENT_RUN_ID = parent_run.info.run_id
    parent_run_id = parent_run.info.run_id

    print(f"Parent MLflow Run ID: {parent_run_id}")
    print(f"Starting optimization...\n")

    # Log optimization configuration
    mlflow.log_params({
        "experiment_run_id": EXPERIMENT_RUN_ID,
        "experiment_timestamp": EXPERIMENT_TIMESTAMP,
        "experiment_short_name": EXPERIMENT_SHORT_NAME,
        "study_name": study.study_name,
        "n_trials": N_TRIALS,
        "num_workers": num_workers,
        "num_epochs": NUM_EPOCHS,
        "model_architecture": "mobilenetv2",
        "optimization_type": "optuna_hpt_organized",

        # ADD GLOBAL DATALOADER CONFIGURATION
        "global_dataloader_type": "streaming_mds",
        "global_dataloader_library": "mosaic_streaming",
        "global_dataset_format": "mds",
        "global_dataset_storage": "uc_volumes",
        "global_dataset_train_path": f"{data_storage_location}/{mds_train_dir}",
        "global_dataset_val_path": f"{data_storage_location}/{mds_val_dir}",
        "global_dataset_num_classes": num_classes,
        "global_dataset_catalog": CATALOG,
        "global_dataset_schema": SCHEMA,
        "global_dataset_volume": VOLUME_NAME,

        # Checkpointing
        "checkpointing_enabled": ENABLE_CHECKPOINTING,
        "checkpoint_frequency": CHECKPOINT_FREQUENCY,
        "checkpoint_base_dir": CHECKPOINT_BASE_DIR,
        "experiment_root": EXPERIMENT_ROOT,
        "mlflow_artifact_location": MLFLOW_ARTIFACT_LOCATION,
        "resume_from_checkpoint": RESUME_FROM_CHECKPOINT
    })

    # Run optimization (sequential for simplicity)
    study.optimize(objective_function, n_trials=N_TRIALS, n_jobs=1)

    # Calculate duration
    end_time = time.time()
    optimization_duration = end_time - start_time

    # Partition trials.
    # NOTE: the objective appears to return 0.0 from its failure branch
    # ("trial_status": "failed"), so such trials still end COMPLETE and are
    # counted as completed here.
    completed_trials = [
        t for t in study.trials
        if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None
    ]
    failed_trials = [
        t for t in study.trials
        if t.state != optuna.trial.TrialState.COMPLETE or t.value is None
    ]
    trial_values = [t.value for t in completed_trials]

    # study.best_trial raises ValueError when nothing completed (it never returns
    # None), so stop here with an explicit message. Leaving the `with` block via
    # an exception marks the parent MLflow run as FAILED.
    if not completed_trials:
        mlflow.set_tag("optimization_status", "failed")
        raise RuntimeError(
            f"No Optuna trial completed successfully ({len(study.trials)} trial(s) ran). "
            "Inspect the trial runs and error logs before analysing results."
        )

    # Get results
    best_frozen_trial = study.best_trial
    best_params = best_frozen_trial.params
    best_value = best_frozen_trial.value
    best_trial_number = best_frozen_trial.number

    print(f"\n{'='*80}")
    print(f"OPTIMIZATION COMPLETED")
    print(f"{'='*80}")
    print(f"Duration: {optimization_duration:.1f}s ({optimization_duration/60:.1f} min)")
    print(f"Completed Trials: {len(completed_trials)}/{len(study.trials)}")
    print(f"Failed Trials: {len(failed_trials)}")
    print(f"\nBest Trial: {best_trial_number}")
    print(f"Best Validation Accuracy: {best_value:.4f}")
    print(f"Best Parameters:")
    for param, value in best_params.items():
        print(f"  {param}: {value}")
    print(f"{'='*80}\n")

    # ============================================================
    # CHECKPOINT SUMMARY
    # ============================================================

    checkpoint_summary = {
        "experiment_timestamp": EXPERIMENT_TIMESTAMP,
        "checkpoint_base_dir": CHECKPOINT_BASE_DIR,
        "trial_checkpoints": {},
        "total_checkpoints": 0,
        "total_size_mb": 0.0,
        "best_trial_number": best_trial_number,
        "best_trial_checkpoint_dir": None
    }

    if ENABLE_CHECKPOINTING and os.path.exists(CHECKPOINT_BASE_DIR):
        print("Collecting checkpoint summary...")

        # Get all trial directories (sorted so the best-trial lookup is deterministic)
        trial_dirs = sorted(
            d for d in os.listdir(CHECKPOINT_BASE_DIR)
            if os.path.isdir(os.path.join(CHECKPOINT_BASE_DIR, d)) and d.startswith('trial_')
        )

        total_checkpoints = 0
        total_size_bytes = 0

        for trial_dir in trial_dirs:
            trial_path = os.path.join(CHECKPOINT_BASE_DIR, trial_dir)

            # Count checkpoint files
            checkpoint_files = [
                f for f in os.listdir(trial_path)
                if f.endswith('.pt')
            ]

            if checkpoint_files:
                # Calculate size
                trial_size = sum(
                    os.path.getsize(os.path.join(trial_path, f))
                    for f in checkpoint_files
                )
                trial_size_mb = trial_size / (1024 * 1024)

                # Check for best checkpoint
                has_best = os.path.exists(os.path.join(trial_path, "best_checkpoint.pt"))

                checkpoint_summary["trial_checkpoints"][trial_dir] = {
                    "count": len(checkpoint_files),
                    "size_mb": round(trial_size_mb, 2),
                    "has_best": has_best,
                    "path": trial_path
                }

                total_checkpoints += len(checkpoint_files)
                total_size_bytes += trial_size

        checkpoint_summary["total_checkpoints"] = total_checkpoints
        checkpoint_summary["total_size_mb"] = round(total_size_bytes / (1024 * 1024), 2)

        # Find best trial checkpoint directory (trial dirs are named trial_<NNN>_...)
        best_trial_dirs = [
            d for d in trial_dirs
            if d.startswith(f"trial_{best_trial_number:03d}_")
        ]
        if best_trial_dirs:
            best_trial_checkpoint_dir = os.path.join(CHECKPOINT_BASE_DIR, best_trial_dirs[0])
            checkpoint_summary["best_trial_checkpoint_dir"] = best_trial_checkpoint_dir

        print(f"✓ Found {total_checkpoints} checkpoints across {len(trial_dirs)} trials")
        print(f"  Total Size: {checkpoint_summary['total_size_mb']:.2f} MB")
        if checkpoint_summary["best_trial_checkpoint_dir"]:
            print(f"  Best Trial Checkpoints: {checkpoint_summary['best_trial_checkpoint_dir']}")

    # ============================================================
    # LOG OPTIMIZATION RESULTS
    # ============================================================

    # Essential metrics
    optimization_results = {
        "best_validation_accuracy": best_value,
        "total_trials": len(study.trials),
        "completed_trials": len(completed_trials),
        "failed_trials": len(failed_trials),
        "optimization_duration_seconds": optimization_duration,
        "optimization_duration_minutes": optimization_duration / 60.0,
        "success_rate": len(completed_trials) / len(study.trials) if len(study.trials) > 0 else 0.0,
        "total_checkpoints_saved": checkpoint_summary["total_checkpoints"],
        "total_checkpoints_size_mb": checkpoint_summary["total_size_mb"]
    }

    mlflow.log_metrics(optimization_results)

    # Log best hyperparameters
    for param_name, param_value in best_params.items():
        mlflow.log_param(f"best_{param_name}", param_value)

    # Log trial statistics
    if trial_values:
        trial_stats = {
            "mean_trial_accuracy": np.mean(trial_values),
            "std_trial_accuracy": np.std(trial_values),
            "min_trial_accuracy": np.min(trial_values),
            "max_trial_accuracy": np.max(trial_values),
            "median_trial_accuracy": np.median(trial_values)
        }
        mlflow.log_metrics(trial_stats)

        # Calculate improvement metrics
        if len(trial_values) > 1:
            improvement = best_value - np.min(trial_values)
            mlflow.log_metric("accuracy_improvement_range", improvement)

    # Log per-trial checkpoint counts in a single batched call
    per_trial_checkpoint_metrics = {}
    for trial_dir, info in checkpoint_summary["trial_checkpoints"].items():
        trial_num = trial_dir.split('_')[1]
        per_trial_checkpoint_metrics[f"trial_{trial_num}_checkpoint_count"] = info["count"]
        per_trial_checkpoint_metrics[f"trial_{trial_num}_checkpoint_size_mb"] = info["size_mb"]
    if per_trial_checkpoint_metrics:
        mlflow.log_metrics(per_trial_checkpoint_metrics)

    # ============================================================
    # PARAMETER IMPORTANCE
    # ============================================================

    try:
        if len(completed_trials) > 1:
            importance = optuna.importance.get_param_importances(study)

            print("\nParameter Importance:")
            for param_name, importance_value in sorted(importance.items(), key=lambda x: x[1], reverse=True):
                mlflow.log_metric(f"param_importance_{param_name}", importance_value)
                print(f"  {param_name}: {importance_value:.4f}")

            most_important = max(importance.items(), key=lambda x: x[1])
            mlflow.log_param("most_important_param", most_important[0])
            mlflow.log_metric("most_important_param_score", most_important[1])
            print(f"\nMost Important: {most_important[0]} ({most_important[1]:.4f})")
    except Exception as e:
        # Importance is informational only; never fail the study over it
        print(f"Could not calculate parameter importance: {e}")
        mlflow.log_param("importance_error", str(e)[:200])

    # ============================================================
    # SAVE AND LOG CHECKPOINT SUMMARY
    # ============================================================

    # The base dir may not exist yet (e.g. checkpointing disabled / no trial saved)
    os.makedirs(CHECKPOINT_BASE_DIR, exist_ok=True)

    # Save checkpoint summary to file
    summary_filename = f"optimization_summary_{EXPERIMENT_TIMESTAMP}.json"
    summary_path = os.path.join(CHECKPOINT_BASE_DIR, summary_filename)

    full_summary = {
        "optimization": {
            "experiment_timestamp": EXPERIMENT_TIMESTAMP,
            "experiment_short_name": EXPERIMENT_SHORT_NAME,
            "study_name": study.study_name,
            "duration_seconds": optimization_duration,
            "duration_minutes": optimization_duration / 60.0,
            "n_trials": N_TRIALS,
            "completed_trials": len(completed_trials),
            "failed_trials": len(failed_trials)
        },
        "best_trial": {
            "trial_number": best_trial_number,
            "validation_accuracy": best_value,
            "parameters": best_params,
            "checkpoint_dir": checkpoint_summary["best_trial_checkpoint_dir"]
        },
        "checkpoints": checkpoint_summary,
        "paths": {
            "experiment_root": EXPERIMENT_ROOT,
            "checkpoint_base_dir": CHECKPOINT_BASE_DIR,
            "mlflow_artifact_location": MLFLOW_ARTIFACT_LOCATION
        },
        "mlflow": {
            "parent_run_id": parent_run_id,
            "experiment_name": EXPERIMENT_NAME
        }
    }

    with open(summary_path, 'w') as f:
        json.dump(full_summary, f, indent=2)

    mlflow.log_artifact(summary_path, "optimization_summary")
    print(f"\n✓ Saved optimization summary: {summary_filename}")

    # ============================================================
    # SAVE STUDY OBJECT
    # ============================================================

    # Save Optuna study for later analysis.
    # SECURITY: unpickling executes arbitrary code -- only load this file from a
    # location you trust.
    study_filename = f"optuna_study_{EXPERIMENT_TIMESTAMP}.pkl"
    study_path = os.path.join(CHECKPOINT_BASE_DIR, study_filename)

    import pickle
    with open(study_path, 'wb') as f:
        pickle.dump(study, f)

    mlflow.log_artifact(study_path, "optimization_summary")
    print(f"✓ Saved Optuna study: {study_filename}")

    # ============================================================
    # SET TAGS
    # ============================================================

    mlflow.set_tags({
        "optimization_type": "optuna_hpt_organized",
        "experiment_short_name": EXPERIMENT_SHORT_NAME,
        "experiment_timestamp": EXPERIMENT_TIMESTAMP,
        "model_architecture": "mobilenetv2",
        "best_optimizer": best_params.get('optimizer', 'unknown'),
        "optimization_status": "completed",
        "total_trials": str(len(study.trials)),
        "completed_trials": str(len(completed_trials)),
        "best_trial_number": str(best_trial_number),
        "best_accuracy": f"{best_value:.4f}",
        "checkpointing_enabled": str(ENABLE_CHECKPOINTING),
        "total_checkpoints": str(checkpoint_summary["total_checkpoints"]),
        "checkpoint_base_dir": CHECKPOINT_BASE_DIR
    })

# ============================================================
# FINAL SUMMARY DISPLAY
# ============================================================

# Same zero-trial guard as the success_rate metric logged in the optimization cell
success_rate_pct = len(completed_trials) / len(study.trials) * 100 if study.trials else 0.0

print(f"\n{'='*80}")
print(f"OPTIMIZATION SUMMARY")
print(f"{'='*80}")
print(f"Parent Run ID: {parent_run_id}")
print(f"Duration: {optimization_duration:.1f}s ({optimization_duration/60:.1f} min)")
print(f"\nResults:")
print(f"  Best Trial: {best_trial_number}")
print(f"  Best Accuracy: {best_value:.4f}")
print(f"  Completed: {len(completed_trials)}/{len(study.trials)}")
print(f"  Success Rate: {success_rate_pct:.1f}%")
print(f"\nCheckpoints:")
print(f"  Total Saved: {checkpoint_summary['total_checkpoints']}")
print(f"  Total Size: {checkpoint_summary['total_size_mb']:.2f} MB")
print(f"  Location: {CHECKPOINT_BASE_DIR}")
if checkpoint_summary["best_trial_checkpoint_dir"]:
    print(f"  Best Trial: {checkpoint_summary['best_trial_checkpoint_dir']}")
print(f"\nMLflow:")
print(f"  Experiment: {EXPERIMENT_NAME}")
print(f"  Parent Run: {parent_run_id}")
print(f"  Artifacts: {MLFLOW_ARTIFACT_LOCATION}")
print(f"{'='*80}\n")

# ============================================================
# DISPLAY DETAILED CHECKPOINT SUMMARY
# ============================================================

if checkpoint_summary["trial_checkpoints"]:
    print(f"{'='*80}")
    print(f"CHECKPOINT DETAILS")
    print(f"{'='*80}")

    for trial_dir in sorted(checkpoint_summary["trial_checkpoints"].keys()):
        info = checkpoint_summary["trial_checkpoints"][trial_dir]
        # ★ marks trials whose directory contains best_checkpoint.pt
        best_marker = "★" if info["has_best"] else " "
        print(f"{best_marker} {trial_dir}:")
        print(f"    Checkpoints: {info['count']}")
        print(f"    Size: {info['size_mb']:.2f} MB")
        print(f"    Path: {info['path']}")

    print(f"{'='*80}\n")

# ============================================================
# PRINT CHECKPOINT SUMMARY AS JSON (display only; nothing is written here)
# ============================================================

print("Checkpoint Summary (JSON):")
print(json.dumps(checkpoint_summary, indent=2))

print(f"\n✓ All results saved to: {CHECKPOINT_BASE_DIR}")
print(f"✓ MLflow experiment: {EXPERIMENT_NAME}")
print(f"✓ Parent run ID: {parent_run_id}")

In [0]:
# Display the MLflow run ID of the parent Optuna study run (set by the optimization cell)
parent_run_id

In [0]:
def load_best_model(parent_run_id):
    """
    Load the best model found by the Optuna study logged under an MLflow run.

    The steps are:
      1. Read the ``optimization_summary/*.json`` artifact of ``parent_run_id``.
      2. Take the best trial's checkpoint directory from it and load
         ``best_checkpoint.pt`` from that directory.
      3. Rebuild a MobileNetV2 whose classifier head matches the saved state
         dict: either the stock head with a resized Linear at index 1, or a
         Dropout/Linear/ReLU/Dropout/Linear head.

    Args:
        parent_run_id: MLflow run ID of the parent optimization run

    Returns:
        Tuple ``(model, device)``; the model is in eval mode on ``device``
        (CUDA if available, otherwise CPU).

    Raises:
        FileNotFoundError: if any of these is missing:
            - the summary JSON in the run;
            - a checkpoint directory for the best trial (e.g. checkpointing
              was disabled);
            - the checkpoint file itself.
    """
    import glob
    import json
    import os

    import mlflow
    import mlflow.artifacts
    import torch
    from torchvision import models

    # Get optimization summary
    # (mlflow.artifacts.download_artifacts supersedes the deprecated
    # MlflowClient.download_artifacts)
    summary_dir = mlflow.artifacts.download_artifacts(
        run_id=parent_run_id, artifact_path="optimization_summary"
    )
    summary_files = sorted(glob.glob(os.path.join(summary_dir, "*.json")))
    if not summary_files:
        raise FileNotFoundError(
            f"No optimization summary JSON found under 'optimization_summary' for run {parent_run_id}"
        )

    with open(summary_files[0], 'r') as f:
        summary = json.load(f)

    # Get best checkpoint; the summary stores None when no checkpoint dir was found
    best_trial = summary['best_trial']
    checkpoint_dir = best_trial['checkpoint_dir']
    if not checkpoint_dir:
        raise FileNotFoundError(
            f"Best trial {best_trial.get('trial_number')} has no checkpoint directory "
            "(was checkpointing enabled?)"
        )
    checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint.pt")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Best checkpoint not found: {checkpoint_path}")

    # Load checkpoint
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    # Build model to match checkpoint
    model = models.mobilenet_v2(weights=None)

    if 'classifier.4.weight' in state_dict:
        # Two-layer head: infer hidden width and class count from the weights
        hidden_dim = state_dict['classifier.1.weight'].shape[0]
        num_classes = state_dict['classifier.4.weight'].shape[0]

        # Dropout rates are irrelevant here: the model is put in eval mode below
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(model.last_channel, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(hidden_dim, num_classes)
        )
    else:
        # Stock head: only the final Linear is resized
        num_classes = state_dict['classifier.1.weight'].shape[0]
        model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    print(f"Loaded model with {num_classes} classes")
    return model, device


def calculate_metrics(all_predictions, all_labels, all_probs, num_classes):
    """
    Calculate accuracy, mAP, top-5 and per-class metrics.

    Args:
        all_predictions: 1-D integer array of predicted class indices (length N)
        all_labels: 1-D integer array of true class indices in [0, num_classes)
        all_probs: (N, num_classes) array of class probabilities / scores
        num_classes: Total number of classes

    Returns:
        Dict with the following keys:
          - 'accuracy': top-1 accuracy.
          - 'mAP': mean one-vs-rest average precision over classes that occur
            in all_labels.
          - 'top5_accuracy': top-5 accuracy.
          - 'per_class_accuracy': length-num_classes array indexed by class id;
            NaN for classes with no samples.
          - 'mean_class_accuracy': mean over classes that have samples.

    Raises:
        ValueError: if there are no samples.
    """
    import numpy as np
    from sklearn.metrics import average_precision_score

    all_predictions = np.asarray(all_predictions)
    all_labels = np.asarray(all_labels)
    all_probs = np.asarray(all_probs)

    n_samples = len(all_labels)
    if n_samples == 0:
        raise ValueError("calculate_metrics() needs at least one sample")

    correct_mask = all_predictions == all_labels

    # Accuracy
    accuracy = correct_mask.sum() / n_samples

    # mAP (mean Average Precision). AP is undefined for a class with no positive
    # samples, so those classes are skipped.
    labels_onehot = np.zeros((n_samples, num_classes))
    labels_onehot[np.arange(n_samples), all_labels] = 1
    aps = [
        average_precision_score(labels_onehot[:, c], all_probs[:, c])
        for c in range(num_classes)
        if labels_onehot[:, c].sum() > 0
    ]
    mAP = np.mean(aps) if aps else 0.0

    # Per-class accuracy, indexed by class id. (sklearn's confusion_matrix
    # without `labels=` only keeps classes present in labels/predictions, and
    # divides 0/0 for classes that are predicted but never occur, which turned
    # the mean into NaN.)
    support = np.bincount(all_labels, minlength=num_classes)
    hits = np.bincount(all_labels[correct_mask], minlength=num_classes)
    has_support = support > 0
    per_class_acc = np.full(num_classes, np.nan)
    per_class_acc[has_support] = hits[has_support] / support[has_support]
    # n_samples > 0 guarantees at least one supported class
    mean_class_acc = np.mean(per_class_acc[has_support])

    # Top-5 accuracy (covers all classes when there are fewer than 5)
    top5_preds = np.argsort(all_probs, axis=1)[:, -5:]
    top5_correct = np.any(top5_preds == all_labels[:, None], axis=1).sum()
    top5_accuracy = top5_correct / n_samples

    return {
        'accuracy': accuracy,
        'mAP': mAP,
        'top5_accuracy': top5_accuracy,
        'per_class_accuracy': per_class_acc,
        'mean_class_accuracy': mean_class_acc
    }


def run_inference(parent_run_id, val_path, batch_size=64, local_path="/local_disk0/tmp/mds_inference"):
    """
    Evaluate the best model of an MLflow/Optuna study on an MDS dataset.

    Args:
        parent_run_id: MLflow run ID of the parent optimization run
        val_path: Remote path of the MDS validation shards
        batch_size: Inference batch size
        local_path: Local cache directory handed to get_dataloader_with_mosaic;
            the cache dir it returns is removed afterwards

    Returns:
        Metrics dict from calculate_metrics, plus 'sample_count' (number of
        evaluated samples) and 'batch_size' (the batch size used).

    Raises:
        ValueError: if the dataloader yields no batches, or the label mapping's
            class count differs from the model's output width.
    """
    import os
    import shutil

    import numpy as np
    import torch

    # Load model
    model, device = load_best_model(parent_run_id)

    # Create class mapping
    print("Creating class mapping...")
    label_to_idx = create_comprehensive_label_mapping(val_path)
    num_classes = len(label_to_idx)

    # Create dataloader
    dataloader, cache_dir = get_dataloader_with_mosaic(
        remote_path=val_path,
        local_path=local_path,
        batch_size=batch_size,
        rank=0
    )

    # Collect predictions and labels
    all_predictions = []
    all_labels = []
    all_probs = []

    try:
        with torch.no_grad():
            for batch in dataloader:
                images, labels = convert_batch_to_tensors(
                    batch,
                    device=device,
                    class_to_idx=label_to_idx,
                    rank=0
                )

                outputs = model(images)
                probs = torch.softmax(outputs, dim=1)
                predictions = outputs.argmax(dim=1)

                all_predictions.append(predictions.cpu().numpy())
                all_labels.append(labels.cpu().numpy())
                all_probs.append(probs.cpu().numpy())

    finally:
        # Always remove the local shard cache, even if inference failed
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir, ignore_errors=True)

    if not all_labels:
        raise ValueError(f"No batches were produced from {val_path}; cannot compute metrics")

    # Concatenate all batches
    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)

    # The label mapping is rebuilt from the data; if its size differs from the
    # model head, class indices cannot line up and the metrics would be meaningless
    if all_probs.shape[1] != num_classes:
        raise ValueError(
            f"Label mapping has {num_classes} classes but the model outputs "
            f"{all_probs.shape[1]}; class indices do not match"
        )

    # Calculate metrics
    metrics = calculate_metrics(all_predictions, all_labels, all_probs, num_classes)

    # Add sample count / batch size to metrics for easy access and logging
    metrics['sample_count'] = len(all_labels)
    metrics['batch_size'] = batch_size

    # Print results
    print(f"\n{'='*70}")
    print(f"INFERENCE METRICS")
    print(f"{'='*70}")
    print(f"  Samples:              {len(all_labels):,}")
    print(f"  Accuracy:             {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)")
    print(f"  Top-5 Accuracy:       {metrics['top5_accuracy']:.4f} ({metrics['top5_accuracy']*100:.2f}%)")
    print(f"  mAP:                  {metrics['mAP']:.4f}")
    print(f"  Mean Class Accuracy:  {metrics['mean_class_accuracy']:.4f}")
    print(f"{'='*70}\n")

    return metrics

# Usage: evaluate the best trial's checkpoint on the validation split.
# Path and batch size come from configuration instead of being hard-coded, so
# the path matches the global_dataset_val_path logged by the optimization run.
INFERENCE_BATCH_SIZE = 64
val_path = f"{data_storage_location}/{mds_val_dir}"
# NOTE: `accuracy` holds the full metrics dict (accuracy, mAP, top-5, ...);
# the name is kept because later cells read it.
accuracy = run_inference(parent_run_id, val_path, batch_size=INFERENCE_BATCH_SIZE)

In [0]:
# List the metric names returned by run_inference (`accuracy` holds the full metrics dict)
accuracy.keys()

In [0]:
# Log inference metrics to the existing parent MLflow run
# This maintains the logical connection between optimization and inference

# `accuracy` is the full metrics dict returned by run_inference in the inference cell
if 'accuracy' in globals() and isinstance(accuracy, dict):
    # sample_count is added by run_inference; stale results from an older version lack it
    if 'sample_count' not in accuracy:
        raise ValueError("Sample count not found in metrics! Please re-run the inference cell with the updated version.")
    sample_count = accuracy['sample_count']

    # Use the metrics returned from run_inference function
    metrics_dict = {
        "inference_accuracy": accuracy['accuracy'],
        "inference_top5_accuracy": accuracy['top5_accuracy'],
        "inference_map": accuracy['mAP'],
        "inference_mean_class_accuracy": accuracy['mean_class_accuracy'],
        "inference_samples": sample_count
    }
    print(f"Extracted metrics programmatically from inference results")
    print(f"Sample count: {sample_count:,} samples")
else:
    # Error out if metrics are not available - no fallback
    raise ValueError("Inference metrics not found! Please run the inference cell first to generate the 'accuracy' variable with metrics.")

# Record the batch size actually used when run_inference reports it (64 is its default)
inference_batch_size = accuracy.get('batch_size', 64)
# One timestamp shared by the param and the tag so they always agree
inference_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Log to the existing parent run from optimization
with mlflow.start_run(run_id=parent_run_id):
    # Log inference metrics (metrics can be overwritten)
    mlflow.log_metrics(metrics_dict)

    # Params cannot be changed once logged on a run, so on re-runs skip existing ones
    params_to_log = {
        "inference_batch_size": inference_batch_size,
        "inference_dataset_path": val_path,
        "inference_timestamp": inference_timestamp
    }

    for param_name, param_value in params_to_log.items():
        try:
            mlflow.log_param(param_name, param_value)
        except Exception as e:
            if "already logged" in str(e):
                print(f"Parameter '{param_name}' already exists, skipping...")
            else:
                print(f"Warning: Could not log parameter '{param_name}': {e}")

    # Add inference tags (tags can be overwritten)
    mlflow.set_tags({
        "inference_completed": "true",
        "inference_phase": "post_optimization",
        "inference_timestamp": inference_timestamp,
    })

    print(f"\nInference metrics logged to existing MLflow run: {parent_run_id}")
    print(f"Final Results Summary:")
    print(f"   • Accuracy: {metrics_dict['inference_accuracy']:.4f} ({metrics_dict['inference_accuracy']*100:.2f}%)")
    print(f"   • Top-5 Accuracy: {metrics_dict['inference_top5_accuracy']:.4f} ({metrics_dict['inference_top5_accuracy']*100:.2f}%)")
    print(f"   • Mean Average Precision: {metrics_dict['inference_map']:.4f}")
    print(f"   • Mean Class Accuracy: {metrics_dict['inference_mean_class_accuracy']:.4f}")
    print(f"   • Total validation samples: {metrics_dict['inference_samples']:,}")
    print(f"\nView complete experiment: {EXPERIMENT_NAME}")
    print(f"Run ID: {parent_run_id}")