# 🏙️ DeepLabv3+ Cityscapes Reproduction on Kaggle

## **Project Overview**
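This notebook reproduces **DeepLabv3+ semantic segmentation** results on the **Cityscapes dataset** using PyTorch and Torchvision. It is adapted from our PASCAL VOC implementation, with critical modifications for urban scene understanding. Note: the model cell builds torchvision's `deeplabv3_resnet101`, which implements DeepLabv3 (ASPP head without the v3+ decoder).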

### **🔧 Key Adaptations for Cityscapes:**
- **NUM_CLASSES**: 19 (vs 21 for PASCAL VOC)
- **RESOLUTION**: 769×769 (vs 513×513 for PASCAL VOC) 
- **LABEL MAPPING**: Critical remapping from labelIds → trainIds
- **TRAINING ITERATIONS**: 60,000 planned; the TPU configuration below sets `MAX_ITERATIONS` to 40,000 because of the larger batch (vs 30,000 for PASCAL VOC)
- **IGNORE INDEX**: 255 for unlabeled pixels

### **📊 Dataset Info:**
- **Training**: 2,975 images (fine annotations)
- **Validation**: 500 images (fine annotations)
- **Classes**: 19 semantic classes (road, sidewalk, building, wall, fence, pole, traffic light, traffic sign, vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle)

---

In [None]:
# 🔧 TPU v5e-8 Setup
# Select the compute device. Prefer an XLA (TPU) device when torch_xla can be
# imported; otherwise fall back to CUDA, then CPU.
# Defines the module-level globals later cells read:
#   device     - where the model and tensors are placed
#   IS_TPU     - True only when the torch_xla imports below succeed
#   world_size - replica count reported by torch_xla (1 off-TPU)
#   ordinal    - index of the current replica (0 off-TPU)
# NOTE(review): only ImportError triggers the fallback. If torch_xla imports
# but a later call fails (e.g. xla_device(), or both world-size APIs raise
# AttributeError), the exception propagates. The fallback branch relies on
# `torch` having been imported by an earlier cell.
try:
    # TPU Setup for Kaggle
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl  # used by the DataLoader cell (MpDeviceLoader)
    import torch_xla.utils.utils as xu
    
    device = xm.xla_device()
    print(f"🚀 Using TPU v5e-8: {device}")
    
    # Get world size with API compatibility
    # NOTE(review): xrt_world_size is the legacy XRT-runtime API. Confirm that
    # `xm.get_world_size` exists in the installed torch_xla; newer releases
    # expose this as torch_xla.runtime.world_size(). In a single notebook
    # process (no xmp.spawn) the value is presumably 1 — verify.
    try:
        world_size = xm.xrt_world_size()  # Old API
    except AttributeError:
        world_size = xm.get_world_size()  # New API
    
    try:
        ordinal = xm.get_ordinal()  # Should work in both versions
    except AttributeError:
        ordinal = 0  # Fallback
    
    print(f"   TPU Cores: {world_size}")
    print(f"   Current Core: {ordinal}")
    
    IS_TPU = True
    
except ImportError:
    # Fallback to GPU/CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
    print(f"🚀 TPU not available, using: {device}")
    if torch.cuda.is_available():
        print(f"   GPU: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    
    IS_TPU = False
    world_size = 1  # Set fallback world_size for non-TPU
    ordinal = 0

In [None]:
# 🏙️ **CITYSCAPES CONFIGURATION** - Optimized for TPU v5e-8
CFG = {
    # 📊 Dataset Configuration (CITYSCAPES SPECIFIC)
    'NUM_CLASSES': 19,              # Number of trainId classes (vs 21 for PASCAL VOC)
    'IGNORE_INDEX': 255,            # Label excluded from loss and mIoU; also the mask padding value
    'CROP_SIZE': 769,               # Train crop size and val resize size (vs 513 for PASCAL VOC)
    'BASE_SIZE': 769,               # Not read by any cell visible in this notebook chunk
    
    # 🎯 Training Configuration (TPU v5e-8 OPTIMIZED)
    'BATCH_SIZE': 32,               # Per-replica DataLoader batch; NOTE(review): very large for 769x769 ResNet-101 on the GPU fallback — confirm memory
    'NUM_WORKERS': 8,               # DataLoader worker processes
    'MAX_ITERATIONS': 40000,        # Total training steps (the markdown overview plans 60,000 — keep in sync)
    'EVAL_INTERVAL': 1000,          # Steps between validation runs
    'SAVE_INTERVAL': 2000,          # Steps between checkpoints
    
    # 🔧 Optimization Configuration (TPU TUNED)
    'LEARNING_RATE': 0.08,          # Base LR, chosen as "linear scaling: 0.01 * 8"; NOTE(review): the printed total batch is BATCH_SIZE * TPU_CORES = 256 — confirm the reference batch this assumes
    'WEIGHT_DECAY': 5e-4,           # L2 regularization (SGD weight_decay)
    'MOMENTUM': 0.9,                # SGD momentum
    'POWER': 0.9,                   # Polynomial LR decay power
    'WARMUP_ITERATIONS': 1000,      # LR warmup length; only printed in the visible cells — confirm the training loop applies it
    
    # 📁 Path Configuration (KAGGLE SPECIFIC)
    'DATASET_ROOT': '/kaggle/input/cityscapes',  # Default; the download cell overwrites it with the kagglehub path
    'OUTPUT_DIR': '/kaggle/working',              # Output directory for models/logs (created in the loss/optimizer cell)
    
    # 🖼️ Data Augmentation Configuration
    'RANDOM_SCALE_MIN': 0.5,        # Minimum scale factor for random scaling
    'RANDOM_SCALE_MAX': 2.0,        # Maximum scale factor for random scaling
    'HORIZONTAL_FLIP_PROB': 0.5,    # Probability for horizontal flip
    
    # 🧠 Model Configuration (TPU OPTIMIZED)
    'BACKBONE': 'resnet101',         # Informational; the model cell calls deeplabv3_resnet101 directly
    'PRETRAINED': True,              # Passed as `pretrained=` to the torchvision constructor
    'MIXED_PRECISION': True,         # GPU: fp16 GradScaler/autocast. TPU: NOTE(review) nothing visible switches to bfloat16 — confirm
    'GRADIENT_ACCUMULATION': 1,      # Not read by any cell visible in this notebook chunk
    'TPU_CORES': 8,                  # Only used for the summary print; the DataLoader cell uses the runtime world size
    
    # 📊 ImageNet Normalization (Standard for pretrained models)
    'MEAN': [0.485, 0.456, 0.406],
    'STD': [0.229, 0.224, 0.225]
}

print("🏙️ **CITYSCAPES CONFIGURATION - TPU v5e-8 OPTIMIZED**")
print(f"   Classes: {CFG['NUM_CLASSES']}")
print(f"   Resolution: {CFG['CROP_SIZE']}×{CFG['CROP_SIZE']}")  
print(f"   Batch Size: {CFG['BATCH_SIZE']} per core")
# Total batch here is computed from the configured TPU_CORES, not the runtime world size.
print(f"   TPU Cores: {CFG['TPU_CORES']} (Total batch: {CFG['BATCH_SIZE'] * CFG['TPU_CORES']})")
print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
print(f"   Learning Rate: {CFG['LEARNING_RATE']} (scaled for large batch)")
print(f"   Dataset Path: {CFG['DATASET_ROOT']}")
print(f"   Mixed Precision: {CFG['MIXED_PRECISION']} (bfloat16)")

In [None]:
# 🏷️ **CITYSCAPES LABEL MAPPING** - Critical for Correct Training
"""
Cityscapes uses complex label system:
- labelIds: Original labels in *_labelIds.png files (0-33)  
- trainIds: Training labels we need (0-18 + 255 for ignore)

This mapping is ESSENTIAL for correct training!
"""

# 🎯 Official Cityscapes Label Mapping (labelId -> trainId)
# Any labelId not listed here (e.g. 0) becomes IGNORE_INDEX in remap_labels.
CITYSCAPES_LABEL_MAP = {
    # labelId: trainId,  # class name
    7: 0,    # road
    8: 1,    # sidewalk
    11: 2,   # building
    12: 3,   # wall
    13: 4,   # fence
    17: 5,   # pole
    19: 6,   # traffic light
    20: 7,   # traffic sign
    21: 8,   # vegetation
    22: 9,   # terrain
    23: 10,  # sky
    24: 11,  # person
    25: 12,  # rider
    26: 13,  # car
    27: 14,  # truck
    28: 15,  # bus
    31: 16,  # train
    32: 17,  # motorcycle
    33: 18,  # bicycle
}

# 📋 Class names indexed by trainId; the order must match the trainIds above.
CITYSCAPES_CLASSES = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
    'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
    'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]

def remap_labels(label_img, label_map=None, ignore_index=None):
    """
    🔄 Remap Cityscapes labelIds to trainIds

    Every pixel whose value is a key of ``label_map`` becomes the mapped
    trainId. Every other pixel becomes ``ignore_index``.

    Args:
        label_img: numpy array (any shape) with original labelIds.
        label_map: optional dict {labelId: trainId}; defaults to
            CITYSCAPES_LABEL_MAP.
        ignore_index: optional fill value for unmapped pixels; defaults to
            CFG['IGNORE_INDEX'].
    Returns:
        remapped_img: uint8 numpy array of the same shape, holding trainIds
            and ``ignore_index`` for every unmapped pixel.
    """
    if label_map is None:
        label_map = CITYSCAPES_LABEL_MAP
    if ignore_index is None:
        ignore_index = CFG['IGNORE_INDEX']

    label_img = np.asarray(label_img)

    if label_img.dtype == np.uint8:
        # Fast path for decoded PNG masks. A 256-entry lookup table turns the
        # remap into a single gather instead of one full-image comparison per
        # class. This runs once per sample inside the DataLoader.
        lut = np.full(256, ignore_index, dtype=np.uint8)
        for label_id, train_id in label_map.items():
            if 0 <= label_id < 256:  # other keys can never match a uint8 pixel
                lut[label_id] = train_id
        return lut[label_img]

    # Generic path for other dtypes (e.g. int64 test arrays or values > 255).
    remapped = np.full(label_img.shape, ignore_index, dtype=np.uint8)
    for label_id, train_id in label_map.items():
        remapped[label_img == label_id] = train_id
    return remapped

# 🧪 Test the mapping function
# Summarise the loaded tables, then sanity-check remap_labels on a tiny array.
print("🏷️ **CITYSCAPES LABEL MAPPING LOADED**")
print(f"   Valid Classes: {len(CITYSCAPES_CLASSES)}")
print(f"   Label Mappings: {len(CITYSCAPES_LABEL_MAP)}")
print(f"   Ignore Index: {CFG['IGNORE_INDEX']}")
print("\n📋 **Class List:**")
for i, class_name in enumerate(CITYSCAPES_CLASSES):
    print(f"   {i:2d}: {class_name}")
    
# Test mapping with dummy data
# 0 and 255 are not keys of CITYSCAPES_LABEL_MAP, so both map to IGNORE_INDEX.
test_labels = np.array([7, 24, 26, 0, 255])  # road, person, car, void, void
test_remapped = remap_labels(test_labels)
print(f"\n🧪 **Mapping Test:**")
print(f"   Original: {test_labels}")
print(f"   Remapped: {test_remapped}")  # Should be [0, 11, 13, 255, 255]

In [None]:
# 📥 Download Cityscapes Dataset
"""
Download and setup Cityscapes dataset from Kaggle
Expected structure:
/kaggle/input/cityscapes/
├── leftImg8bit/
│   ├── train/
│   └── val/
└── gtFine/
    ├── train/
    └── val/
"""

import kagglehub

# Download Cityscapes dataset
# Best-effort download: any failure falls back to the default Kaggle input path.
# NOTE(review): "dansbecker/cityscapes-image-pairs" appears to be a paired-image
# export rather than the official leftImg8bit/gtFine release. Confirm it
# provides *_gtFine_labelIds.png masks. The check below warns when the
# expected top-level folders are absent.
try:
    print("📥 Downloading Cityscapes dataset...")
    dataset_path = kagglehub.dataset_download("dansbecker/cityscapes-image-pairs")
    print(f"✅ Dataset downloaded to: {dataset_path}")

    # Update CFG with actual dataset path
    CFG['DATASET_ROOT'] = dataset_path
    print(f"🔄 Updated dataset root: {CFG['DATASET_ROOT']}")

    # Verify dataset structure
    if os.path.exists(dataset_path):
        print("\n📁 **Dataset Structure:**")
        for item in sorted(os.listdir(dataset_path)):
            item_path = os.path.join(dataset_path, item)
            if not os.path.isdir(item_path):
                print(f"   📄 {item}")
                continue

            print(f"   📂 {item}/")
            # Show up to three subdirectories. Unreadable folders are reported
            # instead of silently skipped (previously a bare `except: pass`).
            try:
                subdirs = sorted(
                    d for d in os.listdir(item_path)
                    if os.path.isdir(os.path.join(item_path, d))
                )
            except OSError as err:
                print(f"      ⚠️  Cannot list {item}/: {err}")
                continue
            for subdir in subdirs[:3]:
                print(f"      📂 {subdir}/")
            if len(subdirs) > 3:
                print(f"      ... and {len(subdirs)-3} more")

        # The standard Cityscapes layout needs these two top-level folders.
        # Warn now instead of discovering an empty dataset later.
        missing = [
            name for name in ('leftImg8bit', 'gtFine')
            if not os.path.isdir(os.path.join(dataset_path, name))
        ]
        if missing:
            print(f"⚠️  Expected Cityscapes folders not found: {missing}")
            print("   The dataset loader will have to rely on a recursive search "
                  "for *leftImg8bit*.png / *gtFine_labelIds*.png files.")

except Exception as e:
    print(f"❌ Error downloading dataset: {e}")
    print("💡 Using default path: /kaggle/input/cityscapes")
    CFG['DATASET_ROOT'] = "/kaggle/input/cityscapes"

In [None]:
# 🏙️ **CITYSCAPES DATASET CLASS** - Custom Dataset Implementation
class CityscapesDataset(Dataset):
    """
    🏙️ Custom Cityscapes Dataset for semantic segmentation

    Pairs each ``*_leftImg8bit.png`` RGB image with its
    ``*_gtFine_labelIds.png`` mask, remaps labelIds -> trainIds via
    ``remap_labels``, and applies optional Albumentations-style transforms
    (called as ``transforms(image=..., mask=...)``).

    Two directory layouts are searched:
    - standard: ``<root>/leftImg8bit/<split>/<city>/`` and
      ``<root>/gtFine/<split>/<city>/``
    - fallback: a recursive search under ``<root>``. It is restricted to
      files with a ``<split>`` directory component whenever any exist.

    Images without a matching label are dropped, so ``image_files`` and
    ``label_files`` always stay index-aligned.
    """

    def __init__(self, root_dir, split='train', transforms=None):
        """
        Args:
            root_dir: Path to cityscapes dataset
            split: 'train' or 'val'
            transforms: Albumentations transforms (or None)
        """
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms

        # 📁 Setup paths
        self.images_dir = os.path.join(root_dir, 'leftImg8bit', split)
        self.labels_dir = os.path.join(root_dir, 'gtFine', split)

        # 🔍 Find all image files
        self.image_files = []
        self.label_files = []

        if os.path.exists(self.images_dir):
            self._collect_standard()
        else:
            print(f"⚠️  Standard structure not found, trying alternative...")
            self._collect_fallback()

        print(f"🏙️ **Cityscapes {split.upper()} Dataset:**")
        print(f"   Images: {len(self.image_files)}")
        print(f"   Labels: {len(self.label_files)}")

        if len(self.image_files) == 0:
            print("❌ No images found! Check dataset structure.")
        elif len(self.image_files) != len(self.label_files):
            print(f"⚠️  Mismatch: {len(self.image_files)} images vs {len(self.label_files)} labels")

    def _collect_standard(self):
        """Collect image/label pairs from the official <split>/<city>/ layout."""
        for city in sorted(os.listdir(self.images_dir)):
            city_img_dir = os.path.join(self.images_dir, city)
            city_lbl_dir = os.path.join(self.labels_dir, city)
            if not os.path.isdir(city_img_dir):
                continue

            img_files = glob.glob(os.path.join(city_img_dir, '*_leftImg8bit.png'))
            for img_file in sorted(img_files):
                basename = os.path.basename(img_file).replace('_leftImg8bit.png', '')
                label_file = os.path.join(city_lbl_dir, f'{basename}_gtFine_labelIds.png')
                if os.path.exists(label_file):
                    self.image_files.append(img_file)
                    self.label_files.append(label_file)

    def _in_split(self, path):
        """Return True if a directory below root_dir on ``path`` is named ``self.split``."""
        rel_dirs = os.path.relpath(path, self.root_dir).split(os.sep)[:-1]
        return self.split in rel_dirs

    def _collect_fallback(self):
        """Collect pairs by one recursive search when the standard layout is absent."""
        images = sorted(glob.glob(
            os.path.join(self.root_dir, '**', '*leftImg8bit*.png'), recursive=True))
        labels = sorted(glob.glob(
            os.path.join(self.root_dir, '**', '*gtFine_labelIds*.png'), recursive=True))

        # Keep train and val apart when the tree has split folders.
        split_images = [p for p in images if self._in_split(p)]
        if split_images:
            images = split_images
        elif images:
            print(f"⚠️  No '{self.split}' folder found; using all {len(images)} images for this split.")

        # Index labels by basename once (instead of re-globbing per image).
        # Same-split labels are inserted first, so they win on name clashes.
        label_index = {}
        for lbl in sorted(labels, key=lambda p: not self._in_split(p)):
            label_index.setdefault(os.path.basename(lbl), lbl)

        # Build both lists together so they can never fall out of alignment.
        for img_file in images:
            label_basename = os.path.basename(img_file).replace('leftImg8bit', 'gtFine_labelIds')
            label_file = label_index.get(label_basename)
            if label_file is not None:
                self.image_files.append(img_file)
                self.label_files.append(label_file)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load, remap and transform the image/label pair at ``idx``.

        Returns:
            (image, label): image as produced by ``transforms`` (raw RGB numpy
            array if none), and label as an int64 torch tensor.
        Raises:
            FileNotFoundError: if OpenCV cannot read either file.
        """
        image_path = self.image_files[idx]
        label_path = self.label_files[idx]

        # cv2.imread returns None instead of raising on unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f"Could not read label: {label_path}")

        # 🔄 Critical: Remap labels (labelIds -> trainIds)
        label = remap_labels(label)

        # 🖼️ Apply transforms
        if self.transforms:
            transformed = self.transforms(image=image, mask=label)
            image = transformed['image']
            label = transformed['mask']

        # as_tensor also covers transforms=None, where label is still numpy.
        return image, torch.as_tensor(label).long()

    def get_sample_info(self, idx):
        """Get file paths for debugging"""
        return {
            'image_path': self.image_files[idx],
            'label_path': self.label_files[idx]
        }

In [None]:
# 🖼️ **DATA AUGMENTATION PIPELINE** - Cityscapes Specific
"""
Augmentation strategy adapted for Cityscapes:
- Higher resolution (769x769) for urban scene detail
- Scale range 0.5-2.0 for variety
- Careful padding with ignore_index for masks
"""

def get_train_transforms():
    """🏋️ Training augmentations for Cityscapes.

    Pipeline order matters: random rescale, then pad small results up to
    CROP_SIZE, then take a random CROP_SIZE×CROP_SIZE crop, then flip, colour
    jitter, normalise, and convert to tensors. Image and mask are transformed
    together (the caller passes ``image=`` and ``mask=``).
    """
    return A.Compose([
        # 📏 Random scaling (0.5x to 2.0x): a (low, high) scale_limit is offset
        # by 1, so (MIN-1, MAX-1) samples factors in [MIN, MAX].
        A.RandomScale(scale_limit=(CFG['RANDOM_SCALE_MIN']-1, CFG['RANDOM_SCALE_MAX']-1), 
                      interpolation=cv2.INTER_LINEAR, p=1.0),
        
        # 📐 Pad if needed to ensure minimum size.
        # Image padding is 0; mask padding is IGNORE_INDEX, so padded pixels are
        # skipped by the loss (criterion uses ignore_index=IGNORE_INDEX).
        # NOTE(review): `value`/`mask_value` were renamed in newer Albumentations
        # releases — confirm against the installed version.
        A.PadIfNeeded(min_height=CFG['CROP_SIZE'], min_width=CFG['CROP_SIZE'],
                      border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=CFG['IGNORE_INDEX']),
        
        # ✂️ Random crop to target size
        A.RandomCrop(height=CFG['CROP_SIZE'], width=CFG['CROP_SIZE']),
        
        # 🔄 Horizontal flip
        A.HorizontalFlip(p=CFG['HORIZONTAL_FLIP_PROB']),
        
        # 🎨 Color augmentations (optional - comment out if too aggressive)
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
        
        # 📊 Normalization & tensor conversion
        A.Normalize(mean=CFG['MEAN'], std=CFG['STD']),
        ToTensorV2()
    ])

def get_val_transforms():
    """🧪 Validation transforms: square resize, then normalise and convert to tensors.

    NOTE(review): resizing non-square frames to CROP_SIZE×CROP_SIZE changes
    the aspect ratio, and mIoU is then scored on resized masks rather than
    full-resolution labels. Confirm this matches the evaluation protocol you
    compare against.
    """
    steps = [
        A.Resize(height=CFG['CROP_SIZE'], width=CFG['CROP_SIZE'],
                 interpolation=cv2.INTER_LINEAR),
        A.Normalize(mean=CFG['MEAN'], std=CFG['STD']),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# 🧪 Test transforms
# Print the augmentation settings, then push one random 1024×2048 frame
# through both pipelines to check output shapes and the normalised value range.
print("🖼️ **AUGMENTATION PIPELINE READY**")
print(f"   Crop Size: {CFG['CROP_SIZE']}×{CFG['CROP_SIZE']}")
print(f"   Scale Range: {CFG['RANDOM_SCALE_MIN']:.1f} - {CFG['RANDOM_SCALE_MAX']:.1f}")
print(f"   Horizontal Flip: {CFG['HORIZONTAL_FLIP_PROB']*100:.0f}%")
print(f"   Ignore Index: {CFG['IGNORE_INDEX']} (for padding)")

# Test with dummy data (mask values are trainIds 0-18, no ignore pixels)
test_img = np.random.randint(0, 255, (1024, 2048, 3), dtype=np.uint8)
test_mask = np.random.randint(0, 19, (1024, 2048), dtype=np.uint8)

train_transform = get_train_transforms()
val_transform = get_val_transforms()

try:
    # Test transforms
    train_result = train_transform(image=test_img, mask=test_mask)
    val_result = val_transform(image=test_img, mask=test_mask)
    
    print(f"\n✅ **Transform Test Passed:**")
    print(f"   Train Output: {train_result['image'].shape}, {train_result['mask'].shape}")
    print(f"   Val Output: {val_result['image'].shape}, {val_result['mask'].shape}")
    print(f"   Train Image Range: [{train_result['image'].min():.3f}, {train_result['image'].max():.3f}]")
    
except Exception as e:
    # Report-only: the notebook keeps running even if the pipeline is broken.
    print(f"❌ Transform Test Failed: {e}")

In [None]:
# 🧠 **DEEPLABV3+ MODEL CREATION** - TPU Optimized
"""
Create DeepLabv3+ model with ResNet-101 backbone
Key modifications for Cityscapes:
- NUM_CLASSES = 19 (vs 21 for PASCAL VOC)
- TPU-optimized with bfloat16 mixed precision
- Proper classifier head modification
"""

def create_deeplabv3plus_model():
    """
    🏗️ Create the Cityscapes segmentation model.

    Builds torchvision's ``deeplabv3_resnet101`` with a 21-class head, so the
    pretrained checkpoint (when CFG['PRETRAINED'] is True) loads with matching
    shapes. It then replaces the final 1x1 convolution of ``classifier`` and,
    if present, of ``aux_classifier`` with a freshly initialised layer that
    outputs CFG['NUM_CLASSES'] logits. The new layers take their input width
    from the layers they replace instead of a hard-coded 256.

    NOTE: the "deeplabv3plus" name is kept for compatibility with the rest of
    the notebook. torchvision's deeplabv3_resnet101 is DeepLabv3 (ASPP head,
    no v3+ decoder).

    Returns:
        The model; the caller moves it to ``device``.
    """
    print("🧠 Creating DeepLabv3+ model...")

    # NOTE(review): `pretrained=` is the legacy torchvision argument (newer
    # releases prefer `weights=`); confirm against the installed version.
    model = models.deeplabv3_resnet101(
        pretrained=CFG['PRETRAINED'],
        progress=True,
        num_classes=21  # Must match the pretrained head so its weights load
    )

    def _replace_final_conv(head):
        """Swap head[4] (the last 1x1 conv) for a NUM_CLASSES-output conv of the same input width."""
        old_conv = head[4]
        head[4] = nn.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=CFG['NUM_CLASSES'],
            kernel_size=1
        )

    # 🔧 Modify classifier(s) for Cityscapes (19 classes)
    _replace_final_conv(model.classifier)
    if getattr(model, 'aux_classifier', None) is not None:
        _replace_final_conv(model.aux_classifier)

    print(f"✅ Model created with {CFG['NUM_CLASSES']} classes")

    # 📊 Model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")

    return model

# 🏗️ Create model
model = create_deeplabv3plus_model()

# 🚀 The move is identical on every backend; only the log line differs.
model = model.to(device)
print(f"📱 Model moved to TPU: {device}" if IS_TPU else f"📱 Model moved to device: {device}")

# 🧪 Smoke test: one forward pass at the training crop size.
try:
    dummy_input = torch.randn(1, 3, CFG['CROP_SIZE'], CFG['CROP_SIZE']).to(device)

    # TPU and GPU/CPU run the same code here: eval mode, no autograd.
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)

    print(f"✅ **Model Test Passed:**")
    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output['out'].shape}")
    if 'aux' in output:
        print(f"   Aux output shape: {output['aux'].shape}")

    # The head should emit one logit map per class at input resolution.
    expected_shape = (1, CFG['NUM_CLASSES'], CFG['CROP_SIZE'], CFG['CROP_SIZE'])
    if output['out'].shape != expected_shape:
        print(f"   ⚠️  Output shape mismatch: expected {expected_shape}, got {output['out'].shape}")
    else:
        print(f"   ✅ Output shape correct: {expected_shape}")

except Exception as e:
    print(f"❌ Model test failed: {e}")

model.train()  # Set back to training mode

In [None]:
# 📊 **DATASET CREATION & TPU DATALOADER** - Optimized Pipeline
"""
Create train/val datasets and TPU-optimized DataLoaders
Key features:
- CityscapesDataset with label remapping
- TPU-aware data loading
- Parallel data loading across cores
"""

# 🏗️ Create datasets
print("📊 Creating Cityscapes datasets...")

train_dataset = CityscapesDataset(
    root_dir=CFG['DATASET_ROOT'],
    split='train',
    transforms=get_train_transforms()
)

val_dataset = CityscapesDataset(
    root_dir=CFG['DATASET_ROOT'], 
    split='val',
    transforms=get_val_transforms()
)

print(f"\n📈 **Dataset Summary:**")
print(f"   Training samples: {len(train_dataset):,}")
print(f"   Validation samples: {len(val_dataset):,}")
print(f"   Total samples: {len(train_dataset) + len(val_dataset):,}")

# 🚀 Create TPU-optimized DataLoaders
# Defines the globals train_loader, val_loader and effective_batch_size
# (plus train_sampler/val_sampler/rank on TPU).
if IS_TPU:
    # TPU requires special data loading
    print("🚀 Creating TPU DataLoaders...")
    
    # Create samplers for distributed TPU training
    # Get world size with API compatibility
    # NOTE(review): this re-queries the values the setup cell already stored in
    # world_size/ordinal. In a single notebook process (no xmp.spawn) the world
    # size is presumably 1, so only one core is fed — verify.
    try:
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()
    except AttributeError:
        world_size = xm.get_world_size()
        rank = xm.get_ordinal()
    
    # Each replica sees a disjoint shard. With the default drop_last=False,
    # DistributedSampler pads shards to equal length by repeating samples, so a
    # few validation images may be counted twice when len(val) % world_size != 0.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )
    
    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['BATCH_SIZE'],
        sampler=train_sampler,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=True,
        pin_memory=False  # Not needed for TPU
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['BATCH_SIZE'],
        sampler=val_sampler,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=False,
        pin_memory=False
    )
    
    # Wrap with TPU ParallelLoader: MpDeviceLoader yields batches already on
    # `device`, which is why the evaluation code skips .to(device) on TPU.
    train_loader = pl.MpDeviceLoader(train_loader, device)
    val_loader = pl.MpDeviceLoader(val_loader, device)
    
    effective_batch_size = CFG['BATCH_SIZE'] * world_size
    print(f"   TPU Cores: {world_size}")
    print(f"   Batch per core: {CFG['BATCH_SIZE']}")
    print(f"   Effective batch size: {effective_batch_size}")
    
else:
    # Regular GPU/CPU DataLoaders
    print("💻 Creating GPU/CPU DataLoaders...")
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['BATCH_SIZE'],
        shuffle=True,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=True,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['BATCH_SIZE'],
        shuffle=False,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=False,
        pin_memory=True
    )
    
    effective_batch_size = CFG['BATCH_SIZE']
    print(f"   Batch size: {effective_batch_size}")

# NOTE(review): assumes MpDeviceLoader supports len() on the installed
# torch_xla; with drop_last=True, a train split smaller than one batch gives 0.
print(f"   Training batches: {len(train_loader):,}")
print(f"   Validation batches: {len(val_loader):,}")

# 🧪 Test data loading
# Pulls one batch from a fresh iterator and reports shapes, dtypes, the label
# set, and the share of ignore pixels. Failures are printed, not raised.
print("\n🧪 Testing data loading...")
try:
    # Get one batch
    data_iter = iter(train_loader)
    images, labels = next(data_iter)
    
    print(f"✅ **Data Loading Test Passed:**")
    print(f"   Batch images shape: {images.shape}")
    print(f"   Batch labels shape: {labels.shape}")
    print(f"   Image dtype: {images.dtype}")
    print(f"   Label dtype: {labels.dtype}")
    print(f"   Image range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"   Label range: [{labels.min()}, {labels.max()}]")
    print(f"   Unique labels in batch: {torch.unique(labels).tolist()}")
    
    # Check for ignore index
    ignore_count = (labels == CFG['IGNORE_INDEX']).sum().item()
    total_pixels = labels.numel()
    print(f"   Ignore pixels: {ignore_count:,} / {total_pixels:,} ({ignore_count/total_pixels*100:.2f}%)")
    
except Exception as e:
    print(f"❌ Data loading test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# 🎯 **LOSS FUNCTION & OPTIMIZER** - TPU Optimized Setup
"""
Setup loss function, optimizer, and learning rate scheduler
Key configurations:
- CrossEntropyLoss with ignore_index=255
- SGD optimizer with momentum
- Polynomial learning rate decay
- TPU-optimized settings
"""

# 📉 Loss Function
# Pixel-wise cross-entropy. Pixels labelled IGNORE_INDEX (unmapped labelIds
# and augmentation padding) contribute nothing to the loss or its gradient.
criterion = nn.CrossEntropyLoss(ignore_index=CFG['IGNORE_INDEX'])
print(f"📉 **Loss Function:** CrossEntropyLoss(ignore_index={CFG['IGNORE_INDEX']})")

# 🔧 Optimizer Setup  
# Plain SGD over every model parameter: backbone, heads, BatchNorm and bias
# terms all share one learning rate and one weight decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=CFG['LEARNING_RATE'],
    momentum=CFG['MOMENTUM'],
    weight_decay=CFG['WEIGHT_DECAY']
)

print(f"🔧 **Optimizer:** SGD")
print(f"   Learning Rate: {CFG['LEARNING_RATE']}")
print(f"   Momentum: {CFG['MOMENTUM']}")
print(f"   Weight Decay: {CFG['WEIGHT_DECAY']}")

# 📅 Learning Rate Scheduler (Polynomial Decay)
def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power, warmup_iters=0):
    """Set and return the polynomial-decay learning rate for step ``iter``.

    lr = init_lr * max(0, 1 - iter / max_iter) ** power

    The base is clamped at 0, so steps at or beyond ``max_iter`` give lr = 0.0.
    Without the clamp, a negative float raised to a fractional power is a
    complex number in Python. When ``warmup_iters`` > 0, steps with
    ``iter < warmup_iters`` use a linear ramp ``init_lr * (iter + 1) / warmup_iters``.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with 'lr').
        init_lr: base learning rate.
        iter: current 0-based step. The name is kept for keyword callers; it
            shadows the builtin only inside this function.
        max_iter: total number of steps; must be positive.
        power: decay exponent (e.g. CFG['POWER']).
        warmup_iters: optional linear-warmup length; 0 (default) disables it.
    Returns:
        The learning rate written into every parameter group.
    Raises:
        ValueError: if ``max_iter`` is not positive.
    """
    if max_iter <= 0:
        raise ValueError(f"max_iter must be positive, got {max_iter}")

    if warmup_iters > 0 and iter < warmup_iters:
        lr = init_lr * (iter + 1) / warmup_iters
    else:
        lr = init_lr * max(0.0, 1 - iter / max_iter) ** power

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# 🔥 Mixed Precision Setup
if CFG['MIXED_PRECISION']:
    if IS_TPU:
        # TPU uses bfloat16 automatically
        # NOTE(review): this branch only sets scaler=None and changes no dtypes.
        # torch_xla presumably needs bf16 requested explicitly (e.g. the
        # XLA_USE_BF16 env var or XLA autocast) — confirm before relying on it.
        print("🔥 **Mixed Precision:** bfloat16 (TPU native)")
        scaler = None  # TPU handles this automatically
    else:
        # GPU uses GradScaler for fp16
        # This import also defines the module-level `autocast` that the
        # evaluation function uses when MIXED_PRECISION is on and not on TPU.
        # NOTE(review): torch.cuda.amp.* is deprecated in recent PyTorch in
        # favour of torch.amp.*, and on a CPU-only fallback GradScaler is
        # expected to disable itself with a warning — confirm.
        from torch.cuda.amp import GradScaler, autocast
        scaler = GradScaler()
        print("🔥 **Mixed Precision:** fp16 + GradScaler")
else:
    scaler = None
    print("🔥 **Mixed Precision:** Disabled")

# 📊 Training State Tracking
# Mutable record the training loop is expected to update (not visible here).
training_state = {
    'iteration': 0,
    'best_miou': 0.0,
    'train_losses': [],
    'val_mious': [],
    'learning_rates': []
}

print(f"\n🎯 **Training Configuration:**")
print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
print(f"   Eval Interval: {CFG['EVAL_INTERVAL']:,}")
print(f"   Save Interval: {CFG['SAVE_INTERVAL']:,}")
print(f"   Warmup Iterations: {CFG['WARMUP_ITERATIONS']:,}")
print(f"   Polynomial Power: {CFG['POWER']}")

# 📁 Create output directory
os.makedirs(CFG['OUTPUT_DIR'], exist_ok=True)
print(f"📁 Output directory: {CFG['OUTPUT_DIR']}")

# 🧪 Test loss computation
print("\n🧪 Testing loss computation...")
try:
    # Create dummy predictions and targets
    dummy_pred = torch.randn(2, CFG['NUM_CLASSES'], 64, 64).to(device)
    dummy_target = torch.randint(0, CFG['NUM_CLASSES'], (2, 64, 64)).to(device)
    
    # Add some ignore pixels
    dummy_target[0, :10, :10] = CFG['IGNORE_INDEX']
    
    # Compute loss
    loss = criterion(dummy_pred, dummy_target)
    
    print(f"✅ **Loss Test Passed:**")
    print(f"   Dummy prediction shape: {dummy_pred.shape}")
    print(f"   Dummy target shape: {dummy_target.shape}")
    print(f"   Loss value: {loss.item():.4f}")
    
    # Test backward pass
    loss.backward()
    print(f"   ✅ Backward pass successful")
    
    # Clear gradients
    optimizer.zero_grad()
    
except Exception as e:
    print(f"❌ Loss test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# 📊 **mIoU EVALUATION METRICS** - Cityscapes Standard
"""
Mean Intersection over Union (mIoU) computation for Cityscapes
Key features:
- Per-class IoU calculation
- Proper handling of ignore_index (255)
- TPU-optimized computation
- Standard Cityscapes evaluation protocol
"""

class mIoUCalculator:
    """
    📊 Calculate mean Intersection over Union for semantic segmentation

    Accumulates a ``num_classes x num_classes`` confusion matrix: rows are
    ground-truth classes, columns are predicted classes. Per-class IoU is then
    TP / (TP + FP + FN), derived from that matrix.

    NOTE(review): each process only accumulates the samples it sees. With a
    DistributedSampler across several replicas, the matrices would need to be
    summed across replicas before computing IoU — confirm for multi-core runs.
    """

    def __init__(self, num_classes, ignore_index=255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Reset all statistics"""
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))

    def update(self, predictions, targets):
        """
        Update confusion matrix with new predictions and targets

        Pixels whose target equals ``ignore_index`` or lies outside
        ``[0, num_classes)`` are skipped. Predictions are clipped into
        ``[0, num_classes - 1]``.

        Args:
            predictions: class indices (numpy array or torch tensor), e.g. (B, H, W)
            targets: ground-truth class indices with the same number of elements
        """
        # Convert to numpy if tensors
        if torch.is_tensor(predictions):
            predictions = predictions.cpu().numpy()
        if torch.is_tensor(targets):
            targets = targets.cpu().numpy()

        predictions = np.asarray(predictions).flatten()
        targets = np.asarray(targets).flatten()

        n = self.num_classes
        valid_mask = (targets != self.ignore_index) & (targets >= 0) & (targets < n)
        # Cast to int64 before arithmetic so uint8 labels cannot overflow below.
        targets = targets[valid_mask].astype(np.int64)
        predictions = np.clip(predictions[valid_mask], 0, n - 1).astype(np.int64)

        # Vectorised form of "confusion_matrix[t, p] += 1 for every pixel":
        # encode each (t, p) pair as t * n + p and count occurrences. The old
        # per-pixel Python loop was far too slow for full validation images.
        counts = np.bincount(n * targets + predictions, minlength=n * n)
        self.confusion_matrix += counts.reshape(n, n)

    def compute_iou(self):
        """
        Compute IoU for each class and mean IoU

        Returns:
            per_class_iou: IoU for each class (0 where the union is empty)
            mean_iou: mean IoU over classes with a non-zero union
        """
        # Calculate IoU for each class
        intersection = np.diag(self.confusion_matrix)
        union = (
            self.confusion_matrix.sum(axis=1) +
            self.confusion_matrix.sum(axis=0) -
            intersection
        )

        # Avoid division by zero
        valid_classes = union > 0
        per_class_iou = np.zeros(self.num_classes)
        per_class_iou[valid_classes] = intersection[valid_classes] / union[valid_classes]

        # Mean over classes that occur in the ground truth OR were predicted at
        # least once; a class that is only predicted (never true) counts as IoU 0.
        mean_iou = per_class_iou[valid_classes].mean() if valid_classes.any() else 0.0

        return per_class_iou, mean_iou

    def get_results(self):
        """Get detailed results"""
        per_class_iou, mean_iou = self.compute_iou()

        results = {
            'mIoU': mean_iou,
            'per_class_IoU': per_class_iou,
            'confusion_matrix': self.confusion_matrix.copy()
        }

        return results

    def print_results(self, class_names=None):
        """Print formatted results"""
        per_class_iou, mean_iou = self.compute_iou()

        print(f"📊 **mIoU Results:**")
        print(f"   Mean IoU: {mean_iou:.4f} ({mean_iou*100:.2f}%)")
        print(f"\n📋 **Per-Class IoU:**")

        if class_names is None:
            class_names = [f"Class_{i}" for i in range(self.num_classes)]

        for i, (class_name, iou) in enumerate(zip(class_names, per_class_iou)):
            print(f"   {i:2d}. {class_name:<15}: {iou:.4f} ({iou*100:.2f}%)")

# 🧪 Test mIoU calculator
# Creates the global `miou_calculator` used by later cells, exercises it on
# random data, then resets it so no test counts leak into real evaluation.
print("📊 **mIoU CALCULATOR INITIALIZED**")
miou_calculator = mIoUCalculator(
    num_classes=CFG['NUM_CLASSES'], 
    ignore_index=CFG['IGNORE_INDEX']
)

# Test with dummy data
print("\n🧪 Testing mIoU calculator...")
try:
    # Create dummy predictions and targets (uniform random, so mIoU is near chance)
    dummy_predictions = np.random.randint(0, CFG['NUM_CLASSES'], (2, 100, 100))
    dummy_targets = np.random.randint(0, CFG['NUM_CLASSES'], (2, 100, 100))
    
    # Add some ignore pixels
    dummy_targets[0, :10, :10] = CFG['IGNORE_INDEX']
    
    # Update calculator
    miou_calculator.update(dummy_predictions, dummy_targets)
    
    # Compute results
    results = miou_calculator.get_results()
    
    print(f"✅ **mIoU Test Passed:**")
    print(f"   Mean IoU: {results['mIoU']:.4f}")
    print(f"   Confusion matrix shape: {results['confusion_matrix'].shape}")
    print(f"   Per-class IoU shape: {results['per_class_IoU'].shape}")
    
    # Test with class names
    miou_calculator.print_results(CITYSCAPES_CLASSES)
    
    # Reset for actual training
    miou_calculator.reset()
    
except Exception as e:
    print(f"❌ mIoU test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# 🔄 **EVALUATION FUNCTION** - TPU Optimized Validation
"""
Comprehensive evaluation function for Cityscapes validation
Key features:
- TPU-aware evaluation loop
- Memory-efficient processing
- mIoU computation (per process; no cross-core aggregation on TPU yet)
- Progress tracking and logging
"""

def evaluate_model(model, val_loader, criterion, miou_calculator, device, is_tpu=False):
    """
    🔍 Evaluate model on validation set

    Runs ``model`` in eval mode under ``torch.no_grad()`` over every batch of
    ``val_loader``. It accumulates ``criterion(outputs['out'], targets)`` and
    feeds argmax predictions into ``miou_calculator``, which is reset first.
    The model is switched back to train mode before returning.

    Args:
        model: Segmentation model whose forward returns a mapping with key ``'out'``.
        val_loader: Iterable of ``(images, targets)`` batches.
        criterion: Loss applied to ``(outputs['out'], targets)``.
        miou_calculator: Accumulator exposing ``reset`` / ``update`` / ``get_results``.
        device: Device that batches are moved to when ``is_tpu`` is False.
        is_tpu: If True, batches are assumed to already be on the XLA device,
            no tqdm bar is shown, and ``xm.mark_step()`` is issued every 10 batches.

    Returns:
        eval_results: The dict from ``miou_calculator.get_results()`` plus
        ``'loss'`` (mean per-batch loss; NaN if the loader yielded no batches)
        and ``'total_samples'``.
    """
    model.eval()

    # Reset metrics
    miou_calculator.reset()
    total_loss = 0.0
    total_samples = 0
    # Count batches actually consumed so the average never relies on len(val_loader)
    num_batches = 0

    print("🔍 Starting validation...")

    with torch.no_grad():
        # Progress bar (tqdm only off-TPU)
        pbar = tqdm(val_loader, desc="🧪 Validating", leave=False) if not is_tpu else val_loader

        for batch_idx, (images, targets) in enumerate(pbar):
            # Move to device (already on TPU if using TPU loader)
            if not is_tpu:
                images = images.to(device)
                targets = targets.to(device)

            # Forward pass
            if CFG['MIXED_PRECISION'] and not is_tpu:
                # GPU with mixed precision
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs['out'], targets)
            else:
                # TPU or regular computation
                outputs = model(images)
                loss = criterion(outputs['out'], targets)

            # Accumulate loss
            total_loss += loss.item()
            total_samples += images.size(0)
            num_batches += 1

            # Get predictions
            predictions = torch.argmax(outputs['out'], dim=1)

            # Update mIoU calculator
            miou_calculator.update(predictions, targets)

            # Update progress bar
            if not is_tpu and hasattr(pbar, 'set_postfix'):
                current_loss = total_loss / num_batches
                pbar.set_postfix({
                    'Loss': f'{current_loss:.4f}',
                    'Samples': f'{total_samples:,}'
                })

            # Periodic graph cut for TPU
            if is_tpu and batch_idx % 10 == 0:
                xm.mark_step()  # TPU step

    # Compute final metrics; an empty loader gives NaN instead of ZeroDivisionError
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    eval_results = miou_calculator.get_results()
    eval_results['loss'] = avg_loss
    eval_results['total_samples'] = total_samples

    # TPU synchronization
    if is_tpu:
        xm.master_print(f"🔍 Validation completed on {total_samples:,} samples")
        # NOTE(review): metrics are still per-core here; aggregating the confusion
        # matrix across cores (e.g. xm.mesh_reduce) is needed for exact TPU mIoU.

    print(f"✅ Validation completed:")
    print(f"   Average Loss: {avg_loss:.4f}")
    print(f"   Mean IoU: {eval_results['mIoU']:.4f} ({eval_results['mIoU']*100:.2f}%)")
    print(f"   Total Samples: {total_samples:,}")

    model.train()  # Set back to training mode
    return eval_results

# 🧪 Evaluation-function summary banner (informational only; nothing is executed here)
print("\n".join((
    "🔍 **EVALUATION FUNCTION READY**",
    "   - TPU-aware validation loop",
    "   - Memory-efficient processing",
    "   - Comprehensive mIoU computation",
    "   - Progress tracking and logging",
)))

In [None]:
# 🚀 **TRAINING LOOP** - TPU Optimized Main Training
"""
Complete training loop optimized for TPU v5e-8
Key features:
- Polynomial learning rate scheduling with linear warmup
- Mixed precision via autocast + GradScaler on GPU (on TPU, any bfloat16 mode
  must be configured outside this loop; nothing here enables it)
- Cross-replica gradient all-reduce on TPU via xm.optimizer_step
- Periodic evaluation and checkpointing
- Comprehensive logging and monitoring
"""


def _compute_training_loss(outputs, targets):
    """Return the training loss for one batch of model outputs.

    Args:
        outputs: Model output mapping; ``outputs['out']`` holds the main-head
            logits and ``outputs['aux']`` (optional, may be None) the auxiliary
            head logits.
        targets: Ground-truth labels passed straight to the global ``criterion``.

    Returns:
        ``criterion(out) + 0.4 * criterion(aux)`` when an auxiliary output is
        present, otherwise ``criterion(out)``.
    """
    main_loss = criterion(outputs['out'], targets)
    if 'aux' in outputs and outputs['aux'] is not None:
        aux_loss = criterion(outputs['aux'], targets)
        return main_loss + 0.4 * aux_loss  # Standard weight for aux loss
    return main_loss


def train_model():
    """
    🚀 Main training function

    Trains the global ``model`` from ``training_state['iteration']`` up to
    ``CFG['MAX_ITERATIONS']`` optimizer steps. It evaluates every
    ``CFG['EVAL_INTERVAL']`` steps, saving ``best_model.pth`` on a new best mIoU,
    and writes a checkpoint every ``CFG['SAVE_INTERVAL']`` steps. The global
    ``training_state`` is updated in place and returned.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches. Without this check
            the outer ``while`` loop would spin forever without progress.
    """
    print("🚀 **STARTING TRAINING**")
    print(f"   Device: {device}")
    print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
    print(f"   Effective Batch Size: {(CFG['BATCH_SIZE'] * world_size if IS_TPU else CFG['BATCH_SIZE'])}")
    print(f"   Learning Rate: {CFG['LEARNING_RATE']}")
    print(f"   Mixed Precision: {CFG['MIXED_PRECISION']}")

    # Training state (resumes from whatever the global state already holds)
    global training_state
    iteration = training_state['iteration']
    best_miou = training_state['best_miou']

    # Only the single GPU/CPU process, or the TPU master replica, logs and saves
    is_master = not IS_TPU or xm.is_master_ordinal()

    # Create progress tracking
    pbar = tqdm(total=CFG['MAX_ITERATIONS'], initial=iteration, desc="🚀 Training") if is_master else None

    # Training loop
    model.train()

    try:
        while iteration < CFG['MAX_ITERATIONS']:

            # Set epoch for distributed sampler so shuffling differs per epoch
            if IS_TPU:
                epoch = iteration // len(train_loader)
                train_loader._loader.sampler.set_epoch(epoch)

            batches_this_pass = 0
            for images, targets in train_loader:

                # Check iteration limit
                if iteration >= CFG['MAX_ITERATIONS']:
                    break
                batches_this_pass += 1

                # Move to device (already on TPU if using TPU loader)
                if not IS_TPU:
                    images = images.to(device)
                    targets = targets.to(device)

                # Learning rate scheduling with warmup
                if iteration < CFG['WARMUP_ITERATIONS']:
                    # Linear warmup. Using (iteration + 1) keeps step 0 from running
                    # at lr=0 and reaches the base LR on the last warmup step.
                    lr = CFG['LEARNING_RATE'] * ((iteration + 1) / CFG['WARMUP_ITERATIONS'])
                else:
                    # Polynomial decay
                    lr = poly_lr_scheduler(
                        optimizer,
                        CFG['LEARNING_RATE'],
                        iteration - CFG['WARMUP_ITERATIONS'],
                        CFG['MAX_ITERATIONS'] - CFG['WARMUP_ITERATIONS'],
                        CFG['POWER']
                    )

                # Set learning rate
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                optimizer.zero_grad()

                if CFG['MIXED_PRECISION'] and not IS_TPU:
                    # GPU mixed precision: forward + loss under autocast, scaled backward
                    with autocast():
                        outputs = model(images)
                        loss = _compute_training_loss(outputs, targets)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # TPU or full-precision computation
                    outputs = model(images)
                    loss = _compute_training_loss(outputs, targets)
                    loss.backward()
                    if IS_TPU:
                        # All-reduces gradients across TPU replicas before stepping;
                        # a plain optimizer.step() would let replicas diverge.
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                # TPU step marking (executes the traced graph)
                if IS_TPU:
                    xm.mark_step()

                # Record training state (single host sync for the loss value)
                loss_value = loss.item()
                training_state['train_losses'].append(loss_value)
                training_state['learning_rates'].append(lr)
                iteration += 1
                training_state['iteration'] = iteration

                # Update progress bar
                if pbar is not None:
                    pbar.set_postfix({
                        'Loss': f'{loss_value:.4f}',
                        'LR': f'{lr:.6f}',
                        'Iter': f'{iteration}/{CFG["MAX_ITERATIONS"]}'
                    })
                    pbar.update(1)

                # Evaluation
                if iteration % CFG['EVAL_INTERVAL'] == 0:
                    print(f"\n🔍 **Evaluation at iteration {iteration:,}**")

                    eval_results = evaluate_model(
                        model, val_loader, criterion,
                        miou_calculator, device, IS_TPU
                    )

                    current_miou = eval_results['mIoU']
                    training_state['val_mious'].append(current_miou)

                    # Check for best model
                    if current_miou > best_miou:
                        best_miou = current_miou
                        training_state['best_miou'] = best_miou

                        # Save best model
                        # NOTE(review): torch_xla documents xm.save() for XLA tensors (it
                        # must be called on every replica); confirm this file reloads on CPU.
                        if is_master:
                            model_path = os.path.join(CFG['OUTPUT_DIR'], 'best_model.pth')
                            torch.save({
                                'iteration': iteration,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                # Plain float so weights_only torch.load can read it back
                                'best_miou': float(best_miou),
                                'config': CFG
                            }, model_path)
                            print(f"💾 Best model saved: {model_path} (mIoU: {best_miou:.4f})")

                    print(f"   Current mIoU: {current_miou:.4f}")
                    print(f"   Best mIoU: {best_miou:.4f}")

                    # Print per-class results
                    miou_calculator.print_results(CITYSCAPES_CLASSES)

                    model.train()  # Set back to training mode

                # Save checkpoint
                if iteration % CFG['SAVE_INTERVAL'] == 0 and is_master:
                    checkpoint_path = os.path.join(CFG['OUTPUT_DIR'], f'checkpoint_{iteration}.pth')
                    torch.save({
                        'iteration': iteration,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'training_state': training_state,
                        'config': CFG
                    }, checkpoint_path)
                    print(f"\n💾 Checkpoint saved: {checkpoint_path}")

                # Memory cleanup
                del outputs, loss
                if IS_TPU and iteration % 50 == 0:
                    xm.mark_step()  # Additional TPU synchronization

            if batches_this_pass == 0 and iteration < CFG['MAX_ITERATIONS']:
                raise RuntimeError("train_loader yielded no batches; training cannot make progress")
    finally:
        # Close progress bar, also on interrupt/error, so tqdm output is not left dangling
        if pbar is not None:
            pbar.close()

    print(f"\n✅ **Training completed!**")
    print(f"   Total iterations: {iteration:,}")
    print(f"   Best mIoU: {best_miou:.4f} ({best_miou*100:.2f}%)")

    return training_state

# Training-loop summary banner (informational only)
print("\n".join((
    "🚀 **TRAINING LOOP READY**",
    "   - TPU-optimized training with bfloat16",
    "   - Polynomial LR scheduling with warmup",
    "   - Periodic evaluation and checkpointing",
    "   - Memory-efficient gradient computation",
    "   - Comprehensive progress tracking",
)))

In [None]:
# 🎬 **START TRAINING** - Execute Training Loop
"""
Launch the complete training process.
Runs train_model() for CFG['MAX_ITERATIONS'] iterations (the notebook overview
specifies 60,000) on the device chosen in the setup cell (TPU v5e-8 when
available). If training raises, an emergency checkpoint is attempted.
"""

# 🚀 Start training
print("🎬 **LAUNCHING CITYSCAPES TRAINING**")
print("=" * 60)

try:
    # Run training
    final_state = train_model()

    print("\n🎉 **TRAINING COMPLETED SUCCESSFULLY!**")
    print(f"   Final Best mIoU: {final_state['best_miou']:.4f}")
    print(f"   Total Iterations: {final_state['iteration']:,}")
    print(f"   Training Losses: {len(final_state['train_losses']):,} recorded")
    print(f"   Validation mIoUs: {len(final_state['val_mious'])} recorded")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print(f"   Current iteration: {training_state['iteration']:,}")
    print(f"   Current best mIoU: {training_state['best_miou']:.4f}")

except Exception as e:
    print(f"\n❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

    # Save emergency checkpoint if possible. This is best effort: a failure here is
    # reported, not raised. Only Exception is caught, so a Ctrl-C during the save
    # still interrupts.
    try:
        emergency_path = os.path.join(CFG['OUTPUT_DIR'], 'emergency_checkpoint.pth')
        torch.save({
            'iteration': training_state['iteration'],
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_state': training_state,
            'config': CFG,
            'error': str(e)
        }, emergency_path)
        print(f"💾 Emergency checkpoint saved: {emergency_path}")
    except Exception as save_err:
        print("❌ Could not save emergency checkpoint")
        print(f"   Reason: {save_err}")

print("=" * 60)

In [None]:
# 📊 **TRAINING VISUALIZATION & ANALYSIS** - Results Dashboard
"""
Visualize training progress and analyze results
Key features:
- Training loss curves
- Validation mIoU progression
- Learning rate schedule
- Per-class IoU breakdown
- Model performance analysis
"""

def plot_training_results(training_state):
    """
    📈 Plot comprehensive training results

    Draws a 2x2 dashboard from ``training_state``:
    - per-iteration training loss, with a moving-average overlay once there are
      more than 100 values;
    - validation mIoU per evaluation;
    - the learning-rate schedule on a log scale;
    - a text summary of the run configuration.

    The figure is saved to ``<CFG['OUTPUT_DIR']>/training_results.png`` before it
    is displayed, and is closed afterwards.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('🏙️ DeepLabv3+ Cityscapes Training Results', fontsize=16, y=0.98)

    losses = training_state['train_losses']

    # 1. Training Loss
    if losses:
        axes[0, 0].plot(losses, 'b-', alpha=0.7, linewidth=1)
        axes[0, 0].set_title('📉 Training Loss')
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True, alpha=0.3)

        # Add smoothed line
        if len(losses) > 100:
            window = len(losses) // 50
            smooth_loss = np.convolve(losses, np.ones(window) / window, mode='valid')
            # Derive x from len(smooth_loss) so both arrays always have the same
            # length (odd and even windows alike); each point sits at its window centre.
            smooth_x = np.arange(len(smooth_loss)) + window // 2
            axes[0, 0].plot(smooth_x, smooth_loss, 'r-', linewidth=2, label='Smoothed')
            axes[0, 0].legend()

    # 2. Validation mIoU (evaluations happen every EVAL_INTERVAL iterations)
    if training_state['val_mious']:
        eval_iterations = [CFG['EVAL_INTERVAL'] * (i + 1) for i in range(len(training_state['val_mious']))]
        axes[0, 1].plot(eval_iterations, training_state['val_mious'], 'g-o', linewidth=2, markersize=6)
        axes[0, 1].set_title('📊 Validation mIoU')
        axes[0, 1].set_xlabel('Iteration')
        axes[0, 1].set_ylabel('mIoU')
        axes[0, 1].grid(True, alpha=0.3)
        axes[0, 1].set_ylim(0, 1)

        # Highlight best score
        best_idx = np.argmax(training_state['val_mious'])
        best_iter = eval_iterations[best_idx]
        best_miou = training_state['val_mious'][best_idx]
        axes[0, 1].scatter([best_iter], [best_miou], color='red', s=100, zorder=5)
        axes[0, 1].annotate(f'Best: {best_miou:.4f}',
                            xy=(best_iter, best_miou),
                            xytext=(10, 10), textcoords='offset points',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

    # 3. Learning Rate Schedule
    if training_state['learning_rates']:
        axes[1, 0].plot(training_state['learning_rates'], 'purple', linewidth=2)
        axes[1, 0].set_title('📅 Learning Rate Schedule')
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True, alpha=0.3)
        axes[1, 0].set_yscale('log')

    # 4. Training Summary
    axes[1, 1].axis('off')
    # Format the final loss separately: a conditional is not valid inside an
    # f-string format spec, and losses[-1] would raise on an empty list.
    final_loss_str = f"{losses[-1]:.4f}" if losses else 'N/A'
    summary_text = f"""
🎯 Training Summary:

• Total Iterations: {training_state['iteration']:,}
• Best mIoU: {training_state['best_miou']:.4f} ({training_state['best_miou']*100:.2f}%)
• Final Loss: {final_loss_str}
• Dataset: Cityscapes (19 classes)
• Resolution: {CFG['CROP_SIZE']}×{CFG['CROP_SIZE']}
• Batch Size: {CFG['BATCH_SIZE']} per core
• TPU Cores: {CFG['TPU_CORES']}
• Mixed Precision: {CFG['MIXED_PRECISION']}
• Backbone: {CFG['BACKBONE']}
"""
    axes[1, 1].text(0.1, 0.9, summary_text, transform=axes[1, 1].transAxes,
                    fontsize=11, verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8))

    plt.tight_layout()

    # Save before showing (Matplotlib's recommendation), then release the figure
    plot_path = os.path.join(CFG['OUTPUT_DIR'], 'training_results.png')
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"📊 Training plot saved: {plot_path}")


def analyze_final_results(training_state):
    """📋 Analyze and summarize final training results

    Prints iteration count, best mIoU, final/recent-average training loss, and
    the mIoU change between the first and last evaluations. Then lists every
    entry in ``CFG['OUTPUT_DIR']`` with its size in MB.
    """
    print("\n📋 **FINAL TRAINING ANALYSIS**")
    print("=" * 50)

    # Basic stats
    print(f"🎯 **Training Completed:**")
    print(f"   Total Iterations: {training_state['iteration']:,}")
    print(f"   Best mIoU: {training_state['best_miou']:.4f} ({training_state['best_miou']*100:.2f}%)")

    if training_state['train_losses']:
        final_loss = training_state['train_losses'][-1]
        avg_loss = np.mean(training_state['train_losses'][-1000:])  # Last 1000 iterations
        print(f"   Final Loss: {final_loss:.4f}")
        print(f"   Average Loss (last 1000): {avg_loss:.4f}")

    if training_state['val_mious']:
        print(f"   Evaluations: {len(training_state['val_mious'])}")
        print(f"   mIoU Improvement: {training_state['val_mious'][-1] - training_state['val_mious'][0]:.4f}")

    # Model files
    print(f"\n📁 **Output Files:**")
    output_files = os.listdir(CFG['OUTPUT_DIR'])
    for fname in sorted(output_files):
        file_path = os.path.join(CFG['OUTPUT_DIR'], fname)
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"   {fname}: {size_mb:.1f} MB")

    print("=" * 50)


# 📊 Check if training has results to visualize
if training_state['iteration'] > 0:
    print("📊 **VISUALIZING TRAINING RESULTS**")
    plot_training_results(training_state)
    analyze_final_results(training_state)
else:
    print("📊 **NO TRAINING DATA TO VISUALIZE**")
    print("   Run the training cells above first!")

In [None]:
# 🖼️ **INFERENCE & VISUALIZATION** - Model Predictions
"""
Inference pipeline for trained model
Key features:
- Load best model checkpoint
- Inference on validation samples
- Visualization of predictions vs ground truth
- Cityscapes color mapping for beautiful results
"""

# 🎨 Cityscapes Color Palette (for visualization)
# Entry i is the RGB color drawn for class index i (19 classes, road → bicycle).
# Kept as an immutable table; the public constant exposes [R, G, B] lists.
_CITYSCAPES_RGB = (
    (128, 64, 128),   # road
    (244, 35, 232),   # sidewalk
    (70, 70, 70),     # building
    (102, 102, 156),  # wall
    (190, 153, 153),  # fence
    (153, 153, 153),  # pole
    (250, 170, 30),   # traffic light
    (220, 220, 0),    # traffic sign
    (107, 142, 35),   # vegetation
    (152, 251, 152),  # terrain
    (70, 130, 180),   # sky
    (220, 20, 60),    # person
    (255, 0, 0),      # rider
    (0, 0, 142),      # car
    (0, 0, 70),       # truck
    (0, 60, 100),     # bus
    (0, 80, 100),     # train
    (0, 0, 230),      # motorcycle
    (119, 11, 32),    # bicycle
)
CITYSCAPES_COLORS = [list(rgb) for rgb in _CITYSCAPES_RGB]

def load_best_model():
    """📥 Load the best trained model

    Restores ``model_state_dict`` from ``<CFG['OUTPUT_DIR']>/best_model.pth`` into
    the global ``model``, with tensors mapped to ``device``. It prints the stored
    iteration and best mIoU, then switches the model to eval mode.

    Returns:
        True if the checkpoint existed and was loaded, False if it is missing.
    """
    model_path = os.path.join(CFG['OUTPUT_DIR'], 'best_model.pth')

    if not os.path.exists(model_path):
        print(f"❌ Model not found: {model_path}")
        print("   Train the model first!")
        return False

    print(f"📥 Loading best model: {model_path}")
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects pickled
    # objects outside its allowlist. That can include a NumPy-scalar mIoU or
    # non-primitive entries inside the saved CFG. This file is written by this
    # notebook's own training loop and is trusted; never load untrusted files this way.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:  # torch < 1.13 has no weights_only argument
        checkpoint = torch.load(model_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    best_miou = checkpoint.get('best_miou', 0.0)
    iteration = checkpoint.get('iteration', 0)

    print(f"✅ Model loaded successfully:")
    print(f"   Iteration: {iteration:,}")
    print(f"   Best mIoU: {best_miou:.4f}")

    model.eval()
    return True

def colorize_prediction(prediction, palette=None):
    """🎨 Convert a label map into an RGB image.

    Args:
        prediction: Integer array (or nested list) of class ids, of any shape,
            for example ``(H, W)`` or a batch ``(N, H, W)``.
        palette: Sequence of RGB triples indexed by class id. Defaults to
            ``CITYSCAPES_COLORS``.

    Returns:
        ``uint8`` array of shape ``prediction.shape + (3,)``. Pixels whose id has
        no palette entry (e.g. the 255 ignore label) stay black ``(0, 0, 0)``.
    """
    if palette is None:
        palette = CITYSCAPES_COLORS
    prediction = np.asarray(prediction)
    colored = np.zeros(prediction.shape + (3,), dtype=np.uint8)

    for class_id, color in enumerate(palette):
        colored[prediction == class_id] = color

    return colored

def inference_on_samples(num_samples=4):
    """🔍 Run the best checkpoint on validation batches and visualize the results.

    For up to ``num_samples`` validation batches, takes the first image of each
    batch and predicts its segmentation. It shows a 4-panel figure (input, ground
    truth, prediction, overlay), saves it to
    ``<CFG['OUTPUT_DIR']>/inference_sample_<k>.png``, and prints that image's mIoU.
    Nothing is done if ``load_best_model()`` finds no checkpoint.
    Per-sample errors are reported and skipped, so one bad sample does not abort
    the rest.

    Args:
        num_samples: Maximum number of validation batches (one image each) to show.
    """
    if not load_best_model():
        return

    print(f"🔍 Running inference on {num_samples} validation samples...")

    # Get samples from validation set
    val_iter = iter(val_loader)

    with torch.no_grad():
        for sample_idx in range(min(num_samples, len(val_loader))):
            fig = None  # tracked so the figure is closed even if plotting fails
            try:
                images, targets = next(val_iter)

                # Take first image from batch
                image = images[0:1]  # Keep batch dimension
                target = targets[0].cpu().numpy()

                # Move to device (TPU loaders already yield device tensors)
                if not IS_TPU:
                    image = image.to(device)

                # Forward pass
                outputs = model(image)
                prediction = torch.argmax(outputs['out'], dim=1)[0].cpu().numpy()

                # Convert CHW tensor to HWC and undo the mean/std normalization
                original_img = images[0].cpu().numpy().transpose(1, 2, 0)
                mean = np.array(CFG['MEAN'])
                std = np.array(CFG['STD'])
                original_img = (original_img * std + mean) * 255
                original_img = np.clip(original_img, 0, 255).astype(np.uint8)

                # Colorize masks
                target_colored = colorize_prediction(target)
                pred_colored = colorize_prediction(prediction)

                # Create visualization
                fig, axes = plt.subplots(1, 4, figsize=(20, 5))
                fig.suptitle(f'🏙️ Sample {sample_idx + 1} - Cityscapes Inference', fontsize=14)

                axes[0].imshow(original_img)
                axes[0].set_title('📷 Original Image')
                axes[0].axis('off')

                axes[1].imshow(target_colored)
                axes[1].set_title('🎯 Ground Truth')
                axes[1].axis('off')

                axes[2].imshow(pred_colored)
                axes[2].set_title('🤖 Prediction')
                axes[2].axis('off')

                # Overlay: 60% image, 40% predicted colors
                alpha = 0.6
                overlay = (alpha * original_img + (1 - alpha) * pred_colored).astype(np.uint8)
                axes[3].imshow(overlay)
                axes[3].set_title('🎨 Overlay')
                axes[3].axis('off')

                plt.tight_layout()

                # Save before showing (Matplotlib's recommendation)
                inference_path = os.path.join(CFG['OUTPUT_DIR'], f'inference_sample_{sample_idx+1}.png')
                fig.savefig(inference_path, dpi=150, bbox_inches='tight')
                plt.show()

                # Compute sample mIoU with a fresh calculator (does not touch the global one)
                sample_miou = mIoUCalculator(CFG['NUM_CLASSES'], CFG['IGNORE_INDEX'])
                sample_miou.update(prediction, target)
                results = sample_miou.get_results()

                print(f"   Sample {sample_idx + 1} mIoU: {results['mIoU']:.4f}")
                print(f"   💾 Saved: {inference_path}")

            except StopIteration:
                print(f"   ⚠️ Only {sample_idx} samples available")
                break
            except Exception as e:
                print(f"   ❌ Error processing sample {sample_idx + 1}: {e}")
            finally:
                if fig is not None:
                    plt.close(fig)

# 🖼️ Create class legend
def create_class_legend():
    """🎨 Draw, display, and save a color-swatch legend for the Cityscapes classes.

    Each name in ``CITYSCAPES_CLASSES`` is paired by position with its RGB entry
    in ``CITYSCAPES_COLORS``. Each pair is drawn as a labelled square on a
    4-column grid, with the first row at the top. The figure is saved to
    ``<CFG['OUTPUT_DIR']>/cityscapes_legend.png`` before it is shown, and is
    closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Matplotlib expects RGB components in [0, 1]
    colors_normalized = [[c / 255.0 for c in color] for color in CITYSCAPES_COLORS]

    # Display as grid
    n_cols = 4
    n_rows = (len(CITYSCAPES_CLASSES) + n_cols - 1) // n_cols
    swatch = 0.8  # side of each square; the remaining 0.2 of a cell is spacing

    for i, (class_name, color) in enumerate(zip(CITYSCAPES_CLASSES, colors_normalized)):
        row, col = divmod(i, n_cols)
        x0 = col
        y0 = n_rows - row - 1  # flip so row 0 is drawn at the top

        rect = plt.Rectangle((x0, y0), swatch, swatch,
                             facecolor=color, edgecolor='black', linewidth=1)
        ax.add_patch(rect)

        # Centre the label on the swatch, and use white text on dark colors
        # (ITU-R BT.601 luma) so labels such as car/truck/bus stay readable.
        luma = 0.299 * color[0] + 0.587 * color[1] + 0.114 * color[2]
        text_color = 'black' if luma > 0.5 else 'white'
        ax.text(x0 + swatch / 2, y0 + swatch / 2, f"{i}: {class_name}",
                ha='center', va='center', fontsize=10, weight='bold', color=text_color)

    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('🏙️ Cityscapes Classes Color Legend', fontsize=16, pad=20)

    plt.tight_layout()

    # Save before showing (Matplotlib's recommendation), then release the figure
    legend_path = os.path.join(CFG['OUTPUT_DIR'], 'cityscapes_legend.png')
    fig.savefig(legend_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"🎨 Class legend saved: {legend_path}")

# Inference-pipeline summary banner, followed by the class color legend
print("\n".join((
    "🖼️ **INFERENCE PIPELINE READY**",
    "   - Load best model checkpoint",
    "   - Inference on validation samples",
    "   - Beautiful Cityscapes color visualization",
    "   - Per-sample mIoU computation",
)))

# Render the class -> color legend once when this cell runs
create_class_legend()

In [None]:
# 🚀 **RUN INFERENCE** - Generate Beautiful Predictions
"""
Execute inference on validation samples to see model results
"""

# 🔍 Visualize predictions for a handful of validation samples
print("🔍 **STARTING INFERENCE**")
print("=" * 40)

try:
    inference_on_samples(num_samples=4)
except Exception as exc:
    print(f"\n❌ Inference failed: {exc}")
    import traceback
    traceback.print_exc()
else:
    print("\n✅ **INFERENCE COMPLETED SUCCESSFULLY!**")
    print("   Check the visualizations above and saved PNG files")

print("=" * 40)