# Distributed Training of ViT-Base on A100

Training ViT-Base (ViT-B/16, ~86M parameters) on ImageNet-1K, with GPU profiling throughout.

## Hyperparameters

- **Model:** ViT-Base/16 (~86M parameters)
- **Dataset:** ImageNet-1K (1000 classes)
- **Batch size:** 32 per GPU
- **Learning rate:** 3e-4 (AdamW)
- **Warmup steps:** 10,000
- **Total epochs:** 300
- **Image size:** 224x224
- **Patch size:** 16x16
- **Mixed precision:** FP16
- **Weight decay:** 0.3
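The warmup entry can be read as a ramp from near zero up to the base learning rate. Below is a minimal sketch that assumes a *linear* warmup applied per optimizer step. The list above does not give the warmup shape or any post-warmup decay, so this sketch holds the rate constant after warmup. `lr_at` is a hypothetical helper.

```python
BASE_LR = 3e-4        # AdamW learning rate from the list above
WARMUP_STEPS = 10_000

def lr_at(step: int) -> float:
    """Learning rate at optimizer step `step` (0-indexed) under linear warmup.

    Assumption: after warmup the LR is held constant, because no decay
    schedule is specified in the hyperparameter list.
    """
    return BASE_LR * min(1.0, (step + 1) / WARMUP_STEPS)

# To drive a PyTorch optimizer with it (LambdaLR multiplies the base LR):
#   torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at(s) / BASE_LR)
print(lr_at(0), lr_at(4_999), lr_at(9_999), lr_at(50_000))
```

Expressing the schedule as a plain function keeps it testable without a GPU. The same function can also be reused when the run is resumed from a checkpoint.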


## ViT-Base Architecture

**Vision Transformer (ViT-B/16) Overview:**
- **Input:** 224x224 RGB images
- **Patch size:** 16x16 (=> 196 patches per image)
- **Embedding dim:** 768
- **Transformer blocks:** 12
- **Attention heads:** 12
- **MLP hidden dim:** 3072
- **Classification head:** Linear layer

**Parameter Calculation (weights and biases):**
- Patch embedding: (16*16*3)*768 + 768 = 590,592
- CLS token: 768
- Position embedding: (196+1)*768 = 151,296 (196 patches + 1 CLS token)
- Each transformer block:
  - Multi-head self-attention (Q, K, V and output projections): 4*768*768 + 4*768 = 2,362,368
  - MLP: (768*3072 + 3072) + (3072*768 + 768) = 4,722,432
  - LayerNorms: 2*(768 + 768) = 3,072
  - Total per block: 7,087,872 (~7.09M)
- 12 blocks: 85,054,464 (~85.1M)
- Final LayerNorm: 2*768 = 1,536
- Classification head: 768*1000 + 1000 = 769,000
- **Total:** 86,567,656 (~86.6M parameters)
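The arithmetic above can be checked in plain Python without instantiating a model. This is a sketch: `vit_param_count` is a hypothetical helper, not a library function. On a real model, `count_parameters` (defined further down) should report the same figure for the custom `VisionTransformer`.

```python
def vit_param_count(img=224, patch=16, chans=3, dim=768, depth=12,
                    mlp_dim=3072, num_classes=1000):
    """ViT parameter count (weights + biases) from the architecture hyperparameters."""
    n_patches = (img // patch) ** 2                       # 196 for 224/16
    patch_embed = patch * patch * chans * dim + dim       # conv weight + bias
    cls_token = dim
    pos_embed = (n_patches + 1) * dim                     # +1 for the CLS token
    attn = 4 * dim * dim + 4 * dim                        # Q, K, V, output projection
    mlp = (dim * mlp_dim + mlp_dim) + (mlp_dim * dim + dim)
    norms = 2 * (2 * dim)                                 # two LayerNorms (gamma, beta)
    block = attn + mlp + norms
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * block + final_norm + head

print(f"{vit_param_count():,}")  # 86,567,656
```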

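**Memory Requirement (Training, single GPU):**
- Model params (fp32 master copy): 86.6M * 4 B = ~346 MB
- Model params (fp16 copy): 86.6M * 2 B = ~173 MB
- Activations (batch=32, fp16): ~2-3 GB (rough estimate; depends on the implementation)
- Gradients (fp32): ~346 MB
- Optimizer states (AdamW, two fp32 moments): 2 * 346 MB = ~692 MB
- **Total (mixed precision):** ~1.6 GB of weights, gradients, and optimizer state plus activations, i.e. ~4-5 GB before CUDA context and allocator overhead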
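Everything in that budget except activations scales linearly with the parameter count. The sketch below assumes the layout listed above: fp32 master weights, an fp16 weight copy, fp32 gradients, and two fp32 AdamW moments. Activation memory is not modelled. `training_state_bytes` is a hypothetical helper.

```python
def training_state_bytes(n_params, master_bytes=4, half_bytes=2,
                         grad_bytes=4, adam_moments=2):
    """Static training memory (bytes) for mixed-precision AdamW; excludes activations."""
    params_fp32 = n_params * master_bytes
    params_fp16 = n_params * half_bytes
    grads = n_params * grad_bytes
    optimizer = n_params * master_bytes * adam_moments   # exp_avg + exp_avg_sq
    return params_fp32 + params_fp16 + grads + optimizer

N = 86_567_656   # ViT-B/16 parameter count
print(f"{training_state_bytes(N) / 1e9:.2f} GB")   # 18 bytes/param -> 1.56 GB
```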

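**Computation per Forward Pass** (counted in multiply-accumulates, MACs; the commonly quoted "17.6 GFLOPs" for ViT-B/16 uses this convention, and true FLOPs are ~2x):
- Patch embedding: 196 * 768 * (3 * 16²) = ~116M MACs
- Self-attention projections (per layer): 4 * 197 * 768² = ~465M MACs
- Attention scores and weighted sum (per layer): 2 * 197² * 768 = ~60M MACs
- MLP (per layer): 2 * 197 * 768 * 3072 = ~930M MACs
- Total per layer: ~1.45G MACs
- **Total forward:** ~17.6 GMACs/image (~35 GFLOPs)
- **Total (forward+backward):** ~3x forward = ~52.7 GMACs/image (~105 GFLOPs)
- For batch=32: ~1.69 TMACs/step (~3.4 TFLOPs)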
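The per-layer terms can be recomputed in a few lines. This sketch counts only matrix multiplies and ignores LayerNorm, softmax, GELU, and bias adds, which is an approximation. It also assumes the backward pass costs about twice the forward pass. `vit_forward_macs` is a hypothetical helper. `torch.profiler`'s `with_flops` reports about 2 FLOPs per multiply-add for the matmul and convolution ops it recognises, so `profile_model_flops` below should land nearer the FLOPs figure than the MAC figure.

```python
def vit_forward_macs(tokens=197, patches=196, dim=768, mlp_dim=3072,
                     depth=12, patch=16, chans=3, num_classes=1000):
    """Multiply-accumulates for one ViT-B/16 forward pass (matmuls only)."""
    patch_embed = patches * dim * chans * patch * patch
    attn_proj = 4 * tokens * dim * dim          # Q, K, V, output projection
    attn_core = 2 * tokens * tokens * dim       # Q @ K^T and attn @ V
    mlp = 2 * tokens * dim * mlp_dim
    head = dim * num_classes
    return patch_embed + depth * (attn_proj + attn_core + mlp) + head

fwd = vit_forward_macs()
train_step_flops = 2 * 3 * fwd * 32   # 2 FLOPs/MAC, fwd+bwd ~3x fwd, batch of 32
print(f"forward: {fwd / 1e9:.1f} GMACs, train step: {train_step_flops / 1e12:.2f} TFLOPs")
# forward: 17.6 GMACs, train step: 3.37 TFLOPs
```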

**A100 Performance:**
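- Theoretical peak (FP16/BF16, dense Tensor Core): 312 TFLOPS
- Memory bandwidth: ~1.6 TB/s (40 GB model) to ~2.0 TB/s (80 GB model)
- Expected utilization: typically 30-50% of peak. Memory-bound operations (LayerNorm, GELU, softmax, dropout, residual adds) cannot keep the Tensor Cores busy, and data loading and communication add idle time.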
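A roofline-style estimate makes the memory-bound argument concrete. An op is compute-bound only if its arithmetic intensity (FLOPs per byte moved) exceeds peak compute divided by bandwidth. This is a minimal sketch under the following assumptions:

- 2.0 TB/s bandwidth.
- Ideal data reuse for the matmul.
- A rough guess of 8 FLOPs per GELU element.
- About 3.37e12 FLOPs per batch-32 training step. The ~1.7 T per-step figure above counts multiply-adds, and each multiply-add is 2 FLOPs.

All helper names are hypothetical.

```python
PEAK_FLOPS = 312e12     # dense FP16 Tensor Core peak, FLOP/s
BANDWIDTH = 2.0e12      # bytes/s; assumes the 80 GB A100 (use ~1.6e12 for 40 GB)
RIDGE = PEAK_FLOPS / BANDWIDTH   # FLOP/byte an op needs to be compute-bound

def matmul_intensity(m, k, n, bytes_per_el=2):
    """FLOP/byte for C[m,n] = A[m,k] @ B[k,n], reading A and B once and writing C once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_el * (m * k + k * n + m * n)
    return flops / bytes_moved

def elementwise_intensity(flops_per_el=8, bytes_per_el=2):
    """FLOP/byte for an op like GELU: one read and one write per element."""
    return flops_per_el / (2 * bytes_per_el)

def step_time_ms(train_flops=3.37e12, utilization=0.4):
    """Step time if the GPU sustains the given fraction of peak."""
    return train_flops / (PEAK_FLOPS * utilization) * 1e3

fc1 = matmul_intensity(32 * 197, 768, 3072)   # MLP fc1 over a batch of 32 images
gelu = elementwise_intensity()
print(f"ridge {RIDGE:.0f}, fc1 {fc1:.0f}, GELU {gelu:.0f} FLOP/byte")
print(f"step time at 40% of peak: {step_time_ms():.1f} ms")
```

The large matmuls sit well above the ridge point, while elementwise ops sit far below it. Those elementwise ops are what pull average utilization down.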


## Environment Setup and Profiling Tools Installation

This section automatically detects the environment (Colab vs local) and installs necessary profiling tools.


In [None]:
import os
import sys
import subprocess
import platform

# Detect environment
def is_colab():
    """Check if running in Google Colab"""
    try:
        import google.colab
        return True
    except ImportError:
        return False

def install_package(package):
    """Install package with pip"""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

def install_apt_package(package):
    """Install apt package (Ubuntu/Debian)"""
    subprocess.check_call(["apt-get", "update", "-qq"])
    subprocess.check_call(["apt-get", "install", "-y", package])

# Environment detection
COLAB = is_colab()
LOCAL = not COLAB

print(f"Environment: {'Colab' if COLAB else 'Local'}")
print(f"Platform: {platform.platform()}")
print(f"Python: {sys.version}")

# Install required packages
required_packages = [
    "torch>=2.0.0",
    "torchvision",
    "tensorboard",
    "wandb",
    "timm",
    "transformers",
    "datasets",
    "pillow",
    "matplotlib",
    "seaborn",
    "nvidia-ml-py"  # NVIDIA's official NVML bindings, imported as `pynvml`
]

print("\nInstalling Python packages...")
for package in required_packages:
    try:
        install_package(package)
        print(f"[INSTALLED] {package}")
    except Exception as e:
        print(f"[FAILED] {package}: {e}")

# Colab-specific setup
if COLAB:
    print("\nSetting up Google Colab environment...")
    
    # Install CUDA profiling tools (if available)
    try:
        # Check CUDA version
        cuda_version = subprocess.check_output(["nvcc", "--version"]).decode()
        print(f"CUDA version: {cuda_version.split('release')[1].split(',')[0].strip()}")
        
        # Try to install nsight tools (may not work in standard Colab)
        print("Attempting to install NVIDIA profiling tools...")
        try:
            install_apt_package("wget")
            # Download and install nsight systems (simplified)
            subprocess.run([
                "wget", "-q", 
                "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb"
            ])
            subprocess.run(["dpkg", "-i", "cuda-keyring_1.0-1_all.deb"])
            subprocess.run(["apt-get", "update", "-qq"])
            # Note: Full nsight installation may require custom runtime
            print("WARNING: Full Nsight tools require custom Colab runtime")
        except Exception as e:
            print(f"WARNING: Nsight installation failed (expected in standard Colab): {e}")
    except Exception as e:
        print(f"WARNING: CUDA tools not available: {e}")
    
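    # Install the TensorBoard plugin that renders torch.profiler traces
    # (tensorboard-plugin-profile targets the TensorFlow/JAX profiler format instead)
    try:
        install_package("torch-tb-profiler")
        print("[INSTALLED] torch-tb-profiler (PyTorch Profiler TensorBoard plugin)")
    except Exception as e:
        print(f"WARNING: torch-tb-profiler installation failed: {e}")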

# Local setup
elif LOCAL:
    print("\nSetting up local environment...")
    
    # nvprof does not support compute capability 8.0+ (A100), so check for the
    # Nsight Systems (nsys) and Nsight Compute (ncu) command-line tools instead
    if platform.system() == "Linux":
        import shutil
        missing = [tool for tool in ("nsys", "ncu") if shutil.which(tool) is None]
        if missing:
            print(f"WARNING: NVIDIA profiling tools not found on PATH: {', '.join(missing)}")
            print("    Install Nsight Systems / Nsight Compute (shipped with the CUDA Toolkit) and add them to PATH")
        else:
            print("[AVAILABLE] Nsight Systems (nsys) and Nsight Compute (ncu)")
    else:
        print("WARNING: Some profiling tools are Linux-only")

print("\nEnvironment setup complete!")
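Nsight Systems profiles a whole process from the outside, so it suits a standalone script better than a notebook cell. The sketch below builds the command line for a training script exported from this notebook. `train_vit.py`, the report name, and `nsys_command` are hypothetical. `--trace=cuda,nvtx,osrt` selects CUDA API and kernel, NVTX range, and OS runtime tracing.

```python
import shlex
import shutil

def nsys_command(script, report="vit_baseline", trace="cuda,nvtx,osrt"):
    """Build an `nsys profile` command line that wraps a Python training script."""
    return ["nsys", "profile", f"--trace={trace}", "-o", report, "python", script]

cmd = nsys_command("train_vit.py")   # hypothetical script name
print(shlex.join(cmd))
if shutil.which("nsys") is None:
    print("nsys not on PATH; install Nsight Systems to run this command")
```

When `nsys` is available, `subprocess.run(cmd, check=True)` runs the script under the profiler and writes a report file that opens in the Nsight Systems GUI.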


## ViT-Base Model Implementation

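A from-scratch PyTorch implementation of ViT-Base/16, plus a factory function that can load torchvision's pretrained `vit_b_16` instead. The distributed wrapper (DeepSpeed) is applied in a later section.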


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torchvision.models import vit_b_16
import math
import time
import json
from pathlib import Path

# Custom ViT implementation (alternative to torchvision)
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)  # (B, embed_dim, N)
        x = x.transpose(1, 2)  # (B, N, embed_dim)
        return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class MLP(nn.Module):
    def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
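        # Initialize weights. Give the learnable position embedding and CLS token small
        # truncated-normal values (all-zero position embeddings would make every
        # position identical at the start), then apply the per-module init.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)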
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        
        for block in self.blocks:
            x = block(x)
            
        x = self.norm(x)
        x = x[:, 0]  # Use CLS token
        x = self.head(x)
        return x

def count_parameters(model):
    """Count model parameters"""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# Create model
def create_vit_model(num_classes=1000, use_pretrained=True):
    if use_pretrained:
        # Use torchvision pretrained model
        model = vit_b_16(weights='IMAGENET1K_V1')
        if num_classes != 1000:
            model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        # Use custom implementation
        model = VisionTransformer(num_classes=num_classes)
    
    total_params, trainable_params = count_parameters(model)
    print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")
    print(f"Model size: {total_params * 4 / 1024**2:.1f} MB (fp32)")
    
    return model

# Test the model
if __name__ == "__main__":
    model = create_vit_model(num_classes=1000, use_pretrained=False)
    x = torch.randn(2, 3, 224, 224)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
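The test above prints only the input and output shapes. To follow how the intermediate shapes annotated in `PatchEmbedding` and `MultiHeadSelfAttention` evolve for that batch of 2, here is a plain-Python sketch. `vit_shapes` is a hypothetical helper that only does the arithmetic. The `attn` entry shows why attention memory grows with the square of the token count.

```python
def vit_shapes(batch=2, img=224, patch=16, dim=768, heads=12, num_classes=1000):
    """Tensor shapes at key points of the ViT forward pass defined above."""
    n = (img // patch) ** 2            # 196 patches for 224/16
    head_dim = dim // heads            # 64
    return {
        "patch_embed": (batch, n, dim),                 # after flatten(2).transpose(1, 2)
        "tokens": (batch, n + 1, dim),                  # CLS token prepended
        "qkv": (3, batch, heads, n + 1, head_dim),      # after reshape + permute(2, 0, 3, 1, 4)
        "attn": (batch, heads, n + 1, n + 1),           # q @ k.transpose(-2, -1)
        "logits": (batch, num_classes),                 # head applied to the CLS token
    }

for name, shape in vit_shapes().items():
    print(f"{name:12s} {shape}")
```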


## Profiling and Monitoring Setup

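This section sets up GPU profiling. It provides:

- a background NVML monitor for utilization, memory, temperature, and power;
- a `torch.profiler` wrapper that writes TensorBoard traces;
- helpers that measure FLOPs, peak memory, and inference throughput.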


In [None]:
import torch
import torch.profiler
import threading
import time
import json
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, deque
import subprocess
import pynvml as nvml  # provided by the nvidia-ml-py (or older nvidia-ml-py3) package

# Initialize NVIDIA ML
try:
    nvml.nvmlInit()
    print("[INITIALIZED] NVIDIA ML library")
except Exception as e:
    print(f"WARNING: NVIDIA ML library not available: {e}")

class GPUMonitor:
    """Real-time GPU monitoring via NVML in a background thread.

    Note: NVML numbers physical GPUs, which can differ from the CUDA device
    index when CUDA_VISIBLE_DEVICES is set.
    """
    def __init__(self, device_id=0, interval=0.1, max_samples=1000):
        self.device_id = device_id
        self.interval = interval
        self.monitoring = False
        # Bounded deques drop the oldest sample automatically
        self.metrics = defaultdict(lambda: deque(maxlen=max_samples))
        # The sampler thread appends while the training loop reads, so guard access
        self.lock = threading.Lock()
        self.thread = None
        
    def start(self):
        """Start monitoring in a background (daemon) thread"""
        self.monitoring = True
        self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.thread.start()
        
    def stop(self):
        """Stop monitoring"""
        self.monitoring = False
        if self.thread:
            self.thread.join()
            
    def _monitor_loop(self):
        """Main monitoring loop"""
        while self.monitoring:
            try:
                # GPU metrics via NVML (pynvml)
                handle = nvml.nvmlDeviceGetHandleByIndex(self.device_id)
                
                # Memory info
                mem_info = nvml.nvmlDeviceGetMemoryInfo(handle)
                gpu_memory_used = mem_info.used / 1024**3  # GB
                gpu_memory_util = (mem_info.used / mem_info.total) * 100
                
                # GPU utilization: % of the last sample period in which any kernel was running
                util = nvml.nvmlDeviceGetUtilizationRates(handle)
                gpu_util = util.gpu
                
                # Temperature
                temp = nvml.nvmlDeviceGetTemperature(handle, nvml.NVML_TEMPERATURE_GPU)
                
                # Power
                power = nvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # mW -> W
                
                # Store metrics (deques are bounded to max_samples)
                timestamp = time.time()
                with self.lock:
                    self.metrics['timestamp'].append(timestamp)
                    self.metrics['gpu_util'].append(gpu_util)
                    self.metrics['gpu_memory_used'].append(gpu_memory_used)
                    self.metrics['gpu_memory_util'].append(gpu_memory_util)
                    self.metrics['temperature'].append(temp)
                    self.metrics['power'].append(power)
                        
            except Exception as e:
                print(f"Monitoring error: {e}")
                
            time.sleep(self.interval)
            
    def get_stats(self):
        """Summary statistics over the buffered samples"""
        with self.lock:
            # Copy under the lock so the sampler thread cannot mutate the deques mid-read
            snapshot = {key: list(values) for key, values in self.metrics.items()}
        if not snapshot.get('gpu_util'):
            return {}
            
        return {
            'gpu_util_avg': sum(snapshot['gpu_util']) / len(snapshot['gpu_util']),
            'gpu_util_max': max(snapshot['gpu_util']),
            'gpu_memory_max': max(snapshot['gpu_memory_used']),
            'gpu_memory_avg': sum(snapshot['gpu_memory_used']) / len(snapshot['gpu_memory_used']),
            'temperature_max': max(snapshot['temperature']),
            'power_avg': sum(snapshot['power']) / len(snapshot['power']),
        }
        
    def plot_metrics(self, save_path=None):
        """Plot monitoring metrics"""
        if not self.metrics['timestamp']:
            print("No metrics to plot")
            return
            
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        timestamps = list(self.metrics['timestamp'])
        start_time = timestamps[0]
        times = [(t - start_time) for t in timestamps]
        
        # GPU Utilization
        axes[0, 0].plot(times, list(self.metrics['gpu_util']))
        axes[0, 0].set_title('GPU Utilization (%)')
        axes[0, 0].set_ylabel('Utilization %')
        
        # Memory Usage
        axes[0, 1].plot(times, list(self.metrics['gpu_memory_used']), label='Used')
        axes[0, 1].set_title('GPU Memory Usage (GB)')
        axes[0, 1].set_ylabel('Memory (GB)')
        
        # Temperature
        axes[1, 0].plot(times, list(self.metrics['temperature']))
        axes[1, 0].set_title('GPU Temperature (°C)')
        axes[1, 0].set_ylabel('Temperature (°C)')
        
        # Power
        axes[1, 1].plot(times, list(self.metrics['power']))
        axes[1, 1].set_title('GPU Power Usage (W)')
        axes[1, 1].set_ylabel('Power (W)')
        
        for ax in axes.flat:
            ax.set_xlabel('Time (s)')
            ax.grid(True)
            
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

class TorchProfiler:
    """PyTorch profiler wrapper"""
    def __init__(self, log_dir="./profiler_logs", trace_handler=None):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)  # log_dir may be nested, e.g. ./logs/basic_training
        
        if trace_handler is None:
            trace_handler = torch.profiler.tensorboard_trace_handler(str(self.log_dir))
            
        self.profiler = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
            on_trace_ready=trace_handler,
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            with_flops=True,
            with_modules=True
        )
        
    def __enter__(self):
        self.profiler.__enter__()
        return self.profiler
        
    def __exit__(self, *args):
        self.profiler.__exit__(*args)

def profile_model_flops(model, input_tensor):
    """Profile model FLOPs using torch.profiler"""
    model.eval()
    with torch.profiler.profile(with_flops=True) as prof:
        with torch.no_grad():
            _ = model(input_tensor)
    
    # Calculate total FLOPs
    total_flops = sum([item.flops for item in prof.key_averages()])
    return total_flops

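def measure_memory_usage(model, input_tensor, device):
    """Measure peak GPU memory (GB) for one forward + backward pass.

    Must be called with autograd enabled (not under torch.no_grad()).
    Returns NaN when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return float('nan')
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    
    model.train()
    output = model(input_tensor)
    loss = output.sum()
    loss.backward()
    
    peak_memory = torch.cuda.max_memory_allocated(device) / 1024**3  # GB
    model.zero_grad(set_to_none=True)  # free the benchmark gradients
    return peak_memory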

def benchmark_throughput(model, input_shape, device, batch_sizes=(1, 8, 16, 32), num_runs=10):
    """Benchmark model throughput"""
    results = {}
    model.eval()
    
    for batch_size in batch_sizes:
        input_tensor = torch.randn(batch_size, *input_shape, device=device)
        
        # Warmup
        for _ in range(5):
            with torch.no_grad():
                _ = model(input_tensor)
        
        torch.cuda.synchronize()
        
        # Measure
        times = []
        for _ in range(num_runs):
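            start = time.perf_counter()  # monotonic, high-resolution timer
            with torch.no_grad():
                _ = model(input_tensor)
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
            times.append(time.perf_counter() - start)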
        
        avg_time = sum(times) / len(times)
        throughput = batch_size / avg_time
        
        results[batch_size] = {
            'avg_time': avg_time,
            'throughput': throughput,
            'images_per_second': throughput
        }
        
        print(f"Batch size {batch_size}: {throughput:.1f} images/sec")
    
    return results

# Quick GPU sanity check
print("Checking GPU availability...")
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    print(f"Using GPU {device}: {torch.cuda.get_device_name(device)}")
    
    # Quick GPU info
    memory_total = torch.cuda.get_device_properties(device).total_memory / 1024**3
    print(f"Total GPU memory: {memory_total:.1f} GB")
else:
    print("CUDA not available")
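`train_epoch` calls `profiler.step()` once per iteration. The `TorchProfiler` schedule (`wait=1, warmup=1, active=3, repeat=2`) decides which of those steps are actually traced. The sketch below re-implements that step logic in plain Python for illustration only; it is not PyTorch's own code path, and `profiler_action` is a hypothetical name. It shows that only steps 2-4 and 7-9 are recorded. Because the profiler context wraps both epochs in the basic training run, both traces come from early iterations of the first epoch.

```python
def profiler_action(step, wait=1, warmup=1, active=3, repeat=2, skip_first=0):
    """Mirror of torch.profiler.schedule's per-step decision, for illustration."""
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"                      # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    return "RECORD_AND_SAVE" if pos == cycle - 1 else "RECORD"

for step in range(12):
    print(step, profiler_action(step))
```

Early iterations can include one-time costs such as allocator warm-up. If the traces look unrepresentative, raising `wait` or setting `skip_first` in the real schedule moves the recorded window later.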


## Training without any optimizations

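Baseline training loop: single GPU, FP32, no mixed precision or distributed training, with profiling and GPU monitoring. It runs on `FakeData`, so the accuracy numbers are meaningless; the point is to collect timing, memory, and profiler traces.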


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import json
from pathlib import Path

def get_imagenet_loaders(data_path="/tmp/imagenet", batch_size=32, num_workers=4):
    """Create ImageNet data loaders (using FakeData for demo)"""
    
    # ImageNet transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Use FakeData for demo (replace with real ImageNet)
    train_dataset = datasets.FakeData(
        size=1000, image_size=(3, 224, 224), 
        num_classes=1000, transform=train_transform
    )
    val_dataset = datasets.FakeData(
        size=200, image_size=(3, 224, 224), 
        num_classes=1000, transform=val_transform
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, 
        shuffle=True, num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, 
        shuffle=False, num_workers=num_workers, pin_memory=True
    )
    
    return train_loader, val_loader

def calculate_accuracy(outputs, targets, topk=(1, 5)):
    """Calculate top-k accuracy"""
    maxk = max(topk)
    batch_size = targets.size(0)
    
    _, pred = outputs.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(targets.view(1, -1).expand_as(pred))
    
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train_epoch(model, train_loader, optimizer, criterion, device, epoch, 
                profiler=None, gpu_monitor=None):
    """Train for one epoch"""
    model.train()
    
    # Metrics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    
    end = time.time()
    
    for i, (images, targets) in enumerate(train_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, targets)
        
        # Measure accuracy
        acc1, acc5 = calculate_accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Profiler step
        if profiler:
            profiler.step()
        
        # Log progress
        if i % 10 == 0:
            print(f'Epoch: [{epoch}][{i}/{len(train_loader)}] '
                  f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  f'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  f'Loss {losses.val:.4f} ({losses.avg:.4f}) '
                  f'Acc@1 {top1.val:.3f} ({top1.avg:.3f}) '
                  f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            
            # GPU stats
            if gpu_monitor:
                stats = gpu_monitor.get_stats()
                if stats:
                    print(f'GPU: {stats.get("gpu_util_avg", 0):.1f}% util, '
                          f'{stats.get("gpu_memory_avg", 0):.1f}GB mem, '
                          f'{stats.get("temperature_max", 0):.0f}°C, '
                          f'{stats.get("power_avg", 0):.0f}W')
        
        # Limit iterations for demo
        if i >= 50:
            break
    
    return {
        'loss': losses.avg,
        'top1': top1.avg,
        'top5': top5.avg,
        'batch_time': batch_time.avg,
        'data_time': data_time.avg
    }

def validate(model, val_loader, criterion, device):
    """Validate the model"""
    model.eval()
    
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    
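    # Cast inputs to the model's parameter dtype (fp32 here; fp16 under DeepSpeed's fp16 engine)
    param_dtype = next(model.parameters()).dtype
    
    with torch.no_grad():
        for i, (images, targets) in enumerate(val_loader):
            images = images.to(device, dtype=param_dtype, non_blocking=True)
            targets = targets.to(device, non_blocking=True)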
            
            outputs = model(images)
            loss = criterion(outputs, targets)
            
            acc1, acc5 = calculate_accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            
            # Limit for demo
            if i >= 10:
                break
    
    print(f'Validation: Loss {losses.avg:.4f} Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}')
    
    return {
        'loss': losses.avg,
        'top1': top1.avg,
        'top5': top5.avg
    }

def main_training_basic():
    """Main training function without optimizations"""
    print("Starting basic ViT training...")
    
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Model
    model = create_vit_model(num_classes=1000, use_pretrained=False)
    model = model.to(device)
    
    # Data
    batch_size = 32 if torch.cuda.is_available() else 8
    train_loader, val_loader = get_imagenet_loaders(batch_size=batch_size)
    
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.3)
    
    # Profiling setup
    gpu_monitor = None
    if torch.cuda.is_available():
        # torch.device('cuda') has index None, so query the active CUDA device instead
        gpu_monitor = GPUMonitor(device_id=torch.cuda.current_device())
        gpu_monitor.start()
    
    # Benchmark model. Each helper manages autograd itself; the memory measurement
    # runs a backward pass, so this block must not be wrapped in torch.no_grad().
    print("\nModel benchmarking...")
    dummy_input = torch.randn(1, 3, 224, 224, device=device)
    
    # FLOPs
    total_flops = profile_model_flops(model, dummy_input)
    print(f"Model FLOPs: {total_flops / 1e9:.2f} GFLOPs")
    
    # Memory and throughput (CUDA only)
    if torch.cuda.is_available():
        peak_memory = measure_memory_usage(model, dummy_input, device)
        print(f"Peak memory: {peak_memory:.2f} GB")
        throughput_results = benchmark_throughput(model, (3, 224, 224), device)
    
    # Training with profiling
    results = []
    
    with TorchProfiler(log_dir="./logs/basic_training") as profiler:
        for epoch in range(2):  # Limited epochs for demo
            print(f"\n=== Epoch {epoch + 1} ===")
            
            # Train
            train_metrics = train_epoch(
                model, train_loader, optimizer, criterion, device,
                epoch, profiler, gpu_monitor
            )
            
            # Validate
            val_metrics = validate(model, val_loader, criterion, device)
            
            # Store results
            epoch_results = {
                'epoch': epoch + 1,
                'train': train_metrics,
                'val': val_metrics
            }
            
            if gpu_monitor:
                epoch_results['gpu_stats'] = gpu_monitor.get_stats()
            
            results.append(epoch_results)
            
            # Save checkpoint
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_metrics['loss'],
                'val_loss': val_metrics['loss'],
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
    
    # Cleanup
    if gpu_monitor:
        gpu_monitor.stop()
        gpu_monitor.plot_metrics(save_path='gpu_metrics_basic.png')
    
    # Save results
    with open('training_results_basic.json', 'w') as f:
        json.dump(results, f, indent=2)
    
    print("\nBasic training completed!")
    print("Check the following files:")
    print("   - training_results_basic.json (metrics)")
    print("   - gpu_metrics_basic.png (GPU monitoring)")
    print("   - ./logs/basic_training/ (PyTorch profiler traces)")
    print("   - checkpoint_epoch_*.pth (model checkpoints)")
    
    return results

# Run basic training
if __name__ == "__main__":
    results = main_training_basic()
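For sanity-checking `calculate_accuracy` on tiny inputs, here is a pure-Python reference for top-k accuracy. A sample counts as correct at k if its target class is among the k highest scores. `topk_accuracy` is a hypothetical helper. Ties are broken by class index here because Python's sort is stable, whereas `torch.topk` makes no such guarantee.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Percent of samples whose target is among the k highest-scoring classes."""
    result = []
    for k in ks:
        hits = 0
        for row, target in zip(scores, targets):
            top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            hits += target in top
        result.append(100.0 * hits / len(targets))
    return result

scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
print(topk_accuracy(scores, targets, ks=(1, 2)))   # top-1 = 1/3, top-2 = 2/3 (in %)
```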


## DeepSpeed ZeRO Stage 1

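ZeRO Stage 1 partitions the optimizer states across the N data-parallel GPUs. For mixed-precision AdamW, those states are the fp32 master weights and both moments. Each GPU then stores only 1/N of them, while per-step communication volume stays the same as standard data parallelism. With a single GPU (N = 1), as in the fallback setup below, it saves no memory.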
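The savings can be estimated with the per-GPU model-state accounting from the ZeRO paper. That accounting assumes mixed-precision Adam with:

- fp16 weights (2 bytes/param);
- fp16 gradients (2 bytes/param);
- K = 12 bytes/param of optimizer state.

This differs slightly from the fp32-gradient breakdown at the top of this notebook. `model_state_gb` is a hypothetical helper. Activations are not included, and ZeRO-1 does not reduce them.

```python
def model_state_gb(n_params, n_gpus, zero_stage1=True, k=12):
    """Per-GPU model-state memory (GB), following the ZeRO paper's accounting:
    fp16 weights (2 B/param) + fp16 grads (2 B/param) + optimizer state (k B/param)."""
    weights = 2 * n_params
    grads = 2 * n_params
    optimizer = k * n_params
    if zero_stage1:
        optimizer /= n_gpus          # ZeRO-1: each GPU keeps 1/N of the optimizer state
    return (weights + grads + optimizer) / 1e9

N = 86_567_656   # ViT-B/16 parameter count
for gpus in (1, 4, 8):
    print(f"{gpus} GPU(s): DDP {model_state_gb(N, gpus, zero_stage1=False):.2f} GB, "
          f"ZeRO-1 {model_state_gb(N, gpus):.2f} GB")
```

For ViT-B/16 this gives about 1.39 GB per GPU without partitioning, about 0.61 GB with ZeRO-1 on 4 GPUs, and about 0.48 GB on 8. The model state of ViT-Base is small compared with batch-32 activations, so ZeRO-1 saves well under 1 GB per GPU here.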


In [None]:
# Install DeepSpeed if not available
try:
    import deepspeed
    print("[AVAILABLE] DeepSpeed")
except ImportError:
    print("Installing DeepSpeed...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "deepspeed"])
    import deepspeed

import torch
import torch.distributed as dist
import os
import json

# DeepSpeed ZeRO-1 Configuration
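def get_deepspeed_config_zero1(micro_batch_per_gpu=8, target_global_batch=32):
    """DeepSpeed configuration for ZeRO Stage 1.

    DeepSpeed requires train_batch_size == micro batch * gradient_accumulation_steps
    * world size, so the accumulation steps are derived from WORLD_SIZE (set by the
    launcher or by setup_deepspeed_distributed()).
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    grad_accum = max(1, target_global_batch // (micro_batch_per_gpu * world_size))
    return {
        "train_batch_size": micro_batch_per_gpu * grad_accum * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum,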
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 3e-4,
                "weight_decay": 0.3,
                "bias_correction": True
            }
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 3e-4,
                "warmup_num_steps": 1000
            }
        },
        "zero_optimization": {
            "stage": 1,
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1
        },
        "gradient_clipping": 1.0,
        "wall_clock_breakdown": True
    }

def setup_deepspeed_distributed():
    """Setup distributed training for DeepSpeed"""
    if 'RANK' not in os.environ:
        # Single GPU setup
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        os.environ['LOCAL_RANK'] = '0'
    
    # Initialize distributed
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')
    
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    
    return local_rank

def train_with_deepspeed_zero1():
    """Train ViT with DeepSpeed ZeRO-1"""
    print("Starting DeepSpeed ZeRO-1 training...")
    
    # Setup distributed
    local_rank = setup_deepspeed_distributed()
    device = torch.device(f'cuda:{local_rank}')
    
    # Model
    model = create_vit_model(num_classes=1000, use_pretrained=False)
    
    # Data loaders
    train_loader, val_loader = get_imagenet_loaders(batch_size=8)  # Micro batch size
    
    # DeepSpeed config
    ds_config = get_deepspeed_config_zero1()
    
    # Initialize DeepSpeed
    model_engine, optimizer, train_loader, _ = deepspeed.initialize(
        args=None,
        model=model,
        model_parameters=model.parameters(),
        training_data=train_loader.dataset,
        config=ds_config
    )
    
    # Profiling setup
    gpu_monitor = GPUMonitor(device_id=local_rank)
    gpu_monitor.start()
    
    criterion = torch.nn.CrossEntropyLoss()
    results = []
    
    # Training loop
    with TorchProfiler(log_dir="./logs/deepspeed_zero1") as profiler:
        for epoch in range(2):
            print(f"\n=== DeepSpeed ZeRO-1 Epoch {epoch + 1} ===")

            model_engine.train()
            losses = AverageMeter()
            top1 = AverageMeter()

            for i, (images, targets) in enumerate(train_loader):
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                # Forward pass
                outputs = model_engine(images)
                loss = criterion(outputs, targets)

                # Backward pass and optimizer step (DeepSpeed handles loss scaling and gradient reduction)
                model_engine.backward(loss)
                model_engine.step()

                # Metrics
                acc1, _ = calculate_accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))

                if profiler:
                    profiler.step()

                if i % 10 == 0 and local_rank == 0:
                    print(f'Epoch: [{epoch}][{i}/{len(train_loader)}] '
                          f'Loss {losses.val:.4f} ({losses.avg:.4f}) '
                          f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})')

                    # GPU stats
                    stats = gpu_monitor.get_stats()
                    if stats:
                        print(f'GPU: {stats.get("gpu_util_avg", 0):.1f}% util, '
                              f'{stats.get("gpu_memory_avg", 0):.1f}GB mem')

                if i >= 50:  # Limit for demo
                    break

            # Validation (rank 0 only; safe here because ZeRO-1 keeps full parameters on every rank)
            if local_rank == 0:
                val_metrics = validate(model_engine, val_loader, criterion, device)

                epoch_results = {
                    'epoch': epoch + 1,
                    'method': 'DeepSpeed ZeRO-1',
                    'train_loss': losses.avg,
                    'train_acc1': top1.avg,
                    'val_loss': val_metrics['loss'],
                    'val_acc1': val_metrics['top1'],
                    'gpu_stats': gpu_monitor.get_stats()
                }
                results.append(epoch_results)

    # Cleanup
    gpu_monitor.stop()
    if local_rank == 0:
        gpu_monitor.plot_metrics(save_path='gpu_metrics_zero1.png')

        with open('training_results_zero1.json', 'w') as f:
            json.dump(results, f, indent=2)

    # Memory stats
    if local_rank == 0:
        print("\nDeepSpeed ZeRO-1 Memory Stats:")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
        print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

        # Allocator breakdown; model_engine.memory_breakdown() only returns the
        # "memory_breakdown" config flag (a bool), not a dict of sizes
        print(f"Currently allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Currently reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

    print("\nDeepSpeed ZeRO-1 training completed!")
    return results

# Run if this cell is executed
if __name__ == "__main__":
    results = train_with_deepspeed_zero1()
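The scheduler entry in `get_deepspeed_config_zero1()` uses DeepSpeed's `WarmupLR`. As we understand it, `WarmupLR` ramps from `warmup_min_lr` to `warmup_max_lr` over `warmup_num_steps` and then holds the learning rate constant. It does not decay; `WarmupDecayLR` is the decaying variant. Its `warmup_type` defaults to a logarithmic ramp rather than a linear one.

The sketch below reproduces that shape under those assumptions. `warmup_lr` is a hypothetical helper, not a DeepSpeed API. Note also that the config uses 1,000 warmup steps, while the hyperparameter list at the top specifies 10,000.

```python
import math


def warmup_lr(step: int, warmup_min_lr: float = 0.0, warmup_max_lr: float = 3e-4,
              warmup_num_steps: int = 1000, warmup_type: str = "log") -> float:
    """Learning rate at a given step for a WarmupLR-style schedule (hypothetical helper)."""
    warmup_num_steps = max(2, warmup_num_steps)  # guard against dividing by log(1) == 0
    if step < warmup_num_steps:
        if warmup_type == "log":
            # gamma rises from 0 at step 0 to 1 at step warmup_num_steps - 1
            gamma = math.log(step + 1) / math.log(warmup_num_steps)
        elif warmup_type == "linear":
            gamma = step / warmup_num_steps
        else:
            raise ValueError(f"unknown warmup_type: {warmup_type!r}")
    else:
        gamma = 1.0  # after warmup the LR stays at warmup_max_lr (no decay)
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma


if __name__ == "__main__":
    for step in (0, 10, 100, 500, 999, 5000):
        print(f"step {step:>4}: log {warmup_lr(step):.2e}  linear {warmup_lr(step, warmup_type='linear'):.2e}")
```

The logarithmic ramp reaches about two thirds of the peak LR after only 100 of the 1,000 steps (log 101 / log 1000 ≈ 0.67). That is far more aggressive early on than a linear warmup. Suppose the intended schedule is the 10,000-step warmup followed by decay over 300 epochs. Then `WarmupDecayLR` with `total_num_steps` set would likely be the closer fit, with `"warmup_type": "linear"` added if a linear ramp is wanted.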


## DeepSpeed ZeRO Stage 2

ZeRO-2 partitions both optimizer states and gradients across the data-parallel GPUs, further reducing per-GPU memory relative to ZeRO-1.
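To see how much each stage can save, use the ZeRO paper's accounting for mixed-precision Adam:

- 2 bytes/parameter for FP16 weights.
- 2 bytes/parameter for FP16 gradients.
- K = 12 bytes/parameter for the FP32 master weights, momentum, and variance.

The sketch below applies that accounting to per-GPU model-state memory for N data-parallel GPUs. `zero_model_state_gb` is a hypothetical helper. Activations, communication buffers, and allocator fragmentation are deliberately left out. With N = 1 (the single-GPU fallback in `setup_deepspeed_distributed`), every stage gives the same number, because there is nothing to partition across.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for ZeRO stages 0-3.

    Assumes the ZeRO paper's mixed-precision Adam accounting:
    FP16 params (2 B/param), FP16 grads (2 B/param), optimizer states (k B/param).
    Activations and buffers are not included.
    """
    params = 2 * num_params   # FP16 weights
    grads = 2 * num_params    # FP16 gradients
    optim = k * num_params    # FP32 master weights + Adam momentum + variance
    if stage == 0:            # plain data parallelism: everything replicated
        total = params + grads + optim
    elif stage == 1:          # optimizer states partitioned
        total = params + grads + optim / num_gpus
    elif stage == 2:          # optimizer states + gradients partitioned
        total = params + (grads + optim) / num_gpus
    elif stage == 3:          # parameters, gradients and optimizer states partitioned
        total = (params + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


if __name__ == "__main__":
    vit_b16 = 86.6e6  # parameter count from the architecture section
    for n in (1, 8):
        row = ", ".join(f"stage {s}: {zero_model_state_gb(vit_b16, n, s):.3f} GB" for s in range(4))
        print(f"N={n}: {row}")
```

For ViT-B/16 on 8 GPUs this gives roughly 1.39 GB (stage 0), 0.48 GB (stage 1), 0.32 GB (stage 2) and 0.17 GB (stage 3) per GPU. The architecture section's activation estimate (2-3 GB at batch 32) is already larger than any of these, so for a model this size the stages differ mainly in communication pattern rather than in whether training fits.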


def get_deepspeed_config_zero2():
    """DeepSpeed configuration for ZeRO Stage 2"""
    config = get_deepspeed_config_zero1()  # Start with ZeRO-1 config
    
    # Update to ZeRO Stage 2
    config["zero_optimization"]["stage"] = 2
    # No offload_optimizer entry: optimizer states stay on the GPU (the older "cpu_offload" flag is deprecated)
    
    return config

def train_with_deepspeed_zero2():
    """Train ViT with DeepSpeed ZeRO-2"""
    print("Starting DeepSpeed ZeRO-2 training...")
    
    # Setup distributed
    local_rank = setup_deepspeed_distributed()
    device = torch.device(f'cuda:{local_rank}')
    
    # Model
    model = create_vit_model(num_classes=1000, use_pretrained=False)
    
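    # Data loaders. deepspeed.initialize() rebuilds the training loader from train_loader.dataset
    # using "train_micro_batch_size_per_gpu" from the config, so batch_size here only affects val_loader
    train_loader, val_loader = get_imagenet_loaders(batch_size=8)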
    
    # DeepSpeed config
    ds_config = get_deepspeed_config_zero2()
    
    # Initialize DeepSpeed
    model_engine, optimizer, train_loader, _ = deepspeed.initialize(
        args=None,
        model=model,
        model_parameters=model.parameters(),
        training_data=train_loader.dataset,
        config=ds_config
    )
    
    # Profiling setup
    gpu_monitor = GPUMonitor(device_id=local_rank)
    gpu_monitor.start()
    
    criterion = torch.nn.CrossEntropyLoss()
    results = []
    
    # Training loop
    with TorchProfiler(log_dir="./logs/deepspeed_zero2") as profiler:
        for epoch in range(2):
            print(f"\n=== DeepSpeed ZeRO-2 Epoch {epoch + 1} ===")
            
            model_engine.train()
            losses = AverageMeter()
            top1 = AverageMeter()
            
            for i, (images, targets) in enumerate(train_loader):
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                
                # Forward pass
                outputs = model_engine(images)
                loss = criterion(outputs, targets)
                
                # Backward pass
                model_engine.backward(loss)
                model_engine.step()
                
                # Metrics
                acc1, _ = calculate_accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                
                if profiler:
                    profiler.step()
                
                if i % 10 == 0 and local_rank == 0:
                    print(f'Epoch: [{epoch}][{i}/{len(train_loader)}] '
                          f'Loss {losses.val:.4f} ({losses.avg:.4f}) '
                          f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})')
                    
                    # GPU stats
                    stats = gpu_monitor.get_stats()
                    if stats:
                        print(f'GPU: {stats.get("gpu_util_avg", 0):.1f}% util, '
                              f'{stats.get("gpu_memory_avg", 0):.1f}GB mem')
                
                if i >= 50:  # Limit for demo
                    break
            
            # Validation
            if local_rank == 0:
                val_metrics = validate(model_engine, val_loader, criterion, device)
                
                epoch_results = {
                    'epoch': epoch + 1,
                    'method': 'DeepSpeed ZeRO-2',
                    'train_loss': losses.avg,
                    'train_acc1': top1.avg,
                    'val_loss': val_metrics['loss'],
                    'val_acc1': val_metrics['top1'],
                    'gpu_stats': gpu_monitor.get_stats()
                }
                results.append(epoch_results)
    
    # Cleanup
    gpu_monitor.stop()
    if local_rank == 0:
        gpu_monitor.plot_metrics(save_path='gpu_metrics_zero2.png')
        
        with open('training_results_zero2.json', 'w') as f:
            json.dump(results, f, indent=2)
    
    # Memory stats
    if local_rank == 0:
        print("\nDeepSpeed ZeRO-2 Memory Stats:")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
        print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        
        # Allocator breakdown; model_engine.memory_breakdown() only returns the
        # "memory_breakdown" config flag (a bool), not a dict of sizes
        print(f"Currently allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Currently reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    
    print("\nDeepSpeed ZeRO-2 training completed!")
    return results

# Run if this cell is executed
if __name__ == "__main__":
    results = train_with_deepspeed_zero2()
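What distinguishes ZeRO-2 from ZeRO-1 is where the gradients live after the backward pass. Plain data parallelism uses an all-reduce, which leaves every GPU with the full summed gradient. A reduce-scatter instead leaves each rank with the sum of only the slice it owns, which is the only slice its optimizer partition needs.

The single-process simulation below illustrates the two collectives conceptually with plain lists. `reduce_scatter` and `all_reduce` are hypothetical helpers, not `torch.distributed` calls. The sketch assumes the gradient length divides evenly by the world size.

```python
from typing import List


def all_reduce(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Every rank receives the full element-wise sum (plain data parallelism / ZeRO-1)."""
    total = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(total) for _ in per_rank_grads]


def reduce_scatter(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Rank r receives only the element-wise sum of shard r (conceptually what ZeRO-2 keeps)."""
    world_size = len(per_rank_grads)
    numel = len(per_rank_grads[0])
    if numel % world_size != 0:
        raise ValueError("sketch assumes numel is divisible by world_size")
    shard = numel // world_size
    total = [sum(vals) for vals in zip(*per_rank_grads)]
    return [total[r * shard:(r + 1) * shard] for r in range(world_size)]


if __name__ == "__main__":
    # Two simulated ranks, each with a local gradient for a 4-element parameter
    grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
    print(all_reduce(grads))      # each rank stores 4 summed values
    print(reduce_scatter(grads))  # each rank stores 2 summed values
```

Per rank, gradient storage falls from `numel` to `numel / world_size`. After its optimizer step, each rank's updated parameter shard is all-gathered so that every GPU again holds the full weights for the next forward pass.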


## DeepSpeed ZeRO Stage 3

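ZeRO-3 partitions model parameters, gradients, and optimizer states for maximum memory efficiency, at the cost of all-gathering each layer's parameters before it is used in the forward and backward passes.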
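Conceptually, ZeRO-3 flattens each parameter, pads it to a multiple of the world size, and keeps only one shard per GPU. A layer's full weights are reassembled with an all-gather just before use and released afterwards. The `stage3_*` settings in the config cell tune how many gathered parameters may stay live and how far ahead to prefetch, and parameters below `stage3_param_persistence_threshold` are not partitioned at all.

The single-process sketch below shows the shard/reassemble round trip with plain lists. `partition` and `all_gather` are hypothetical helpers. They ignore the per-module hooks, prefetching, and real communication that DeepSpeed performs.

```python
from typing import List


def partition(flat: List[float], world_size: int) -> List[List[float]]:
    """Pad a flattened parameter to a multiple of world_size and split it into equal shards."""
    pad = (-len(flat)) % world_size
    padded = list(flat) + [0.0] * pad
    shard = len(padded) // world_size
    return [padded[r * shard:(r + 1) * shard] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding to recover the full parameter."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


if __name__ == "__main__":
    weight = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]  # 7 elements, 4 simulated GPUs
    shards = partition(weight, world_size=4)
    print(shards)                           # each GPU persistently stores only 2 values
    print(all_gather(shards, len(weight)))  # full weight, materialized only for the layer's compute
```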


def get_deepspeed_config_zero3():
    """DeepSpeed configuration for ZeRO Stage 3"""
    config = get_deepspeed_config_zero1()  # Start with ZeRO-1 config
    
    # Update to ZeRO Stage 3
    config["zero_optimization"]["stage"] = 3
    config["zero_optimization"]["stage3_max_live_parameters"] = 1e9
    config["zero_optimization"]["stage3_max_reuse_distance"] = 1e9
    config["zero_optimization"]["stage3_prefetch_bucket_size"] = 5e8
    config["zero_optimization"]["stage3_param_persistence_threshold"] = 1e6
    config["zero_optimization"]["sub_group_size"] = 1e9
    config["zero_optimization"]["reduce_bucket_size"] = 5e8
    
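    # CPU offload of optimizer states and parameters (optional: trades PCIe transfer time
    # for GPU memory; delete both entries to keep everything on the GPU)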
    config["zero_optimization"]["offload_optimizer"] = {
        "device": "cpu",
        "pin_memory": True
    }
    config["zero_optimization"]["offload_param"] = {
        "device": "cpu",
        "pin_memory": True
    }
    
    return config

def train_with_deepspeed_zero3():
    """Train ViT with DeepSpeed ZeRO-3"""
    print("Starting DeepSpeed ZeRO-3 training...")
    
    # Setup distributed
    local_rank = setup_deepspeed_distributed()
    device = torch.device(f'cuda:{local_rank}')
    
    # Model
    model = create_vit_model(num_classes=1000, use_pretrained=False)
    
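    # Data loaders. deepspeed.initialize() rebuilds the training loader from train_loader.dataset
    # using "train_micro_batch_size_per_gpu" from the config, so batch_size here only affects val_loader
    train_loader, val_loader = get_imagenet_loaders(batch_size=16)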
    
    # DeepSpeed config
    ds_config = get_deepspeed_config_zero3()
    
    # Initialize DeepSpeed
    model_engine, optimizer, train_loader, _ = deepspeed.initialize(
        args=None,
        model=model,
        model_parameters=model.parameters(),
        training_data=train_loader.dataset,
        config=ds_config
    )
    
    # Profiling setup
    gpu_monitor = GPUMonitor(device_id=local_rank)
    gpu_monitor.start()
    
    criterion = torch.nn.CrossEntropyLoss()
    results = []
    
    # Training loop
    with TorchProfiler(log_dir="./logs/deepspeed_zero3") as profiler:
        for epoch in range(2):
            print(f"\n=== DeepSpeed ZeRO-3 Epoch {epoch + 1} ===")
            
            model_engine.train()
            losses = AverageMeter()
            top1 = AverageMeter()
            
            for i, (images, targets) in enumerate(train_loader):
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                
                # Forward pass
                outputs = model_engine(images)
                loss = criterion(outputs, targets)
                
                # Backward pass
                model_engine.backward(loss)
                model_engine.step()
                
                # Metrics
                acc1, _ = calculate_accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                
                if profiler:
                    profiler.step()
                
                if i % 10 == 0 and local_rank == 0:
                    print(f'Epoch: [{epoch}][{i}/{len(train_loader)}] '
                          f'Loss {losses.val:.4f} ({losses.avg:.4f}) '
                          f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})')
                    
                    # GPU stats
                    stats = gpu_monitor.get_stats()
                    if stats:
                        print(f'GPU: {stats.get("gpu_util_avg", 0):.1f}% util, '
                              f'{stats.get("gpu_memory_avg", 0):.1f}GB mem')
                
                if i >= 50:  # Limit for demo
                    break
            
            # Validation: under ZeRO-3 every rank must run the forward pass, because the
            # partitioned parameters are all-gathered collectively; only rank 0 records results
            val_metrics = validate(model_engine, val_loader, criterion, device)

            if local_rank == 0:
                epoch_results = {
                    'epoch': epoch + 1,
                    'method': 'DeepSpeed ZeRO-3',
                    'train_loss': losses.avg,
                    'train_acc1': top1.avg,
                    'val_loss': val_metrics['loss'],
                    'val_acc1': val_metrics['top1'],
                    'gpu_stats': gpu_monitor.get_stats()
                }
                results.append(epoch_results)
    
    # Cleanup
    gpu_monitor.stop()
    if local_rank == 0:
        gpu_monitor.plot_metrics(save_path='gpu_metrics_zero3.png')
        
        with open('training_results_zero3.json', 'w') as f:
            json.dump(results, f, indent=2)
    
    # Memory stats
    if local_rank == 0:
        print("\nDeepSpeed ZeRO-3 Memory Stats:")
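        # Partitioned ZeRO-3 parameters are empty placeholders locally; ds_numel holds the full size
        print(f"Model parameters: {sum(getattr(p, 'ds_numel', p.numel()) for p in model.parameters()) / 1e6:.1f}M")
        print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

        # Allocator breakdown; model_engine.memory_breakdown() only returns the
        # "memory_breakdown" config flag (a bool), not a dict of sizes
        print(f"Currently allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Currently reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")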
    
    print("\nDeepSpeed ZeRO-3 training completed!")
    return results

# Run if this cell is executed
if __name__ == "__main__":
    results = train_with_deepspeed_zero3()
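`get_deepspeed_config_zero3()` labels CPU offload as optional but always enables it, for both optimizer states and parameters. For an 86M-parameter model on an A100, offload probably costs more in host-device transfers than it saves in GPU memory, so comparing runs with and without it may be informative.

The helper below is a hypothetical sketch, not part of DeepSpeed. It returns a modified deep copy of a config dict. It uses only the `offload_optimizer` and `offload_param` keys that function already sets.

```python
import copy


def with_zero3_offload(config: dict, offload_optimizer: bool = True, offload_params: bool = True) -> dict:
    """Return a copy of a DeepSpeed config dict with ZeRO CPU offload switched on or off.

    Hypothetical helper: the input config is left unmodified.
    """
    cfg = copy.deepcopy(config)
    zero = cfg.setdefault("zero_optimization", {})
    for key, enabled in (("offload_optimizer", offload_optimizer), ("offload_param", offload_params)):
        if enabled:
            zero[key] = {"device": "cpu", "pin_memory": True}
        else:
            zero.pop(key, None)  # no offload entry: keep this state on the GPU
    return cfg


if __name__ == "__main__":
    base = {"zero_optimization": {"stage": 3,
                                  "offload_optimizer": {"device": "cpu", "pin_memory": True},
                                  "offload_param": {"device": "cpu", "pin_memory": True}}}
    gpu_only = with_zero3_offload(base, offload_optimizer=False, offload_params=False)
    print(sorted(gpu_only["zero_optimization"]))  # only the "stage" key remains
```

Inside `train_with_deepspeed_zero3()` this would be used as `ds_config = with_zero3_offload(get_deepspeed_config_zero3(), False, False)`. The peak-memory and step-time numbers from the two runs can then be compared directly.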


## Results Comparison and Analysis

Compare performance and GPU utilization across all training methods.
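The utilization numbers collected by `GPUMonitor` presumably come from NVML/`nvidia-smi`. There, "GPU utilization" is the fraction of time at least one kernel was running, not the fraction of peak arithmetic throughput. A memory-bound kernel can therefore report 100% while using a small share of the tensor cores. A complementary metric is model FLOPs utilization (MFU): achieved FLOP/s divided by peak FLOP/s.

The sketch below uses the architecture section's estimates: about 52.5 GFLOPs per image for forward plus backward, and a 312 TFLOPs FP16 peak on A100. `model_flops_utilization` is a hypothetical helper. The step time must be measured separately, for example with `time.perf_counter()` around a step after `torch.cuda.synchronize()`. The training loops here do not record it.

```python
def model_flops_utilization(flops_per_image: float, images_per_step: int, step_time_s: float,
                            num_gpus: int = 1, peak_flops_per_gpu: float = 312e12) -> float:
    """Achieved FLOP/s divided by aggregate peak FLOP/s.

    images_per_step is the global batch (all GPUs); defaults assume the A100 FP16 peak.
    """
    if step_time_s <= 0:
        raise ValueError("step_time_s must be positive")
    achieved = flops_per_image * images_per_step / step_time_s
    return achieved / (peak_flops_per_gpu * num_gpus)


if __name__ == "__main__":
    # Illustrative input only: a hypothetical 0.02 s step at batch 32 on one A100
    mfu = model_flops_utilization(52.5e9, 32, 0.02)
    print(f"MFU: {mfu:.1%}")
```

With those illustrative inputs the estimate is about 27%, which sits just below the 30-50% range the architecture section expects.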


import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import glob
from pathlib import Path

def analyze_training_results():
    """Analyze and compare all training results"""
    print("Analyzing training results...")
    
    # Load all result files
    result_files = glob.glob("training_results_*.json")
    all_results = {}
    
    for file in result_files:
        method = file.replace("training_results_", "").replace(".json", "")
        try:
            with open(file, 'r') as f:
                all_results[method] = json.load(f)
                print(f"[LOADED] {method} results")
        except Exception as e:
            print(f"WARNING: Could not load {file}: {e}")
    
    if not all_results:
        print("No results found. Run training cells first.")
        return
    
    # Create comparison plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Training Loss Comparison
    ax = axes[0, 0]
    for method, results in all_results.items():
        epochs = [r['epoch'] for r in results]
        losses = [r['train']['loss'] if 'train' in r else r['train_loss'] for r in results]
        ax.plot(epochs, losses, marker='o', label=method, linewidth=2)
    ax.set_title('Training Loss Comparison')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 2. Validation Accuracy Comparison
    ax = axes[0, 1]
    for method, results in all_results.items():
        epochs = [r['epoch'] for r in results]
        accs = [r['val']['top1'] if 'val' in r else r['val_acc1'] for r in results]
        ax.plot(epochs, accs, marker='s', label=method, linewidth=2)
    ax.set_title('Validation Accuracy Comparison')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Top-1 Accuracy (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 3. GPU Memory Usage
    ax = axes[1, 0]
    methods = []
    memory_usage = []
    for method, results in all_results.items():
        if results and 'gpu_stats' in results[-1]:
            methods.append(method)
            memory_usage.append(results[-1]['gpu_stats'].get('gpu_memory_max', 0))
    
    if methods:
        bars = ax.bar(methods, memory_usage, color=['skyblue', 'lightcoral', 'lightgreen', 'orange'][:len(methods)])
        ax.set_title('Peak GPU Memory Usage')
        ax.set_ylabel('Memory (GB)')
        ax.tick_params(axis='x', rotation=45)
        
        # Add value labels on bars
        for bar, value in zip(bars, memory_usage):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, 
                   f'{value:.1f}GB', ha='center', va='bottom')
    
    # 4. GPU Utilization
    ax = axes[1, 1]
    methods = []
    gpu_util = []
    for method, results in all_results.items():
        if results and 'gpu_stats' in results[-1]:
            methods.append(method)
            gpu_util.append(results[-1]['gpu_stats'].get('gpu_util_avg', 0))
    
    if methods:
        bars = ax.bar(methods, gpu_util, color=['skyblue', 'lightcoral', 'lightgreen', 'orange'][:len(methods)])
        ax.set_title('Average GPU Utilization')
        ax.set_ylabel('Utilization (%)')
        ax.tick_params(axis='x', rotation=45)
        
        # Add value labels on bars
        for bar, value in zip(bars, gpu_util):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                   f'{value:.1f}%', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig('training_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Summary table
    print("\nTraining Summary:")
    print("=" * 80)
    summary_data = []
    
    for method, results in all_results.items():
        if not results:
            continue
            
        final_result = results[-1]
        summary = {
            'Method': method,
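            # Explicit key checks (not `or`) so a legitimate 0.0 metric is not replaced by the fallback
            'Final Train Loss': final_result['train']['loss'] if 'train' in final_result else final_result.get('train_loss', 'N/A'),
            'Final Val Acc@1': final_result['val']['top1'] if 'val' in final_result else final_result.get('val_acc1', 'N/A'),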
            'Peak GPU Memory (GB)': final_result.get('gpu_stats', {}).get('gpu_memory_max', 'N/A'),
            'Avg GPU Util (%)': final_result.get('gpu_stats', {}).get('gpu_util_avg', 'N/A'),
            'Avg Power (W)': final_result.get('gpu_stats', {}).get('power_avg', 'N/A'),
        }
        summary_data.append(summary)
    
    if summary_data:
        df = pd.DataFrame(summary_data)
        print(df.to_string(index=False, float_format='%.2f'))
    
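    print("\nKey Insights (ZeRO savings scale with the number of data-parallel GPUs N):")
    print("- ZeRO-1: Partitions optimizer states across the N GPUs")
    print("- ZeRO-2: Also partitions gradients, further reducing per-GPU memory")
    print("- ZeRO-3: Also partitions parameters; lowest per-GPU model-state memory, extra all-gather traffic")
    print("- With N = 1 partitioning saves nothing; of these configs only ZeRO-3's CPU offload lowers GPU memory")
    print("- Basic training: replicates all model states on every GPU (no ZeRO partitioning)")
    print("- Peak memory also includes activations, which depend on each run's micro-batch size")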
    
    return all_results

def create_profiling_commands():
    """Generate commands for external profiling tools"""
    print("\nExternal Profiling Commands:")
    print("=" * 50)
    
    # Nsight Systems
    print("1. NVIDIA Nsight Systems:")
    print("   nsys profile --trace=cuda,nvtx,osrt,cudnn,cublas -o vit_training python training_script.py")
    
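    # Nsight Compute (kernel replay is slow: skip warm-up kernels and profile a limited number)
    print("\n2. NVIDIA Nsight Compute:")
    print("   ncu --set full --launch-skip 1000 --launch-count 20 --export vit_kernels python training_script.py")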
    
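    # nvprof (legacy): does not support compute capability 8.0+ GPUs such as the A100
    print("\n3. nvprof (legacy, pre-Ampere GPUs only; not supported on A100):")
    print("   nvprof --print-gpu-trace --export-profile vit_profile.nvvp python training_script.py")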
    
    # PyTorch Profiler TensorBoard
    print("\n4. View PyTorch Profiler in TensorBoard:")
    print("   tensorboard --logdir=./logs/")
    
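    # Memory profiler (host RAM only; GPU memory is covered by the PyTorch profiler and torch.cuda stats)
    print("\n5. Host (CPU) memory profiling (line-by-line for functions decorated with @profile):")
    print("   python -m memory_profiler training_script.py")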
    
    if COLAB:
        print("\nNote: In Colab, use the following to view TensorBoard inline:")
        print("   %load_ext tensorboard")
        print("   %tensorboard --logdir=./logs/")

def run_comprehensive_analysis():
    """Run all training methods and compare results"""
    print("Running comprehensive training analysis...")
    
    methods = [
        ("Basic Training", main_training_basic),
        ("DeepSpeed ZeRO-1", train_with_deepspeed_zero1),
        ("DeepSpeed ZeRO-2", train_with_deepspeed_zero2),
        ("DeepSpeed ZeRO-3", train_with_deepspeed_zero3),
    ]
    
    results = {}
    
    for method_name, method_func in methods:
        try:
            print(f"\n{'='*60}")
            print(f"Running {method_name}")
            print(f"{'='*60}")
            
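            # Reset allocator peak stats so each method reports its own peak rather than
            # the maximum over every run executed earlier in this process
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            result = method_func()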
            results[method_name] = result
            
            print(f"[SUCCESS] {method_name} completed successfully")
            
        except Exception as e:
            print(f"[FAILED] {method_name} failed: {e}")
            results[method_name] = None
    
    # Analyze results
    print(f"\n{'='*60}")
    print("COMPREHENSIVE ANALYSIS")
    print(f"{'='*60}")
    
    analyze_training_results()
    create_profiling_commands()
    
    return results

# Example usage
if __name__ == "__main__":
    # Option 1: Analyze existing results
    analyze_training_results()
    
    # Option 2: Run all methods and analyze (uncomment to run)
    # comprehensive_results = run_comprehensive_analysis()
