<a href="https://colab.research.google.com/github/hikmat690/mamba/blob/main/TurboMamba_TAP_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🚀 TurboMamba-TAP: Weather-Robust Semantic Segmentation

**Senior Deep Learning Engineer Implementation for Google Colab (T4 GPU)**

---

## Architecture Overview:
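- **Stage 1:** TAP (Task-Adaptive Prompt) Cleaner - learnable prompt plus a small residual CNN that applies a bounded correction aimed at weather degradation
- **Stage 2:** Mamba-inspired Encoder - depthwise Conv1d blocks over the flattened feature map (a lightweight, local stand-in for Mamba's selective scan)
- **Stage 3:** Detail-preserving Decoder - transposed convolutions that upsample back to full-resolution segmentation logits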

## Dataset Structure:
```
dataset.zip
└── acdc_night_train/
    ├── folder_1/ (images)
    ├── folder_2/ (images)
    ├── folder_3/ (images)
    ├── folder_4/ (images)
    └── folder_5/ (images)

cityscapes.zip
└── cityscapes_data/
    ├── cityscapes_data/
    │   ├── train/ (images)
    │   └── val/ (images)
    ├── train/ (images)
    └── val/ (images)
```

---

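**📌 Instructions:**
1. Switch to a GPU runtime (**Runtime → Change runtime type → T4 GPU**) so that the first cell reports `CUDA available: True`
2. Upload `dataset.zip` and `cityscapes.zip` to `/content/`
3. Run all cells in order; if the runtime restarts, rerun from Step 1, otherwise later cells fail with `NameError`
4. Wait for training to complete (20 epochs by default)
5. Download your trained model!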

## 📦 Step 1: Install Dependencies & Import Libraries

In [3]:
# Install required packages (if needed)
!pip install -q torch torchvision tqdm matplotlib pillow numpy

# Import libraries
import os
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import random
import glob

print("✓ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

✓ All libraries imported successfully!
PyTorch version: 2.9.0+cpu
CUDA available: False


## 📂 Step 2: Extract & Verify Dataset Structure

In [2]:
def setup_datasets():
    """
    Extract datasets and verify structure.

    Expected structure:
    - dataset.zip → acdc_night_train → 5 folders with images
    - cityscapes.zip → cityscapes_data → cityscapes_data/{train,val}, train, val
    """
    print("=" * 70)
    print("TASK 1: Dataset Setup & Extraction")
    print("=" * 70)

    base_path = '/content'

    # Dataset configurations
    datasets = {
        'dataset.zip': 'acdc_night_train',
        'cityscapes.zip': 'cityscapes_data'
    }

    extracted_paths = {}

    for zip_file, expected_folder in datasets.items():
        zip_path = os.path.join(base_path, zip_file)

        # Check if zip file exists
        if not os.path.exists(zip_path):
            print(f"⚠️  WARNING: {zip_file} not found in /content/")
            print(f"   Please upload {zip_file} to Colab before running this cell.")
            continue

        print(f"\nüì¶ Processing {zip_file}...")

        # Extract
        print(f"   Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(base_path)
        print("   ✓ Extracted")

        # Find the actual extracted folder
        # Sometimes zip files have the folder inside, sometimes they don't
        possible_paths = [
            os.path.join(base_path, expected_folder),
            os.path.join(base_path, zip_file.replace('.zip', '')),
        ]

        extracted_path = None
        for path in possible_paths:
            if os.path.exists(path):
                extracted_path = path
                break

        if extracted_path:
            extracted_paths[zip_file] = extracted_path
            print(f"   ✓ Found at: {extracted_path}")
        else:
            print(f"   ⚠️  Could not locate {expected_folder}")

    # Verify and print structure
    print("\n" + "=" * 70)
    print("Directory Structure Verification:")
    print("=" * 70)

    for zip_file, path in extracted_paths.items():
        print(f"\nüìÅ {os.path.basename(path)}/")

        # Walk through directory structure
        for root, dirs, files in os.walk(path):
            level = root.replace(path, '').count(os.sep)
            indent = ' ' * 2 * level
            folder_name = os.path.basename(root)

            if level < 3:  # Only show up to 3 levels deep
                if level > 0:
                    print(f"{indent}├── {folder_name}/")

                # Show image count
                if files and level < 3:
                    image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                    if image_files:
                        print(f"{indent}│   └── ({len(image_files)} images)")

    print("\n" + "=" * 70)
    return extracted_paths

# Run extraction
extracted_datasets = setup_datasets()

print("\n✓ Dataset extraction complete!")
print(f"\nExtracted datasets: {list(extracted_datasets.keys())}")

TASK 1: Dataset Setup & Extraction

📦 Processing dataset.zip...
   Extracting...
   ✓ Extracted
   ✓ Found at: /content/acdc_night_train

📦 Processing cityscapes.zip...
   Extracting...
   ✓ Extracted
   ✓ Found at: /content/cityscapes_data

Directory Structure Verification:

📁 acdc_night_train/
  ├── GP020397/
  │   └── (44 images)
  ├── GP010376/
  │   └── (56 images)
  ├── GP010397/
  │   └── (60 images)
  ├── GOPR0376/
  │   └── (147 images)
  ├── GOPR0351/
  │   └── (93 images)

📁 cityscapes_data/
  ├── val/
  │   └── (500 images)
  ├── train/
  │   └── (2975 images)
  ├── cityscapes_data/
    ├── val/
    │   └── (500 images)
    ├── train/
    │   └── (2975 images)


✓ Dataset extraction complete!

Extracted datasets: ['dataset.zip', 'cityscapes.zip']
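The tree above shows `cityscapes_data/` containing a nested copy of itself with identical `train/` (2975) and `val/` (500) counts. Because `CombinedWeatherDataset` scans each root recursively, every Cityscapes image is loaded twice: the later total of 7350 samples is exactly 400 ACDC images + 2 × 3475 Cityscapes images, and after a random split the same picture can land in both train and val. A minimal sketch of one way to drop the duplicate copy, assuming the two copies share folder and file names; `dedupe_image_paths` is a hypothetical helper, not part of the notebook:

```python
from pathlib import Path


def dedupe_image_paths(paths):
    """Keep one copy of each image, preferring the shallowest path.

    Hypothetical helper: two paths count as duplicates when they share the
    same parent-folder name and file name (e.g. 'train/aachen_1.jpg').
    """
    seen = set()
    unique = []
    # Shallow paths first, so /content/cityscapes_data/train/x.jpg wins over
    # /content/cityscapes_data/cityscapes_data/train/x.jpg (sort is stable)
    for p in sorted(paths, key=lambda s: len(Path(s).parts)):
        key = (Path(p).parent.name, Path(p).name)
        if key not in seen:
            seen.add(key)
            unique.append(p)
    return unique


# Example on a few made-up paths
paths = [
    '/content/cityscapes_data/cityscapes_data/train/1.jpg',
    '/content/cityscapes_data/train/1.jpg',
    '/content/cityscapes_data/val/1.jpg',
]
print(dedupe_image_paths(paths))
# → ['/content/cityscapes_data/train/1.jpg', '/content/cityscapes_data/val/1.jpg']
```

A simpler alternative is to pass only one of the two Cityscapes folders (for example `/content/cityscapes_data/cityscapes_data`, which holds both `train/` and `val/`) as a root directory in Step 7.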


## 🔧 Step 3: Custom Dataset Class

In [3]:
class CombinedWeatherDataset(Dataset):
    """
    Multi-dataset loader for semantic segmentation.

    Handles:
    - ACDC Night Train: acdc_night_train/folder_X/images
    - Cityscapes: cityscapes_data/train or val/images

    Features:
    - Automatic image/mask discovery
    - RGB mask → class index conversion
    - Unified 512x512 resizing for T4 GPU
    - Handles image-only datasets (creates dummy all-zero masks, i.e. class 0 "road")
    """

    # Cityscapes 19-class color mapping
    CITYSCAPES_COLORS = {
        (128, 64, 128): 0,   # road
        (244, 35, 232): 1,   # sidewalk
        (70, 70, 70): 2,     # building
        (102, 102, 156): 3,  # wall
        (190, 153, 153): 4,  # fence
        (153, 153, 153): 5,  # pole
        (250, 170, 30): 6,   # traffic light
        (220, 220, 0): 7,    # traffic sign
        (107, 142, 35): 8,   # vegetation
        (152, 251, 152): 9,  # terrain
        (70, 130, 180): 10,  # sky
        (220, 20, 60): 11,   # person
        (255, 0, 0): 12,     # rider
        (0, 0, 142): 13,     # car
        (0, 0, 70): 14,      # truck
        (0, 60, 100): 15,    # bus
        (0, 80, 100): 16,    # train
        (0, 0, 230): 17,     # motorcycle
        (119, 11, 32): 18,   # bicycle
    }

    def __init__(self, root_dirs, img_size=512, has_masks=True):
        """
        Args:
            root_dirs: List of root directories or single directory
            img_size: Target size for resizing (default: 512x512)
            has_masks: Whether dataset has ground truth masks
        """
        self.root_dirs = root_dirs if isinstance(root_dirs, list) else [root_dirs]
        self.img_size = img_size
        self.has_masks = has_masks

        # Find all images
        self.samples = []
        self._build_dataset()

        print(f"\nüìä Dataset Statistics:")
        print(f"   Total samples: {len(self.samples)}")
        print(f"   Image size: {img_size}x{img_size}")
        print(f"   Has masks: {has_masks}")

    def _build_dataset(self):
        """Scan directories and build image list."""
        for root_dir in self.root_dirs:
            root_path = Path(root_dir)

            if not root_path.exists():
                print(f"⚠️  Warning: {root_dir} does not exist, skipping...")
                continue

            # Find all image files recursively
            image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']

            for ext in image_extensions:
                # Search recursively for images
                for img_path in root_path.rglob(ext):
                    # Skip mask files (common naming patterns)
                    if any(x in str(img_path).lower() for x in ['mask', 'label', 'gt', 'gtfine']):
                        continue

                    # Try to find corresponding mask
                    mask_path = self._find_mask(img_path)

                    self.samples.append({
                        'image': str(img_path),
                        'mask': str(mask_path) if mask_path else None
                    })

    def _find_mask(self, img_path):
        """Try to find corresponding mask file."""
        if not self.has_masks:
            return None

        img_path = Path(img_path)

        # Common mask directory patterns
        mask_patterns = [
            img_path.parent.parent / 'masks' / img_path.name,
            img_path.parent.parent / 'labels' / img_path.name,
            img_path.parent.parent / 'gt' / img_path.name,
            img_path.parent / 'masks' / img_path.name,
            img_path.parent / 'labels' / img_path.name,
        ]

        for mask_path in mask_patterns:
            if mask_path.exists():
                return mask_path

        return None

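    def _rgb_to_class(self, mask_rgb):
        """Convert an RGB color mask to class indices (unmatched colors -> 255)."""
        mask_rgb = np.array(mask_rgb)
        h, w = mask_rgb.shape[:2]
        # Start from the ignore index so colors outside the 19-class palette
        # (e.g. void regions) are skipped by CrossEntropyLoss(ignore_index=255)
        mask_class = np.full((h, w), 255, dtype=np.int64)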

        # Convert RGB to class index
        for color, class_idx in self.CITYSCAPES_COLORS.items():
            matches = np.all(mask_rgb == color, axis=-1)
            mask_class[matches] = class_idx

        return mask_class

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load image
        image = Image.open(sample['image']).convert('RGB')
        image = image.resize((self.img_size, self.img_size), Image.BILINEAR)
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1)  # HWC → CHW

        # Load or create mask
        if sample['mask'] and os.path.exists(sample['mask']):
            mask = Image.open(sample['mask'])
            mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)

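            # Convert RGB(A) color masks to class indices; single-channel masks
            # are assumed to already hold train IDs (0-18, 255 = ignore)
            if mask.mode in ('RGB', 'RGBA'):
                mask = self._rgb_to_class(mask.convert('RGB'))
            else:
                mask = np.array(mask)

            mask = torch.from_numpy(mask).long()
        else:
            # Dummy all-zero mask when no label is found; 0 is the "road" class,
            # so these samples still contribute to the loss as all-road targets
            mask = torch.zeros((self.img_size, self.img_size), dtype=torch.long)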

        return image, mask

print("✓ Dataset class defined successfully!")

✓ Dataset class defined successfully!
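With the folder layout shown in Step 2, `_find_mask` finds no `masks/`, `labels/` or `gt/` folder, so every sample falls back to the all-zero dummy mask and the loss only rewards predicting class 0 (road) everywhere. If your Cityscapes copy stores each sample as one image with the photo on the left half and the color-coded label on the right half (as the Kaggle "Cityscapes Image Pairs" export appears to do; open one file to confirm before relying on this), the label can be recovered as sketched below. If those pairs are JPEG-compressed, exact color matching in `_rgb_to_class` would miss most pixels, so this sketch snaps each pixel to the nearest palette color and marks pixels far from every color as 255 (ignore). `split_pair`, `rgb_to_class_nearest` and the `max_dist=40.0` threshold are assumptions for illustration:

```python
import numpy as np

# Class-ordered copy of CombinedWeatherDataset.CITYSCAPES_COLORS
PALETTE = np.array([
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
    (0, 80, 100), (0, 0, 230), (119, 11, 32),
], dtype=np.int32)


def split_pair(pair):
    """Split an H x 2W x 3 array into (photo, label) halves (assumed left/right layout)."""
    w = pair.shape[1] // 2
    return pair[:, :w], pair[:, w:2 * w]


def rgb_to_class_nearest(label_rgb, palette=PALETTE, max_dist=40.0, ignore_index=255):
    """Snap each pixel to the nearest palette color; pixels far from all colors -> ignore_index."""
    rgb = label_rgb.astype(np.int32)[:, :, None, :]            # H, W, 1, 3
    dist2 = ((rgb - palette[None, None]) ** 2).sum(axis=-1)     # H, W, K squared distances
    classes = dist2.argmin(axis=-1).astype(np.int64)            # H, W nearest class index
    classes[dist2.min(axis=-1) > max_dist ** 2] = ignore_index  # e.g. void / unknown colors
    return classes


# Tiny synthetic pair: 2 x 2 photo on the left, 2 x 2 label on the right
pair = np.zeros((2, 4, 3), dtype=np.uint8)
pair[:, 2:] = (130, 62, 126)   # slightly off "road" color, as compression noise might produce
pair[0, 3] = (0, 255, 0)       # pure green: far from every palette color
photo, label = split_pair(pair)
print(photo.shape, label.shape)
print(rgb_to_class_nearest(label))
```

Inside `__getitem__`, the loaded array would be split first, the label half converted to class indices at its native resolution, and only then resized with nearest-neighbor interpolation so that no new class values are introduced.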


## 🧠 Step 4: TurboMamba-TAP Architecture

In [4]:
class TAP_Cleaner(nn.Module):
    """
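    Task-Adaptive Prompt (TAP) Module.
    Adds a learnable per-channel prompt to the input, then predicts a bounded
    residual correction (intended to reduce night, fog and rain degradation).
    Trained end to end through the segmentation loss only; there is no explicit
    restoration target.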
    """
    def __init__(self, in_channels=3, hidden_dim=64):
        super().__init__()

        # Learnable prompt tensor
        self.prompt = nn.Parameter(torch.randn(1, in_channels, 1, 1) * 0.02)

        # 3-layer Conv2d cleaner with residual
        self.cleaner = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),

            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),

            nn.Conv2d(hidden_dim, in_channels, 3, padding=1),
            nn.Tanh()  # Residual in [-1, 1]
        )

    def forward(self, x):
        # Add learnable prompt
        x_prompted = x + self.prompt

        # Generate residual correction
        residual = self.cleaner(x_prompted)

        # Apply residual (scaled for stability)
        cleaned = x + 0.1 * residual
        cleaned = torch.clamp(cleaned, 0, 1)

        return cleaned, residual

print("✓ TAP Cleaner defined!")

✓ TAP Cleaner defined!
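Because the cleaner's last layer is `Tanh` and its output is scaled by 0.1 before being added back, `TAP_Cleaner` can move any pixel by at most ±0.1 on the [0, 1] intensity scale (clamping can only shrink that change). A quick numeric check of the bound, mirroring `forward` in NumPy with a random stand-in for the CNN's pre-activation (the random values are an assumption, not the trained network):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins: a dark "night" image and a large, unbounded pre-activation from the CNN
x = rng.uniform(0.0, 0.2, size=(3, 8, 8))
pre_activation = rng.normal(scale=5.0, size=(3, 8, 8))

residual = np.tanh(pre_activation)                 # in [-1, 1], like nn.Tanh
cleaned = np.clip(x + 0.1 * residual, 0.0, 1.0)    # same update as TAP_Cleaner.forward

max_change = np.abs(cleaned - x).max()
print(f"largest per-pixel change: {max_change:.3f}")  # bounded by 0.1
```

So the prompt-plus-residual design acts as a mild, bounded correction; substantially brightening a very dark night image would need a larger scale factor or an unbounded residual.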


In [5]:
class SimpleMambaEncoder(nn.Module):
    """
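    Mamba-inspired Encoder (pure PyTorch).
    Flattens the feature map row by row into a sequence of length H*W and mixes
    it with depthwise Conv1d blocks (kernel 7). This is a local, input-independent
    stand-in for Mamba's selective scan: each block sees only 3 neighbors on each
    side of the flattened sequence, so it does not capture long-range dependencies.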
    """
    def __init__(self, in_channels=3, hidden_dim=128, num_layers=4):
        super().__init__()

        # Initial projection
        self.input_proj = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 3, stride=2, padding=1),
            nn.GroupNorm(8, hidden_dim),
            nn.GELU()
        )

        # Mamba-style blocks (depthwise Conv1d stand-in for the selective scan)
        self.mamba_blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.mamba_blocks.append(
                nn.ModuleDict({
                    'scan': nn.Sequential(
                        nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7,
                                 padding=3, groups=hidden_dim),
                        nn.GroupNorm(8, hidden_dim),
                        nn.GELU(),
                    ),
                    'ffn': nn.Sequential(
                        nn.Conv1d(hidden_dim, hidden_dim * 4, 1),
                        nn.GELU(),
                        nn.Conv1d(hidden_dim * 4, hidden_dim, 1),
                    )
                })
            )

        # Downsampling for multi-scale features
        self.downsample = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim * 2, 3, stride=2, padding=1),
            nn.GroupNorm(8, hidden_dim * 2),
            nn.GELU()
        )

    def forward(self, x):
        # Initial projection: B,3,512,512 → B,128,256,256
        x = self.input_proj(x)

        # Flatten for 1D convolution (simulate sequence)
        B, C, H, W = x.shape
        x_flat = x.view(B, C, H * W)  # B,C,L where L=H*W

        # Mamba blocks
        for block in self.mamba_blocks:
            # Sequence mixing (depthwise Conv1d stand-in for selective scan)
            residual = x_flat
            x_flat = block['scan'](x_flat) + residual

            # FFN
            residual = x_flat
            x_flat = block['ffn'](x_flat) + residual

        # Reshape back to 2D
        x = x_flat.view(B, C, H, W)

        # Downsample: B,128,256,256 → B,256,128,128
        x = self.downsample(x)

        return x

print("✓ Mamba Encoder defined!")

✓ Mamba Encoder defined!
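For contrast with the Conv1d stand-in, a selective scan as described for Mamba is an input-dependent linear recurrence over the sequence: at each step a step size Δ and projections B and C are computed from the current token, the state decays by exp(ΔA), and the output reads the state through C. The sketch below is a plain sequential NumPy version for a single sequence; the shapes, random weights, softplus for Δ and the simplified discretization of B (Δ·B instead of the exact zero-order hold) are assumptions for illustration, and a loop like this is far too slow for 65,536-step sequences without a parallel scan kernel.

```python
import numpy as np


def softplus(z):
    return np.log1p(np.exp(z))


def selective_scan(x, A, W_delta, W_B, W_C):
    """Sequential selective scan over one sequence.

    x: (L, D) input tokens; A: (D, N) negative decay rates;
    W_delta: (D, D), W_B: (D, N), W_C: (D, N) hypothetical input projections.
    Returns y: (L, D).
    """
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))                      # hidden state per channel
    y = np.empty_like(x)
    for t in range(L):
        xt = x[t]                             # (D,)
        delta = softplus(xt @ W_delta)        # (D,) input-dependent step size
        B = xt @ W_B                          # (N,) input-dependent input matrix
        C = xt @ W_C                          # (N,) input-dependent output matrix
        A_bar = np.exp(delta[:, None] * A)    # (D, N) decay factors in (0, 1)
        B_bar = delta[:, None] * B[None, :]   # (D, N) simplified discretization
        h = A_bar * h + B_bar * xt[:, None]   # state carries every earlier token
        y[t] = h @ C                          # (D,)
    return y


rng = np.random.default_rng(0)
L, D, N = 16, 4, 8
x = rng.normal(size=(L, D))
A = -np.exp(rng.normal(size=(D, N)))          # negative, so the state decays
W_delta, W_B, W_C = (rng.normal(scale=0.5, size=s) for s in [(D, D), (D, N), (D, N)])
y = selective_scan(x, A, W_delta, W_B, W_C)
print(y.shape)  # (16, 4)
```

Unlike the depthwise Conv1d blocks, the state `h` can carry information from every earlier position, and how much of it is kept depends on the current input through `delta`.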


In [6]:
class DetailHead(nn.Module):
    """
    Detail-preserving Decoder
    Progressive upsampling back to original resolution.
    """
    def __init__(self, in_channels=256, num_classes=19):
        super().__init__()

        self.decoder = nn.Sequential(
            # 128x128 ‚Üí 256x256
            nn.ConvTranspose2d(in_channels, 128, 4, stride=2, padding=1),
            nn.GroupNorm(8, 128),
            nn.GELU(),

            # 256x256 ‚Üí 512x512
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.GroupNorm(8, 64),
            nn.GELU(),

            # Final classification
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, x):
        return self.decoder(x)

print("✓ Detail Head defined!")

✓ Detail Head defined!


In [7]:
class TurboMambaTAP(nn.Module):
    """
    Complete TurboMamba-TAP Architecture

    Pipeline:
    Input → TAP Cleaner → Mamba Encoder → Detail Head → Segmentation
    """
    def __init__(self, num_classes=19):
        super().__init__()

        self.tap_cleaner = TAP_Cleaner(in_channels=3, hidden_dim=64)
        self.mamba_encoder = SimpleMambaEncoder(in_channels=3, hidden_dim=128, num_layers=4)
        self.detail_head = DetailHead(in_channels=256, num_classes=num_classes)

    def forward(self, x, return_cleaned=False):
        # Stage 1: TAP Cleaning
        x_clean, residual = self.tap_cleaner(x)

        # Stage 2: Mamba Encoding
        features = self.mamba_encoder(x_clean)

        # Stage 3: Detail Decoding
        logits = self.detail_head(features)

        if return_cleaned:
            return logits, x_clean

        return logits

print("✓ TurboMamba-TAP model defined!")

✓ TurboMamba-TAP model defined!
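The spatial sizes in the comments (512 → 256 → 128 → 256 → 512) follow from the standard output-size formulas, with dilation 1 and no output padding: for `Conv2d`, out = ⌊(in + 2·padding − kernel) / stride⌋ + 1; for `ConvTranspose2d`, out = (in − 1)·stride − 2·padding + kernel. A small pure-Python trace through the layers of `TurboMambaTAP`; the helper names are hypothetical:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """nn.Conv2d output size (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0):
    """nn.ConvTranspose2d output size (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


def trace_turbomamba(size=512):
    """Spatial size after each stage of TurboMambaTAP for a size x size input."""
    sizes = {'input': size}
    # TAP_Cleaner: three 3x3 convs with padding 1 keep the size
    sizes['tap_cleaner'] = conv_out(conv_out(conv_out(size, 3, 1, 1), 3, 1, 1), 3, 1, 1)
    # SimpleMambaEncoder.input_proj: 3x3, stride 2, padding 1
    sizes['input_proj'] = conv_out(sizes['tap_cleaner'], 3, 2, 1)
    # The Conv1d blocks keep the sequence length H*W (kernel 7, padding 3)
    # SimpleMambaEncoder.downsample: 3x3, stride 2, padding 1
    sizes['downsample'] = conv_out(sizes['input_proj'], 3, 2, 1)
    # DetailHead: two 4x4 transposed convs, stride 2, padding 1
    sizes['deconv1'] = deconv_out(sizes['downsample'], 4, 2, 1)
    sizes['deconv2'] = deconv_out(sizes['deconv1'], 4, 2, 1)
    return sizes


print(trace_turbomamba(512))
# → {'input': 512, 'tap_cleaner': 512, 'input_proj': 256, 'downsample': 128, 'deconv1': 256, 'deconv2': 512}
```

The same trace shows why `img_size` should be a multiple of 4: for a 510 x 510 input the logits come back at 512 x 512, which no longer matches the mask shape expected by `CrossEntropyLoss`.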


## 🏋️ Step 5: Training Functions

In [8]:
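def calculate_metrics(pred, target, num_classes=19, ignore_index=255):
    """Calculate pixel accuracy and mIoU for one batch, skipping ignored pixels."""
    pred = pred.cpu().numpy()
    target = target.cpu().numpy()

    # Score only pixels that carry a valid label (matches the loss's ignore_index)
    valid = target != ignore_index
    pred = pred[valid]
    target = target[valid]
    if target.size == 0:
        return 0.0, 0.0

    # Pixel accuracy
    pixel_acc = (pred == target).mean()

    # mIoU over classes present in the prediction or the target
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred == cls)
        target_cls = (target == cls)

        intersection = (pred_cls & target_cls).sum()
        union = (pred_cls | target_cls).sum()

        if union > 0:
            ious.append(intersection / union)

    miou = np.mean(ious) if ious else 0.0

    return pixel_acc, miou

print("✓ Metrics function defined!")

✓ Metrics function defined!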
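`train` averages `calculate_metrics` over validation batches, so each batch contributes its own mIoU and a class seen in only a few batches is weighted differently than in a single dataset-level mIoU. A common alternative is to accumulate one confusion matrix over the whole validation set and compute IoU once at the end. A minimal NumPy sketch; `ConfusionMeter` is a hypothetical name, and pixels labeled 255 are skipped to match `CrossEntropyLoss(ignore_index=255)`:

```python
import numpy as np


class ConfusionMeter:
    """Accumulate a num_classes x num_classes confusion matrix over batches."""

    def __init__(self, num_classes=19, ignore_index=255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    def update(self, pred, target):
        pred = np.asarray(pred).ravel()
        target = np.asarray(target).ravel()
        valid = target != self.ignore_index
        # Row = true class, column = predicted class
        idx = target[valid].astype(np.int64) * self.num_classes + pred[valid].astype(np.int64)
        counts = np.bincount(idx, minlength=self.num_classes ** 2)
        self.matrix += counts.reshape(self.num_classes, self.num_classes)

    def compute(self):
        tp = np.diag(self.matrix).astype(np.float64)
        union = self.matrix.sum(axis=0) + self.matrix.sum(axis=1) - tp
        pixel_acc = tp.sum() / max(self.matrix.sum(), 1)
        present = union > 0
        miou = (tp[present] / union[present]).mean() if present.any() else 0.0
        return pixel_acc, miou


meter = ConfusionMeter(num_classes=3)
meter.update([0, 0, 1, 2], [0, 1, 1, 255])   # last pixel is ignored
meter.update([2, 2], [2, 2])
pixel_acc, miou = meter.compute()
print(f"pixel acc {pixel_acc:.3f}, mIoU {miou:.3f}")
```

In `train`, the meter would be created before the validation loop, updated with `pred.cpu().numpy()` and `masks.cpu().numpy()` per batch, and queried once with `compute()` after the loop.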


In [9]:
def train(model, train_loader, val_loader, device, epochs=20, lr=1e-4):
    """
    Training loop with progress tracking.
    """
    print("\n" + "=" * 70)
    print("TRAINING TURBOMAMBA-TAP")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Epochs: {epochs}")
    print(f"Learning rate: {lr}")
    print("=" * 70 + "\n")

    # Optimizer and loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_miou': [],
        'val_pixel_acc': []
    }

    best_miou = 0.0

    for epoch in range(epochs):
        # ========== TRAINING ==========
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # ========== VALIDATION ==========
        model.eval()
        val_loss = 0.0
        all_pixel_acc = []
        all_miou = []

        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
            for images, masks in pbar:
                images = images.to(device)
                masks = masks.to(device)

                logits = model(images)
                loss = criterion(logits, masks)
                val_loss += loss.item()

                pred = logits.argmax(dim=1)
                pixel_acc, miou = calculate_metrics(pred, masks)
                all_pixel_acc.append(pixel_acc)
                all_miou.append(miou)

                pbar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'mIoU': f"{miou:.4f}"
                })

        val_loss /= len(val_loader)
        avg_pixel_acc = np.mean(all_pixel_acc)
        avg_miou = np.mean(all_miou)

        history['val_loss'].append(val_loss)
        history['val_miou'].append(avg_miou)
        history['val_pixel_acc'].append(avg_pixel_acc)

        scheduler.step()

        # Print summary
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{epochs} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val mIoU:   {avg_miou:.4f}")
        print(f"  Val Acc:    {avg_pixel_acc:.4f}")
        print(f"{'='*70}\n")

        # Save best model
        if avg_miou > best_miou:
            best_miou = float(avg_miou)  # plain Python float rather than a NumPy scalar
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'miou': best_miou,
            }, '/content/turbo_mamba_best.pth')
            print(f"✓ Saved best model (mIoU: {best_miou:.4f})")

    # Save final model
    torch.save(model.state_dict(), '/content/turbo_mamba_colab.pth')
    print("\n✓ Training complete! Model saved to: turbo_mamba_colab.pth")

    return history

print("✓ Training function defined!")

✓ Training function defined!
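`CosineAnnealingLR(optimizer, T_max=epochs)` with `scheduler.step()` called once per epoch decays the learning rate along a half cosine from `lr` toward `eta_min` (0 by default): lr_t = eta_min + (lr − eta_min)·(1 + cos(π·t / T_max)) / 2. A quick pure-Python check of the per-epoch values for `lr=1e-4` and 20 epochs, using the closed form (which matches the scheduler when it is stepped once per epoch without restarts):

```python
import math


def cosine_lr(epoch, base_lr=1e-4, t_max=20, eta_min=0.0):
    """Closed-form cosine annealing (no warm restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Learning rate in effect during each of the 20 training epochs
schedule = [cosine_lr(e) for e in range(20)]
print(f"epoch 1: {schedule[0]:.2e}, epoch 11: {schedule[10]:.2e}, epoch 20: {schedule[19]:.2e}")
```

The last epoch already trains at well under 1% of the initial rate; the schedule would reach exactly 0 only after the 20th `scheduler.step()`, when training has ended.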


## 📊 Step 6: Visualization Functions

In [10]:
def visualize_results(model, val_dataset, device):
    """
    Visualize: [Original Image, Cleaned, Ground Truth, Prediction]
    """
    print("\n" + "=" * 70)
    print("Generating Visualization...")
    print("=" * 70)

    model.eval()

    # Get random sample
    idx = random.randint(0, len(val_dataset) - 1)
    image, mask = val_dataset[idx]

    # Inference
    image_input = image.unsqueeze(0).to(device)

    with torch.no_grad():
        logits, cleaned = model(image_input, return_cleaned=True)
        pred = logits.argmax(dim=1)

    # Convert to numpy
    image_np = image.permute(1, 2, 0).cpu().numpy()
    cleaned_np = cleaned.squeeze(0).permute(1, 2, 0).cpu().numpy()
    mask_np = mask.cpu().numpy()
    pred_np = pred.squeeze(0).cpu().numpy()

    # Plot
    cmap = plt.get_cmap('tab20', 19)  # matplotlib.cm.get_cmap is removed in recent Matplotlib

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(image_np)
    axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(cleaned_np)
    axes[1].set_title('TAP Cleaned', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    axes[2].imshow(mask_np, cmap=cmap, vmin=0, vmax=18)
    axes[2].set_title('Ground Truth', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    im = axes[3].imshow(pred_np, cmap=cmap, vmin=0, vmax=18)
    axes[3].set_title('Prediction', fontsize=14, fontweight='bold')
    axes[3].axis('off')

    plt.colorbar(im, ax=axes, orientation='horizontal', fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.savefig('/content/turbo_mamba_result.png', dpi=150, bbox_inches='tight')
    print("✓ Saved to: turbo_mamba_result.png")
    plt.show()


def plot_training_history(history):
    """Plot training curves."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Loss
    axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
    axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training & Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # mIoU
    axes[1].plot(history['val_miou'], label='Val mIoU', marker='o', color='green')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('mIoU')
    axes[1].set_title('Validation mIoU')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Pixel Accuracy
    axes[2].plot(history['val_pixel_acc'], label='Val Pixel Acc', marker='o', color='orange')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Accuracy')
    axes[2].set_title('Validation Pixel Accuracy')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/content/training_history.png', dpi=150, bbox_inches='tight')
    print("✓ Saved to: training_history.png")
    plt.show()

print("✓ Visualization functions defined!")

✓ Visualization functions defined!
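The `tab20` colors do not match the official Cityscapes colors, so predictions are hard to compare with RGB ground-truth masks or other Cityscapes visualizations. A minimal sketch that paints a class-index map with the same palette as `CombinedWeatherDataset.CITYSCAPES_COLORS` and renders the ignore index 255 as black; `colorize_mask` is a hypothetical helper:

```python
import numpy as np

# Class-ordered palette copied from CombinedWeatherDataset.CITYSCAPES_COLORS
PALETTE = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
    (0, 80, 100), (0, 0, 230), (119, 11, 32),
]


def colorize_mask(mask, palette=PALETTE):
    """Map an H x W class-index array (values 0-255) to an H x W x 3 uint8 RGB image."""
    lut = np.zeros((256, 3), dtype=np.uint8)   # indices without a color (incl. 255) stay black
    lut[:len(palette)] = palette
    return lut[np.asarray(mask, dtype=np.int64)]


rgb = colorize_mask(np.array([[0, 13], [255, 10]]))   # road, car, ignore, sky
print(rgb.shape, rgb[0, 0], rgb[1, 0])
```

In `visualize_results`, `axes[2].imshow(colorize_mask(mask_np))` and `axes[3].imshow(colorize_mask(pred_np))` could replace the `cmap`/`vmin`/`vmax` calls; the colorbar would then no longer apply and could be dropped.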


## 🚀 Step 7: Main Execution - Build Dataset & Initialize Model

In [11]:
print("\n" + "=" * 70)
print("CREATING COMBINED DATASET")
print("=" * 70)

# Collect all dataset paths
all_data_dirs = []

# Add extracted dataset paths
for zip_file, path in extracted_datasets.items():
    all_data_dirs.append(path)
    print(f"Added: {path}")

# Create combined dataset
full_dataset = CombinedWeatherDataset(
    root_dirs=all_data_dirs,
    img_size=512,
    has_masks=True  # Set to False if you don't have masks
)

# Split into train/val (80/20)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size]
)

print(f"\n✓ Dataset split: {train_size} train, {val_size} val")

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=8,  # intended for a T4; lower it if you hit CUDA out-of-memory
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print("✓ DataLoaders created!")


CREATING COMBINED DATASET
Added: /content/acdc_night_train
Added: /content/cityscapes_data

📊 Dataset Statistics:
   Total samples: 7350
   Image size: 512x512
   Has masks: True

✓ Dataset split: 5880 train, 1470 val
✓ DataLoaders created!
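`random_split` shuffles individual images, so images from the same ACDC recording folder (often visually similar) and, with the duplicated Cityscapes folder, identical images can end up in both train and val, which inflates validation scores. A minimal sketch of a split by source folder instead, using only the `samples` list the dataset already builds; `split_by_folder` and the choice of held-out folders are assumptions for illustration:

```python
from pathlib import Path


def split_by_folder(samples, val_folders):
    """Return (train_idx, val_idx); every image goes to the side of its parent folder.

    samples: list of dicts with an 'image' path, like CombinedWeatherDataset.samples.
    val_folders: parent-folder names to hold out for validation (hypothetical choice).
    """
    train_idx, val_idx = [], []
    for i, sample in enumerate(samples):
        folder = Path(sample['image']).parent.name
        (val_idx if folder in val_folders else train_idx).append(i)
    return train_idx, val_idx


# Example: Cityscapes' official val/ folder plus one ACDC recording held out
samples = [
    {'image': '/content/acdc_night_train/GOPR0351/a.png'},
    {'image': '/content/acdc_night_train/GP010376/b.png'},
    {'image': '/content/cityscapes_data/val/c.jpg'},
    {'image': '/content/cityscapes_data/train/d.jpg'},
]
train_idx, val_idx = split_by_folder(samples, val_folders={'val', 'GOPR0351'})
print(train_idx, val_idx)  # [1, 3] [0, 2]
```

With the real dataset, `split_by_folder(full_dataset.samples, ...)` followed by `torch.utils.data.Subset(full_dataset, train_idx)` and `Subset(full_dataset, val_idx)` would take the place of `random_split`. Because the nested Cityscapes copy also names its folders `train/` and `val/`, duplicated images stay on the same side of the split.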


In [2]:
print("\n" + "=" * 70)
print("INITIALIZING TURBOMAMBA-TAP MODEL")
print("=" * 70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")

# Initialize model
model = TurboMambaTAP(num_classes=19).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nüìä Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Model size: ~{total_params * 4 / 1024**2:.2f} MB")
print("\n✓ Model ready for training!")


INITIALIZING TURBOMAMBA-TAP MODEL


NameError: name 'torch' is not defined

## 🎯 Step 8: Train the Model

**Rough estimate: 30-60 minutes on a T4 GPU for 20 epochs; a CPU-only runtime (as in the `CUDA available: False` output of Step 1) is far slower.**

In [1]:
# Train the model
history = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=20,
    lr=1e-4
)

NameError: name 'train' is not defined

## 📈 Step 9: Visualize Results

In [None]:
# Plot training history
plot_training_history(history)

In [None]:
# Visualize predictions
visualize_results(model, val_dataset, device)

## 💾 Step 10: Download Your Trained Model

Run the cell below to download your trained model files.

In [None]:
from google.colab import files

print("\n" + "=" * 70)
print("DOWNLOADING TRAINED MODELS")
print("=" * 70)

# Download final model
if os.path.exists('/content/turbo_mamba_colab.pth'):
    print("\nDownloading turbo_mamba_colab.pth...")
    files.download('/content/turbo_mamba_colab.pth')
    print("✓ Downloaded!")

# Download best model
if os.path.exists('/content/turbo_mamba_best.pth'):
    print("\nDownloading turbo_mamba_best.pth...")
    files.download('/content/turbo_mamba_best.pth')
    print("✓ Downloaded!")

# Download visualizations
if os.path.exists('/content/turbo_mamba_result.png'):
    print("\nDownloading turbo_mamba_result.png...")
    files.download('/content/turbo_mamba_result.png')
    print("✓ Downloaded!")

if os.path.exists('/content/training_history.png'):
    print("\nDownloading training_history.png...")
    files.download('/content/training_history.png')
    print("✓ Downloaded!")

print("\n" + "=" * 70)
print("✓ ALL DOWNLOADS COMPLETE!")
print("=" * 70)

## 🔮 Bonus: Inference on New Images

Use this cell to test your trained model on new images!

In [None]:
def predict_single_image(model, image_path, device):
    """
    Run inference on a single image.
    """
    model.eval()

    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    image = image.resize((512, 512), Image.BILINEAR)
    image_np = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        logits, cleaned = model(image_tensor, return_cleaned=True)
        pred = logits.argmax(dim=1)

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(image_np)
    axes[0].set_title('Original', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    cleaned_np = cleaned.squeeze(0).permute(1, 2, 0).cpu().numpy()
    axes[1].imshow(cleaned_np)
    axes[1].set_title('TAP Cleaned', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    pred_np = pred.squeeze(0).cpu().numpy()
    cmap = plt.get_cmap('tab20', 19)  # matplotlib.cm.get_cmap is removed in recent Matplotlib
    im = axes[2].imshow(pred_np, cmap=cmap, vmin=0, vmax=18)
    axes[2].set_title('Segmentation', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    plt.colorbar(im, ax=axes, orientation='horizontal', fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

# Example usage (uncomment and provide your image path):
# predict_single_image(model, '/content/your_test_image.jpg', device)
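`train` writes two different formats: `turbo_mamba_colab.pth` holds a bare `state_dict`, while `turbo_mamba_best.pth` wraps it in a dictionary under `'model_state_dict'` together with the epoch, optimizer state and mIoU. To run `predict_single_image` in a fresh session, the weights must first be loaded into a new `TurboMambaTAP` instance. A small, hypothetical helper that accepts either format:

```python
def extract_state_dict(checkpoint):
    """Return the model weights from either checkpoint format written by train().

    - turbo_mamba_best.pth: {'epoch', 'model_state_dict', 'optimizer_state_dict', 'miou'}
    - turbo_mamba_colab.pth: the state_dict itself
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


# Usage in a fresh Colab session (requires the model classes defined above):
# model = TurboMambaTAP(num_classes=19).to(device)
# ckpt = torch.load('/content/turbo_mamba_best.pth', map_location=device)
# model.load_state_dict(extract_state_dict(ckpt))
# predict_single_image(model, '/content/your_test_image.jpg', device)
```

In PyTorch 2.6 and later, `torch.load` defaults to `weights_only=True`, which can reject extra Python objects stored in a checkpoint; for a file you created yourself, `weights_only=False` is a fallback.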

---

## 🎉 Training Complete!

### Generated Files:
- ✅ `turbo_mamba_colab.pth` - Final trained model
- ✅ `turbo_mamba_best.pth` - Best validation checkpoint
- ✅ `turbo_mamba_result.png` - Visualization
- ✅ `training_history.png` - Training curves

### Model Architecture:
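- **TAP Cleaner**: learnable prompt plus a bounded residual correction aimed at weather degradation
- **Mamba Encoder**: Conv1d sequence mixing over flattened features (a local stand-in for the selective scan)
- **Detail Head**: upsamples back to full-resolution segmentation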

### Next Steps:
1. Download your trained models
2. Test on new images using the inference cell
3. Fine-tune hyperparameters if needed
4. Deploy to your application

---

**Need help?** Check the paper or reach out to the research team!

**Senior Deep Learning Engineer** üöÄ