# üèôÔ∏è DeepLabv3+ Cityscapes Reproduction on Kaggle

## **Project Overview**
This notebook reproduces **DeepLabv3+ semantic segmentation** results on the **Cityscapes dataset** using PyTorch and Torchvision. Adapted from our PASCAL VOC implementation with critical modifications for urban scene understanding.

### **üîß Key Adaptations for Cityscapes:**
- **NUM_CLASSES**: 19 (vs 21 for PASCAL VOC)
- **RESOLUTION**: 769√ó769 (vs 513√ó513 for PASCAL VOC) 
- **LABEL MAPPING**: Critical remapping from labelIds ‚Üí trainIds
- **TRAINING ITERATIONS**: 60,000 (vs 30,000 for PASCAL VOC)
- **IGNORE INDEX**: 255 for unlabeled pixels

### **üìä Dataset Info:**
- **Training**: ~3,000 images (fine annotations)
- **Validation**: ~500 images (fine annotations)
- **Classes**: 19 semantic classes (road, sidewalk, building, wall, fence, pole, traffic light, traffic sign, vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle)

---

In [None]:
# ? Install Required Dependencies
!pip install albumentations==1.3.1 kagglehub --quiet

import warnings
warnings.filterwarnings('ignore')

# üì¶ Import Essential Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models.segmentation as models
import torchvision.transforms as transforms

# üñºÔ∏è Image Processing & Augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from PIL import Image
import numpy as np

# üìÅ File System & Utilities  
import os
import glob
import json
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import kagglehub

# ?üîß TPU v5e-8 Setup
try:
    # TPU Setup for Kaggle
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    
    device = xm.xla_device()
    print(f"üöÄ Using TPU v5e-8: {device}")
    
    # Modern torch_xla API (v2.0+)
    if hasattr(xm, 'xrt_world_size'):
        # Legacy API
        world_size = xm.xrt_world_size()
        ordinal = xm.get_ordinal()
    elif hasattr(xm, 'get_world_size'):
        # Newer API
        world_size = xm.get_world_size()
        ordinal = xm.get_ordinal()
    else:
        # Latest API - use environment variables
        import os
        world_size = int(os.environ.get('WORLD_SIZE', '8'))
        ordinal = int(os.environ.get('RANK', '0'))
    
    print(f"   TPU Cores: {world_size}")
    print(f"   Current Core: {ordinal}")
    
    IS_TPU = True
    
except ImportError:
    # Fallback to GPU/CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
    print(f"üöÄ TPU not available, using: {device}")
    if torch.cuda.is_available():
        print(f"   GPU: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    
    IS_TPU = False
    world_size = 1  # Set fallback world_size for non-TPU
    ordinal = 0

In [None]:
# üèôÔ∏è **CITYSCAPES CONFIGURATION** - Optimized for TPU v5e-8
CFG = {
    # üìä Dataset Configuration (CITYSCAPES SPECIFIC)
    'NUM_CLASSES': 19,              # 19 semantic classes for Cityscapes (vs 21 for PASCAL VOC)
    'IGNORE_INDEX': 255,            # Standard ignore index for unlabeled pixels
    'CROP_SIZE': 769,               # Higher resolution for urban scenes (vs 513 for PASCAL VOC)  
    'BASE_SIZE': 769,               # Base size for resize operations
    
    # üéØ Training Configuration (TPU v5e-8 OPTIMIZED)
    'BATCH_SIZE': 32,               # Higher batch size leveraging TPU v5e-8 HBM (vs 2 for GPU)
    'NUM_WORKERS': 8,               # More workers for TPU data loading efficiency
    'MAX_ITERATIONS': 40000,        # Reduced due to larger effective batch size
    'EVAL_INTERVAL': 1000,          # More frequent evaluation with faster TPU
    'SAVE_INTERVAL': 2000,          # More frequent checkpoints
    
    # üîß Optimization Configuration (TPU TUNED)
    'LEARNING_RATE': 0.08,          # Higher LR for larger batch size (linear scaling: 0.01 * 8)
    'WEIGHT_DECAY': 5e-4,           # L2 regularization
    'MOMENTUM': 0.9,                # SGD momentum
    'POWER': 0.9,                   # Polynomial LR decay power
    'WARMUP_ITERATIONS': 1000,      # LR warmup for large batch training
    
    # üìÅ Path Configuration (KAGGLE SPECIFIC)
    'DATASET_ROOT': '/kaggle/input/cityscapes',  # Input dataset path
    'OUTPUT_DIR': '/kaggle/working',              # Output directory for models/logs
    
    # üñºÔ∏è Data Augmentation Configuration
    'RANDOM_SCALE_MIN': 0.5,        # Minimum scale for random scaling
    'RANDOM_SCALE_MAX': 2.0,        # Maximum scale for random scaling  
    'HORIZONTAL_FLIP_PROB': 0.5,    # Probability for horizontal flip
    
    # üß† Model Configuration (TPU OPTIMIZED)
    'BACKBONE': 'resnet101',         # ResNet-101 backbone
    'PRETRAINED': True,              # Use ImageNet pretrained weights
    'MIXED_PRECISION': True,         # bfloat16 for TPU (instead of fp16)
    'GRADIENT_ACCUMULATION': 1,      # No accumulation needed with larger batch size
    'TPU_CORES': 8,                  # TPU v5e-8 has 8 cores
    
    # üìä ImageNet Normalization (Standard for pretrained models)
    'MEAN': [0.485, 0.456, 0.406],
    'STD': [0.229, 0.224, 0.225]
}

print("üèôÔ∏è **CITYSCAPES CONFIGURATION - TPU v5e-8 OPTIMIZED**")
print(f"   Classes: {CFG['NUM_CLASSES']}")
print(f"   Resolution: {CFG['CROP_SIZE']}√ó{CFG['CROP_SIZE']}")  
print(f"   Batch Size: {CFG['BATCH_SIZE']} per core")
print(f"   TPU Cores: {CFG['TPU_CORES']} (Total batch: {CFG['BATCH_SIZE'] * CFG['TPU_CORES']})")
print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
print(f"   Learning Rate: {CFG['LEARNING_RATE']} (scaled for large batch)")
print(f"   Dataset Path: {CFG['DATASET_ROOT']}")
print(f"   Mixed Precision: {CFG['MIXED_PRECISION']} (bfloat16)")

In [None]:
# üè∑Ô∏è **CITYSCAPES LABEL MAPPING** - Critical for Correct Training
"""
Cityscapes uses complex label system:
- labelIds: Original labels in *_labelIds.png files (0-33)  
- trainIds: Training labels we need (0-18 + 255 for ignore)

This mapping is ESSENTIAL for correct training!
"""

# üéØ Official Cityscapes Label Mapping (labelId -> trainId)
CITYSCAPES_LABEL_MAP = {
    # Road & Ground
    7: 0,    # road
    8: 1,    # sidewalk
    11: 2,   # building
    12: 3,   # wall
    13: 4,   # fence
    17: 5,   # pole
    19: 6,   # traffic light
    20: 7,   # traffic sign
    21: 8,   # vegetation
    22: 9,   # terrain
    23: 10,  # sky
    24: 11,  # person
    25: 12,  # rider
    26: 13,  # car
    27: 14,  # truck
    28: 15,  # bus
    31: 16,  # train
    32: 17,  # motorcycle
    33: 18,  # bicycle
}

# üìã Class Names for Reference
CITYSCAPES_CLASSES = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
    'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
    'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]

def remap_labels(label_img):
    """
    üîÑ Remap Cityscapes labelIds to trainIds
    
    Args:
        label_img: numpy array with original labelIds
    Returns:
        remapped_img: numpy array with trainIds (0-18) and ignore_index (255)
    """
    # Initialize with ignore index (255)
    remapped = np.full_like(label_img, CFG['IGNORE_INDEX'], dtype=np.uint8)
    
    # Apply mapping for valid classes
    for label_id, train_id in CITYSCAPES_LABEL_MAP.items():
        mask = (label_img == label_id)
        remapped[mask] = train_id
    
    return remapped

# üß™ Test the mapping function
print("üè∑Ô∏è **CITYSCAPES LABEL MAPPING LOADED**")
print(f"   Valid Classes: {len(CITYSCAPES_CLASSES)}")
print(f"   Label Mappings: {len(CITYSCAPES_LABEL_MAP)}")
print(f"   Ignore Index: {CFG['IGNORE_INDEX']}")
print("\nüìã **Class List:**")
for i, class_name in enumerate(CITYSCAPES_CLASSES):
    print(f"   {i:2d}: {class_name}")
    
# Test mapping with dummy data
test_labels = np.array([7, 24, 26, 0, 255])  # road, person, car, void, void
test_remapped = remap_labels(test_labels)
print(f"\nüß™ **Mapping Test:**")
print(f"   Original: {test_labels}")
print(f"   Remapped: {test_remapped}")  # Should be [0, 11, 13, 255, 255]

In [None]:
# üì• Download Cityscapes Dataset
"""
Download and setup Cityscapes dataset from Kaggle
Expected structure:
/kaggle/input/cityscapes/
‚îú‚îÄ‚îÄ leftImg8bit/
‚îÇ   ‚îú‚îÄ‚îÄ train/
‚îÇ   ‚îî‚îÄ‚îÄ val/
‚îî‚îÄ‚îÄ gtFine/
    ‚îú‚îÄ‚îÄ train/
    ‚îî‚îÄ‚îÄ val/
"""

import kagglehub

# Download Cityscapes dataset
try:
    print("üì• Downloading Cityscapes dataset...")
    
    # Download latest version
    path = kagglehub.dataset_download("dansbecker/cityscapes-image-pairs")
    print("Path to dataset files:", path)
    
    # Update CFG with actual dataset path
    CFG['DATASET_ROOT'] = path
    print(f"üîÑ Updated dataset root: {CFG['DATASET_ROOT']}")
    
    # Verify dataset structure
    if os.path.exists(path):
        print("\nüìÅ **Dataset Structure:**")
        for item in sorted(os.listdir(path)):
            item_path = os.path.join(path, item)
            if os.path.isdir(item_path):
                print(f"   üìÇ {item}/")
                # Show subdirectories
                try:
                    subdirs = [d for d in os.listdir(item_path) if os.path.isdir(os.path.join(item_path, d))]
                    for subdir in sorted(subdirs)[:5]:  # Show first 5 subdirs
                        print(f"      üìÇ {subdir}/")
                        # Show sample files in first subdir
                        if subdir == sorted(subdirs)[0]:
                            subdir_path = os.path.join(item_path, subdir)
                            files = [f for f in os.listdir(subdir_path) if os.path.isfile(os.path.join(subdir_path, f))]
                            for file in sorted(files)[:3]:  # Show first 3 files
                                print(f"         üìÑ {file}")
                    if len(subdirs) > 5:
                        print(f"      ... and {len(subdirs)-5} more")
                except Exception as e:
                    print(f"      Error reading subdirs: {e}")
            else:
                print(f"   üìÑ {item}")
        
        # Check for common Cityscapes patterns
        possible_structures = [
            "leftImg8bit",
            "gtFine", 
            "leftImg8bit_trainvaltest",
            "gtFine_trainvaltest",
            "images",
            "labels",
            "train",
            "val"
        ]
        
        found_structures = []
        for struct in possible_structures:
            if os.path.exists(os.path.join(path, struct)):
                found_structures.append(struct)
        
        if found_structures:
            print(f"\nüîç **Found Cityscapes structures:** {found_structures}")
        else:
            print(f"\n‚ö†Ô∏è  **Unknown dataset structure** - will try adaptive loading")
    
except Exception as e:
    print(f"‚ùå Error downloading dataset: {e}")
    print("üí° Using default path: /kaggle/input/cityscapes")
    CFG['DATASET_ROOT'] = "/kaggle/input/cityscapes"

In [None]:
# üèôÔ∏è **CITYSCAPES DATASET CLASS** - Custom Dataset Implementation
class CityscapesDataset(Dataset):
    """
    üèôÔ∏è Custom Cityscapes Dataset for semantic segmentation
    
    Key Features:
    - Handles leftImg8bit (RGB images) and gtFine (segmentation masks)
    - Applies critical label remapping (labelIds -> trainIds)
    - Supports different augmentations for train/val
    - Optimized for TPU v5e-8 with efficient data loading
    - Adaptive loading for different dataset structures
    """
    
    def __init__(self, root_dir, split='train', transforms=None):
        """
        Args:
            root_dir: Path to cityscapes dataset
            split: 'train' or 'val'
            transforms: Albumentations transforms
        """
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms
        
        # üîç Find all image files with adaptive structure detection
        self.image_files = []
        self.label_files = []
        
        print(f"üîç Searching for {split} images in: {root_dir}")
        
        # Method 1: Standard Cityscapes structure
        self.images_dir = os.path.join(root_dir, 'leftImg8bit', split)
        self.labels_dir = os.path.join(root_dir, 'gtFine', split)
        
        if os.path.exists(self.images_dir) and os.path.exists(self.labels_dir):
            print(f"   ‚úÖ Found standard Cityscapes structure")
            self._load_standard_structure()
        else:
            print(f"   ‚ö†Ô∏è  Standard structure not found, trying alternatives...")
            
            # Method 2: Check for cityscapes_data subdirectory
            cityscapes_subdir = os.path.join(root_dir, 'cityscapes_data')
            if os.path.exists(cityscapes_subdir):
                print(f"   üîç Checking cityscapes_data subdirectory...")
                
                # Try standard structure inside cityscapes_data
                self.images_dir = os.path.join(cityscapes_subdir, 'leftImg8bit', split)
                self.labels_dir = os.path.join(cityscapes_subdir, 'gtFine', split)
                
                if os.path.exists(self.images_dir) and os.path.exists(self.labels_dir):
                    print(f"   ‚úÖ Found standard structure in cityscapes_data/")
                    self._load_standard_structure()
                else:
                    # Try train/val folders directly in cityscapes_data
                    split_dir = os.path.join(cityscapes_subdir, split)
                    if os.path.exists(split_dir):
                        print(f"   üîç Found direct {split} folder in cityscapes_data/")
                        self._load_split_directory(split_dir)
                    else:
                        print(f"   üîç Searching recursively in cityscapes_data/")
                        self._load_recursive_search(cityscapes_subdir)
            else:
                # Method 3: Check direct train/val folders
                split_dir = os.path.join(root_dir, split)
                if os.path.exists(split_dir):
                    print(f"   üîç Found direct {split} folder")
                    self._load_split_directory(split_dir)
                else:
                    # Method 4: Recursive search
                    print(f"   üîç Performing recursive search...")
                    self._load_recursive_search(root_dir)
        
        print(f"üèôÔ∏è **Cityscapes {split.upper()} Dataset:**")
        print(f"   Images: {len(self.image_files)}")
        print(f"   Labels: {len(self.label_files)}")
        
        if len(self.image_files) == 0:
            print("‚ùå No images found! Debugging dataset structure...")
            self._debug_dataset_structure()
        elif len(self.image_files) != len(self.label_files):
            print(f"‚ö†Ô∏è  Mismatch: {len(self.image_files)} images vs {len(self.label_files)} labels")
    
    def _debug_dataset_structure(self):
        """Debug function to show what's actually in the dataset"""
        print("üîç **DEBUGGING DATASET STRUCTURE:**")
        
        def explore_directory(path, max_depth=3, current_depth=0):
            if current_depth > max_depth or not os.path.exists(path):
                return
                
            try:
                items = sorted(os.listdir(path))[:20]  # Show first 20 items
                indent = "  " * current_depth
                
                for item in items:
                    item_path = os.path.join(path, item)
                    if os.path.isdir(item_path):
                        print(f"{indent}üìÇ {item}/")
                        if current_depth < max_depth:
                            explore_directory(item_path, max_depth, current_depth + 1)
                    else:
                        # Show file extension and size
                        size = os.path.getsize(item_path)
                        size_str = f"({size/1024:.1f}KB)" if size < 1024*1024 else f"({size/1024/1024:.1f}MB)"
                        print(f"{indent}üìÑ {item} {size_str}")
                        
                if len(os.listdir(path)) > 20:
                    print(f"{indent}... and {len(os.listdir(path)) - 20} more items")
                    
            except Exception as e:
                print(f"{indent}‚ùå Error reading {path}: {e}")
        
        print(f"üìÅ Root directory: {self.root_dir}")
        explore_directory(self.root_dir)
        
        # Look for any image files
        print(f"\nüîç **Searching for ANY image files:**")
        common_extensions = ['.png', '.jpg', '.jpeg']
        found_images = []
        
        for root, dirs, files in os.walk(self.root_dir):
            for file in files:
                if any(file.lower().endswith(ext) for ext in common_extensions):
                    found_images.append(os.path.join(root, file))
                    if len(found_images) >= 10:  # Limit to first 10
                        break
            if len(found_images) >= 10:
                break
        
        if found_images:
            print(f"üì∑ Found {len(found_images)} image files (showing first 10):")
            for img in found_images:
                rel_path = os.path.relpath(img, self.root_dir)
                print(f"   üìÑ {rel_path}")
        else:
            print("‚ùå No image files found anywhere in the dataset!")
    
    def _load_standard_structure(self):
        """Load from standard Cityscapes structure with city folders"""
        for city in sorted(os.listdir(self.images_dir)):
            city_img_dir = os.path.join(self.images_dir, city)
            city_lbl_dir = os.path.join(self.labels_dir, city)
            
            if os.path.isdir(city_img_dir):
                # Get image files
                img_files = glob.glob(os.path.join(city_img_dir, '*_leftImg8bit.png'))
                
                for img_file in sorted(img_files):
                    # Corresponding label file
                    basename = os.path.basename(img_file).replace('_leftImg8bit.png', '')
                    label_file = os.path.join(city_lbl_dir, f'{basename}_gtFine_labelIds.png')
                    
                    if os.path.exists(label_file):
                        self.image_files.append(img_file)
                        self.label_files.append(label_file)
    
    def _load_split_directory(self, split_dir):
        """Load from a split directory (train/ or val/)"""
        print(f"   üìÅ Scanning directory: {split_dir}")
        
        # Look for image and label patterns
        all_files = []
        for root, dirs, files in os.walk(split_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    all_files.append(os.path.join(root, file))
        
        print(f"   üìÑ Found {len(all_files)} image files")
        
        # Separate images and labels based on filename patterns
        potential_images = []
        potential_labels = []
        
        for file_path in all_files:
            filename = os.path.basename(file_path).lower()
            
            # Image patterns
            if any(pattern in filename for pattern in ['leftimg8bit', 'img', 'image', 'rgb']):
                potential_images.append(file_path)
            # Label patterns  
            elif any(pattern in filename for pattern in ['labelids', 'gtfine', 'label', 'mask', 'seg']):
                potential_labels.append(file_path)
            # If no clear pattern, assume images are larger files
            else:
                try:
                    size = os.path.getsize(file_path)
                    if size > 50000:  # Assume files > 50KB are images
                        potential_images.append(file_path)
                    else:
                        potential_labels.append(file_path)
                except:
                    potential_images.append(file_path)  # Default to image
        
        print(f"   üñºÔ∏è  Potential images: {len(potential_images)}")
        print(f"   üè∑Ô∏è  Potential labels: {len(potential_labels)}")
        
        # Show some examples
        if potential_images:
            print(f"   üìù Image examples: {[os.path.basename(f) for f in potential_images[:3]]}")
        if potential_labels:
            print(f"   ? Label examples: {[os.path.basename(f) for f in potential_labels[:3]]}")
        
        # Match images with labels
        for img_file in sorted(potential_images):
            img_basename = os.path.basename(img_file)
            
            # Try to find corresponding label
            best_label = None
            
            # Method 1: Try exact matching patterns
            for label_file in potential_labels:
                label_basename = os.path.basename(label_file)
                
                # Remove common suffixes and compare
                img_core = img_basename.lower().replace('_leftimg8bit', '').replace('_img', '').replace('_image', '').split('.')[0]
                label_core = label_basename.lower().replace('_gtfine_labelids', '').replace('_label', '').replace('_mask', '').split('.')[0]
                
                if img_core == label_core:
                    best_label = label_file
                    break
            
            # Method 2: If no exact match, try same directory
            if not best_label and potential_labels:
                img_dir = os.path.dirname(img_file)
                for label_file in potential_labels:
                    if os.path.dirname(label_file) == img_dir:
                        best_label = label_file
                        break
            
            # Method 3: Use first available label (fallback)
            if not best_label and potential_labels:
                best_label = potential_labels[0]
                potential_labels.remove(best_label)  # Remove so it's not reused
            
            if best_label:
                self.image_files.append(img_file)
                self.label_files.append(best_label)
    
    def _load_recursive_search(self, search_dir):
        """Recursive search for image-label pairs"""
        print(f"   üîç Recursive search in: {search_dir}")
        self._load_split_directory(search_dir)  # Use same logic as split directory
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Get image and label pair with preprocessing
        """
        if idx >= len(self.image_files):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.image_files)}")
        
        # üìñ Load image and label
        image_path = self.image_files[idx] 
        label_path = self.label_files[idx]
        
        try:
            # Load image (RGB)
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # Load label (grayscale)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None:
                raise ValueError(f"Could not load label: {label_path}")
            
            # üîÑ Critical: Remap labels (labelIds -> trainIds)
            label = remap_labels(label)
            
            # üñºÔ∏è Apply transforms
            if self.transforms:
                transformed = self.transforms(image=image, mask=label)
                image = transformed['image']
                label = transformed['mask']
            
            return image, label.long()
            
        except Exception as e:
            print(f"‚ùå Error loading sample {idx}: {e}")
            print(f"   Image path: {image_path}")
            print(f"   Label path: {label_path}")
            raise
    
    def get_sample_info(self, idx):
        """Get file paths for debugging"""
        return {
            'image_path': self.image_files[idx] if idx < len(self.image_files) else None,
            'label_path': self.label_files[idx] if idx < len(self.label_files) else None
        }

In [None]:
# üîç **QUICK DATASET STRUCTURE INSPECTION**
"""
Let's examine what's actually in the downloaded Cityscapes dataset
"""

def inspect_dataset_structure(root_path, max_files=10):
    """Inspect and display dataset structure"""
    print(f"üîç **INSPECTING DATASET:** {root_path}")
    print("=" * 60)
    
    if not os.path.exists(root_path):
        print(f"‚ùå Path does not exist: {root_path}")
        return
    
    def explore_path(path, depth=0, max_depth=4):
        if depth > max_depth:
            return
            
        indent = "  " * depth
        try:
            items = os.listdir(path)
            dirs = [item for item in items if os.path.isdir(os.path.join(path, item))]
            files = [item for item in items if os.path.isfile(os.path.join(path, item))]
            
            # Show directories first
            for dir_name in sorted(dirs)[:10]:  # Limit to 10 dirs
                dir_path = os.path.join(path, dir_name)
                file_count = 0
                try:
                    file_count = len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
                except:
                    pass
                print(f"{indent}üìÇ {dir_name}/ ({file_count} files)")
                if depth < max_depth:
                    explore_path(dir_path, depth + 1, max_depth)
            
            # Show some sample files
            if files and depth <= 2:
                image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                other_files = [f for f in files if not f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                
                if image_files:
                    print(f"{indent}üñºÔ∏è  Image files ({len(image_files)} total):")
                    for file_name in sorted(image_files)[:max_files]:
                        file_path = os.path.join(path, file_name)
                        size = os.path.getsize(file_path) / 1024  # KB
                        print(f"{indent}   üìÑ {file_name} ({size:.1f}KB)")
                    if len(image_files) > max_files:
                        print(f"{indent}   ... and {len(image_files) - max_files} more image files")
                
                if other_files and len(other_files) <= 20:
                    print(f"{indent}üìÑ Other files: {other_files}")
                elif other_files:
                    print(f"{indent}üìÑ Other files ({len(other_files)} total): {other_files[:5]}...")
                    
        except Exception as e:
            print(f"{indent}‚ùå Error reading {path}: {e}")
    
    explore_path(root_path)
    
    # Search for specific Cityscapes patterns
    print(f"\nüîç **SEARCHING FOR CITYSCAPES PATTERNS:**")
    cityscapes_patterns = ['leftImg8bit', 'gtFine', 'labelIds', 'train', 'val']
    
    for pattern in cityscapes_patterns:
        found_paths = []
        for root, dirs, files in os.walk(root_path):
            # Check directories
            for dir_name in dirs:
                if pattern.lower() in dir_name.lower():
                    found_paths.append(os.path.join(root, dir_name))
            # Check files  
            for file_name in files:
                if pattern.lower() in file_name.lower():
                    found_paths.append(os.path.join(root, file_name))
                    if len(found_paths) >= 5:  # Limit results
                        break
            if len(found_paths) >= 5:
                break
        
        if found_paths:
            print(f"   üéØ '{pattern}' found in:")
            for path in found_paths[:5]:
                rel_path = os.path.relpath(path, root_path)
                print(f"      üìç {rel_path}")
        else:
            print(f"   ‚ùå '{pattern}' not found")

# Inspect the actual dataset
if 'CFG' in globals() and 'DATASET_ROOT' in CFG:
    inspect_dataset_structure(CFG['DATASET_ROOT'])
else:
    print("‚ö†Ô∏è CFG not found, using default path")
    inspect_dataset_structure("/kaggle/input/cityscapes-image-pairs")

In [None]:
# üéØ **CITYSCAPES DATASET FIX** - Handle Specific Structure
"""
Fix the dataset loading for the specific structure we found:
cityscapes_data/train/ and cityscapes_data/val/
"""

def quick_test_dataset_loading():
    """Test dataset loading with the actual structure"""
    dataset_root = CFG['DATASET_ROOT']
    
    print(f"üß™ **TESTING DATASET LOADING:**")
    print(f"   Dataset root: {dataset_root}")
    
    # Check the specific paths we know exist
    cityscapes_data_dir = os.path.join(dataset_root, 'cityscapes_data')
    train_dir = os.path.join(cityscapes_data_dir, 'train')
    val_dir = os.path.join(cityscapes_data_dir, 'val')
    
    print(f"   Cityscapes data dir exists: {os.path.exists(cityscapes_data_dir)}")
    print(f"   Train dir exists: {os.path.exists(train_dir)}")
    print(f"   Val dir exists: {os.path.exists(val_dir)}")
    
    # Count files in each directory
    for split, split_dir in [('train', train_dir), ('val', val_dir)]:
        if os.path.exists(split_dir):
            all_files = []
            for root, dirs, files in os.walk(split_dir):
                for file in files:
                    if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        all_files.append(os.path.join(root, file))
            
            print(f"   {split.capitalize()} directory: {len(all_files)} image files")
            
            # Show some examples
            if all_files:
                print(f"   {split.capitalize()} examples:")
                for i, file_path in enumerate(sorted(all_files)[:5]):
                    filename = os.path.basename(file_path)
                    size = os.path.getsize(file_path) / 1024  # KB
                    print(f"      {i+1}. {filename} ({size:.1f}KB)")
                
                # Try to identify image/label patterns
                potential_images = []
                potential_labels = []
                
                for file_path in all_files:
                    filename = os.path.basename(file_path).lower()
                    size = os.path.getsize(file_path)
                    
                    # Classify based on filename patterns and size
                    if 'label' in filename or 'mask' in filename or 'seg' in filename:
                        potential_labels.append(file_path)
                    elif 'img' in filename or 'image' in filename or size > 100000:  # > 100KB likely images
                        potential_images.append(file_path)
                    else:
                        # Default classification by size
                        if size > 50000:  # > 50KB
                            potential_images.append(file_path)
                        else:
                            potential_labels.append(file_path)
                
                print(f"   Classified as images: {len(potential_images)}")
                print(f"   Classified as labels: {len(potential_labels)}")
                
                if potential_images:
                    print(f"   Sample images: {[os.path.basename(f) for f in potential_images[:3]]}")
                if potential_labels:
                    print(f"   Sample labels: {[os.path.basename(f) for f in potential_labels[:3]]}")
    
    print("\nüîß Now let's test the actual dataset class...")
    
    # Test with a simple dataset creation
    try:
        # Use the enhanced dataset path
        enhanced_root = cityscapes_data_dir  # Point directly to cityscapes_data
        
        print(f"üß™ Testing dataset with root: {enhanced_root}")
        
        # Test train dataset
        test_train_dataset = CityscapesDataset(
            root_dir=enhanced_root,
            split='train',
            transforms=None  # No transforms for testing
        )
        
        # Test val dataset  
        test_val_dataset = CityscapesDataset(
            root_dir=enhanced_root,
            split='val',
            transforms=None
        )
        
        print(f"\n‚úÖ **DATASET TEST RESULTS:**")
        print(f"   Train samples: {len(test_train_dataset)}")
        print(f"   Val samples: {len(test_val_dataset)}")
        
        if len(test_train_dataset) > 0:
            print(f"   ‚úÖ Successfully found training data!")
            # Test loading one sample
            try:
                sample_img, sample_label = test_train_dataset[0]
                print(f"   Sample image shape: {sample_img.shape}")
                print(f"   Sample label shape: {sample_label.shape}")
                print(f"   Sample label unique values: {torch.unique(sample_label).tolist()}")
            except Exception as e:
                print(f"   ‚ö†Ô∏è Error loading sample: {e}")
        
        # Update the global CFG to use the correct path
        if len(test_train_dataset) > 0 or len(test_val_dataset) > 0:
            CFG['DATASET_ROOT'] = enhanced_root
            print(f"\nüîÑ Updated CFG['DATASET_ROOT'] to: {CFG['DATASET_ROOT']}")
            return True
        else:
            print(f"\n‚ùå No data found even with enhanced path")
            return False
            
    except Exception as e:
        print(f"‚ùå Dataset test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

# Run the test
success = quick_test_dataset_loading()

In [None]:
# üñºÔ∏è **DATA AUGMENTATION PIPELINE** - Cityscapes Specific
"""
Augmentation strategy adapted for Cityscapes:
- Higher resolution (769x769) for urban scene detail
- Scale range 0.5-2.0 for variety
- Careful padding with ignore_index for masks
"""

def get_train_transforms():
    """üèãÔ∏è Training augmentations for Cityscapes"""
    return A.Compose([
        # üìè Random scaling (0.5x to 2.0x)
        A.RandomScale(scale_limit=(CFG['RANDOM_SCALE_MIN']-1, CFG['RANDOM_SCALE_MAX']-1), 
                      interpolation=cv2.INTER_LINEAR, p=1.0),
        
        # üìê Pad if needed to ensure minimum size
        A.PadIfNeeded(min_height=CFG['CROP_SIZE'], min_width=CFG['CROP_SIZE'],
                      border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=CFG['IGNORE_INDEX']),
        
        # ‚úÇÔ∏è Random crop to target size
        A.RandomCrop(height=CFG['CROP_SIZE'], width=CFG['CROP_SIZE']),
        
        # üîÑ Horizontal flip
        A.HorizontalFlip(p=CFG['HORIZONTAL_FLIP_PROB']),
        
        # üé® Color augmentations (optional - comment out if too aggressive)
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
        
        # üìä Normalization & tensor conversion
        A.Normalize(mean=CFG['MEAN'], std=CFG['STD']),
        ToTensorV2()
    ])

def get_val_transforms():
    """üß™ Validation transforms (minimal processing)"""
    return A.Compose([
        # üìè Resize to target size (can use different strategy)
        A.Resize(height=CFG['CROP_SIZE'], width=CFG['CROP_SIZE'], 
                 interpolation=cv2.INTER_LINEAR),
        
        # üìä Normalization & tensor conversion  
        A.Normalize(mean=CFG['MEAN'], std=CFG['STD']),
        ToTensorV2()
    ])

# üß™ Test transforms
print("üñºÔ∏è **AUGMENTATION PIPELINE READY**")
print(f"   Crop Size: {CFG['CROP_SIZE']}√ó{CFG['CROP_SIZE']}")
print(f"   Scale Range: {CFG['RANDOM_SCALE_MIN']:.1f} - {CFG['RANDOM_SCALE_MAX']:.1f}")
print(f"   Horizontal Flip: {CFG['HORIZONTAL_FLIP_PROB']*100:.0f}%")
print(f"   Ignore Index: {CFG['IGNORE_INDEX']} (for padding)")

# Test with dummy data
test_img = np.random.randint(0, 255, (1024, 2048, 3), dtype=np.uint8)
test_mask = np.random.randint(0, 19, (1024, 2048), dtype=np.uint8)

train_transform = get_train_transforms()
val_transform = get_val_transforms()

try:
    # Test transforms
    train_result = train_transform(image=test_img, mask=test_mask)
    val_result = val_transform(image=test_img, mask=test_mask)
    
    print(f"\n‚úÖ **Transform Test Passed:**")
    print(f"   Train Output: {train_result['image'].shape}, {train_result['mask'].shape}")
    print(f"   Val Output: {val_result['image'].shape}, {val_result['mask'].shape}")
    print(f"   Train Image Range: [{train_result['image'].min():.3f}, {train_result['image'].max():.3f}]")
    
except Exception as e:
    print(f"‚ùå Transform Test Failed: {e}")

In [None]:
# üß† **DEEPLABV3+ MODEL CREATION** - TPU Optimized
"""
Create DeepLabv3+ model with ResNet-101 backbone
Key modifications for Cityscapes:
- NUM_CLASSES = 19 (vs 21 for PASCAL VOC)
- TPU-optimized with bfloat16 mixed precision
- Proper classifier head modification
"""

def create_deeplabv3plus_model():
    """
    üèóÔ∏è Create DeepLabv3+ model for Cityscapes
    """
    print("üß† Creating DeepLabv3+ model...")
    
    # Load pretrained DeepLabv3+ with ResNet-101
    model = models.deeplabv3_resnet101(
        pretrained=CFG['PRETRAINED'],
        progress=True,
        num_classes=21  # Start with PASCAL VOC pretrained
    )
    
    # üîß Modify classifier for Cityscapes (19 classes)
    # DeepLabv3+ has classifier and aux_classifier
    model.classifier[4] = nn.Conv2d(
        in_channels=256,
        out_channels=CFG['NUM_CLASSES'], 
        kernel_size=1
    )
    
    if hasattr(model, 'aux_classifier') and model.aux_classifier is not None:
        model.aux_classifier[4] = nn.Conv2d(
            in_channels=256,
            out_channels=CFG['NUM_CLASSES'],
            kernel_size=1
        )
    
    print(f"‚úÖ Model created with {CFG['NUM_CLASSES']} classes")
    
    # üìä Model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    
    return model

# üèóÔ∏è Create model
model = create_deeplabv3plus_model()

# ? Check TPU availability (from previous cell)
try:
    import torch_xla.core.xla_model as xm
    IS_TPU = True
    device = xm.xla_device()
except ImportError:
    IS_TPU = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ?üöÄ Move to TPU device
if IS_TPU:
    model = model.to(device)
    print(f"üì± Model moved to TPU: {device}")
else:
    model = model.to(device)
    print(f"üì± Model moved to device: {device}")

# üß™ Test model with dummy input
try:
    dummy_input = torch.randn(1, 3, CFG['CROP_SIZE'], CFG['CROP_SIZE']).to(device)
    
    if IS_TPU:
        # TPU requires different handling
        model.eval()
        with torch.no_grad():
            output = model(dummy_input)
    else:
        model.eval()
        with torch.no_grad():
            output = model(dummy_input)
    
    print(f"‚úÖ **Model Test Passed:**")
    print(f"   Input shape: {dummy_input.shape}")
    print(f"   Output shape: {output['out'].shape}")
    if 'aux' in output:
        print(f"   Aux output shape: {output['aux'].shape}")
    
    # Verify output channels
    expected_shape = (1, CFG['NUM_CLASSES'], CFG['CROP_SIZE'], CFG['CROP_SIZE'])
    if output['out'].shape == expected_shape:
        print(f"   ‚úÖ Output shape correct: {expected_shape}")
    else:
        print(f"   ‚ö†Ô∏è  Output shape mismatch: expected {expected_shape}, got {output['out'].shape}")

except Exception as e:
    print(f"‚ùå Model test failed: {e}")

model.train()  # Set back to training mode

In [None]:
# üìä **DATASET CREATION & TPU DATALOADER** - Optimized Pipeline
"""
Create train/val datasets and TPU-optimized DataLoaders
Key features:
- CityscapesDataset with label remapping
- TPU-aware data loading
- Parallel data loading across cores
"""

# üì• Check dataset availability first
if not os.path.exists(CFG['DATASET_ROOT']):
    print(f"‚ùå Dataset not found at: {CFG['DATASET_ROOT']}")
    print("üí° Creating dummy dataset for testing...")
    
    # Create minimal dummy structure for testing
    dummy_root = "/tmp/cityscapes_dummy"
    os.makedirs(os.path.join(dummy_root, "leftImg8bit", "train"), exist_ok=True)
    os.makedirs(os.path.join(dummy_root, "leftImg8bit", "val"), exist_ok=True)
    os.makedirs(os.path.join(dummy_root, "gtFine", "train"), exist_ok=True)
    os.makedirs(os.path.join(dummy_root, "gtFine", "val"), exist_ok=True)
    
    # Create a few dummy files for testing
    import cv2
    for split in ['train', 'val']:
        for i in range(2):  # Just 2 samples each
            # Dummy image
            dummy_img = np.random.randint(0, 255, (512, 1024, 3), dtype=np.uint8)
            img_path = os.path.join(dummy_root, "leftImg8bit", split, f"sample_{i:03d}_leftImg8bit.png")
            cv2.imwrite(img_path, dummy_img)
            
            # Dummy label
            dummy_label = np.random.randint(7, 34, (512, 1024), dtype=np.uint8)  # Use valid labelIds
            label_path = os.path.join(dummy_root, "gtFine", split, f"sample_{i:03d}_gtFine_labelIds.png")
            cv2.imwrite(label_path, dummy_label)
    
    CFG['DATASET_ROOT'] = dummy_root
    print(f"‚úÖ Dummy dataset created at: {CFG['DATASET_ROOT']}")

# üèóÔ∏è Create datasets
print("üìä Creating Cityscapes datasets...")

train_dataset = CityscapesDataset(
    root_dir=CFG['DATASET_ROOT'],
    split='train',
    transforms=get_train_transforms()
)

val_dataset = CityscapesDataset(
    root_dir=CFG['DATASET_ROOT'], 
    split='val',
    transforms=get_val_transforms()
)

print(f"\nüìà **Dataset Summary:**")
print(f"   Training samples: {len(train_dataset):,}")
print(f"   Validation samples: {len(val_dataset):,}")
print(f"   Total samples: {len(train_dataset) + len(val_dataset):,}")

# üîç Get TPU/device info (from previous cells)
try:
    import torch_xla.core.xla_model as xm
    IS_TPU = True
    device = xm.xla_device()
    
    # Get world_size and ordinal with modern API detection
    if hasattr(xm, 'xrt_world_size'):
        world_size = xm.xrt_world_size()
        ordinal = xm.get_ordinal()
    elif hasattr(xm, 'get_world_size'):
        world_size = xm.get_world_size()
        ordinal = xm.get_ordinal()
    else:
        import os
        world_size = int(os.environ.get('WORLD_SIZE', '8'))
        ordinal = int(os.environ.get('RANK', '0'))
        
except ImportError:
    IS_TPU = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    world_size = 1
    ordinal = 0

# üöÄ Create TPU-optimized DataLoaders
if IS_TPU:
    # TPU requires special data loading
    print("üöÄ Creating TPU DataLoaders...")
    
    # Use the world_size and ordinal variables
    print(f"   Using world_size: {world_size}, ordinal: {ordinal}")
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=ordinal,
        shuffle=True
    )
    
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=ordinal,
        shuffle=False
    )
    
    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['BATCH_SIZE'],
        sampler=train_sampler,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=True,
        pin_memory=False  # Not needed for TPU
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['BATCH_SIZE'],
        sampler=val_sampler,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=False,
        pin_memory=False
    )
    
    # Wrap with TPU ParallelLoader
    import torch_xla.distributed.parallel_loader as pl
    train_loader = pl.MpDeviceLoader(train_loader, device)
    val_loader = pl.MpDeviceLoader(val_loader, device)
    
    effective_batch_size = CFG['BATCH_SIZE'] * world_size
    print(f"   TPU Cores: {world_size}")
    print(f"   Batch per core: {CFG['BATCH_SIZE']}")
    print(f"   Effective batch size: {effective_batch_size}")
    
else:
    # Regular GPU/CPU DataLoaders
    print("üíª Creating GPU/CPU DataLoaders...")
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['BATCH_SIZE'],
        shuffle=True,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=True,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['BATCH_SIZE'],
        shuffle=False,
        num_workers=CFG['NUM_WORKERS'],
        drop_last=False,
        pin_memory=True
    )
    
    effective_batch_size = CFG['BATCH_SIZE']
    print(f"   Batch size: {effective_batch_size}")

print(f"   Training batches: {len(train_loader):,}")
print(f"   Validation batches: {len(val_loader):,}")

# üß™ Test data loading
if len(train_dataset) > 0:
    print("\nüß™ Testing data loading...")
    try:
        # Get one batch
        data_iter = iter(train_loader)
        images, labels = next(data_iter)
        
        print(f"‚úÖ **Data Loading Test Passed:**")
        print(f"   Batch images shape: {images.shape}")
        print(f"   Batch labels shape: {labels.shape}")
        print(f"   Image dtype: {images.dtype}")
        print(f"   Label dtype: {labels.dtype}")
        print(f"   Image range: [{images.min():.3f}, {images.max():.3f}]")
        print(f"   Label range: [{labels.min()}, {labels.max()}]")
        print(f"   Unique labels in batch: {torch.unique(labels).tolist()}")
        
        # Check for ignore index
        ignore_count = (labels == CFG['IGNORE_INDEX']).sum().item()
        total_pixels = labels.numel()
        print(f"   Ignore pixels: {ignore_count:,} / {total_pixels:,} ({ignore_count/total_pixels*100:.2f}%)")
        
    except Exception as e:
        print(f"‚ùå Data loading test failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("\n‚ö†Ô∏è No dataset available - skipping data loading test")
    print("   Please ensure Cityscapes dataset is properly downloaded!")

In [None]:
# üéØ **LOSS FUNCTION & OPTIMIZER** - TPU Optimized Setup
"""
Setup loss function, optimizer, and learning rate scheduler
Key configurations:
- CrossEntropyLoss with ignore_index=255
- SGD optimizer with momentum
- Polynomial learning rate decay
- TPU-optimized settings
"""

# üìâ Loss Function
criterion = nn.CrossEntropyLoss(ignore_index=CFG['IGNORE_INDEX'])
print(f"üìâ **Loss Function:** CrossEntropyLoss(ignore_index={CFG['IGNORE_INDEX']})")

# üîß Optimizer Setup  
optimizer = optim.SGD(
    model.parameters(),
    lr=CFG['LEARNING_RATE'],
    momentum=CFG['MOMENTUM'],
    weight_decay=CFG['WEIGHT_DECAY']
)

print(f"üîß **Optimizer:** SGD")
print(f"   Learning Rate: {CFG['LEARNING_RATE']}")
print(f"   Momentum: {CFG['MOMENTUM']}")
print(f"   Weight Decay: {CFG['WEIGHT_DECAY']}")

# üìÖ Learning Rate Scheduler (Polynomial Decay)
def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power):
    """Polynomial learning rate decay"""
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# üî• Mixed Precision Setup
if CFG['MIXED_PRECISION']:
    if IS_TPU:
        # TPU uses bfloat16 automatically
        print("üî• **Mixed Precision:** bfloat16 (TPU native)")
        scaler = None  # TPU handles this automatically
    else:
        # GPU uses GradScaler for fp16
        from torch.cuda.amp import GradScaler, autocast
        scaler = GradScaler()
        print("üî• **Mixed Precision:** fp16 + GradScaler")
else:
    scaler = None
    print("üî• **Mixed Precision:** Disabled")

# üìä Training State Tracking
training_state = {
    'iteration': 0,
    'best_miou': 0.0,
    'train_losses': [],
    'val_mious': [],
    'learning_rates': []
}

print(f"\nüéØ **Training Configuration:**")
print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
print(f"   Eval Interval: {CFG['EVAL_INTERVAL']:,}")
print(f"   Save Interval: {CFG['SAVE_INTERVAL']:,}")
print(f"   Warmup Iterations: {CFG['WARMUP_ITERATIONS']:,}")
print(f"   Polynomial Power: {CFG['POWER']}")

# üìÅ Create output directory
os.makedirs(CFG['OUTPUT_DIR'], exist_ok=True)
print(f"üìÅ Output directory: {CFG['OUTPUT_DIR']}")

# üß™ Test loss computation with proper gradient setup
print("\nüß™ Testing loss computation...")
try:
    # Create dummy predictions and targets with gradients enabled
    dummy_pred = torch.randn(2, CFG['NUM_CLASSES'], 64, 64, requires_grad=True).to(device)
    dummy_target = torch.randint(0, CFG['NUM_CLASSES'], (2, 64, 64)).to(device)
    
    # Add some ignore pixels
    dummy_target[0, :10, :10] = CFG['IGNORE_INDEX']
    
    # Compute loss
    loss = criterion(dummy_pred, dummy_target)
    
    print(f"‚úÖ **Loss Test Passed:**")
    print(f"   Dummy prediction shape: {dummy_pred.shape}")
    print(f"   Dummy target shape: {dummy_target.shape}")
    print(f"   Loss value: {loss.item():.4f}")
    print(f"   Requires grad: {dummy_pred.requires_grad}")
    print(f"   Loss has grad_fn: {loss.grad_fn is not None}")
    
    # Test backward pass
    loss.backward()
    print(f"   ‚úÖ Backward pass successful")
    print(f"   Gradient shape: {dummy_pred.grad.shape if dummy_pred.grad is not None else 'None'}")
    
    # Clear gradients for clean state
    dummy_pred.grad = None
    
except Exception as e:
    print(f"‚ùå Loss test failed: {e}")
    import traceback
    traceback.print_exc()

# üß™ Test with actual model gradients
print("\nüß™ Testing with model gradients...")
try:
    # Zero gradients
    optimizer.zero_grad()
    
    # Create dummy input and run through model
    dummy_input = torch.randn(1, 3, 64, 64).to(device)
    dummy_target = torch.randint(0, CFG['NUM_CLASSES'], (1, 64, 64)).to(device)
    
    # Forward pass through model
    model_output = model(dummy_input)
    model_loss = criterion(model_output['out'], dummy_target)
    
    print(f"‚úÖ **Model Loss Test Passed:**")
    print(f"   Model output shape: {model_output['out'].shape}")
    print(f"   Model loss: {model_loss.item():.4f}")
    print(f"   Loss has grad_fn: {model_loss.grad_fn is not None}")
    
    # Test backward pass
    model_loss.backward()
    print(f"   ‚úÖ Model backward pass successful")
    
    # Check if gradients were computed
    total_grad_norm = 0
    param_count = 0
    for param in model.parameters():
        if param.grad is not None:
            total_grad_norm += param.grad.data.norm(2).item() ** 2
            param_count += 1
    
    total_grad_norm = total_grad_norm ** 0.5
    print(f"   üìä Gradient norm: {total_grad_norm:.4f} across {param_count} parameters")
    
    # Clear gradients
    optimizer.zero_grad()
    
except Exception as e:
    print(f"‚ùå Model loss test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# üìä **mIoU EVALUATION METRICS** - Cityscapes Standard
"""
Mean Intersection over Union (mIoU) computation for Cityscapes
Key features:
- Per-class IoU calculation
- Proper handling of ignore_index (255)
- TPU-optimized computation
- Standard Cityscapes evaluation protocol
"""

class mIoUCalculator:
    """
    üìä Calculate mean Intersection over Union for semantic segmentation
    """
    
    def __init__(self, num_classes, ignore_index=255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.reset()
    
    def reset(self):
        """Reset all statistics"""
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))
    
    def update(self, predictions, targets):
        """
        Update confusion matrix with new predictions and targets
        
        Args:
            predictions: Model predictions (B, H, W) - class indices
            targets: Ground truth labels (B, H, W) - class indices
        """
        # Convert to numpy if tensors
        if torch.is_tensor(predictions):
            predictions = predictions.cpu().numpy()
        if torch.is_tensor(targets):
            targets = targets.cpu().numpy()
        
        # Flatten arrays
        predictions = predictions.flatten()
        targets = targets.flatten()
        
        # Remove ignore pixels
        valid_mask = (targets != self.ignore_index)
        predictions = predictions[valid_mask]
        targets = targets[valid_mask]
        
        # Ensure predictions are within valid range
        predictions = np.clip(predictions, 0, self.num_classes - 1)
        
        # Update confusion matrix
        for pred, target in zip(predictions, targets):
            if 0 <= target < self.num_classes:
                self.confusion_matrix[target, pred] += 1
    
    def compute_iou(self):
        """
        Compute IoU for each class and mean IoU
        
        Returns:
            per_class_iou: IoU for each class
            mean_iou: Mean IoU across all classes
        """
        # Calculate IoU for each class
        intersection = np.diag(self.confusion_matrix)
        union = (
            self.confusion_matrix.sum(axis=1) + 
            self.confusion_matrix.sum(axis=0) - 
            intersection
        )
        
        # Avoid division by zero
        valid_classes = union > 0
        per_class_iou = np.zeros(self.num_classes)
        per_class_iou[valid_classes] = intersection[valid_classes] / union[valid_classes]
        
        # Mean IoU (only for classes that appear in ground truth)
        mean_iou = per_class_iou[valid_classes].mean() if valid_classes.any() else 0.0
        
        return per_class_iou, mean_iou
    
    def get_results(self):
        """Get detailed results"""
        per_class_iou, mean_iou = self.compute_iou()
        
        results = {
            'mIoU': mean_iou,
            'per_class_IoU': per_class_iou,
            'confusion_matrix': self.confusion_matrix.copy()
        }
        
        return results
    
    def print_results(self, class_names=None):
        """Print formatted results"""
        per_class_iou, mean_iou = self.compute_iou()
        
        print(f"üìä **mIoU Results:**")
        print(f"   Mean IoU: {mean_iou:.4f} ({mean_iou*100:.2f}%)")
        print(f"\nüìã **Per-Class IoU:**")
        
        if class_names is None:
            class_names = [f"Class_{i}" for i in range(self.num_classes)]
        
        for i, (class_name, iou) in enumerate(zip(class_names, per_class_iou)):
            print(f"   {i:2d}. {class_name:<15}: {iou:.4f} ({iou*100:.2f}%)")

# üß™ Test mIoU calculator
print("üìä **mIoU CALCULATOR INITIALIZED**")
miou_calculator = mIoUCalculator(
    num_classes=CFG['NUM_CLASSES'], 
    ignore_index=CFG['IGNORE_INDEX']
)

# Test with dummy data
print("\nüß™ Testing mIoU calculator...")
try:
    # Create dummy predictions and targets
    dummy_predictions = np.random.randint(0, CFG['NUM_CLASSES'], (2, 100, 100))
    dummy_targets = np.random.randint(0, CFG['NUM_CLASSES'], (2, 100, 100))
    
    # Add some ignore pixels
    dummy_targets[0, :10, :10] = CFG['IGNORE_INDEX']
    
    # Update calculator
    miou_calculator.update(dummy_predictions, dummy_targets)
    
    # Compute results
    results = miou_calculator.get_results()
    
    print(f"‚úÖ **mIoU Test Passed:**")
    print(f"   Mean IoU: {results['mIoU']:.4f}")
    print(f"   Confusion matrix shape: {results['confusion_matrix'].shape}")
    print(f"   Per-class IoU shape: {results['per_class_IoU'].shape}")
    
    # Test with class names
    miou_calculator.print_results(CITYSCAPES_CLASSES)
    
    # Reset for actual training
    miou_calculator.reset()
    
except Exception as e:
    print(f"‚ùå mIoU test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# üîÑ **EVALUATION FUNCTION** - TPU Optimized Validation
"""
Comprehensive evaluation function for Cityscapes validation
Key features:
- TPU-aware evaluation loop
- Memory-efficient processing
- mIoU computation with proper aggregation
- Progress tracking and logging
"""

def evaluate_model(model, val_loader, criterion, miou_calculator, device, is_tpu=False):
    """
    üîç Evaluate model on validation set
    
    Returns:
        eval_results: Dictionary with loss, mIoU, and per-class IoU
    """
    model.eval()
    
    # Reset metrics
    miou_calculator.reset()
    total_loss = 0.0
    total_samples = 0
    
    print("üîç Starting validation...")
    
    with torch.no_grad():
        # Progress bar
        pbar = tqdm(val_loader, desc="üß™ Validating", leave=False) if not is_tpu else val_loader
        
        for batch_idx, (images, targets) in enumerate(pbar):
            # Move to device (already on TPU if using TPU loader)
            if not is_tpu:
                images = images.to(device)
                targets = targets.to(device)
            
            # Forward pass
            if CFG['MIXED_PRECISION'] and not is_tpu:
                # GPU with mixed precision
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs['out'], targets)
            else:
                # TPU or regular computation
                outputs = model(images)
                loss = criterion(outputs['out'], targets)
            
            # Accumulate loss
            total_loss += loss.item()
            total_samples += images.size(0)
            
            # Get predictions
            predictions = torch.argmax(outputs['out'], dim=1)
            
            # Update mIoU calculator
            miou_calculator.update(predictions, targets)
            
            # Update progress bar
            if not is_tpu and hasattr(pbar, 'set_postfix'):
                current_loss = total_loss / (batch_idx + 1)
                pbar.set_postfix({
                    'Loss': f'{current_loss:.4f}',
                    'Samples': f'{total_samples:,}'
                })
            
            # Periodic memory cleanup for TPU
            if is_tpu and batch_idx % 10 == 0:
                xm.mark_step()  # TPU step
    
    # Compute final metrics
    avg_loss = total_loss / len(val_loader)
    eval_results = miou_calculator.get_results()
    eval_results['loss'] = avg_loss
    eval_results['total_samples'] = total_samples
    
    # TPU synchronization
    if is_tpu:
        # Reduce metrics across TPU cores
        xm.master_print(f"üîç Validation completed on {total_samples:,} samples")
        # Note: For proper TPU evaluation, we'd need to aggregate metrics across cores
        # This is simplified for demonstration
    
    print(f"‚úÖ Validation completed:")
    print(f"   Average Loss: {avg_loss:.4f}")
    print(f"   Mean IoU: {eval_results['mIoU']:.4f} ({eval_results['mIoU']*100:.2f}%)")
    print(f"   Total Samples: {total_samples:,}")
    
    model.train()  # Set back to training mode
    return eval_results

# üß™ Test evaluation function (dry run)
print("üîç **EVALUATION FUNCTION READY**")
print("   - TPU-aware validation loop")
print("   - Memory-efficient processing")
print("   - Comprehensive mIoU computation")
print("   - Progress tracking and logging")

In [None]:
# üöÄ **TRAINING LOOP** - TPU Optimized Main Training
"""
Complete training loop optimized for TPU v5e-8
Key features:
- Polynomial learning rate scheduling with warmup
- TPU-native mixed precision (bfloat16)
- Periodic evaluation and checkpointing
- Memory-efficient gradient computation
- Comprehensive logging and monitoring
"""

def train_model():
    """
    üöÄ Main training function
    """
    print("üöÄ **STARTING TRAINING**")
    print(f"   Device: {device}")
    print(f"   Max Iterations: {CFG['MAX_ITERATIONS']:,}")
    print(f"   Effective Batch Size: {(CFG['BATCH_SIZE'] * world_size if IS_TPU else CFG['BATCH_SIZE'])}")
    print(f"   Learning Rate: {CFG['LEARNING_RATE']}")
    print(f"   Mixed Precision: {CFG['MIXED_PRECISION']}")
    
    # Training state
    global training_state
    iteration = training_state['iteration']
    best_miou = training_state['best_miou']
    
    # Create progress tracking
    if not IS_TPU or xm.is_master_ordinal():
        pbar = tqdm(total=CFG['MAX_ITERATIONS'], initial=iteration, desc="üöÄ Training")
    
    # Training loop
    model.train()
    
    while iteration < CFG['MAX_ITERATIONS']:
        
        # Set epoch for distributed sampler
        if IS_TPU:
            epoch = iteration // len(train_loader)
            train_loader._loader.sampler.set_epoch(epoch)
        
        for batch_idx, (images, targets) in enumerate(train_loader):
            
            # Check iteration limit
            if iteration >= CFG['MAX_ITERATIONS']:
                break
            
            # Move to device (already on TPU if using TPU loader)
            if not IS_TPU:
                images = images.to(device)
                targets = targets.to(device)
            
            # Learning rate scheduling with warmup
            if iteration < CFG['WARMUP_ITERATIONS']:
                # Linear warmup
                lr = CFG['LEARNING_RATE'] * (iteration / CFG['WARMUP_ITERATIONS'])
            else:
                # Polynomial decay
                lr = poly_lr_scheduler(
                    optimizer, 
                    CFG['LEARNING_RATE'],
                    iteration - CFG['WARMUP_ITERATIONS'],
                    CFG['MAX_ITERATIONS'] - CFG['WARMUP_ITERATIONS'],
                    CFG['POWER']
                )
            
            # Set learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass with mixed precision
            if CFG['MIXED_PRECISION'] and not IS_TPU:
                # GPU mixed precision
                with autocast():
                    outputs = model(images)
                    main_loss = criterion(outputs['out'], targets)
                    
                    # Auxiliary loss (if available)
                    if 'aux' in outputs and outputs['aux'] is not None:
                        aux_loss = criterion(outputs['aux'], targets)
                        loss = main_loss + 0.4 * aux_loss  # Standard weight for aux loss
                    else:
                        loss = main_loss
                
                # Backward pass with scaling
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                
            else:
                # TPU or regular computation
                outputs = model(images)
                main_loss = criterion(outputs['out'], targets)
                
                # Auxiliary loss (if available)
                if 'aux' in outputs and outputs['aux'] is not None:
                    aux_loss = criterion(outputs['aux'], targets)
                    loss = main_loss + 0.4 * aux_loss
                else:
                    loss = main_loss
                
                # Backward pass
                loss.backward()
                optimizer.step()
            
            # TPU step marking
            if IS_TPU:
                xm.mark_step()
            
            # Record training state
            training_state['train_losses'].append(loss.item())
            training_state['learning_rates'].append(lr)
            iteration += 1
            training_state['iteration'] = iteration
            
            # Update progress bar
            if not IS_TPU or xm.is_master_ordinal():
                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'LR': f'{lr:.6f}',
                    'Iter': f'{iteration}/{CFG["MAX_ITERATIONS"]}'
                })
                pbar.update(1)
            
            # Evaluation
            if iteration % CFG['EVAL_INTERVAL'] == 0:
                print(f"\nüîç **Evaluation at iteration {iteration:,}**")
                
                # Run evaluation
                eval_results = evaluate_model(
                    model, val_loader, criterion, 
                    miou_calculator, device, IS_TPU
                )
                
                current_miou = eval_results['mIoU']
                training_state['val_mious'].append(current_miou)
                
                # Check for best model
                if current_miou > best_miou:
                    best_miou = current_miou
                    training_state['best_miou'] = best_miou
                    
                    # Save best model
                    if not IS_TPU or xm.is_master_ordinal():
                        model_path = os.path.join(CFG['OUTPUT_DIR'], 'best_model.pth')
                        torch.save({
                            'iteration': iteration,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_miou': best_miou,
                            'config': CFG
                        }, model_path)
                        print(f"üíæ Best model saved: {model_path} (mIoU: {best_miou:.4f})")
                
                print(f"   Current mIoU: {current_miou:.4f}")
                print(f"   Best mIoU: {best_miou:.4f}")
                
                # Print per-class results
                miou_calculator.print_results(CITYSCAPES_CLASSES)
                
                model.train()  # Set back to training mode
            
            # Save checkpoint
            if iteration % CFG['SAVE_INTERVAL'] == 0:
                if not IS_TPU or xm.is_master_ordinal():
                    checkpoint_path = os.path.join(CFG['OUTPUT_DIR'], f'checkpoint_{iteration}.pth')
                    torch.save({
                        'iteration': iteration,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'training_state': training_state,
                        'config': CFG
                    }, checkpoint_path)
                    print(f"\nüíæ Checkpoint saved: {checkpoint_path}")
            
            # Memory cleanup
            del outputs, loss
            if IS_TPU and iteration % 50 == 0:
                xm.mark_step()  # Additional TPU synchronization
    
    # Close progress bar
    if not IS_TPU or xm.is_master_ordinal():
        pbar.close()
    
    print(f"\n‚úÖ **Training completed!**")
    print(f"   Total iterations: {iteration:,}")
    print(f"   Best mIoU: {best_miou:.4f} ({best_miou*100:.2f}%)")
    
    return training_state

print("üöÄ **TRAINING LOOP READY**")
print("   - TPU-optimized training with bfloat16")
print("   - Polynomial LR scheduling with warmup")
print("   - Periodic evaluation and checkpointing")
print("   - Memory-efficient gradient computation")
print("   - Comprehensive progress tracking")

In [None]:
# üé¨ **START TRAINING** - Execute Training Loop
"""
Launch the complete training process
This will run for 40,000 iterations on TPU v5e-8
"""

# üöÄ Start training
print("üé¨ **LAUNCHING CITYSCAPES TRAINING**")
print("=" * 60)

try:
    # Run training
    final_state = train_model()
    
    print("\nüéâ **TRAINING COMPLETED SUCCESSFULLY!**")
    print(f"   Final Best mIoU: {final_state['best_miou']:.4f}")
    print(f"   Total Iterations: {final_state['iteration']:,}")
    print(f"   Training Losses: {len(final_state['train_losses']):,} recorded")
    print(f"   Validation mIoUs: {len(final_state['val_mious'])} recorded")
    
except KeyboardInterrupt:
    print("\n‚ö†Ô∏è Training interrupted by user")
    print(f"   Current iteration: {training_state['iteration']:,}")
    print(f"   Current best mIoU: {training_state['best_miou']:.4f}")
    
except Exception as e:
    print(f"\n‚ùå Training failed: {e}")
    import traceback
    traceback.print_exc()
    
    # Save emergency checkpoint if possible
    try:
        emergency_path = os.path.join(CFG['OUTPUT_DIR'], 'emergency_checkpoint.pth')
        torch.save({
            'iteration': training_state['iteration'],
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_state': training_state,
            'config': CFG,
            'error': str(e)
        }, emergency_path)
        print(f"üíæ Emergency checkpoint saved: {emergency_path}")
    except:
        print("‚ùå Could not save emergency checkpoint")

print("=" * 60)

In [None]:
# üìä **TRAINING VISUALIZATION & ANALYSIS** - Results Dashboard
"""
Visualize training progress and analyze results
Key features:
- Training loss curves
- Validation mIoU progression
- Learning rate schedule
- Per-class IoU breakdown
- Model performance analysis
"""

def plot_training_results(training_state):
    """
    üìà Plot comprehensive training results
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('üèôÔ∏è DeepLabv3+ Cityscapes Training Results', fontsize=16, y=0.98)
    
    # 1. Training Loss
    if training_state['train_losses']:
        axes[0, 0].plot(training_state['train_losses'], 'b-', alpha=0.7, linewidth=1)
        axes[0, 0].set_title('üìâ Training Loss')
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Add smoothed line
        if len(training_state['train_losses']) > 100:
            window = len(training_state['train_losses']) // 50
            smooth_loss = np.convolve(training_state['train_losses'], 
                                    np.ones(window)/window, mode='valid')
            smooth_x = np.arange(window//2, len(training_state['train_losses']) - window//2 + 1)
            axes[0, 0].plot(smooth_x, smooth_loss, 'r-', linewidth=2, label='Smoothed')
            axes[0, 0].legend()
    
    # 2. Validation mIoU
    if training_state['val_mious']:
        eval_iterations = [CFG['EVAL_INTERVAL'] * (i+1) for i in range(len(training_state['val_mious']))]
        axes[0, 1].plot(eval_iterations, training_state['val_mious'], 'g-o', linewidth=2, markersize=6)
        axes[0, 1].set_title('üìä Validation mIoU')
        axes[0, 1].set_xlabel('Iteration')
        axes[0, 1].set_ylabel('mIoU')
        axes[0, 1].grid(True, alpha=0.3)
        axes[0, 1].set_ylim(0, 1)
        
        # Highlight best score
        best_idx = np.argmax(training_state['val_mious'])
        best_iter = eval_iterations[best_idx]
        best_miou = training_state['val_mious'][best_idx]
        axes[0, 1].scatter([best_iter], [best_miou], color='red', s=100, zorder=5)
        axes[0, 1].annotate(f'Best: {best_miou:.4f}', 
                          xy=(best_iter, best_miou), 
                          xytext=(10, 10), textcoords='offset points',
                          bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))
    
    # 3. Learning Rate Schedule
    if training_state['learning_rates']:
        axes[1, 0].plot(training_state['learning_rates'], 'purple', linewidth=2)
        axes[1, 0].set_title('üìÖ Learning Rate Schedule')
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True, alpha=0.3)
        axes[1, 0].set_yscale('log')
    
    # 4. Training Summary
    axes[1, 1].axis('off')
    summary_text = f\"\"\"\nüéØ Training Summary:\n\n‚Ä¢ Total Iterations: {training_state['iteration']:,}\n‚Ä¢ Best mIoU: {training_state['best_miou']:.4f} ({training_state['best_miou']*100:.2f}%)\n‚Ä¢ Final Loss: {training_state['train_losses'][-1]:.4f if training_state['train_losses'] else 'N/A'}\n‚Ä¢ Dataset: Cityscapes (19 classes)\n‚Ä¢ Resolution: {CFG['CROP_SIZE']}√ó{CFG['CROP_SIZE']}\n‚Ä¢ Batch Size: {CFG['BATCH_SIZE']} per core\n‚Ä¢ TPU Cores: {CFG['TPU_CORES']}\n‚Ä¢ Mixed Precision: {CFG['MIXED_PRECISION']}\n‚Ä¢ Backbone: {CFG['BACKBONE']}\n\"\"\"\n    axes[1, 1].text(0.1, 0.9, summary_text, transform=axes[1, 1].transAxes, \n                    fontsize=11, verticalalignment='top', \n                    bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8))\n    \n    plt.tight_layout()\n    plt.show()\n    \n    # Save plot\n    plot_path = os.path.join(CFG['OUTPUT_DIR'], 'training_results.png')\n    fig.savefig(plot_path, dpi=150, bbox_inches='tight')\n    print(f\"üìä Training plot saved: {plot_path}\")\n\ndef analyze_final_results(training_state):\n    \"\"\"üìã Analyze and summarize final training results\"\"\"\n    print(\"\\nüìã **FINAL TRAINING ANALYSIS**\")\n    print(\"=\" * 50)\n    \n    # Basic stats\n    print(f\"üéØ **Training Completed:**\")\n    print(f\"   Total Iterations: {training_state['iteration']:,}\")\n    print(f\"   Best mIoU: {training_state['best_miou']:.4f} ({training_state['best_miou']*100:.2f}%)\")\n    \n    if training_state['train_losses']:\n        final_loss = training_state['train_losses'][-1]\n        avg_loss = np.mean(training_state['train_losses'][-1000:])  # Last 1000 iterations\n        print(f\"   Final Loss: {final_loss:.4f}\")\n        print(f\"   Average Loss (last 1000): {avg_loss:.4f}\")\n    \n    if training_state['val_mious']:\n        print(f\"   Evaluations: {len(training_state['val_mious'])}\")\n        print(f\"   mIoU Improvement: {training_state['val_mious'][-1] - training_state['val_mious'][0]:.4f}\")\n    \n    # Model files\n    print(f\"\\nüìÅ **Output Files:**\")\n    output_files = os.listdir(CFG['OUTPUT_DIR'])\n    for file in sorted(output_files):\n        file_path = os.path.join(CFG['OUTPUT_DIR'], file)\n        size_mb = os.path.getsize(file_path) / (1024 * 1024)\n        print(f\"   {file}: {size_mb:.1f} MB\")\n    \n    print(\"=\" * 50)\n\n# üìä Check if training has results to visualize\nif training_state['iteration'] > 0:\n    print(\"üìä **VISUALIZING TRAINING RESULTS**\")\n    plot_training_results(training_state)\n    analyze_final_results(training_state)\nelse:\n    print(\"üìä **NO TRAINING DATA TO VISUALIZE**\")\n    print(\"   Run the training cells above first!\")

In [None]:
# üñºÔ∏è **INFERENCE & VISUALIZATION** - Model Predictions
"""
Inference pipeline for trained model
Key features:
- Load best model checkpoint
- Inference on validation samples
- Visualization of predictions vs ground truth
- Cityscapes color mapping for beautiful results
"""

# üé® Cityscapes Color Palette (for visualization)
CITYSCAPES_COLORS = [
    [128, 64, 128],   # road
    [244, 35, 232],   # sidewalk  
    [70, 70, 70],     # building
    [102, 102, 156],  # wall
    [190, 153, 153],  # fence
    [153, 153, 153],  # pole
    [250, 170, 30],   # traffic light
    [220, 220, 0],    # traffic sign
    [107, 142, 35],   # vegetation
    [152, 251, 152],  # terrain
    [70, 130, 180],   # sky
    [220, 20, 60],    # person
    [255, 0, 0],      # rider
    [0, 0, 142],      # car
    [0, 0, 70],       # truck
    [0, 60, 100],     # bus
    [0, 80, 100],     # train
    [0, 0, 230],      # motorcycle
    [119, 11, 32]     # bicycle
]

def load_best_model():
    """üì• Load the best trained model"""
    model_path = os.path.join(CFG['OUTPUT_DIR'], 'best_model.pth')
    
    if os.path.exists(model_path):
        print(f"üì• Loading best model: {model_path}")
        checkpoint = torch.load(model_path, map_location=device)
        
        model.load_state_dict(checkpoint['model_state_dict'])
        best_miou = checkpoint.get('best_miou', 0.0)
        iteration = checkpoint.get('iteration', 0)
        
        print(f"‚úÖ Model loaded successfully:")
        print(f"   Iteration: {iteration:,}")
        print(f"   Best mIoU: {best_miou:.4f}")
        
        model.eval()
        return True
    else:
        print(f"‚ùå Model not found: {model_path}")
        print("   Train the model first!")
        return False

def colorize_prediction(prediction):
    """üé® Convert prediction to colored image"""
    h, w = prediction.shape
    colored = np.zeros((h, w, 3), dtype=np.uint8)
    
    for class_id, color in enumerate(CITYSCAPES_COLORS):
        mask = (prediction == class_id)
        colored[mask] = color
    
    return colored

def inference_on_samples(num_samples=4):
    """üîç Run inference on validation samples"""
    if not load_best_model():
        return
    
    print(f"üîç Running inference on {num_samples} validation samples...")
    
    # Get samples from validation set
    val_iter = iter(val_loader)
    
    with torch.no_grad():
        for sample_idx in range(min(num_samples, len(val_loader))):
            try:
                images, targets = next(val_iter)
                
                # Take first image from batch
                image = images[0:1]  # Keep batch dimension
                target = targets[0].cpu().numpy()
                
                # Move to device
                if not IS_TPU:
                    image = image.to(device)
                
                # Forward pass
                outputs = model(image)
                prediction = torch.argmax(outputs['out'], dim=1)[0].cpu().numpy()
                
                # Convert to visualization format
                original_img = images[0].cpu().numpy().transpose(1, 2, 0)
                # Denormalize image
                mean = np.array(CFG['MEAN'])
                std = np.array(CFG['STD'])
                original_img = (original_img * std + mean) * 255
                original_img = np.clip(original_img, 0, 255).astype(np.uint8)
                
                # Colorize masks
                target_colored = colorize_prediction(target)
                pred_colored = colorize_prediction(prediction)
                
                # Create visualization
                fig, axes = plt.subplots(1, 4, figsize=(20, 5))
                fig.suptitle(f'üèôÔ∏è Sample {sample_idx + 1} - Cityscapes Inference', fontsize=14)
                
                # Original image
                axes[0].imshow(original_img)
                axes[0].set_title('üì∑ Original Image')
                axes[0].axis('off')
                
                # Ground truth
                axes[1].imshow(target_colored)
                axes[1].set_title('üéØ Ground Truth')
                axes[1].axis('off')
                
                # Prediction
                axes[2].imshow(pred_colored)
                axes[2].set_title('ü§ñ Prediction')  
                axes[2].axis('off')
                
                # Overlay
                alpha = 0.6
                overlay = (alpha * original_img + (1-alpha) * pred_colored).astype(np.uint8)
                axes[3].imshow(overlay)
                axes[3].set_title('üé® Overlay')
                axes[3].axis('off')
                
                plt.tight_layout()
                plt.show()
                
                # Compute sample mIoU
                sample_miou = mIoUCalculator(CFG['NUM_CLASSES'], CFG['IGNORE_INDEX'])
                sample_miou.update(prediction, target)
                results = sample_miou.get_results()
                
                print(f"   Sample {sample_idx + 1} mIoU: {results['mIoU']:.4f}")
                
                # Save inference result
                inference_path = os.path.join(CFG['OUTPUT_DIR'], f'inference_sample_{sample_idx+1}.png')
                fig.savefig(inference_path, dpi=150, bbox_inches='tight')
                print(f"   üíæ Saved: {inference_path}")
                
            except StopIteration:
                print(f"   ‚ö†Ô∏è Only {sample_idx} samples available")
                break
            except Exception as e:
                print(f"   ‚ùå Error processing sample {sample_idx + 1}: {e}")

# üñºÔ∏è Create class legend
def create_class_legend():
    """üé® Create color legend for Cityscapes classes"""
    fig, ax = plt.subplots(figsize=(12, 8))
    
    # Create color patches
    colors_normalized = [[c/255.0 for c in color] for color in CITYSCAPES_COLORS]
    
    # Display as grid
    n_cols = 4
    n_rows = (len(CITYSCAPES_CLASSES) + n_cols - 1) // n_cols
    
    for i, (class_name, color) in enumerate(zip(CITYSCAPES_CLASSES, colors_normalized)):
        row = i // n_cols
        col = i % n_cols
        
        # Create rectangle
        rect = plt.Rectangle((col, n_rows - row - 1), 0.8, 0.8, 
                           facecolor=color, edgecolor='black', linewidth=1)
        ax.add_patch(rect)
        
        # Add text
        ax.text(col + 0.4, n_rows - row - 0.5, f"{i}: {class_name}", 
               ha='center', va='center', fontsize=10, weight='bold')
    
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('üèôÔ∏è Cityscapes Classes Color Legend', fontsize=16, pad=20)
    
    plt.tight_layout()
    plt.show()
    
    # Save legend
    legend_path = os.path.join(CFG['OUTPUT_DIR'], 'cityscapes_legend.png')
    fig.savefig(legend_path, dpi=150, bbox_inches='tight')
    print(f"üé® Class legend saved: {legend_path}")

print("üñºÔ∏è **INFERENCE PIPELINE READY**")
print("   - Load best model checkpoint")
print("   - Inference on validation samples") 
print("   - Beautiful Cityscapes color visualization")
print("   - Per-sample mIoU computation")

# Show class legend
create_class_legend()

In [None]:
# üöÄ **RUN INFERENCE** - Generate Beautiful Predictions
"""
Execute inference on validation samples to see model results
"""

# üîç Run inference on samples
print("üîç **STARTING INFERENCE**")
print("=" * 40)

try:
    # Run inference on 4 validation samples
    inference_on_samples(num_samples=4)
    
    print("\n‚úÖ **INFERENCE COMPLETED SUCCESSFULLY!**")
    print("   Check the visualizations above and saved PNG files")
    
except Exception as e:
    print(f"\n‚ùå Inference failed: {e}")
    import traceback
    traceback.print_exc()

print("=" * 40)