An example of distributing PyTorch training with hyperparameter tuning using [Optuna](https://optuna.org/).

The cluster is the same as before; you are welcome to test other cluster configurations.

```
"spark_version": "16.4.x-scala2.13",
"node_type_id": "g5.12xlarge",
"autoscale": {
    "min_workers": 2, ## alternatively, use a fixed "num_workers" instead of "autoscale"
    "max_workers": 4
}
```
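A quick sketch of what this cluster config implies for distributed training. It assumes each `g5.12xlarge` worker exposes 4 GPUs and that one training process runs per GPU. The per-GPU batch size of 256 is the default of `distributed_train_and_evaluate` below; `cluster_capacity` is a hypothetical helper, not part of the notebook.

```python
# Hypothetical helper: sizes a distributed job for the cluster config above.
# Assumption: 4 GPUs per g5.12xlarge worker, one training process per GPU.
GPUS_PER_NODE = 4

def cluster_capacity(num_workers: int, per_gpu_batch_size: int, gpus_per_node: int = GPUS_PER_NODE):
    """Return (world_size, effective_batch_size) for a given worker count."""
    world_size = num_workers * gpus_per_node                 # total training processes
    effective_batch_size = per_gpu_batch_size * world_size   # samples per optimizer step
    return world_size, effective_batch_size

for workers in (2, 4):  # min_workers and max_workers from the autoscale block
    ws, ebs = cluster_capacity(workers, per_gpu_batch_size=256)
    print(f"{workers} workers -> world_size={ws}, effective batch size={ebs}")
```

Because the number of processes is fixed when a TorchDistributor job launches, pinning the worker count is one way to keep the effective batch size, and therefore learning-rate comparisons, consistent across Optuna trials.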


In [0]:
%pip install "mlflow[skinny]>=3" optuna nvidia-ml-py3 --upgrade

dbutils.library.restartPython()

In [0]:
# Versions used when this notebook was last run:
# MLflow version: 3.4.0
# Optuna version: 4.5.0
# PyTorch version: 2.6.0+cu124

In [0]:
import time
import os
import io
import base64
import shutil
import tempfile
import uuid
from datetime import datetime, timedelta
from functools import partial

import numpy as np
import torch
import torch.nn as nn  
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image

import mlflow
import optuna
from pyspark.ml.torch.distributor import TorchDistributor
from streaming import StreamingDataset, StreamingDataLoader
from streaming.base.util import clean_stale_shared_memory

print(f"MLflow version: {mlflow.__version__}")
print(f"Optuna version: {optuna.__version__}")
print(f"PyTorch version: {torch.__version__}")

In [0]:
# Unity Catalog configuration: update the catalog, schema, and volume names below to match your workspace
CATALOG = "mmt"
SCHEMA = "pytorch"
VOLUME_NAME = "torch_data"

# Dataset Configuration
mds_train_dir = 'imagenet_tiny200_mds_train'
mds_val_dir = 'imagenet_tiny200_mds_val'
data_storage_location = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"

# Training configuration
num_classes = 200
NUM_EPOCHS = 2
num_workers = 2

# MLflow Configuration
USER_NAME = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()
experiment_name = f"/Users/{USER_NAME}/mlflow_experiments/pytorch_imagenet_mobilenetv2_hpt"  # user folder; a shared workspace folder is a common alternative
ARTIFACT_PATH = f"dbfs:/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"

# Setup MLflow
import mlflow 

# Enable system metrics logging globally to track CPU/GPU utilization
mlflow.enable_system_metrics_logging()

mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

if mlflow.get_experiment_by_name(experiment_name) is None:
    mlflow.create_experiment(name=experiment_name, artifact_location=ARTIFACT_PATH)
mlflow.set_experiment(experiment_name)

print(f"Data storage location: {data_storage_location}")
print(f"Training data: {data_storage_location}/{mds_train_dir}")
print(f"Validation data: {data_storage_location}/{mds_val_dir}")
print(f"MLflow experiment: {experiment_name}")

In [0]:
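def get_dataloader_with_mosaic(remote_path, local_cache_path, batch_size, rank=0):
    """Create a MosaicML StreamingDataset and a DataLoader tuned for larger batch sizes.

    Returns (dataloader, cache_path); the caller should delete cache_path when done.
    """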
    
    import os
    import shutil
    from streaming import StreamingDataset
    from torch.utils.data import DataLoader
    
    print(f"Rank {rank}: Loading MDS data from {remote_path}")
    
    # Create unique cache directory
    import time
    import uuid
    cache_suffix = f"{int(time.time())}_{rank}_{str(uuid.uuid4())[:8]}"
    unique_cache_path = f"{local_cache_path}_{cache_suffix}"
    
    try:
        # StreamingDataset configuration
        dataset = StreamingDataset(
            remote=remote_path,
            local=unique_cache_path,
            shuffle=True,
            batch_size=batch_size,
            # Key optimizations for Mosaic streaming
            download_retry=3,                # Retry failed downloads
            download_timeout=120,            # Longer timeout for large files
            keep_zip=False,                  # Don't keep compressed files
            cache_limit="50gb",              # Limit cache size
            predownload=min(batch_size * 4, 1000),  # Predownload samples
            partition_algo='relaxed'         # Better load balancing
        )
        
        print(f"Rank {rank}: Created StreamingDataset with {len(dataset)} samples")
        
        # Heuristic worker count: about one worker per 8 samples, at least 2, capped at min(CPU count, 16)
        import multiprocessing
        max_workers = min(multiprocessing.cpu_count(), 16)
        optimal_workers = min(max_workers, max(2, batch_size // 8))
        
        # Optimized DataLoader settings
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,                   # StreamingDataset handles shuffling
            num_workers=optimal_workers,     # Dynamic worker count
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=max(2, min(8, batch_size // 16)),  # Dynamic prefetch
            drop_last=True,
            multiprocessing_context='spawn' if os.name == 'nt' else 'fork'  # OS-specific
        )
        
        print(f"Rank {rank}: Created optimized dataloader with {len(dataloader)} batches, {optimal_workers} workers")
        
        return dataloader, unique_cache_path
        
    except Exception as e:
        print(f"Rank {rank}: Error creating dataloader: {e}")
        if os.path.exists(unique_cache_path):
            shutil.rmtree(unique_cache_path, ignore_errors=True)
        raise
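The `num_workers` and `prefetch_factor` heuristics above are easier to judge with concrete numbers. This stdlib sketch reproduces the same arithmetic as a standalone function; `loader_settings` is a hypothetical name, and 48 vCPUs is assumed as the CPU count of a `g5.12xlarge` node.

```python
import multiprocessing
from typing import Optional, Tuple

def loader_settings(batch_size: int, cpu_count: Optional[int] = None) -> Tuple[int, int]:
    """Mirror the num_workers / prefetch_factor heuristics in get_dataloader_with_mosaic."""
    if cpu_count is None:
        cpu_count = multiprocessing.cpu_count()
    max_workers = min(cpu_count, 16)                         # never more than 16 loader processes
    num_workers = min(max_workers, max(2, batch_size // 8))  # ~1 worker per 8 samples, floor of 2
    prefetch_factor = max(2, min(8, batch_size // 16))       # batches prefetched per worker, 2..8
    return num_workers, prefetch_factor

# Assumed: 48 vCPUs per g5.12xlarge node
for bs in (16, 64, 256):
    print(bs, loader_settings(bs, cpu_count=48))
```

Each rank computes this independently, so with 4 ranks per node and a per-GPU batch size of 256, a node could start 4 x 16 = 64 loader workers on 48 vCPUs. If data loading contends for CPU, capping workers at roughly `cpu_count // gpus_per_node` is one option to consider.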

In [0]:
### Your MDS layout may differ from the one assumed here; if so, update verify_mds_dataset accordingly.

import io
import tempfile
import shutil
import os
from streaming import StreamingDataset

def create_comprehensive_label_mapping(remote_path):
    """Build a class-name -> index mapping by scanning up to 50 batches of the MDS dataset."""
    print(f"Creating comprehensive label mapping from: {remote_path}")
    
    temp_cache = tempfile.mkdtemp(prefix="label_mapping_")
    
    try:
        dataset = StreamingDataset(
            remote=remote_path,
            local=temp_cache,
            shuffle=True,
            batch_size=32  # Use a reasonable batch size
        )
        
        # Create DataLoader
        from torch.utils.data import DataLoader
        dataloader = DataLoader(
            dataset,
            batch_size=32,
            num_workers=0  # Use 0 to avoid multiprocessing issues during class extraction
        )
        
        print(f"Dataset created with {len(dataset)} total samples")
        
        unique_classes = set()
        sample_count = 0
        max_batches = 50  # Process 50 batches to get class names
        
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
                
            # batch['class_name'] is now a list of class names
            class_names = batch['class_name']
            unique_classes.update(class_names)
            sample_count += len(class_names)
            
            if batch_idx % 10 == 0:
                print(f"  Processed {batch_idx} batches, {sample_count} samples, found {len(unique_classes)} unique classes")
            
            # Stop early once all 200 Tiny-ImageNet classes have been seen
            if len(unique_classes) >= 200:
                break
        
        classes = sorted(list(unique_classes))
        print(f"Found {len(classes)} unique classes")
        print(f"Sample classes: {classes[:5]}")
        
        label_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        
        print(f"\n=== Label Mapping Summary ===")
        print(f"Total classes: {len(label_to_idx)}")
        
        return label_to_idx
            
    except Exception as e:
        print(f"Error creating label mapping: {e}")
        import traceback
        print(f"Full error: {traceback.format_exc()}")
        return create_imagenet_tiny200_mapping()
    
    finally:
        if os.path.exists(temp_cache):
            shutil.rmtree(temp_cache, ignore_errors=True)

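def create_imagenet_tiny200_mapping():
    """Fallback mapping with synthetic names (class_000 ... class_199).

    These names will not match the dataset's real class_name values, so every sample
    would fall back to label 0; treat this as a last resort.
    """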
    print("Creating fallback ImageNet Tiny-200 mapping...")
    classes = [f"class_{i:03d}" for i in range(200)]
    return {cls: idx for idx, cls in enumerate(classes)}

def verify_mds_dataset(remote_path):
    """Simple verification for multi-shard MDS dataset"""
    print(f"Verifying multi-shard MDS dataset at: {remote_path}")
    
    if not os.path.exists(remote_path):
        print(f"ERROR: Remote path does not exist: {remote_path}")
        return False
    
    files_in_dir = os.listdir(remote_path)
    print(f"Items in root directory: {len(files_in_dir)}")
    
    # Check for main index.json
    has_main_index = 'index.json' in files_in_dir
    if has_main_index:
        print("Found main index.json")
    
    # Check for numbered shard directories
    shard_dirs = [f for f in files_in_dir if f.isdigit() and os.path.isdir(os.path.join(remote_path, f))]
    shard_dirs.sort(key=int)
    
    print(f"Found {len(shard_dirs)} shard directories")
    if len(shard_dirs) > 0:
        print(f"  Shard range: {shard_dirs[0]} to {shard_dirs[-1]}")
        print(f"  Sample shards: {shard_dirs[:5]}{'...' if len(shard_dirs) > 5 else ''}")
        
        # Inspect the first shard
        shard_path = os.path.join(remote_path, shard_dirs[0])
        shard_contents = os.listdir(shard_path)
        print(f"Shard {shard_dirs[0]} contents: {shard_contents}")

        # Look for MDS shard files (plain or zstd-compressed)
        shard_files = [f for f in shard_contents if f.endswith(('.mds', '.mds.zstd'))]
        if shard_files:
            shard_file_path = os.path.join(shard_path, shard_files[0])
            file_size = os.path.getsize(shard_file_path)
            print(f"Shard file size: {file_size:,} bytes")

        print("Verified sample shards")
        return True
    
    print("No valid shard directories found")
    return False
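Why `verify_mds_dataset` sorts shard directories with `key=int`: the names are strings, and plain string sorting puts `"10"` before `"2"`. This stdlib sketch builds a throwaway directory that mimics the layout the function expects (a top-level `index.json` plus numbered shard folders); the file names and `list_shard_dirs` helper are illustrative assumptions.

```python
import os
import tempfile

def list_shard_dirs(root: str) -> list:
    """Return numeric shard subdirectories of `root`, sorted by integer value."""
    names = [n for n in os.listdir(root) if n.isdigit() and os.path.isdir(os.path.join(root, n))]
    return sorted(names, key=int)

with tempfile.TemporaryDirectory() as root:
    # Illustrative layout: top-level index.json plus numbered shard directories
    open(os.path.join(root, "index.json"), "w").close()
    for name in ("0", "1", "2", "10"):
        os.makedirs(os.path.join(root, name))
        open(os.path.join(root, name, "shard.00000.mds"), "w").close()
    shards = list_shard_dirs(root)

print(shards)          # numeric order
print(sorted(shards))  # lexicographic order, for comparison
```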


In [0]:
def convert_batch_to_tensors(batch, device=None, class_to_idx=None, rank=None, 
                           use_augmentation=False, augmentation_strength='medium'):
    """Convert MDS batch to PyTorch tensors with optional data augmentation"""
    import io
    from PIL import Image
    import torchvision.transforms as transforms
    import torch
    
    # Auto-detect device if not provided
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Handle rank parameter - get from distributed training if available
    if rank is None:
        try:
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            else:
                rank = 0
        except Exception:
            rank = 0
    
    # Define augmentation presets
    augmentation_presets = {
        'light': {
            'rotation': 10,
            'brightness': 0.1,
            'contrast': 0.1,
            'saturation': 0.1,
            'hue': 0.05,
            'horizontal_flip': 0.3,
            'vertical_flip': 0.0,
            'gaussian_blur': 0.0,
            'random_erasing': 0.0
        },
        'medium': {
            'rotation': 15,
            'brightness': 0.2,
            'contrast': 0.2,
            'saturation': 0.2,
            'hue': 0.1,
            'horizontal_flip': 0.5,
            'vertical_flip': 0.1,
            'gaussian_blur': 0.1,
            'random_erasing': 0.1
        },
        'heavy': {
            'rotation': 25,
            'brightness': 0.3,
            'contrast': 0.3,
            'saturation': 0.3,
            'hue': 0.15,
            'horizontal_flip': 0.5,
            'vertical_flip': 0.2,
            'gaussian_blur': 0.2,
            'random_erasing': 0.2
        }
    }
    
    # Build transform pipeline
    transform_list = [transforms.Resize((224, 224))]
    
    if use_augmentation:
        # Get augmentation parameters
        if isinstance(augmentation_strength, str):
            aug_params = augmentation_presets.get(augmentation_strength, augmentation_presets['medium'])
        elif isinstance(augmentation_strength, dict):
            aug_params = augmentation_strength
        else:
            aug_params = augmentation_presets['medium']
        
        # Print augmentation info only once per process, and only from rank 0
        if not hasattr(convert_batch_to_tensors, '_printed_aug_info'):
            if rank == 0:
                print(f"Using {augmentation_strength} augmentation with parameters: {aug_params}")
            convert_batch_to_tensors._printed_aug_info = True

        # Add augmentation transforms
        if aug_params.get('rotation', 0) > 0:
            transform_list.append(
                transforms.RandomRotation(degrees=aug_params['rotation'])
            )
        
        if aug_params.get('horizontal_flip', 0) > 0:
            transform_list.append(
                transforms.RandomHorizontalFlip(p=aug_params['horizontal_flip'])
            )
        
        if aug_params.get('vertical_flip', 0) > 0:
            transform_list.append(
                transforms.RandomVerticalFlip(p=aug_params['vertical_flip'])
            )
        
        # Color jittering
        color_params = [
            aug_params.get('brightness', 0),
            aug_params.get('contrast', 0),
            aug_params.get('saturation', 0),
            aug_params.get('hue', 0)
        ]
        if any(p > 0 for p in color_params):
            transform_list.append(
                transforms.ColorJitter(
                    brightness=color_params[0],
                    contrast=color_params[1],
                    saturation=color_params[2],
                    hue=color_params[3]
                )
            )
        
        # Gaussian blur
        if aug_params.get('gaussian_blur', 0) > 0:
            transform_list.append(
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
                ], p=aug_params['gaussian_blur'])
            )
    
    # Add normalization (always last before tensor conversion)
    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Add random erasing after tensor conversion if specified
    if use_augmentation and aug_params.get('random_erasing', 0) > 0:
        transform_list.append(
            transforms.RandomErasing(
                p=aug_params['random_erasing'],
                scale=(0.02, 0.33),
                ratio=(0.3, 3.3),
                value=0
            )
        )
    
    transform = transforms.Compose(transform_list)
    
    # Debug info only for rank 0 and only occasionally
    debug_output = (rank == 0)
    
    # Process images
    images = []
    for i, img_bytes in enumerate(batch['image_data']):
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
            img_tensor = transform(img)
            images.append(img_tensor)
        except Exception as e:
            if debug_output and i < 3:  # Only show first few errors
                print(f"Error processing image {i}: {e}")
            images.append(torch.zeros((3, 224, 224)))
    
    images_tensor = torch.stack(images).to(device)
    
    # Process labels 
    class_names = batch['class_name']
    
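    # Without a usable dict, fall back to a per-batch mapping. WARNING: indices then depend on
    # which classes appear in this batch, so pass a global class_to_idx for real training.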
    if class_to_idx is None or not isinstance(class_to_idx, dict):
        unique_classes = sorted(set(class_names))
        class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}
        if debug_output:
            print(f"Created temporary class mapping with {len(class_to_idx)} classes")
    
    labels = []
    for class_name in class_names:
        if class_name in class_to_idx:
            labels.append(class_to_idx[class_name])
        else:
            labels.append(0)  # Unknown classes default to index 0 (this silently adds label noise)
            if debug_output:
                print(f"Warning: Unknown class '{class_name}', using default")
    
    labels_tensor = torch.tensor(labels, dtype=torch.long).to(device)
    
    return images_tensor, labels_tensor


def create_custom_augmentation_config(
    rotation=15,
    brightness=0.2,
    contrast=0.2,
    saturation=0.2,
    hue=0.1,
    horizontal_flip=0.5,
    vertical_flip=0.1,
    gaussian_blur=0.1,
    random_erasing=0.1
):
    """Create custom augmentation configuration
    
    Args:
        rotation: Degrees of random rotation (0 to disable)
        brightness: Brightness jitter factor (0 to disable)
        contrast: Contrast jitter factor (0 to disable)
        saturation: Saturation jitter factor (0 to disable)
        hue: Hue jitter factor (0 to disable)
        horizontal_flip: Probability of horizontal flip (0 to disable)
        vertical_flip: Probability of vertical flip (0 to disable)
        gaussian_blur: Probability of Gaussian blur (0 to disable)
        random_erasing: Probability of random erasing (0 to disable)
    
    Returns:
        dict: Custom augmentation configuration
    """
    return {
        'rotation': rotation,
        'brightness': brightness,
        'contrast': contrast,
        'saturation': saturation,
        'hue': hue,
        'horizontal_flip': horizontal_flip,
        'vertical_flip': vertical_flip,
        'gaussian_blur': gaussian_blur,
        'random_erasing': random_erasing
    }


## Usage examples:
# def example_usage():
#     """Examples of how to use the enhanced convert_batch_to_tensors function"""
    
#     # Example 1: No augmentation (default behavior)
#     images, labels = convert_batch_to_tensors(batch, device, class_to_idx)
    
#     # Example 2: Light augmentation
#     images, labels = convert_batch_to_tensors(
#         batch, device, class_to_idx, 
#         use_augmentation=True, 
#         augmentation_strength='light'
#     )
    
#     # Example 3: Heavy augmentation
#     images, labels = convert_batch_to_tensors(
#         batch, device, class_to_idx,
#         use_augmentation=True,
#         augmentation_strength='heavy'
#     )
    
#     # Example 4: Custom augmentation
#     custom_aug = create_custom_augmentation_config(
#         rotation=20,
#         brightness=0.3,
#         horizontal_flip=0.7,
#         gaussian_blur=0.2,
#         random_erasing=0.15
#     )
    
#     images, labels = convert_batch_to_tensors(
#         batch, device, class_to_idx,
#         use_augmentation=True,
#         augmentation_strength=custom_aug
#     )
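A torchvision-free sketch of two details in `convert_batch_to_tensors`: the order in which transforms are appended (only those whose parameter is greater than 0 are enabled, and `RandomErasing` comes after `ToTensor` because it operates on tensors, not PIL images), and the per-channel arithmetic of `ToTensor` followed by `Normalize` with the ImageNet statistics. The helper names are hypothetical; they mirror the notebook's logic rather than replace it.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_uint8):
    """ToTensor scales 0..255 to 0..1; Normalize then applies (x - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_uint8, IMAGENET_MEAN, IMAGENET_STD))

def pipeline_order(aug_params, use_augmentation=True):
    """List transform names in the order convert_batch_to_tensors appends them."""
    steps = ["Resize"]
    if use_augmentation:
        if aug_params.get("rotation", 0) > 0:
            steps.append("RandomRotation")
        if aug_params.get("horizontal_flip", 0) > 0:
            steps.append("RandomHorizontalFlip")
        if aug_params.get("vertical_flip", 0) > 0:
            steps.append("RandomVerticalFlip")
        if any(aug_params.get(k, 0) > 0 for k in ("brightness", "contrast", "saturation", "hue")):
            steps.append("ColorJitter")
        if aug_params.get("gaussian_blur", 0) > 0:
            steps.append("RandomApply(GaussianBlur)")
    steps += ["ToTensor", "Normalize"]
    # RandomErasing works on tensors, so it is appended after ToTensor/Normalize
    if use_augmentation and aug_params.get("random_erasing", 0) > 0:
        steps.append("RandomErasing")
    return steps

light = {"rotation": 10, "brightness": 0.1, "contrast": 0.1, "saturation": 0.1, "hue": 0.05,
         "horizontal_flip": 0.3, "vertical_flip": 0.0, "gaussian_blur": 0.0, "random_erasing": 0.0}
print(pipeline_order(light))
print([round(c, 3) for c in normalize_pixel((255, 255, 255))])
```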
    

In [0]:
import torch

# 1. FIRST: Define device globally
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 2. Create class mapping
print("Creating class mapping...")
label_to_idx = create_comprehensive_label_mapping(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_train_dir}")

# 3. Verify datasets
print("Verifying datasets...")
train_valid = verify_mds_dataset(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_train_dir}")
val_valid = verify_mds_dataset(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_val_dir}")

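if train_valid and val_valid:
    print("Setup complete! Ready for training.")
else:
    print("Dataset verification failed; check the paths above before training.")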

In [0]:
# Label Mapping
print("=== Creating Comprehensive Label Mapping ===")
train_path = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}/{mds_train_dir}"

# Module-level mapping shared with the training functions
label_to_idx = {}

try:
    label_to_idx = create_comprehensive_label_mapping(train_path)
    num_classes = len(label_to_idx)
    
    # If we didn't get enough classes, use the fallback
    if num_classes < 50:
        print("Insufficient classes found, using fallback mapping...")
        label_to_idx = create_imagenet_tiny200_mapping()
        num_classes = len(label_to_idx)
        
except Exception as e:
    print(f"Error in comprehensive mapping: {e}")
    print("Using fallback mapping...")
    label_to_idx = create_imagenet_tiny200_mapping()
    num_classes = len(label_to_idx)

print(f"\nFinal label mapping created with {num_classes} classes")

# Test the mapping with a few samples
print("\n=== Testing Label Mapping ===")
test_classes = ['barrel, cask', 'school bus', 'pizza, pizza pie']
for test_class in test_classes:
    idx = label_to_idx.get(test_class, -1)
    print(f"'{test_class}' -> index {idx}")

# Show some actual mappings that exist
print(f"\nFirst 10 actual mappings:")
for class_name, idx in list(label_to_idx.items())[:10]:
    print(f"  {idx}: {class_name}")

print(f"\nLabel mapping setup complete!")
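The mapping is reproducible because the discovered names are sorted before they are enumerated, so scan order does not matter. It is only correct, though, if the 50-batch scan finds every class: a missed class silently shifts the index of every class that sorts after it. The class names and the `build_label_mapping` helper below are illustrative.

```python
def build_label_mapping(class_names):
    """Same construction as create_comprehensive_label_mapping: sort unique names, then enumerate."""
    return {cls: idx for idx, cls in enumerate(sorted(set(class_names)))}

# Two scans that see the same classes in different orders agree on the mapping
run_a = build_label_mapping(["pizza", "barrel", "school bus", "pizza"])
run_b = build_label_mapping(["school bus", "pizza", "barrel"])
print(run_a == run_b)

# If "barrel" is never seen, "pizza" moves from index 1 to index 0
incomplete = build_label_mapping(["pizza", "school bus"])
print(run_a["pizza"], incomplete["pizza"])
```

Given this, a scan that finds fewer than `num_classes` classes may be better treated as an error than patched with a fallback.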

In [0]:
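def get_model(lr=0.001):
    """Create a MobileNetV2 for ImageNet Tiny-200 with a frozen backbone and a new classifier.

    `lr` is accepted for call-site compatibility but is unused; set the learning rate on the optimizer.
    """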
    from torchvision.models import MobileNet_V2_Weights
    
    model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    
    # Freeze all pretrained layers; the new classifier is unfrozen below
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace classifier for num_classes
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_classes)
    
    # Only train the classifier
    for param in model.classifier.parameters():
        param.requires_grad = True
    
    return model
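With the backbone frozen, only the replacement `nn.Linear` layer trains. Assuming torchvision's default `mobilenet_v2` (whose `classifier[1].in_features` is 1280) and 200 classes, this stdlib arithmetic gives the trainable parameter count, a small fraction of MobileNetV2's roughly 3.5 million parameters. `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters in nn.Linear: a weight matrix (out x in) plus an optional bias vector (out)."""
    return in_features * out_features + (out_features if bias else 0)

MOBILENET_V2_LAST_CHANNEL = 1280   # assumed in_features of classifier[1] (default width)
NUM_CLASSES = 200                  # ImageNet Tiny-200

trainable = linear_param_count(MOBILENET_V2_LAST_CHANNEL, NUM_CLASSES)
print(f"Trainable parameters after freezing the backbone: {trainable:,}")
```

With PyTorch available, `sum(p.numel() for p in model.parameters() if p.requires_grad)` on the output of `get_model()` should report the same number.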

In [0]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with label smoothing: the one-hot target is mixed with a uniform distribution."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, pred, target):
        n_classes = pred.size(1)
        log_preds = torch.log_softmax(pred, dim=1)
        smooth_target = torch.zeros_like(log_preds).scatter_(1, target.unsqueeze(1), 1)
        smooth_target = smooth_target * (1 - self.smoothing) + self.smoothing / n_classes
        return (-smooth_target * log_preds).sum(dim=1).mean()
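For one example with target class `t`, `K` classes, and smoothing `s`, the smoothed target is `(1 - s) * onehot(t) + s / K`, so the loss decomposes as `(1 - s) * (-log p_t) + s * mean_k(-log p_k)`. The stdlib sketch below checks that decomposition numerically on made-up logits. This is the same smoothing rule that `nn.CrossEntropyLoss(label_smoothing=s)` applies (PyTorch 1.10+), so the built-in is a possible alternative to the custom class.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax for a single example."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def smoothed_ce(logits, target, smoothing=0.1):
    """Same math as LabelSmoothingCrossEntropy.forward, for one example."""
    n = len(logits)
    log_p = log_softmax(logits)
    q = [(1 - smoothing) * (i == target) + smoothing / n for i in range(n)]  # smoothed target
    return -sum(qi * lpi for qi, lpi in zip(q, log_p))

logits, target, eps = [2.0, 1.0, 0.1], 0, 0.1   # toy values
log_p = log_softmax(logits)
nll = -log_p[target]                  # plain cross-entropy term
uniform = -sum(log_p) / len(logits)   # cross-entropy against a uniform target
print(smoothed_ce(logits, target, eps))
print((1 - eps) * nll + eps * uniform)
```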

#### Optional: sanity checks before training

In [0]:
# # Before training starts:
# if torch.cuda.is_available():
#     for i in range(torch.cuda.device_count()):
#         total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
#         print(f"GPU {i}: {total_memory:.1f}GB total memory")
#         torch.cuda.empty_cache()

In [0]:
# def debug_data_loading_bottleneck():
#     """Debug the specific data loading steps"""
#     import time
    
#     print("=== DEBUGGING DATA LOADING BOTTLENECK ===")
    
#     # Test 1: Check if it's the MDS dataset creation
#     print("1. Testing MDS dataset creation...")
#     start_time = time.time()
    
#     try:
#         from streaming import StreamingDataset
#         import tempfile
        
#         cache_dir = tempfile.mkdtemp(prefix="debug_cache_")
#         print(f"   Cache dir: {cache_dir}")
        
#         print("   Creating StreamingDataset...")
#         dataset = StreamingDataset(
#             remote="/Volumes/mmt/pytorch/torch_data/imagenet_tiny200_mds_train",
#             local=cache_dir,
#             shuffle=True,
#             batch_size=4,  # Small batch for testing
#             predownload=8,  # Minimal predownload
#             download_timeout=120  # 2 minute timeout
#         )
        
#         creation_time = time.time() - start_time
#         print(f"✓ Dataset created in {creation_time:.1f} seconds")
        
#         # Test 2: Check if it's the first sample access
#         print("2. Testing first sample access...")
#         start_time = time.time()
        
#         first_sample = dataset[0]
#         sample_time = time.time() - start_time
#         print(f"✓ First sample loaded in {sample_time:.1f} seconds")
#         print(f"   Sample keys: {list(first_sample.keys())}")
        
#         # Test 3: Check if it's the DataLoader iteration
#         print("3. Testing DataLoader iteration...")
#         from streaming import StreamingDataLoader
        
#         dataloader = StreamingDataLoader(
#             dataset,
#             batch_size=4,
#             num_workers=0,  # No multiprocessing for debugging
#             pin_memory=False
#         )
        
#         start_time = time.time()
#         batch = next(iter(dataloader))
#         batch_time = time.time() - start_time
#         print(f"✓ First batch loaded in {batch_time:.1f} seconds")
        
#         # Test 4: Check if it's the tensor conversion
#         print("4. Testing tensor conversion...")
#         start_time = time.time()
        
#         device = torch.device('cuda:0')
#         inputs, labels = convert_batch_to_tensors(batch, device, label_to_idx, 0)
#         conversion_time = time.time() - start_time
#         print(f"✓ Tensor conversion completed in {conversion_time:.1f} seconds")
#         print(f"   Inputs shape: {inputs.shape}, Labels shape: {labels.shape}")
        
#         print("=== ALL DATA LOADING TESTS PASSED ===")
        
#         # Cleanup
#         import shutil
#         shutil.rmtree(cache_dir, ignore_errors=True)
        
#     except Exception as e:
#         elapsed = time.time() - start_time
#         print(f"✗ FAILED after {elapsed:.1f} seconds: {e}")
#         import traceback
#         traceback.print_exc()

# # Run this test
# debug_data_loading_bottleneck()

In [0]:
# def debug_distributed_training():
#     """Debug the distributed training components specifically"""
#     import time
    
#     print("=== DEBUGGING DISTRIBUTED TRAINING ===")
    
#     # Test 1: Model creation and DDP wrapping
#     print("1. Testing model creation...")
#     start_time = time.time()
    
#     try:
#         model = get_model(lr=0.001)
#         device = torch.device('cuda:0')
#         model = model.to(device)
#         print(f"✓ Model created and moved to GPU in {time.time() - start_time:.1f} seconds")
        
#         # Test 2: Simple forward pass
#         print("2. Testing simple forward pass...")
#         start_time = time.time()
        
#         # Create dummy input
#         dummy_input = torch.randn(4, 3, 224, 224).to(device)
#         model.eval()
#         with torch.no_grad():
#             output = model(dummy_input)
        
#         print(f"✓ Forward pass completed in {time.time() - start_time:.1f} seconds")
#         print(f"   Output shape: {output.shape}")
        
#         # Test 3: Training step without distributed
#         print("3. Testing training step (non-distributed)...")
#         start_time = time.time()
        
#         model.train()
#         criterion = torch.nn.CrossEntropyLoss()
#         optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
        
#         # Dummy labels
#         dummy_labels = torch.randint(0, 200, (4,)).to(device)
        
#         optimizer.zero_grad()
#         outputs = model(dummy_input)
#         loss = criterion(outputs, dummy_labels)
#         loss.backward()
#         optimizer.step()
        
#         print(f"✓ Training step completed in {time.time() - start_time:.1f} seconds")
#         print(f"   Loss: {loss.item():.4f}")
        
#         print("=== NON-DISTRIBUTED TRAINING WORKS FINE ===")
        
#     except Exception as e:
#         print(f"✗ FAILED: {e}")
#         import traceback
#         traceback.print_exc()

# # Run this test
# debug_distributed_training()

In [0]:
# def simple_objective_function(trial):
#     """[simple] version to isolate the distributed issue"""
    
#     lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
#     batch_size = trial.suggest_categorical('batch_size', [8, 16])  # Smaller batches
#     optimizer_name = trial.suggest_categorical('optimizer', ['AdamW']) 
    
#     print(f"Trial {trial.number}: lr={lr}, batch_size={batch_size}, optimizer={optimizer_name}")
    
#     with mlflow.start_run(run_name=f"simplified_trial_{trial.number}", nested=True):
#         mlflow.log_params({
#             'lr': lr,
#             'batch_size': batch_size,
#             'optimizer': optimizer_name,
#             'test_mode': 'simplified'
#         })
        
#         try:
#             # Test with SINGLE PROCESS first (no distribution)
#             distributor = TorchDistributor(
#                 num_processes=1,  # ← Single process to test
#                 local_mode=True,   # ← Local mode
#                 use_gpu=True
#             )
            
#             result = distributor.run(
#                 simplified_train_function,  # ← We'll create this
#                 lr=lr,
#                 batch_size=batch_size,
#                 optimizer_name=optimizer_name
#             )
            
#             trial_metric = result.get('val_acc', 0.0) if isinstance(result, dict) else 0.0
#             mlflow.log_metric("validation_accuracy", trial_metric)
            
#             print(f"Simplified trial {trial.number} completed: {trial_metric:.4f}")
#             return trial_metric
            
#         except Exception as e:
#             print(f"Simplified trial {trial.number} failed: {e}")
#             mlflow.log_param("error", str(e))
#             return 0.0

# def simplified_train_function(lr, batch_size, optimizer_name):
#     """Simplified training function - single process"""
#     print("Starting simplified training (single process)...")
    
#     try:
#         device = torch.device('cuda:0')
        
#         # Create model
#         model = get_model(lr=lr)
#         model = model.to(device)
        
#         # Create single dataloader (no distributed)
#         train_loader, train_cache = get_dataloader_with_mosaic(
#             f"{data_storage_location}/{mds_train_dir}",
#             os.path.join(tempfile.gettempdir(), "mds_cache_train"), batch_size, rank=0
#         )
        
#         val_loader, val_cache = get_dataloader_with_mosaic(
#             f"{data_storage_location}/{mds_val_dir}",
#             os.path.join(tempfile.gettempdir(), "mds_cache_val"), batch_size, rank=0
#         )
        
#         # Setup training
#         criterion = torch.nn.CrossEntropyLoss()
#         optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        
#         print("Starting training loop...")
        
#         # Train for just 1 epoch with limited steps
#         model.train()
#         for step, batch in enumerate(train_loader):
#             if step >= 10:  # Only 10 steps for testing
#                 break
                
#             print(f"Step {step}/10...")
            
#             inputs, labels = convert_batch_to_tensors(batch, device, label_to_idx, 0)
            
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             if step % 5 == 0:
#                 print(f"Step {step}: Loss = {loss.item():.4f}")
        
#         # Quick validation
#         model.eval()
#         val_acc = 0.1  # Placeholder
        
#         print(f"Simplified training completed. Val accuracy: {val_acc:.4f}")
        
#         # Cleanup
#         import shutil
#         for cache in [train_cache, val_cache]:
#             if os.path.exists(cache):
#                 shutil.rmtree(cache)
        
#         return {"val_acc": val_acc, "status": "completed"}
        
#     except Exception as e:
#         print(f"Error in simplified training: {e}")
#         import traceback
#         traceback.print_exc()
#         return {"val_acc": 0.0, "status": "failed", "error": str(e)}

# # Test with simplified version
# print("Testing simplified single-process version...")
# study_simple = optuna.create_study(direction="maximize")
# study_simple.optimize(simple_objective_function, n_trials=1)
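The objectives draw the learning rate with `trial.suggest_float('lr', 1e-5, 1e-2, log=True)`. With `log=True`, Optuna searches that range in the log domain, so each decade gets comparable attention; uniform sampling on a linear scale would put about 90% of draws in `[1e-3, 1e-2]`. The stdlib sketch shows what uniform sampling in log space looks like. It illustrates the idea (roughly what Optuna's `RandomSampler` does), not Optuna's implementation, and the default TPE sampler concentrates its draws over time.

```python
import math
import random

def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw uniformly in log space, then map back, so each decade is equally likely."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-2, rng) for _ in range(10_000)]
per_decade = [sum(lo <= s < lo * 10 for s in samples) for lo in (1e-5, 1e-4, 1e-3)]
print(per_decade)   # roughly a third of the samples in each decade
```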

### Training Setup
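The training function below reads `LOCAL_RANK`, `RANK`, and `WORLD_SIZE` from the environment. TorchDistributor launches each process through `torchrun` (`torch.distributed.run`), which sets these variables; when they are absent, for example when the function is called directly in the notebook, the defaults describe a single process and `init_process_group` is skipped because `world_size` is 1. `read_dist_env` is a hypothetical helper that mirrors those lines, and the rank values are illustrative.

```python
import os

def read_dist_env(env=None):
    """Read rank info the way distributed_train_and_evaluate does, with single-process defaults."""
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))   # GPU index on this node
    global_rank = int(env.get("RANK", 0))        # unique index across all nodes
    world_size = int(env.get("WORLD_SIZE", 1))   # total number of processes
    return local_rank, global_rank, world_size

# Illustrative: second GPU on the second of two 4-GPU workers
print(read_dist_env({"LOCAL_RANK": "1", "RANK": "5", "WORLD_SIZE": "8"}))
# No launcher variables: single-process defaults
print(read_dist_env({}))
```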

In [0]:
def distributed_train_and_evaluate(lr=0.001, batch_size=256, optimizer_name='AdamW',
                                 weight_decay=1e-4, step_size=7, gamma=0.1,
                                 dropout_rate=0.2, label_smoothing=0.1,
                                 momentum=0.9, nesterov=False, beta1=0.9, beta2=0.999, eps=1e-8,
                                 data_storage_location=None, mds_train_dir=None, 
                                 mds_val_dir=None, num_epochs=2,
                                 mlflow_run_id=None, mlflow_tracking_uri=None,
                                 mlflow_experiment_name=None,
                                 use_augmentation=True, augmentation_strength='medium'):
    
    """Distributed (DDP) training and evaluation for one rank, with optional data augmentation and essential model metrics."""
    
    import os
    import time
    import shutil
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from datetime import timedelta
    import numpy as np
    
    # Initialize core variables
    cache_paths = []
    training_start_time = time.time()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    
    # Core metrics storage
    epoch_metrics = []
    essential_model_metrics = {
        'parameter_count': 0,
        'trainable_parameters': 0,
        'model_size_mb': 0.0,
        'gradient_norms': [],
        'weight_norms': [],
        'learning_curves': {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    }
    
    # Calculate effective batch size
    effective_batch_size = batch_size * world_size
    
    try:
        print(f"Rank {global_rank}/{world_size}: Starting training...")
        print(f"Per-GPU batch size: {batch_size}, Effective batch size: {effective_batch_size}")
        
        if global_rank == 0:
            print("Memory optimization settings:")
            print(f"   - Large batch size: {batch_size} per GPU")
            print(f"   - Effective batch size: {effective_batch_size}")
            print(f"   - Augmentation: {augmentation_strength if use_augmentation else 'disabled'}")
        
        # Initialize distributed training
        if world_size > 1:
            dist.init_process_group("nccl", timeout=timedelta(seconds=1800))
        
        # Select this process's GPU (falls back to CPU if CUDA is unavailable)
        device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            torch.cuda.empty_cache()
        
        # Create enhanced model
        model = get_model(lr=lr)
        
        # Add dropout to classifier
        if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Sequential):
            if len(model.classifier) > 1 and hasattr(model.classifier[1], 'in_features'):
                num_features = model.classifier[1].in_features
                model.classifier = nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.Linear(num_features, 512),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout_rate * 0.5),
                    nn.Linear(512, 200)  # 200 output classes, matching num_classes
                )
        
        model = model.to(device)
        
        # Enable channels_last for better memory efficiency with larger batches
        model = model.to(memory_format=torch.channels_last)
        
        # Calculate essential model metrics
        if global_rank == 0:
            essential_model_metrics['parameter_count'] = sum(p.numel() for p in model.parameters())
            essential_model_metrics['trainable_parameters'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
            essential_model_metrics['model_size_mb'] = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            print(f"Model: {essential_model_metrics['parameter_count']:,} total params, "
                  f"{essential_model_metrics['trainable_parameters']:,} trainable")
        
        if world_size > 1:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        
        # Create data loaders with larger batch sizes
        train_input_remote_path = os.path.join(data_storage_location, mds_train_dir)
        val_input_remote_path = os.path.join(data_storage_location, mds_val_dir)
        
        # Use larger batch sizes with optimized DataLoader
        train_dataloader, train_cache_path = get_dataloader_with_mosaic(
            train_input_remote_path, f"/local_disk0/tmp/train_cache", batch_size, global_rank
        )
        cache_paths.append(train_cache_path)
        
        # Use even larger batch size for validation (no gradients needed)
        val_batch_size = min(batch_size * 2, 1024)  # 2x larger for validation
        val_dataloader, val_cache_path = get_dataloader_with_mosaic(
            val_input_remote_path, f"/local_disk0/tmp/val_cache", val_batch_size, global_rank
        )
        cache_paths.append(val_cache_path)
        
        if world_size > 1:
            dist.barrier()
        
        # Setup training components
        criterion = LabelSmoothingCrossEntropy(label_smoothing) if label_smoothing > 0 else nn.CrossEntropyLoss()
        
        model_params = [p for p in (model.module if world_size > 1 else model).parameters() if p.requires_grad]
        
        # Adjust learning rate for larger batch size (linear scaling rule)
        adjusted_lr = lr * max(1.0, effective_batch_size / 256)  # Scale LR with batch size
        
        if optimizer_name == 'SGD':
            optimizer = torch.optim.SGD(model_params, lr=adjusted_lr, momentum=momentum, 
                                      weight_decay=weight_decay, nesterov=nesterov)
        elif optimizer_name == 'Adam':
            optimizer = torch.optim.Adam(model_params, lr=adjusted_lr, weight_decay=weight_decay, 
                                       betas=(beta1, beta2), eps=eps)
        else:  # AdamW
            optimizer = torch.optim.AdamW(model_params, lr=adjusted_lr, weight_decay=weight_decay, 
                                        betas=(beta1, beta2), eps=eps)
        
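        # StepLR multiplies the LR by gamma every step_size epochs; scheduler.step() is called once per epoch in train_one_epoch
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)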
        
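        # Mixed-precision gradient scaler; torch.amp.GradScaler('cuda') replaces the deprecated
        # torch.cuda.amp.GradScaler(). On CPU-only runs no scaler is created and training uses fp32.
        scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None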
        
        best_val_acc = 0.0
        
        # Log memory usage before training
        if global_rank == 0 and torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3
            memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            memory_util = memory_allocated/memory_total*100
            print(f"Pre-training GPU memory: {memory_allocated:.2f} GB / {memory_total:.2f} GB ({memory_util:.1f}%)")
        
        # Training loop
        for epoch in range(num_epochs):
            print(f"Rank {global_rank}: Epoch {epoch+1}/{num_epochs}")
            
            # Training phase with memory optimization
            train_results = train_one_epoch(
                model, criterion, optimizer, scheduler, train_dataloader, 
                epoch, device, global_rank, label_to_idx,
                use_augmentation=use_augmentation,
                augmentation_strength=augmentation_strength,
                scaler=scaler  # Pass scaler for mixed precision
            )
            
            # Validation phase
            val_results = evaluate(
                model, criterion, val_dataloader, epoch, device, global_rank, label_to_idx,
                use_augmentation=False,
                scaler=scaler  # only used to decide whether evaluation runs under autocast
            )
            
            train_loss, train_acc = train_results['loss'], train_results['accuracy']
            val_loss, val_acc = val_results['loss'], val_results['accuracy']
            
            if val_acc > best_val_acc:
                best_val_acc = val_acc
            
            # Store essential metrics
            if global_rank == 0:
                essential_model_metrics['gradient_norms'].append(train_results.get('gradient_norm', 0.0))
                essential_model_metrics['weight_norms'].append(train_results.get('weight_norm', 0.0))
                essential_model_metrics['learning_curves']['train_loss'].append(float(train_loss))
                essential_model_metrics['learning_curves']['val_loss'].append(float(val_loss))
                essential_model_metrics['learning_curves']['train_acc'].append(float(train_acc))
                essential_model_metrics['learning_curves']['val_acc'].append(float(val_acc))
                
                epoch_data = {
                    'epoch': epoch,
                    'train_loss': float(train_loss),
                    'train_acc': float(train_acc),
                    'val_loss': float(val_loss),
                    'val_acc': float(val_acc),
                    'learning_rate': float(optimizer.param_groups[0]['lr']),
                    'gradient_norm': train_results.get('gradient_norm', 0.0),
                    'weight_norm': train_results.get('weight_norm', 0.0)
                }
                epoch_metrics.append(epoch_data)
                
                # Log memory usage during training
                if torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    memory_util = memory_allocated/memory_total*100
                    print(f'Epoch {epoch+1}: Train {train_acc:.4f}, Val {val_acc:.4f}, '
                          f'Loss {train_loss:.4f}/{val_loss:.4f}, LR {optimizer.param_groups[0]["lr"]:.2e}, '
                          f'Memory {memory_util:.1f}%')
                else:
                    print(f'Epoch {epoch+1}: Train {train_acc:.4f}, Val {val_acc:.4f}, '
                          f'Loss {train_loss:.4f}/{val_loss:.4f}, LR {optimizer.param_groups[0]["lr"]:.2e}')
        
        training_time = time.time() - training_start_time
        
        return {
            "val_acc": float(best_val_acc),
            "train_loss": float(train_loss),
            "val_loss": float(val_loss),
            "train_acc": float(train_acc),
            "best_val_acc": float(best_val_acc),
            "status": "completed",
            "epochs_completed": num_epochs,
            "training_time": training_time,
            "epoch_metrics": epoch_metrics,
            "model_metrics": essential_model_metrics,
            "hyperparameters": {
                "lr": lr, "adjusted_lr": adjusted_lr, "batch_size": batch_size, 
                "effective_batch_size": effective_batch_size, "optimizer": optimizer_name,
                "weight_decay": weight_decay, "dropout_rate": dropout_rate, 
                "label_smoothing": label_smoothing, "use_augmentation": use_augmentation,
                "augmentation_strength": augmentation_strength, "memory_optimized": True
            }
        }
        
    except Exception as e:
        print(f"Rank {global_rank}: Training error: {e}")
        return {"val_acc": 0.0, "status": "failed", "error": str(e)}
    
    finally:
        for cache_path in cache_paths:
            if os.path.exists(cache_path):
                shutil.rmtree(cache_path, ignore_errors=True)
        if world_size > 1 and dist.is_initialized():
            dist.destroy_process_group()


def train_one_epoch(model, criterion, optimizer, scheduler, train_dataloader, 
                   epoch, device, global_rank, label_to_idx,
                   use_augmentation=True, augmentation_strength='medium',
                   scaler=None):
    """Run one training epoch and return loss, accuracy, and gradient/weight norms."""
    
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    successful_batches = 0
    gradient_norms = []
    
    # Enable mixed precision training
    use_mixed_precision = scaler is not None
    
    for step, batch in enumerate(train_dataloader):
        try:
            # Apply augmentation during training
            inputs, labels = convert_batch_to_tensors(
                batch, device, label_to_idx, global_rank,
                use_augmentation=use_augmentation,
                augmentation_strength=augmentation_strength
            )
            
            # Convert to channels_last for better memory efficiency
            inputs = inputs.to(memory_format=torch.channels_last)
            
            optimizer.zero_grad()
            
            # Use mixed precision if available
            if use_mixed_precision:
                # torch.amp.autocast('cuda') replaces the deprecated torch.cuda.amp.autocast()
                with torch.amp.autocast('cuda'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                
                # Scale loss and backward pass
                scaler.scale(loss).backward()
                
                # Unscale gradients before measuring them so the logged norm is not inflated by
                # the loss scale; scaler.step() below will not unscale a second time
                if global_rank == 0 and step % 50 == 0:
                    scaler.unscale_(optimizer)
                    grad_norm = calculate_gradient_norm(
                        model, torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
                    )
                    gradient_norms.append(grad_norm)
                
                # Update with scaler
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard precision training
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                
                # Gradient norm (rank 0 only, every 50 steps); world size is 1 when no process group exists
                if global_rank == 0 and step % 50 == 0:
                    grad_norm = calculate_gradient_norm(
                        model, torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
                    )
                    gradient_norms.append(grad_norm)
                
                optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            successful_batches += 1
            
            # Log memory usage periodically
            if global_rank == 0 and step % 100 == 0:
                current_acc = correct_predictions / total_samples if total_samples > 0 else 0
                aug_status = f"(aug: {augmentation_strength})" if use_augmentation else "(no aug)"
                
                if torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    memory_util = memory_allocated/memory_total*100
                    print(f"  Step {step}: Loss {loss.item():.4f}, Acc {current_acc:.4f} {aug_status}, "
                          f"Memory {memory_util:.1f}%")
                else:
                    print(f"  Step {step}: Loss {loss.item():.4f}, Acc {current_acc:.4f} {aug_status}")
                
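        except Exception as e:
            # The failing batch (including a CUDA out-of-memory error) is skipped. Under DDP this
            # can desynchronize ranks, since the others still run their gradient all-reduce, and may hang.
            if step < 5:
                print(f"Training step {step} error: {e}")
            continue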
    
    scheduler.step()
    
    epoch_loss = running_loss / successful_batches if successful_batches > 0 else 0.0
    epoch_acc = correct_predictions / total_samples if total_samples > 0 else 0.0
    avg_gradient_norm = np.mean(gradient_norms) if gradient_norms else 0.0
    
    return {
        'loss': epoch_loss,
        'accuracy': epoch_acc,
        'gradient_norm': avg_gradient_norm,
        'weight_norm': calculate_weight_norm(model, 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()),
        'mixed_precision_used': use_mixed_precision
    }


def evaluate(model, criterion, val_dataloader, epoch, device, global_rank, label_to_idx,
            use_augmentation=False, augmentation_strength='light', scaler=None):
    """Evaluate on the validation batches seen by this rank and return loss and accuracy."""
    
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    successful_batches = 0
    
    # Enable mixed precision for evaluation too
    use_mixed_precision = scaler is not None
    
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            try:
                # Typically no augmentation during validation
                inputs, labels = convert_batch_to_tensors(
                    batch, device, label_to_idx, global_rank,
                    use_augmentation=use_augmentation,
                    augmentation_strength=augmentation_strength
                )
                
                # Convert to channels_last for consistency
                inputs = inputs.to(memory_format=torch.channels_last)
                
                # Use mixed precision if available
                if use_mixed_precision:
                    # torch.amp.autocast('cuda') replaces the deprecated torch.cuda.amp.autocast()
                    with torch.amp.autocast('cuda'):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                else:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                
                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_samples += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()
                successful_batches += 1
                
                # Log GPU memory once per validation pass (first batch, rank 0 only)
                if global_rank == 0 and step == 0 and torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    memory_util = memory_allocated/memory_total*100
                    print(f"  Validation memory usage: {memory_util:.1f}%")
                
            except Exception:
                # Skip batches that fail to convert; they are excluded from the validation metrics
                continue
    
    val_loss = running_loss / successful_batches if successful_batches > 0 else 0.0
    val_acc = correct_predictions / total_samples if total_samples > 0 else 0.0
    
    return {
        'loss': val_loss, 
        'accuracy': val_acc,
        'mixed_precision_used': use_mixed_precision
    }
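
Note that `evaluate` (like `train_one_epoch`) accumulates loss and accuracy only over the batches the calling rank sees, and `TorchDistributor.run` returns the value from the rank 0 process. Assuming `get_dataloader_with_mosaic` partitions samples across ranks the way Mosaic's `StreamingDataset` does, the reported `val_acc` therefore reflects only rank 0's portion of the validation data. A minimal sketch of summing the raw counters over all ranks before dividing is below; `reduce_eval_counts` is a hypothetical helper (not part of the notebook), and with the NCCL backend you would pass the rank's CUDA device (e.g. `device=device`).

```python
import torch
import torch.distributed as dist


def reduce_eval_counts(loss_sum, num_batches, correct, total, device="cpu"):
    """Sum raw evaluation counters over all ranks and return (mean_loss, accuracy).

    Hypothetical helper: falls back to the local values when no process group
    is initialized, so it also works in single-process runs.
    """
    # Pack the four counters into one tensor so a single collective suffices
    stats = torch.tensor([loss_sum, num_batches, correct, total],
                         dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    loss_sum, num_batches, correct, total = stats.tolist()
    mean_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    accuracy = correct / total if total > 0 else 0.0
    return mean_loss, accuracy


# Single-process usage: values pass through unchanged
mean_loss, acc = reduce_eval_counts(loss_sum=12.0, num_batches=4, correct=30, total=40)
print(f"mean_loss={mean_loss:.2f}, accuracy={acc:.2f}")  # mean_loss=3.00, accuracy=0.75
```

Called on every rank at the end of `evaluate` with `running_loss`, `successful_batches`, `correct_predictions`, and `total_samples`, it would give every rank (including rank 0) the same global loss and accuracy. Summing counts rather than averaging per-rank accuracies stays correct when ranks see different numbers of samples.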

In [0]:
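def calculate_gradient_norm(model, world_size):
    """Return the global L2 norm of all parameter gradients (0.0 if no gradients exist).

    `world_size` is kept for call-site compatibility; a DDP wrapper is detected via
    its `.module` attribute, so this also works when no process group is initialized.
    """
    import torch

    base_model = model.module if hasattr(model, "module") else model
    grads = [p.grad.detach().float() for p in base_model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    # L2 norm of the per-tensor L2 norms == L2 norm over every gradient entry;
    # stacking first means a single device-to-host sync instead of one per tensor
    per_tensor = torch.stack([torch.linalg.vector_norm(g, 2) for g in grads])
    return torch.linalg.vector_norm(per_tensor, 2).item()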


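def calculate_weight_norm(model, world_size):
    """Return the global L2 norm of all model parameters (0.0 if the model has none).

    `world_size` is kept for call-site compatibility; a DDP wrapper is detected directly.
    """
    import torch

    base_model = model.module if hasattr(model, "module") else model
    weights = [p.detach().float() for p in base_model.parameters()]
    if not weights:
        return 0.0
    # Same per-tensor-then-global reduction as calculate_gradient_norm
    per_tensor = torch.stack([torch.linalg.vector_norm(w, 2) for w in weights])
    return torch.linalg.vector_norm(per_tensor, 2).item()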
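
Both norm helpers produce one global L2 norm by squaring each tensor's L2 norm, summing, and taking the square root. That equals the L2 norm you would get by flattening every gradient (or weight) into a single vector. A plain-Python check with made-up toy values (the lists stand in for flattened parameter tensors):

```python
import math


def l2(values):
    """L2 norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def global_norm(tensors):
    """Combine per-tensor L2 norms the same way the notebook's helpers do."""
    return math.sqrt(sum(l2(t) ** 2 for t in tensors))


# Toy "gradients" for three parameter tensors, already flattened
grads = [[3.0, 4.0], [1.0, 2.0, 2.0], [0.0]]
flat = [v for t in grads for v in t]

print(global_norm(grads))  # sqrt(34), about 5.831
print(l2(flat))            # same value: one norm over every gradient entry
```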


In [0]:
def objective_function(trial):
    """Optuna objective: sample hyperparameters, run distributed training, log to MLflow, and return validation accuracy."""
    
    # Core hyperparameters (per-GPU batch sizes enlarged from the earlier [32, 64, 128] range)
    lr = trial.suggest_float('lr', 5e-5, 5e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [128, 256, 384, 512])  # per-GPU batch size
    optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'SGD', 'Adam'])
    
    # Optimizer-specific weight decay
    if optimizer_name == 'AdamW':
        weight_decay = trial.suggest_float('weight_decay', 1e-4, 1e-1, log=True)
    elif optimizer_name == 'SGD':
        weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    else:
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)
    
    # Additional key hyperparameters
    step_size = trial.suggest_int('step_size', 5, 15)
    gamma = trial.suggest_float('gamma', 0.1, 0.7)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    label_smoothing = trial.suggest_float('label_smoothing', 0.0, 0.2)
    
    # Data augmentation hyperparameters
    use_augmentation = trial.suggest_categorical('use_augmentation', [True, False])
    
    if use_augmentation:
        augmentation_strength = trial.suggest_categorical('augmentation_strength', ['light', 'medium', 'heavy'])
    else:
        augmentation_strength = None
    
    # Optimizer-specific parameters
    momentum = trial.suggest_float('momentum', 0.85, 0.99) if optimizer_name == 'SGD' else 0.9
    nesterov = trial.suggest_categorical('nesterov', [True, False]) if optimizer_name == 'SGD' else False
    beta1 = trial.suggest_float('beta1', 0.85, 0.95) if optimizer_name == 'AdamW' else 0.9
    beta2 = trial.suggest_float('beta2', 0.95, 0.999) if optimizer_name == 'AdamW' else 0.999
    eps = trial.suggest_float('eps', 1e-9, 1e-7, log=True) if optimizer_name == 'AdamW' else 1e-8
    
    # Effective (global) batch size for logging; num_workers is the number of TorchDistributor processes
    effective_batch_size = batch_size * num_workers
    
    aug_info = f"aug_{augmentation_strength}" if use_augmentation else "no_aug"
    print(f"LARGE BATCH Trial {trial.number}: {optimizer_name}, lr={lr:.2e}, bs={batch_size}, "
          f"eff_bs={effective_batch_size}, wd={weight_decay:.2e}, {aug_info}")
    
    run_name = f"trial_{trial.number}_{optimizer_name}_bs{batch_size}_lr{lr:.2e}_{aug_info}"
    
    with mlflow.start_run(run_name=run_name, nested=True) as trial_run:
        trial_start_time = time.time()
        trial_run_id = trial_run.info.run_id
        
        # Log essential parameters including large batch info
        trial_params = {
            'trial_number': trial.number,
            'lr': lr, 'batch_size': batch_size, 'effective_batch_size': effective_batch_size,
            'optimizer': optimizer_name,
            'weight_decay': weight_decay, 'step_size': step_size, 'gamma': gamma,
            'dropout_rate': dropout_rate, 'label_smoothing': label_smoothing,
            'momentum': momentum, 'nesterov': nesterov, 'beta1': beta1, 'beta2': beta2, 'eps': eps,
            'use_augmentation': use_augmentation,
            'augmentation_strength': str(augmentation_strength) if augmentation_strength else 'none',
            'num_epochs': NUM_EPOCHS, 'num_workers': num_workers,
            'model_architecture': 'mobilenetv2',
            'large_batch_optimization': True,  # Flag for large batch training
            'memory_optimized': True
        }
        mlflow.log_params(trial_params)
        
        try:
            distributor = TorchDistributor(num_processes=num_workers, local_mode=False, use_gpu=True)
            result = distributor.run(
                distributed_train_and_evaluate,  # Using updated function with large batch support
                lr=lr, batch_size=batch_size, optimizer_name=optimizer_name,
                weight_decay=weight_decay, step_size=step_size, gamma=gamma,
                dropout_rate=dropout_rate, label_smoothing=label_smoothing,
                momentum=momentum, nesterov=nesterov, beta1=beta1, beta2=beta2, eps=eps,
                data_storage_location=data_storage_location,
                mds_train_dir=mds_train_dir, mds_val_dir=mds_val_dir,
                num_epochs=NUM_EPOCHS, mlflow_run_id=trial_run_id,
                use_augmentation=use_augmentation,
                augmentation_strength=augmentation_strength
            )
            
            trial_end_time = time.time()
            trial_duration = trial_end_time - trial_start_time
            
            if isinstance(result, dict):
                trial_metric = result.get('val_acc', 0.0)
                
                # Log essential model metrics
                model_metrics = result.get('model_metrics', {})
                if model_metrics:
                    mlflow.log_metric("model_parameter_count", model_metrics.get('parameter_count', 0))
                    mlflow.log_metric("model_trainable_parameters", model_metrics.get('trainable_parameters', 0))
                    mlflow.log_metric("model_size_mb", model_metrics.get('model_size_mb', 0.0))
                
                    # Log gradient and weight norms
                    gradient_norms = model_metrics.get('gradient_norms', [])
                    weight_norms = model_metrics.get('weight_norms', [])
                    
                    if gradient_norms:
                        mlflow.log_metric("final_gradient_norm", gradient_norms[-1])
                        mlflow.log_metric("mean_gradient_norm", np.mean(gradient_norms))
                        mlflow.log_metric("gradient_stability", 1.0 / (np.var(gradient_norms) + 1e-8))
                    
                    if weight_norms:
                        mlflow.log_metric("final_weight_norm", weight_norms[-1])
                        mlflow.log_metric("mean_weight_norm", np.mean(weight_norms))
                
                # Log epoch-by-epoch metrics
                epoch_metrics = result.get('epoch_metrics', [])
                if epoch_metrics:
                    for epoch_data in epoch_metrics:
                        epoch = epoch_data['epoch']
                        mlflow.log_metric("train_loss_epoch", epoch_data['train_loss'], step=epoch)
                        mlflow.log_metric("train_accuracy_epoch", epoch_data['train_acc'], step=epoch)
                        mlflow.log_metric("val_loss_epoch", epoch_data['val_loss'], step=epoch)
                        mlflow.log_metric("val_accuracy_epoch", epoch_data['val_acc'], step=epoch)
                        mlflow.log_metric("learning_rate_epoch", epoch_data['learning_rate'], step=epoch)
                        mlflow.log_metric("gradient_norm_epoch", epoch_data['gradient_norm'], step=epoch)
                
                # Log large batch training metrics
                training_metrics = {
                    'final_validation_accuracy': trial_metric,
                    'final_train_loss': result.get('train_loss', 0.0),
                    'final_val_loss': result.get('val_loss', 0.0),
                    'final_train_accuracy': result.get('train_acc', 0.0),
                    'best_validation_accuracy': result.get('best_val_acc', trial_metric),
                    'epochs_completed': result.get('epochs_completed', NUM_EPOCHS),
                    'training_time_seconds': result.get('training_time', 0.0),
                    'trial_duration_seconds': trial_duration,
                    # Large batch specific metrics
                    'large_batch_size': float(batch_size),
                    'effective_batch_size': float(effective_batch_size),
                    'memory_efficiency_score': trial_metric / (batch_size / 100),  # Accuracy per 100 batch units
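                    # 50000 is a hard-coded sample count; if the data loader shards each epoch
                    # across ranks, multiplying by num_workers overcounts the samples processed
                    'throughput_samples_per_second': (50000 * num_workers * NUM_EPOCHS) / trial_duration if trial_duration > 0 else 0.0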
                }
                mlflow.log_metrics(training_metrics)
                
                # Log augmentation-specific metrics
                if use_augmentation:
                    mlflow.log_metric("augmentation_enabled", 1.0)
                    mlflow.log_metric("augmentation_impact_score", trial_metric)
                else:
                    mlflow.log_metric("augmentation_enabled", 0.0)
                
                # Log efficiency metrics (assumes 50000 samples per worker per epoch; adjust if the loader shards data across ranks)
                if trial_duration > 0:
                    total_samples = 50000 * num_workers * NUM_EPOCHS
                    samples_per_second = total_samples / trial_duration
                    mlflow.log_metric('samples_per_second', samples_per_second)
                    mlflow.log_metric('time_per_epoch_seconds', trial_duration / NUM_EPOCHS)
                    mlflow.log_metric('samples_per_second_per_batch_size', samples_per_second / batch_size)
                
                # Log convergence metrics
                if epoch_metrics and len(epoch_metrics) > 1:
                    val_accs = [ep['val_acc'] for ep in epoch_metrics]
                    final_acc = val_accs[-1]
                    initial_acc = val_accs[0]
                    improvement = final_acc - initial_acc
                    
                    mlflow.log_metric("total_accuracy_improvement", improvement)
                    # Average improvement per epoch transition (len(val_accs) - 1 intervals)
                    mlflow.log_metric("accuracy_improvement_rate", improvement / (len(val_accs) - 1))
                    
                    # Training stability
                    if len(val_accs) >= 3:
                        recent_stability = 1.0 / (np.var(val_accs[-3:]) + 1e-8)
                        mlflow.log_metric("validation_stability", recent_stability)
                        
            else:
                trial_metric = 0.0
                mlflow.log_metric("final_validation_accuracy", trial_metric)
            
            # Set essential tags including large batch info
            performance_tier = "excellent" if trial_metric > 0.7 else "good" if trial_metric > 0.5 else "fair" if trial_metric > 0.3 else "poor"
            augmentation_tag = f"aug_{augmentation_strength}" if use_augmentation else "no_augmentation"
            
            mlflow.set_tags({
                "trial_number": str(trial.number),
                "optimizer": optimizer_name,
                "performance_tier": performance_tier,
                "augmentation_strategy": augmentation_tag,
                "trial_status": result.get("status", "completed") if isinstance(result, dict) else "unknown",
                "large_batch_training": "true",
                "batch_size": str(batch_size),
                "effective_batch_size": str(effective_batch_size),
                "memory_optimized": "true"
            })
            
            print(f"Large Batch Trial {trial.number} completed: Accuracy = {trial_metric:.4f}, "
                  f"Batch Size = {batch_size}, Duration = {trial_duration:.1f}s, Aug = {aug_info}")
            
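        except RuntimeError as e:
            # Most worker-side errors (including CUDA OOM) are caught inside the training functions,
            # so this branch mainly handles failures surfaced by TorchDistributor on the driver
            if "out of memory" in str(e).lower():
                print(f"OOM Error with large batch_size={batch_size}: {e}")
                trial_metric = 0.0
                mlflow.log_param("oom_error", f"batch_size_{batch_size}")
                mlflow.set_tag("trial_status", "oom_failed")
            else:
                # Non-OOM RuntimeErrors propagate; Optuna marks the trial as failed and stops the
                # study unless study.optimize(..., catch=(RuntimeError,)) is used
                raise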
                
        except Exception as e:
            trial_end_time = time.time()
            trial_duration = trial_end_time - trial_start_time
            
            print(f"Large Batch Trial {trial.number} failed: {e}")
            trial_metric = 0.0
            
            mlflow.log_param("trial_status", "failed")
            mlflow.log_param("error_message", str(e)[:200])
            mlflow.log_metric("final_validation_accuracy", trial_metric)
            mlflow.log_metric("trial_duration_seconds", trial_duration)
            
            augmentation_tag = f"aug_{augmentation_strength}" if use_augmentation else "no_augmentation"
            mlflow.set_tags({
                "trial_number": str(trial.number),
                "trial_status": "failed",
                "optimizer": optimizer_name,
                "augmentation_strategy": augmentation_tag,
                "large_batch_training": "true"
            })
    
    return trial_metric
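
The two throughput metrics above use `50000 * num_workers * NUM_EPOCHS` as the number of samples processed. If `get_dataloader_with_mosaic` wraps a Mosaic `StreamingDataset`, which splits each epoch across ranks, the whole training set is read once per epoch in total rather than once per worker, so the count should not scale with `num_workers`. The sketch below contrasts the two formulas; the function names and the 100,000-sample figure are illustrative assumptions (set `train_set_size` to the size of your MDS training split). Because `trial_duration` also includes TorchDistributor startup and MDS caching, either number is an end-to-end rather than pure-compute throughput.

```python
def throughput_as_logged(assumed_samples, num_workers, num_epochs, duration_s):
    """Mirrors the notebook's formula: counts the dataset once per worker."""
    return assumed_samples * num_workers * num_epochs / duration_s if duration_s > 0 else 0.0


def throughput_sharded(train_set_size, num_epochs, duration_s):
    """Counts the dataset once per epoch, assuming ranks split it between them."""
    return train_set_size * num_epochs / duration_s if duration_s > 0 else 0.0


# Hypothetical run: 2 processes, 5 epochs, 1,000 s trial, 100,000-sample training split
print(throughput_as_logged(100_000, num_workers=2, num_epochs=5, duration_s=1000.0))  # 1000.0
print(throughput_sharded(100_000, num_epochs=5, duration_s=1000.0))                   # 500.0
```

Under the sharding assumption the logged value is too high by exactly a factor of `num_workers`.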

In [0]:
# Training Configuration
# num_classes = 200
NUM_EPOCHS = 5  # use 1 epoch first for a quick smoke test if needed
num_workers = 2  # TorchDistributor processes (one GPU each); kept at 2 to reduce I/O contention

# n_trials = 5 ## can be increased 
# batch_sizes = [8, 16] -- update parameter-ranges/etc
# optimizers = ['AdamW']
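
With these settings, two quantities inside `distributed_train_and_evaluate` are easy to check by hand: the linearly scaled learning rate, `adjusted_lr = lr * max(1.0, effective_batch_size / 256)`, and the per-epoch rate produced by `StepLR`, which is `lr * gamma ** (epoch // step_size)` when `scheduler.step()` runs once per epoch. A plain-Python sketch of both (the helper names are illustrative): with a per-GPU batch of 512 on 2 processes the LR is scaled 4x, so Optuna's upper bound of 5e-3 becomes 2e-2. Because Optuna samples `step_size` from 5 to 15, no decay happens within a 5-epoch run, which means `step_size` and `gamma` have no effect at `NUM_EPOCHS = 5`.

```python
def scaled_lr(base_lr, per_gpu_batch, num_processes, reference_batch=256):
    """Linear scaling rule as used in the notebook: scale up, never down."""
    effective_batch = per_gpu_batch * num_processes
    return base_lr * max(1.0, effective_batch / reference_batch)


def step_lr_schedule(lr, step_size, gamma, num_epochs):
    """Learning rate used in each epoch by StepLR when stepped once per epoch."""
    return [lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]


lr0 = scaled_lr(1e-3, per_gpu_batch=512, num_processes=2)  # effective batch 1024 -> 4x
print(lr0)                                                  # 0.004
print(step_lr_schedule(lr0, step_size=5, gamma=0.1, num_epochs=5))  # no decay within 5 epochs
print(step_lr_schedule(lr0, step_size=2, gamma=0.5, num_epochs=5))  # halves at epochs 2 and 4
```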

In [0]:
def setup_spark_for_optuna():
    """Set the per-task GPU amount for Spark tasks launched during Optuna tuning.

    Note: GPU resource settings are normally fixed in the cluster's Spark config;
    changing them on a running session may have no effect.
    """
    from pyspark.sql import SparkSession
    
    try:
        import torch
        gpu_available = torch.cuda.is_available()  # checks the driver node's GPUs
    except ImportError:
        gpu_available = False
    
    gpu_amount = "1" if gpu_available else "0"
    
    # Try to update the existing session first
    try:
        spark = SparkSession.getActiveSession()
        if spark:
            spark.conf.set("spark.task.resource.gpu.amount", gpu_amount)
            return spark
    except Exception:
        pass
    
    # Create new session if needed
    return SparkSession.builder \
        .appName("OptunaTuning") \
        .config("spark.task.resource.gpu.amount", gpu_amount) \
        .config("spark.executor.resource.gpu.amount", gpu_amount) \
        .getOrCreate()

spark = setup_spark_for_optuna()
# study.optimize(objective_function, n_trials=5, n_jobs=1)
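With `spark.task.resource.gpu.amount` at 1, TorchDistributor needs roughly one Spark task (and one GPU) per training process (our reading of how it sizes its job; verify on your runtime). Trials that run at the same time (Optuna `n_jobs > 1`) therefore compete for the same GPU slots. The hypothetical helper below sketches that budget check. The GPU count per node is an input you supply (a g5.12xlarge has 4 GPUs), and `processes_per_trial` stands for whatever `num_processes` your objective passes to TorchDistributor.

```python
def max_concurrent_trials(num_worker_nodes: int, gpus_per_node: int, processes_per_trial: int) -> int:
    """Upper bound on trials that can hold GPUs simultaneously, assuming 1 GPU per process."""
    if processes_per_trial < 1:
        raise ValueError("processes_per_trial must be >= 1")
    total_gpus = num_worker_nodes * gpus_per_node
    # Integer division: a trial needs all of its GPUs at once, so a partial allocation does not count
    return total_gpus // processes_per_trial


# Example: 2 g5.12xlarge workers (4 GPUs each), 2 training processes per trial
n_jobs_bound = max_concurrent_trials(num_worker_nodes=2, gpus_per_node=4, processes_per_trial=2)
print(f"Upper bound for study.optimize(..., n_jobs=...): {n_jobs_bound}")
```

Keeping `n_jobs=1`, as this notebook does, avoids the contention entirely at the cost of running trials one after another.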

In [0]:
## Run optimization with data augmentation tracking
start_time = time.time()

study = optuna.create_study(
    # study_name=f"pytorch_mobilenetv2_aug_{int(time.time())}",  
    study_name=f"pytorch_mobilenetv2_aug_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42)
)

print("Starting Optuna optimization with large batch sizes and data augmentation...")
print(f"Trials: 5, Workers: {num_workers}, Epochs: {NUM_EPOCHS}")
print("Batch sizes: [128, 256, 384, 512]")  # cf. the smaller range [32, 64, 128]

with mlflow.start_run(run_name=f"optuna_optimization_{study.study_name}") as parent_run:
    parent_run_id = parent_run.info.run_id
    
    # Log optimization configuration including large batch info
    mlflow.log_params({
        "study_name": study.study_name,
        "n_trials": 5,
        "num_workers": num_workers,
        "num_epochs": NUM_EPOCHS,
        "model_architecture": "mobilenetv2",
        "optimization_type": "large_batch_with_augmentation",
        "augmentation_enabled": False,  # change to True if augmentation is part of the search space
        "large_batch_optimization": True,  # update to match your search space
        "batch_size_range": "[128, 256, 384, 512]",
        "memory_optimization": True
    })
    
    print(f"Parent run ID: {parent_run_id}")
    
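    # Run the optimization. Exceptions not handled inside objective_function (e.g. non-OOM
    # RuntimeErrors) stop study.optimize; pass catch=(RuntimeError,) to record them as failed trials and continue.
    study.optimize(objective_function, n_trials=5, n_jobs=1)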
    
    # Log results
    end_time = time.time()
    optimization_duration = end_time - start_time
    
    best_params = study.best_params
    best_value = study.best_value
    
    completed_trials = [
        t for t in study.trials
        if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None
    ]
    trial_values = [t.value for t in completed_trials]
    
    print(f"\nBest parameters: {best_params}")
    print(f"Best validation accuracy: {best_value:.4f}")
    print(f"Best batch size: {best_params.get('batch_size', 'N/A')}")
    print(f"Completed trials: {len(completed_trials)}/{len(study.trials)}")
    
    # Analyze batch size impact
    batch_size_analysis = {}
    for trial in completed_trials:
        bs = trial.params.get('batch_size', 'unknown')
        if bs not in batch_size_analysis:
            batch_size_analysis[bs] = []
        batch_size_analysis[bs].append(trial.value)
    
    if batch_size_analysis:
        print(f"\nBatch Size Performance Analysis:")
        for bs, values in batch_size_analysis.items():
            avg_acc = np.mean(values)
            print(f"Batch size {bs}: {len(values)} trials, avg accuracy: {avg_acc:.4f}")
    
    # Analyze augmentation impact
    aug_trials = [t for t in completed_trials if t.params.get('use_augmentation', False)]
    no_aug_trials = [t for t in completed_trials if not t.params.get('use_augmentation', True)]
    
    aug_values = [t.value for t in aug_trials] if aug_trials else []
    no_aug_values = [t.value for t in no_aug_trials] if no_aug_trials else []
    
    print(f"\nAugmentation Analysis:")
    print(f"Trials with augmentation: {len(aug_trials)}")
    print(f"Trials without augmentation: {len(no_aug_trials)}")
    
    if aug_values:
        print(f"Avg accuracy with augmentation: {np.mean(aug_values):.4f}")
    if no_aug_values:
        print(f"Avg accuracy without augmentation: {np.mean(no_aug_values):.4f}")
    
    # Analyze augmentation strength impact
    strength_analysis = {}
    for trial in aug_trials:
        strength = trial.params.get('augmentation_strength', 'unknown')
        if strength not in strength_analysis:
            strength_analysis[strength] = []
        strength_analysis[strength].append(trial.value)
    
    if strength_analysis:
        print(f"\nAugmentation Strength Analysis:")
        for strength, values in strength_analysis.items():
            print(f"{strength}: {len(values)} trials, avg accuracy: {np.mean(values):.4f}")
    
    # Log essential optimization results with batch metrics
    optimization_results = {
        "best_validation_accuracy": float(best_value),
        "best_batch_size": float(best_params.get('batch_size', 0)),
        "best_effective_batch_size": float(best_params.get('batch_size', 0) * num_workers),
        "total_trials": float(len(study.trials)),
        "completed_trials": float(len(completed_trials)),
        "optimization_duration_minutes": float(optimization_duration / 60.0),
        "success_rate": float(len(completed_trials) / len(study.trials)) if len(study.trials) > 0 else 0.0,
        "large_batch_optimization": 1.0,
        "memory_optimization_enabled": 1.0
    }
    
    # Log batch size effectiveness
    if batch_size_analysis:
        for bs, values in batch_size_analysis.items():
            if isinstance(bs, (int, float)):
                mlflow.log_metric(f"avg_accuracy_batch_size_{bs}", float(np.mean(values)))
                mlflow.log_metric(f"trials_batch_size_{bs}", float(len(values)))
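                # Heuristic ratio: mean accuracy per 100 samples of per-worker batch size (not a measured memory metric)
                mlflow.log_metric(f"memory_efficiency_batch_size_{bs}", float(np.mean(values) / (bs / 100)))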
    
    # Log augmentation-specific results with proper type conversion
    augmentation_results = {
        "trials_with_augmentation": float(len(aug_trials)),
        "trials_without_augmentation": float(len(no_aug_trials)),
        "augmentation_usage_rate": float(len(aug_trials) / len(completed_trials)) if completed_trials else 0.0
    }
    
    if aug_values:
        augmentation_results.update({
            "avg_accuracy_with_augmentation": float(np.mean(aug_values)),
            "std_accuracy_with_augmentation": float(np.std(aug_values)),
            "best_accuracy_with_augmentation": float(np.max(aug_values))
        })
    
    if no_aug_values:
        augmentation_results.update({
            "avg_accuracy_without_augmentation": float(np.mean(no_aug_values)),
            "std_accuracy_without_augmentation": float(np.std(no_aug_values)),
            "best_accuracy_without_augmentation": float(np.max(no_aug_values))
        })
    
    # Calculate augmentation effectiveness with proper type conversion
    if aug_values and no_aug_values:
        aug_improvement = np.mean(aug_values) - np.mean(no_aug_values)
        augmentation_results["augmentation_improvement"] = float(aug_improvement)
        augmentation_results["augmentation_effective"] = 1.0 if aug_improvement > 0 else 0.0
        print(f"Augmentation improvement: {aug_improvement:+.4f}")
    
    # Log augmentation strength effectiveness - handle string values properly
    if strength_analysis:
        best_strength = max(strength_analysis.items(), key=lambda x: np.mean(x[1]))
        
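        # Log the string value as a parameter, not a metric. Use a distinct key: the loop over
        # best_params below logs "best_augmentation_strength" with the best trial's value, and
        # MLflow rejects changing an already-logged param to a different value.
        mlflow.log_param("best_avg_augmentation_strength", str(best_strength[0]))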
        
        # Log the numeric accuracy as a metric
        augmentation_results["best_strength_accuracy"] = float(np.mean(best_strength[1]))
        
        for strength, values in strength_analysis.items():
            # Use safe metric names (replace special characters)
            safe_strength = str(strength).replace(' ', '_').replace('-', '_')
            mlflow.log_metric(f"avg_accuracy_strength_{safe_strength}", float(np.mean(values)))
            mlflow.log_metric(f"trials_strength_{safe_strength}", float(len(values)))
    
    # Log best hyperparameters including batch size info
    for param_name, param_value in best_params.items():
        mlflow.log_param(f"best_{param_name}", param_value)
    
    # Log effective batch size separately
    if 'batch_size' in best_params:
        mlflow.log_param("best_effective_batch_size", best_params['batch_size'] * num_workers)
    
    # Now safe to log metrics (all values are guaranteed to be numeric)
    mlflow.log_metrics(optimization_results)
    mlflow.log_metrics(augmentation_results)
    
    # Log trial statistics with proper type conversion
    if trial_values:
        trial_stats = {
            "mean_trial_accuracy": float(np.mean(trial_values)),
            "std_trial_accuracy": float(np.std(trial_values)),
            "min_trial_accuracy": float(np.min(trial_values)),
            "max_trial_accuracy": float(np.max(trial_values))
        }
        mlflow.log_metrics(trial_stats)
    
    # Log parameter importance including batch size
    try:
        if len(completed_trials) > 1:
            importance = optuna.importance.get_param_importances(study)
            for param_name, importance_value in importance.items():
                mlflow.log_metric(f"param_importance_{param_name}", float(importance_value))
            
            most_important = max(importance.items(), key=lambda x: x[1])
            mlflow.log_param("most_important_param", str(most_important[0]))
            mlflow.log_metric("most_important_param_score", float(most_important[1]))
            
            # Check batch size importance
            batch_size_importance = importance.get('batch_size', 0.0)
            if batch_size_importance > 0:
                mlflow.log_metric("batch_size_importance", float(batch_size_importance))
                print(f"Batch size parameter importance: {batch_size_importance:.4f}")
            
            # Check if augmentation parameters are important
            aug_params = ['use_augmentation', 'augmentation_strength']
            aug_importance = {k: v for k, v in importance.items() if k in aug_params}
            if aug_importance:
                total_aug_importance = sum(aug_importance.values())
                mlflow.log_metric("total_augmentation_importance", float(total_aug_importance))
                print(f"Augmentation parameter importance: {total_aug_importance:.4f}")
                
    except Exception as e:
        mlflow.log_param("importance_error", str(e)[:200])
    
    # Set final tags including large batch and augmentation info
    best_aug_strategy = "no_augmentation"
    if best_params.get('use_augmentation', False):
        best_aug_strategy = f"aug_{best_params.get('augmentation_strength', 'unknown')}"
    
    # augmentation_effective exists only when both augmented and non-augmented trials completed
    aug_effective_value = augmentation_results.get("augmentation_effective")
    if aug_effective_value is None:
        aug_effective_str = "unknown"
    else:
        aug_effective_str = "true" if aug_effective_value > 0 else "false"
    
    mlflow.set_tags({
        "optimization_type": "large_batch_with_augmentation",
        "model_architecture": "mobilenetv2",
        "best_optimizer": best_params.get('optimizer', 'unknown'),
        "best_batch_size": str(best_params.get('batch_size', 'unknown')),
        "best_effective_batch_size": str(best_params.get('batch_size', 0) * num_workers),
        "best_augmentation_strategy": best_aug_strategy,
        "augmentation_in_best": str(best_params.get('use_augmentation', False)),
        "large_batch_optimization": "true",
        "memory_optimized": "true",
        "optimization_status": "completed",
        "total_trials": str(len(study.trials)),
        "best_accuracy": f"{best_value:.4f}",
        "augmentation_effective": aug_effective_str
    })

print(f"\nOptimization completed in {optimization_duration:.2f} seconds")
print(f"Parent run ID: {parent_run_id}")

# Print final summary including batch size info
print("\nOPTIMIZATION SUMMARY:")
print(f"Best batch size: {best_params.get('batch_size', 'N/A')}")
print(f"Effective batch size: {best_params.get('batch_size', 0) * num_workers}")
print(f"Best optimizer: {best_params.get('optimizer', 'N/A')}")

if best_params.get('use_augmentation', False):
    print(f"Best trial used augmentation: {best_params.get('augmentation_strength', 'unknown')}")
else:
    print("Best trial did not use augmentation")

if 'augmentation_improvement' in augmentation_results:
    improvement = augmentation_results['augmentation_improvement']
    if improvement > 0:
        print(f"Augmentation provided {improvement:.4f} average improvement")
    else:
        print(f"Augmentation did not improve average accuracy ({improvement:+.4f})")
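The batch-size and augmentation-strength analyses in the cell above repeat the same group-then-average loop. A small hypothetical helper, `summarize_by_param`, can factor that out. It relies only on the `.params` and `.value` attributes that Optuna's `FrozenTrial` objects expose, so the stand-in trials below are plain `types.SimpleNamespace` objects with made-up values for illustration.

```python
from collections import defaultdict
from statistics import mean
from types import SimpleNamespace


def summarize_by_param(trials, param_name, default="unknown"):
    """Group trial values by one hyperparameter; return {setting: (n_trials, mean_value)}.

    `trials` is any iterable of objects with `.params` (dict) and `.value` (float),
    such as the notebook's `completed_trials` list.
    """
    groups = defaultdict(list)
    for t in trials:
        groups[t.params.get(param_name, default)].append(t.value)
    return {setting: (len(values), mean(values)) for setting, values in groups.items()}


# Stand-in trials with made-up values, for illustration only
fake_trials = [
    SimpleNamespace(params={"batch_size": 128}, value=0.30),
    SimpleNamespace(params={"batch_size": 128}, value=0.40),
    SimpleNamespace(params={"batch_size": 256}, value=0.20),
]
print(summarize_by_param(fake_trials, "batch_size"))
```

In the notebook this would be called as `summarize_by_param(completed_trials, "batch_size")` or `summarize_by_param(aug_trials, "augmentation_strength")`.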


In [0]:
### NOTES:
# - `Top-1 validation accuracy` is the standard metric for Tiny ImageNet classification, and it is what this example optimizes. If your use case requires a different metric (e.g., `top-5 accuracy` or `F1`), modify objective_function to return and log that metric instead.

# - `py4j errors and failures` when scaling to more workers (e.g., the 4 workers the cluster can autoscale to) are likely due to several factors specific to running distributed PyTorch with Optuna on Databricks:

#     1. **Resource Contention**: Each worker creates its own MDS cache directories and data loaders. With 4 workers, you are creating 4x the I/O operations, leading to:
#         - File system contention on `/local_disk0/tmp/mds_cache_*` directories
#         - Network bandwidth saturation when reading MDS data
#         - Memory pressure from multiple concurrent data-loading processes

#     2. **Spark Context Conflicts**: When Optuna trials run in parallel (`n_jobs > 1`), each trial launches its own TorchDistributor job, and concurrent trials can conflict when they try to claim the same executor resources. This notebook uses `n_jobs=1`, so its trials run sequentially.

#     3. **Port/Resource Binding**: Distributed PyTorch processes communicate over network ports (such as the `MASTER_PORT` used for rendezvous). With more workers, port conflicts become more likely.

# - **Some Heuristics**
#     1. **Start Conservative**: Begin with 2 workers, validate stability, then scale incrementally.
#     2. **Monitor Resources**: Track CPU, memory, and I/O utilization during trials.
#     3. **Use Pruning**: Stop unpromising trials early to free resources for promising ones.
#     4. **Stagger Trial Starts**: Add small delays between trial launches to reduce resource conflicts.
#     5. **Isolate and Coordinate Resources**: Give each trial its own resources (e.g., cache directories) and coordinate access to shared ones.
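The pruning heuristic can be made concrete with a median stopping rule, the idea behind Optuna's `MedianPruner`: stop a trial at a given epoch if its best accuracy so far is below the median of what earlier trials reported at the same epoch. The sketch below is a simplified, framework-free version with a hypothetical `should_prune` function and an assumed warm-up of `n_startup_trials` earlier trials. With Optuna itself you would instead call `trial.report(value, step)` each epoch and check `trial.should_prune()`, letting the configured pruner decide.

```python
from statistics import median


def should_prune(current_history, previous_histories, epoch, n_startup_trials=2):
    """Simplified median stopping rule for a maximized metric such as accuracy.

    current_history: per-epoch accuracies of the running trial so far.
    previous_histories: per-epoch accuracy lists from earlier trials.
    epoch: 0-based index of the epoch that was just evaluated.
    """
    # Only compare against earlier trials that actually reached this epoch
    peers = [h[epoch] for h in previous_histories if len(h) > epoch]
    if len(peers) < n_startup_trials:
        return False  # not enough history to judge yet
    best_so_far = max(current_history[: epoch + 1])
    return best_so_far < median(peers)


history = [[0.20, 0.35], [0.25, 0.40], [0.22, 0.38]]  # made-up accuracies for illustration
print(should_prune([0.10], history, epoch=0))  # True: 0.10 is below the epoch-0 median of 0.22
print(should_prune([0.30], history, epoch=0))  # False
```

Because `TorchDistributor.run` returns only after training finishes, pruning in this setup would need either one distributor call per epoch or some other way to surface intermediate accuracy to the driver during a trial.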
