This is a reference example of how one could auto-detect cluster resources and configure training settings from them. The heuristics still need to be tested against your own model, data, and available compute.

In [0]:
# ============================================================
# CLUSTER RESOURCE DETECTION & AUTO-CONFIGURATION
# ============================================================
import torch
from pyspark.sql import SparkSession
from typing import Dict

def compute_distributor_settings(use_gpu: bool, num_gpus_per_node: int, num_workers: int) -> Dict:
    """Return TorchDistributor ``num_processes`` / ``local_mode`` for the cluster shape.

    On a driver-only cluster (``num_workers == 0``) training runs locally on the
    driver. Otherwise ``local_mode=False`` and only the worker nodes are counted
    (matching the CPU branch), so the driver's GPUs are not requested.
    Assumes every worker has the same GPU count as the driver -- TODO confirm
    for heterogeneous clusters.

    Returns:
        dict with keys 'distributor_num_processes' and 'distributor_local_mode'.
    """
    local_mode = num_workers == 0
    if use_gpu:
        training_nodes = 1 if local_mode else num_workers
        num_processes = num_gpus_per_node * training_nodes
    else:
        num_processes = max(num_workers, 1)
    return {
        'distributor_num_processes': num_processes,
        'distributor_local_mode': local_mode,
    }


def detect_and_configure_resources() -> Dict:
    """Detect cluster resources and return a heuristic training configuration.

    GPU count and memory are read on the current (driver) node via torch; the
    number of worker nodes is read from Spark's status tracker. If Spark cannot
    be queried, single-node defaults are kept and a warning is printed.

    Returns:
        dict with keys num_gpus_per_node, num_workers, total_nodes, total_gpus,
        use_gpu, cuda_available, recommended_batch_size,
        recommended_num_workers, distributor_num_processes,
        distributor_local_mode.
    """
    config = {
        'num_gpus_per_node': 0,
        'num_workers': 0,
        'total_nodes': 1,
        'total_gpus': 0,
        'use_gpu': False,
        'cuda_available': torch.cuda.is_available(),
        'recommended_batch_size': 32,  # heuristic fallback
        'recommended_num_workers': 4,  # heuristic fallback
        'distributor_num_processes': 1,
        'distributor_local_mode': True
    }

    # GPU detection
    if torch.cuda.is_available():
        config['num_gpus_per_node'] = torch.cuda.device_count()
        config['use_gpu'] = True

        # Recommend batch size from the memory of GPU 0. Memory is measured in
        # GiB (bytes / 1024**3), which reads below a card's nominal "GB" label
        # (e.g. ~39.4 for a 40GB A100), so a card matches a tier at 90% of it.
        try:
            props = torch.cuda.get_device_properties(0)
            memory_gb = props.total_memory / (1024**3)

            for nominal_gb, batch_size in ((40, 128),   # A100/H100
                                           (24, 64),    # A10/RTX 4090
                                           (16, 32)):   # V100/T4
                if memory_gb >= nominal_gb * 0.9:
                    config['recommended_batch_size'] = batch_size
                    break
            else:
                config['recommended_batch_size'] = 16
        except Exception:
            # Property query failed: keep a middle-of-the-road default
            config['recommended_batch_size'] = 32

    # Cluster size detection
    try:
        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        # The "- 1" assumes the executor list includes the driver; clamp so an
        # unexpected empty list cannot yield a negative worker count.
        executors = sc._jsc.sc().statusTracker().getExecutorInfos()
        config['num_workers'] = max(len(executors) - 1, 0)
        config['total_nodes'] = config['num_workers'] + 1

        # Dataloader workers: 4 per GPU, capped at 8 in total
        if config['use_gpu']:
            config['recommended_num_workers'] = min(4 * config['num_gpus_per_node'], 8)
        else:
            config['recommended_num_workers'] = 4

    except Exception as e:
        print(f"Warning: Could not detect cluster size: {e}")
        config['num_workers'] = 0
        config['total_nodes'] = 1
        config['recommended_num_workers'] = 4

    # total_gpus counts every node including the driver, assuming all nodes
    # have the driver's GPU count.
    config['total_gpus'] = config['num_gpus_per_node'] * config['total_nodes']
    config.update(compute_distributor_settings(
        config['use_gpu'], config['num_gpus_per_node'], config['num_workers']))

    return config

# Detect resources (first-pass heuristics; the next cell redefines
# detect_and_configure_resources and recomputes CLUSTER_RESOURCES)
CLUSTER_RESOURCES = detect_and_configure_resources()
# Bare last expression: displays the config dict as the cell output
CLUSTER_RESOURCES

In [0]:
# ============================================================
# BATCH SIZE AND NUM_WORKERS DETERMINATION LOGIC
# ============================================================
from typing import Dict
import torch
import numpy as np
from pyspark.sql import SparkSession

# (nominal GPU memory in GB, recommended per-GPU batch size), largest tier first.
# ~Conservative estimates for image classification with ResNet/MobileNet.
GPU_BATCH_SIZE_TIERS = (
    (80, 256),  # H100 / A100 (80GB)
    (40, 128),  # A100 (40GB)
    (24, 64),   # A10G (24GB), RTX 4090 (24GB)
    (16, 32),   # V100 (16GB), T4 (16GB)
    (12, 24),   # 12GB cards
    (8, 16),    # RTX 2080 (8GB)
)

# Memory is measured as total_memory / 1024**3 (GiB), which reads lower than a
# card's nominal "GB" label (an A100 40GB typically reports ~39.4). Without a
# tolerance every card would fall one tier below its label (and the 80GB tier
# would be unreachable), so a card matches a tier at >= 90% of its nominal size.
NOMINAL_MEMORY_TOLERANCE = 0.9


def recommend_batch_size_for_memory(memory_gb: float) -> int:
    """Map GPU memory (GiB, as measured via torch) to a heuristic per-GPU batch size."""
    for nominal_gb, batch_size in GPU_BATCH_SIZE_TIERS:
        if memory_gb >= nominal_gb * NOMINAL_MEMORY_TOLERANCE:
            return batch_size
    return 8  # smaller GPUs


def recommend_dataloader_workers(
    use_gpu: bool,
    num_gpus_per_node: int,
    cpu_cores: int,
    workers_per_gpu: int = 4,
) -> int:
    """Heuristic DataLoader worker budget for one node (total across its GPUs).

    GPU: ``workers_per_gpu`` per GPU, clamped to 2-8 per GPU, and never more than
    75% of the CPU cores. The core cap takes precedence over the per-GPU minimum
    so the node is not oversubscribed.
    CPU-only: half of the cores, but at least 2.
    """
    if use_gpu:
        per_gpu = min(max(workers_per_gpu, 2), 8)
        core_cap = int(cpu_cores * 0.75)  # leave cores for system/training processes
        return min(per_gpu * num_gpus_per_node, core_cap)
    return max(2, int(cpu_cores * 0.5))


def compute_distributor_settings(use_gpu: bool, num_gpus_per_node: int, num_workers: int) -> Dict:
    """Return TorchDistributor ``num_processes`` / ``local_mode`` for the cluster shape.

    On a driver-only cluster (``num_workers == 0``) training runs locally on the
    driver. Otherwise ``local_mode=False`` and only the worker nodes are counted
    (matching the CPU branch), so the driver's GPUs are not requested.
    Assumes every worker has the same GPU count as the driver -- TODO confirm
    for heterogeneous clusters.

    Returns:
        dict with keys 'distributor_num_processes' and 'distributor_local_mode'.
    """
    local_mode = num_workers == 0
    if use_gpu:
        training_nodes = 1 if local_mode else num_workers
        num_processes = num_gpus_per_node * training_nodes
    else:
        num_processes = max(num_workers, 1)
    return {
        'distributor_num_processes': num_processes,
        'distributor_local_mode': local_mode,
    }


def detect_and_configure_resources() -> Dict:
    """Detect cluster resources and return an optimized (heuristic) training configuration.

    Steps:
      1. Batch size from the memory of GPU 0 (16 on CPU-only nodes).
      2. Node count from Spark's status tracker; single-node defaults are kept
         if Spark cannot be queried.
      3. DataLoader worker budget from local CPU cores and GPU count -- computed
         even when Spark is unavailable.
      4. TorchDistributor settings.

    Returns:
        dict with the keys initialised below (GPU/node counts, batch size and
        worker recommendations, distributor settings, gpu_memory_gb, gpu_name).
    """
    import multiprocessing

    config = {
        'num_gpus_per_node': 0,
        'num_workers': 0,
        'total_nodes': 1,
        'total_gpus': 0,
        'use_gpu': False,
        'cuda_available': torch.cuda.is_available(),
        'recommended_batch_size': 32,  # fallback
        'recommended_num_workers': 4,  # fallback
        'distributor_num_processes': 1,
        'distributor_local_mode': True,
        'gpu_memory_gb': 0.0,
        'gpu_name': None
    }

    # ============================================================
    # BATCH SIZE DETERMINATION - Based on GPU Memory
    # ============================================================
    if torch.cuda.is_available():
        config['num_gpus_per_node'] = torch.cuda.device_count()
        config['use_gpu'] = True

        try:
            props = torch.cuda.get_device_properties(0)
            memory_gb = props.total_memory / (1024**3)
            gpu_name = props.name

            config['gpu_memory_gb'] = round(memory_gb, 2)
            config['gpu_name'] = gpu_name
            config['recommended_batch_size'] = recommend_batch_size_for_memory(memory_gb)

            print(f"  GPU: {gpu_name} ({memory_gb:.1f}GB)")
            print(f"  Recommended batch size: {config['recommended_batch_size']}")

        except Exception as e:
            print(f"  Warning: Could not detect GPU memory: {e}")
            config['recommended_batch_size'] = 32  # Safe default
    else:
        # CPU-only training - use smaller batch size
        config['recommended_batch_size'] = 16
        print(f"  CPU-only mode: Using batch size {config['recommended_batch_size']}")

    # ============================================================
    # CLUSTER SIZE - Spark status tracker
    # ============================================================
    try:
        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        # The "- 1" assumes the executor list includes the driver; clamp so an
        # unexpected empty list cannot yield a negative worker count.
        executors = sc._jsc.sc().statusTracker().getExecutorInfos()
        config['num_workers'] = max(len(executors) - 1, 0)
        config['total_nodes'] = config['num_workers'] + 1
    except Exception as e:
        print(f"  Warning: Could not detect cluster configuration: {e}")
        # Keep the single-node defaults (num_workers=0, total_nodes=1)

    # ============================================================
    # NUM_WORKERS DETERMINATION - Based on CPU cores and GPUs
    # ============================================================
    try:
        cpu_cores = multiprocessing.cpu_count()
    except NotImplementedError:
        cpu_cores = None

    if cpu_cores is None:
        # Core count unknown: fall back to 4 per GPU (or 4 on CPU)
        config['recommended_num_workers'] = (
            4 * config['num_gpus_per_node'] if config['use_gpu'] else 4
        )
    else:
        config['recommended_num_workers'] = recommend_dataloader_workers(
            config['use_gpu'], config['num_gpus_per_node'], cpu_cores
        )
        print(f"  CPU cores: {cpu_cores}")
        if config['use_gpu']:
            max_workers = int(cpu_cores * 0.75)
            print(f"  GPUs per node: {config['num_gpus_per_node']}")
            print(f"  Recommended DataLoader workers: {config['recommended_num_workers']}")
            print(f"    (4 workers/GPU, capped at {max_workers} based on {cpu_cores} cores)")
        else:
            print(f"  Recommended DataLoader workers: {config['recommended_num_workers']}")

    # ============================================================
    # DISTRIBUTOR CONFIG
    # ============================================================
    # total_gpus counts every node including the driver, assuming all nodes
    # have the driver's GPU count.
    config['total_gpus'] = config['num_gpus_per_node'] * config['total_nodes']
    config.update(compute_distributor_settings(
        config['use_gpu'], config['num_gpus_per_node'], config['num_workers']))

    return config


# ============================================================
# e.g. MODEL-SPECIFIC BATCH SIZE RECOMMENDATIONS
# ============================================================

def get_model_specific_batch_size(
    model_name: str,
    gpu_memory_gb: float,
    image_size: int = 224
) -> int:
    """
    Get a per-GPU batch size recommendation for a specific model architecture.

    Looks up ``model_name`` in a table of {memory tier (GB): batch size} values
    for 224x224 inputs, scales the batch by ``(224 / image_size) ** 2`` and
    rounds *down* to a power of two. Unknown models fall back to a generic
    memory-based estimate (scaled, not rounded to a power of two). The table
    values are heuristics -- validate them on your own hardware.

    Args:
        model_name: e.g. 'mobilenet_v2', 'resnet50', 'vit_base'.
        gpu_memory_gb: GPU memory as measured via torch (GiB).
        image_size: square input resolution in pixels.

    Returns:
        A batch size >= 1.

    Raises:
        ValueError: if ``image_size`` is not positive.
    """
    if image_size <= 0:
        raise ValueError(f"image_size must be positive, got {image_size}")

    # Batch size lookup table: {model: {memory_gb: batch_size}}
    batch_size_table = {
        'mobilenet_v2': {
            8: 32,
            16: 64,
            24: 128,
            40: 256
        },
        'resnet50': {
            8: 16,
            16: 32,
            24: 64,
            40: 128
        },
        'resnet101': {
            8: 8,
            16: 16,
            24: 32,
            40: 64
        },
        'efficientnet_b0': {
            8: 32,
            16: 64,
            24: 128,
            40: 256
        },
        'efficientnet_b4': {
            8: 8,
            16: 16,
            24: 32,
            40: 64
        },
        'vit_base': {
            8: 8,
            16: 16,
            24: 32,
            40: 64
        }
    }

    # Measured GiB reads below the nominal GB label (e.g. ~39.4 for a 40GB
    # card), so a GPU qualifies for a tier at >= 90% of its nominal size.
    memory_tolerance = 0.9

    # Adjust for image size (larger images need smaller batches)
    size_multiplier = (224 / image_size) ** 2

    if model_name in batch_size_table:
        tiers = batch_size_table[model_name]
        # Use the largest tier the GPU actually satisfies -- never round *up* to
        # a tier that needs more memory than is available. GPUs below every tier
        # get the smallest tier.
        eligible = [tier for tier in tiers if gpu_memory_gb >= tier * memory_tolerance]
        tier = max(eligible) if eligible else min(tiers)

        # Clamp to 1: large images can scale the batch below 1, which previously
        # hit log2(0) and raised OverflowError.
        adjusted_batch_size = max(int(tiers[tier] * size_multiplier), 1)

        # Round DOWN to a power of two (safer for memory than rounding up)
        return 1 << (adjusted_batch_size.bit_length() - 1)

    # Fallback to memory-based estimation
    for nominal_gb, base_batch_size in ((40, 128), (24, 64), (16, 32)):
        if gpu_memory_gb >= nominal_gb * memory_tolerance:
            return max(int(base_batch_size * size_multiplier), 1)
    return max(int(16 * size_multiplier), 1)


# ============================================================
# USAGE EXAMPLES
# ============================================================

# Basic usage: run the resource heuristics defined above.
CLUSTER_RESOURCES = detect_and_configure_resources()

# Advanced: on GPU clusters, replace the generic memory-based batch size with a
# model-specific one (here MobileNetV2 at 224px). CPU-only runs keep the
# generic recommendation.
if CLUSTER_RESOURCES['use_gpu']:
    model_batch_size = get_model_specific_batch_size(
        'mobilenet_v2',
        CLUSTER_RESOURCES['gpu_memory_gb'],
        224,
    )
    print(f"\nModel-specific batch size recommendation: {model_batch_size}")
    CLUSTER_RESOURCES['recommended_batch_size'] = model_batch_size


# ============================================================
# VALIDATION AND WARNINGS
# ============================================================

def validate_configuration(config: Dict, memory_per_sample_gb: float = 0.1) -> None:
    """Validate a resource configuration and print warnings (never raises).

    Args:
        config: dict as produced by ``detect_and_configure_resources``. Reads
            use_gpu, recommended_batch_size, recommended_num_workers and, on
            GPU, gpu_memory_gb / num_gpus_per_node (missing keys are tolerated).
        memory_per_sample_gb: rough training-memory cost of one sample, used for
            the "batch too large" check. The default (~0.1 GB) is an assumed
            ballpark for ResNet/MobileNet-class models at 224px -- calibrate it
            for your model. A value of 4 GB/sample would flag every batch size
            this notebook recommends (e.g. 128 x 4 = 512 GB on a 40 GB GPU).
    """
    messages = []
    use_gpu = config['use_gpu']
    batch_size = config['recommended_batch_size']
    num_workers = config['recommended_num_workers']

    # Check batch size
    if use_gpu and batch_size < 8:
        messages.append(
            f"Very small batch size ({batch_size}) - "
            "may lead to unstable training"
        )

    # Check num_workers
    if num_workers == 0:
        messages.append(
            "DataLoader workers set to 0 - data loading will be synchronous (slower)"
        )

    # Check GPU memory headroom. Skipped when memory is unknown (0.0 after a
    # failed detection, or key absent) to avoid a meaningless warning.
    gpu_memory_gb = config.get('gpu_memory_gb') or 0.0
    if use_gpu and gpu_memory_gb > 0:
        expected_memory_usage = batch_size * memory_per_sample_gb  # rough, in GB
        if expected_memory_usage > gpu_memory_gb * 0.9:  # keep ~10% headroom
            messages.append(
                f"Batch size may be too large - estimated {expected_memory_usage:.1f}GB "
                f"vs {gpu_memory_gb:.1f}GB available"
            )

    # Check worker/GPU ratio (recommended_num_workers is a per-node total)
    num_gpus = config.get('num_gpus_per_node', 0)
    if use_gpu and num_gpus > 0:
        workers_per_gpu = num_workers / num_gpus
        if workers_per_gpu < 2:
            messages.append(
                f"Only {workers_per_gpu:.1f} workers per GPU - may bottleneck data loading"
            )
        elif workers_per_gpu > 8:
            messages.append(
                f"{workers_per_gpu:.1f} workers per GPU - may cause overhead"
            )

    if messages:
        print(f"\n⚠ Configuration Warnings:")
        for message in messages:
            print(f"  • {message}")
    else:
        print(f"\n✓ Configuration validated successfully")


# Run validation (prints warnings only; CLUSTER_RESOURCES is not modified)
validate_configuration(CLUSTER_RESOURCES)

In [0]:
# Show the final config, including any model-specific batch-size override
CLUSTER_RESOURCES

In [0]:
# ============================================================
# DATASET CONFIGURATION
# ============================================================
CATALOG = "mmt"
SCHEMA = "pytorch"
VOLUME_NAME = "torch_data"

# Dataset paths
mds_train_dir = 'imagenet_tiny200_mds_train'
mds_val_dir = 'imagenet_tiny200_mds_val'
data_storage_location = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"

# Dataset parameters
num_classes = 200
# recommended_num_workers is a per-node budget (~4 per GPU). The distributor is
# configured with one process per GPU, each building its own DataLoader, so
# divide the budget to get the per-process value; using it unchanged would
# start num_gpus_per_node times too many workers on each node.
num_workers = (
    CLUSTER_RESOURCES['recommended_num_workers']
    // max(CLUSTER_RESOURCES['num_gpus_per_node'], 1)
)  # AUTO-CONFIGURED, per process

# ============================================================
# TRAINING CONFIGURATION (AUTO-TUNED)
# ============================================================
NUM_EPOCHS = 2
N_TRIALS = 3

# Batch size configuration - will be used as base for Optuna search
BASE_BATCH_SIZE = CLUSTER_RESOURCES['recommended_batch_size']  # AUTO-CONFIGURED, per GPU
TOTAL_GPUS = CLUSTER_RESOURCES['total_gpus']  # AUTO-CONFIGURED

# Effective batch size across all GPUs.
# NOTE(review): total_gpus counts every node including the driver; if training
# runs only on worker GPUs (local_mode=False) this may over-report -- confirm.
EFFECTIVE_BATCH_SIZE = BASE_BATCH_SIZE * max(TOTAL_GPUS, 1)

# ============================================================
# DISTRIBUTED TRAINING CONFIGURATION
# ============================================================
# Keys match TorchDistributor's keyword arguments (num_processes, local_mode, use_gpu)
DISTRIBUTOR_CONFIG = {
    'num_processes': CLUSTER_RESOURCES['distributor_num_processes'],
    'local_mode': CLUSTER_RESOURCES['distributor_local_mode'],
    'use_gpu': CLUSTER_RESOURCES['use_gpu']
}

# ============================================================
# CHECKPOINTING CONFIGURATION
# ============================================================
ENABLE_CHECKPOINTING = True
CHECKPOINT_FREQUENCY = 1  # presumably in epochs -- confirm against the training loop
RESUME_FROM_CHECKPOINT = False

# ============================================================
# EXPERIMENT-SPECIFIC PATH CONFIGURATION
# ============================================================
import os
from datetime import datetime
import mlflow

# The same volume expressed two ways: a filesystem path (used with os.* below)
# and a dbfs: URI (used for the MLflow artifact location).
BASE_VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"
BASE_VOLUME_DBFS = f"dbfs:/Volumes/{CATALOG}/{SCHEMA}/{VOLUME_NAME}"

EXPERIMENT_SHORT_NAME = "imagenet_mobilenetv2_hpt_chkpt"
# Local wall-clock timestamp; gives each run its own checkpoint subdirectory
EXPERIMENT_TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# `dbutils` is not imported in this notebook; it is expected to be a
# Databricks notebook global.
USER_NAME = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()
EXPERIMENT_NAME = f"/Users/{USER_NAME}/mlflow_experiments/pytorch_{EXPERIMENT_SHORT_NAME}"

EXPERIMENT_ROOT = f"{BASE_VOLUME_PATH}/{EXPERIMENT_SHORT_NAME}"
EXPERIMENT_ROOT_DBFS = f"{BASE_VOLUME_DBFS}/{EXPERIMENT_SHORT_NAME}"

CHECKPOINT_BASE_DIR = f"{EXPERIMENT_ROOT}/checkpoints/{EXPERIMENT_TIMESTAMP}"
# Artifact location used when a new experiment is created; the MLflow setup
# reassigns it to the existing experiment's location when reusing one.
MLFLOW_ARTIFACT_LOCATION = f"{EXPERIMENT_ROOT_DBFS}/mlflow_artifacts"

# Create output directories up front (safe to re-run: exist_ok=True)
os.makedirs(CHECKPOINT_BASE_DIR, exist_ok=True)
os.makedirs(f"{EXPERIMENT_ROOT}/mlflow_artifacts", exist_ok=True)

# ============================================================
# MLFLOW SETUP
# ============================================================
# Log to the Databricks tracking server and use the Unity Catalog model registry
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")
mlflow.enable_system_metrics_logging()

# Reuse the experiment if it already exists; otherwise create it with our
# artifact location.
# NOTE(review): get_experiment_by_name may also return an experiment in the
# "deleted" lifecycle stage, in which case set_experiment below would fail --
# confirm, and restore it (MlflowClient.restore_experiment) if that case matters.
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)

if experiment:
    experiment_id = experiment.experiment_id
    print(f"✓ Reusing existing MLflow experiment")
    print(f"  Name: {EXPERIMENT_NAME}")
    print(f"  Experiment ID: {experiment_id}")
    # An existing experiment keeps its original artifact location; report that one
    MLFLOW_ARTIFACT_LOCATION = experiment.artifact_location
else:
    experiment_id = mlflow.create_experiment(
        name=EXPERIMENT_NAME,
        artifact_location=MLFLOW_ARTIFACT_LOCATION
    )
    experiment = mlflow.get_experiment(experiment_id)
    print(f"✓ Created new MLflow experiment")
    print(f"  Name: {EXPERIMENT_NAME}")
    print(f"  Experiment ID: {experiment_id}")

# Make it the active experiment for subsequent runs
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

# ============================================================
# CONFIGURATION SUMMARY
# ============================================================
# Informational only: prints the resolved cluster, distributor, dataset and
# path settings; nothing below changes any configuration value.
print(f"\n{'='*80}")
print(f"CLUSTER RESOURCE CONFIGURATION")
print(f"{'='*80}")
print(f"Cluster Size:")
print(f"  Total Nodes:          {CLUSTER_RESOURCES['total_nodes']} ({CLUSTER_RESOURCES['num_workers']} workers + 1 driver)")
print(f"  GPUs per Node:        {CLUSTER_RESOURCES['num_gpus_per_node']}")
print(f"  Total GPUs:           {CLUSTER_RESOURCES['total_gpus']}")
print(f"  CUDA Available:       {CLUSTER_RESOURCES['cuda_available']}")
print(f"\nDistributed Training:")
print(f"  num_processes:        {DISTRIBUTOR_CONFIG['num_processes']}")
print(f"  local_mode:           {DISTRIBUTOR_CONFIG['local_mode']}")
print(f"  use_gpu:              {DISTRIBUTOR_CONFIG['use_gpu']}")
print(f"\nAuto-Configured Parameters:")
print(f"  Base Batch Size:      {BASE_BATCH_SIZE} (per GPU)")
print(f"  Effective Batch Size: {EFFECTIVE_BATCH_SIZE} (total across {max(TOTAL_GPUS, 1)} GPUs)")
print(f"  Dataloader Workers:   {num_workers} (per process)")

print(f"\n{'='*80}")
print(f"EXPERIMENT CONFIGURATION")
print(f"{'='*80}")
print(f"Experiment: {EXPERIMENT_SHORT_NAME}")
print(f"Timestamp:  {EXPERIMENT_TIMESTAMP}")
print(f"\nDataset Configuration:")
print(f"  Storage Location: {data_storage_location}")
print(f"  Training Data:    {data_storage_location}/{mds_train_dir}")
print(f"  Validation Data:  {data_storage_location}/{mds_val_dir}")
print(f"  Num Classes:      {num_classes}")
print(f"\nDirectory Structure:")
print(f"  Root: {EXPERIMENT_ROOT}/")
print(f"  ├── checkpoints/")
print(f"  │   └── {EXPERIMENT_TIMESTAMP}/")
print(f"  └── mlflow_artifacts/")
print(f"\nPaths:")
print(f"  Experiment Root:     {EXPERIMENT_ROOT}")
print(f"  Checkpoint Base:     {CHECKPOINT_BASE_DIR}")
print(f"  MLflow Artifacts:    {MLFLOW_ARTIFACT_LOCATION}")
print(f"  MLflow Experiment:   {EXPERIMENT_NAME}")
print(f"  Experiment ID:       {experiment_id}")
print(f"\nTraining Configuration:")
print(f"  Num Classes:      {num_classes}")
print(f"  Num Epochs:       {NUM_EPOCHS}")
print(f"  Num Trials:       {N_TRIALS}")
print(f"  Checkpointing:    {ENABLE_CHECKPOINTING}")
print(f"  Checkpoint Freq:  {CHECKPOINT_FREQUENCY}")
print(f"  Resume Enabled:   {RESUME_FROM_CHECKPOINT}")
print(f"{'='*80}\n")

# ============================================================
# VERIFY PATHS
# ============================================================
# Fail fast if the input data or the directories created above are missing.
# NOTE: these are `assert` statements, so they are skipped when Python runs with -O.
print("Verifying paths...")
assert os.path.exists(data_storage_location), f"Data storage location not found: {data_storage_location}"
assert os.path.exists(f"{data_storage_location}/{mds_train_dir}"), f"Training data not found"
assert os.path.exists(f"{data_storage_location}/{mds_val_dir}"), f"Validation data not found"
assert os.path.exists(CHECKPOINT_BASE_DIR), f"Checkpoint directory not created"
assert os.path.exists(f"{EXPERIMENT_ROOT}/mlflow_artifacts"), f"MLflow artifacts directory not created"
print("✓ All paths verified")
# Counts are top-level directory entries only (os.listdir is not recursive)
print(f"  ✓ Training data: {len(os.listdir(f'{data_storage_location}/{mds_train_dir}'))} files")
print(f"  ✓ Validation data: {len(os.listdir(f'{data_storage_location}/{mds_val_dir}'))} files")
print()

# ============================================================
# OPTUNA HYPERPARAMETER SEARCH SPACE (BATCH SIZE AWARE)
# ============================================================
def get_optuna_search_space(trial, base_batch_size=None):
    """Sample one trial's hyperparameters, with the batch size scaled to cluster resources.

    Args:
        trial: an Optuna trial (only ``suggest_float`` and
            ``suggest_categorical`` are used).
        base_batch_size: centre of the batch-size search (an int). Defaults to
            the notebook-level ``BASE_BATCH_SIZE``.

    Returns:
        dict with learning_rate, weight_decay, dropout_rate, batch_size, optimizer.
    """
    if base_batch_size is None:
        base_batch_size = BASE_BATCH_SIZE

    # Half / base / double, clamped to >= 1 and de-duplicated, so a tiny base
    # (e.g. 1) never offers batch_size=0 or repeated choices.
    batch_size_choices = sorted({
        max(base_batch_size // 2, 1),
        max(base_batch_size, 1),
        max(base_batch_size * 2, 1),
    })

    return {
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),
        'batch_size': trial.suggest_categorical('batch_size', batch_size_choices),
        'optimizer': trial.suggest_categorical('optimizer', ['adam', 'adamw', 'sgd'])
    }

# ============================================================
# GLOBAL VARIABLES
# ============================================================
# Placeholder; presumably assigned by the optimization cell (not shown here) -- confirm
EXPERIMENT_RUN_ID = None

# Informational banner only
print(f"{'='*80}")
print(f"READY TO START OPTIMIZATION")
print(f"{'='*80}")
print(f"Configuration complete. Run the optimization cell to begin training.")
print(f"Use DISTRIBUTOR_CONFIG for TorchDistributor initialization.")
print(f"Use BASE_BATCH_SIZE as reference for Optuna batch size search.")
print(f"{'='*80}\n")


In [0]:
# **Example integration -- [!!] validate these settings against your own data, model, and training setup:**

# 1. **Auto-configured `num_workers`**: Derived from the GPU count and CPU cores (roughly 4 DataLoader workers per GPU, bounded by the node's CPU core count)
# 2. **Auto-configured `BASE_BATCH_SIZE`**: Based on GPU memory detection
# 3. **`DISTRIBUTOR_CONFIG`**: Ready-to-use dictionary for TorchDistributor
# 4. **`EFFECTIVE_BATCH_SIZE`**: Total batch size across all GPUs for monitoring

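# **Usage in your training function:** pass `num_workers` and the trial's batch size (centred on `BASE_BATCH_SIZE`) to your `DataLoader`, and create the distributor with `TorchDistributor(**DISTRIBUTOR_CONFIG)`.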


In [0]:
# Show the TorchDistributor settings derived from CLUSTER_RESOURCES
DISTRIBUTOR_CONFIG

**NOTES:**

References:   

Batch Size Scaling:  
Goyal et al. (2017) - "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"    
Smith et al. (2017) - "Don't Decay the Learning Rate, Increase the Batch Size"    

Model-Specific:   
He et al. (2015) - ResNet paper (batch 256 across 8 GPUs)    
Sandler et al. (2018) - MobileNetV2 (batch size 96, 16 asynchronous GPU workers)   
Tan & Le (2019) - EfficientNet (batch 2048 on TPUs)   

Practical Implementations:   
PyTorch ImageNet examples: https://github.com/pytorch/examples/tree/main/imagenet   
TorchVision training reference: https://github.com/pytorch/vision/tree/main/references/classification    

**General Heuristics** 
- Start with literature values if available
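- Empirically find the largest batch size that fits in GPU memory (e.g. with a `find_max_batch_size()` helper that grows the batch until it runs out of memory; not defined in this notebook) -- the most reliable method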
- Use Optuna to tune around that value
- Monitor GPU utilization and adjust
