<a href="https://colab.research.google.com/github/khanmhmdi/Moe-llm-edge-computing/blob/main/MOE_SWITCH_TRANSFORMER_LYFT_AND_BUDK100_DATASET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# 1. Image patch embedding
# -----------------------------
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        x += self.pos_embedding
        return x

# -----------------------------
# 2. Switch MoE Feed-forward module
# -----------------------------
class ExpertFFN(nn.Module):
    def __init__(self, embed_dim, expansion=4):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * expansion)
        self.fc2 = nn.Linear(embed_dim * expansion, embed_dim)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

class SwitchFFN(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4, capacity_factor=1.0):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor

        # Instantiate experts as a ModuleList
        self.experts = nn.ModuleList([ExpertFFN(embed_dim, expansion=expansion) for _ in range(num_experts)])
        # Gating network (top-1 gating)
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        # x: [B, num_tokens, embed_dim]
        B, N, D = x.shape
        logits = self.gate(x)  # [B, N, num_experts]
        # Choose top-1 expert for each token
        indices = logits.argmax(dim=-1)  # [B, N]

        # Create output tensor
        output = torch.zeros_like(x)
        # For auxiliary loss (optional, not computed here)
        expert_counts = torch.zeros(x.size(0), self.num_experts, device=x.device)

        # Dispatch tokens to experts (this is a basic loop version; for efficiency, consider vectorizing or using libraries)
        for expert_id in range(self.num_experts):
            # Find tokens assigned to this expert
            mask = (indices == expert_id)  # [B, N] boolean mask
            if mask.sum() == 0:
                continue
            # Get the indices of tokens to be processed
            selected_tokens = x[mask]  # [num_selected, D]
            expert_output = self.experts[expert_id](selected_tokens)  # Process tokens
            output[mask] = expert_output
            expert_counts[:, expert_id] += mask.sum(dim=-1).float()
        return output, expert_counts  # return expert counts if you want to add auxiliary loss

# -----------------------------
# 3. Switch Transformer Layer
# -----------------------------
class SwitchTransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.switch_ffn = SwitchFFN(embed_dim, num_experts=num_experts, expansion=expansion)

    def forward(self, x):
        # Self-attention block
        attn_out, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # Switch MoE FFN block
        ffn_out, expert_counts = self.switch_ffn(x)
        x = x + self.dropout(ffn_out)
        x = self.norm2(x)
        return x, expert_counts

# -----------------------------
# 4. Encoder with multiple Switch Transformer Layers
# -----------------------------
class SwitchTransformerEncoder(nn.Module):
    def __init__(self, depth, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwitchTransformerLayer(embed_dim, num_experts, expansion, num_heads, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        expert_loss = 0.0  # sum of auxiliary losses from each layer (if computed)
        for layer in self.layers:
            x, expert_counts = layer(x)
            # Optionally: Compute a load-balancing loss with expert_counts here
            # For instance, a simple loss that penalizes deviation from even load distribution.
            # expert_loss += compute_auxiliary_loss(expert_counts)
        return x, expert_loss

# -----------------------------
# 5. Segmentation Decoder
# -----------------------------
class SegmentationDecoder(nn.Module):
    def __init__(self, embed_dim, num_patches, img_size=224, patch_size=16, num_classes=21):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches_side = img_size // patch_size

        # A simple decoder that reshape tokens back to spatial map and applies a conv layer.
        self.conv1 = nn.Conv2d(embed_dim, embed_dim // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(embed_dim // 2)
        self.conv2 = nn.Conv2d(embed_dim // 2, num_classes, kernel_size=1)

    def forward(self, x):
        # x: [B, num_patches, embed_dim]
        B, N, D = x.shape
        H = W = self.num_patches_side
        x = x.transpose(1, 2).reshape(B, D, H, W)  # [B, D, H, W]

        x = F.relu(self.bn1(self.conv1(x)))
        # Upsample to original image size
        x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        logits = self.conv2(x)  # [B, num_classes, img_size, img_size]
        return logits

# -----------------------------
# 6. Full Model
# -----------------------------
class SwitchTransformerSegmentationModel(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=6,
                 num_experts=4, expansion=4, num_heads=8, dropout=0.1, num_classes=21):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = (img_size // patch_size) ** 2
        self.encoder = SwitchTransformerEncoder(depth, embed_dim, num_experts, expansion, num_heads, dropout)
        self.decoder = SegmentationDecoder(embed_dim, num_patches, img_size, patch_size, num_classes)

    def forward(self, x):
        # x: [B, in_channels, img_size, img_size]
        x = self.patch_embedding(x)
        x, aux_loss = self.encoder(x)
        segmentation_logits = self.decoder(x)
        return segmentation_logits, aux_loss

# -----------------------------
# 7. Example Training Loop
# -----------------------------
if __name__ == "__main__":
    # Example input
    img = torch.randn(2, 3, 224, 224)  # batch of 2 images
    model = SwitchTransformerSegmentationModel(img_size=224, patch_size=16, in_channels=3,
                                                 embed_dim=768, depth=6, num_experts=4,
                                                 expansion=4, num_heads=8, dropout=0.1, num_classes=21)
    logits, aux_loss = model(img)
    print("Segmentation logits shape:", logits.shape)  # should be [B, num_classes, 224, 224]

    # Define segmentation loss (e.g., cross entropy)
    targets = torch.randint(0, 21, (2, 224, 224))  # random segmentation maps
    seg_loss = F.cross_entropy(logits, targets)

    loss = seg_loss + 0.01 * aux_loss  # combine with auxiliary load balancing loss
    loss.backward()
    print("Backward pass successful!")


Segmentation logits shape: torch.Size([2, 21, 224, 224])
Backward pass successful!


In [1]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
kumaresanmanickavelu_lyft_udacity_challenge_path = kagglehub.dataset_download('kumaresanmanickavelu/lyft-udacity-challenge')

Downloading from https://www.kaggle.com/api/v1/datasets/download/kumaresanmanickavelu/lyft-udacity-challenge?dataset_version_number=1...


100%|██████████| 5.11G/5.11G [04:13<00:00, 21.6MB/s]

Extracting files...





In [2]:
import os
print(os.listdir())

['.config', 'sample_data']


In [3]:
print(kumaresanmanickavelu_lyft_udacity_challenge_path)

/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1


In [3]:
import os

# Update the base path to the new location
base_path = '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/'

# Update the image and mask paths
image_path = os.path.join(base_path, 'CameraRGB/')
mask_path = os.path.join(base_path, 'CameraSeg/')

# List the files in the image and mask directories
image_list = os.listdir(image_path)
mask_list = os.listdir(mask_path)

# Create full paths for the images and masks
image_list = [os.path.join(image_path, i) for i in image_list]
mask_list = [os.path.join(mask_path, i) for i in mask_list]

# Print the first few paths to verify
print("First 5 image paths:", image_list[:5])
print("First 5 mask paths:", mask_list[:5])

First 5 image paths: ['/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraRGB/F68-11.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraRGB/F70-90.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraRGB/F3-11.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraRGB/06_00_240.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraRGB/09_00_210.png']
First 5 mask paths: ['/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraSeg/F68-11.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dataa/dataA/CameraSeg/F70-90.png', '/root/.cache/kagglehub/datasets/kumaresanmanickavelu/lyft-udacity-challenge/versions/1/dat

In [4]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import imageio

import matplotlib.pyplot as plt
%matplotlib inline


In [13]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# -----------------------------
# 1. Define your Dataset class
# -----------------------------
class LyftSegmentationDataset(Dataset):
    """
    Dataset for segmentation that loads images and masks from disk.
    Assumes images are in RGB and masks are single-channel with integer labels.
    """
    def __init__(self, image_list, mask_list, transform_image=None, transform_mask=None):
        assert len(image_list) == len(mask_list), "Image and mask lists must be equal length."
        self.image_list = image_list
        self.mask_list = mask_list
        self.transform_image = transform_image
        self.transform_mask = transform_mask

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = Image.open(self.image_list[idx]).convert("RGB")
        mask = Image.open(self.mask_list[idx]).convert("L")  # Convert to single channel

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)
        else:
            mask = torch.from_numpy(np.array(mask, dtype=np.int64))
        return image, mask

# -----------------------------
# 2. Define image and mask transformations
# -----------------------------
img_size = 224  # Defined image size
transform_image = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),  # [0, 1] float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_mask = transforms.Compose([
    transforms.Resize((img_size, img_size), interpolation=Image.NEAREST),
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x, dtype=np.int64)))
])

# -----------------------------
# 3. File lists for images and masks
# -----------------------------
# image_path = '/kaggle/input/lyft-udacity-challenge/dataa/dataA/CameraRGB/'
# mask_path = '/kaggle/input/lyft-udacity-challenge/dataa/dataA/CameraSeg/'

# image_list = sorted([os.path.join(image_path, i) for i in os.listdir(image_path)])
# mask_list = sorted([os.path.join(mask_path, i) for i in os.listdir(mask_path)])

# -----------------------------
# 4. Define the Switch Transformer Segmentation Model and components
# -----------------------------
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        # x: [B, in_channels, H, W]
        x = self.proj(x)                    # [B, embed_dim, H/patch_size, W/patch_size]
        x = x.flatten(2)                    # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)               # [B, num_patches, embed_dim]
        x += self.pos_embedding             # Add position embedding
        return x

class ExpertFFN(nn.Module):
    def __init__(self, embed_dim, expansion=4):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * expansion)
        self.fc2 = nn.Linear(embed_dim * expansion, embed_dim)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

class SwitchFFN(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([ExpertFFN(embed_dim, expansion) for _ in range(num_experts)])
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        B, N, D = x.shape
        logits = self.gate(x)               # [B, N, num_experts]
        indices = logits.argmax(dim=-1)     # [B, N] - select expert index with highest score
        output = torch.zeros_like(x)
        expert_counts = torch.zeros(B, self.num_experts, device=x.device)
        for expert_id in range(self.num_experts):
            mask = (indices == expert_id)   # Boolean mask [B, N]
            if mask.sum() == 0:
                continue
            selected_tokens = x[mask]       # [num_selected, D]
            expert_output = self.experts[expert_id](selected_tokens)
            output[mask] = expert_output
            expert_counts[:, expert_id] += mask.float().sum(dim=1) if mask.dim() == 2 else mask.float().sum(dim=0)
        return output, expert_counts

class SwitchTransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.switch_ffn = SwitchFFN(embed_dim, num_experts=num_experts, expansion=expansion)

    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ffn_out, expert_counts = self.switch_ffn(x)
        x = x + self.dropout(ffn_out)
        x = self.norm2(x)
        return x, expert_counts

class SwitchTransformerEncoder(nn.Module):
    def __init__(self, depth, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwitchTransformerLayer(embed_dim, num_experts, expansion, num_heads, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        aux_loss = 0.0
        for layer in self.layers:
            x, expert_counts = layer(x)
            # Optionally, incorporate an auxiliary loss for load balancing here if desired.
        return x, aux_loss

class SegmentationDecoder(nn.Module):
    def __init__(self, embed_dim, img_size=224, patch_size=16, num_classes=21):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches_side = img_size // patch_size
        self.conv1 = nn.Conv2d(embed_dim, embed_dim // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(embed_dim // 2)
        self.conv2 = nn.Conv2d(embed_dim // 2, num_classes, kernel_size=1)

    def forward(self, x):
        B, N, D = x.shape
        H = W = self.num_patches_side
        x = x.transpose(1, 2).reshape(B, D, H, W)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        logits = self.conv2(x)
        return logits

class SwitchTransformerSegmentationModel(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=6,
                 num_experts=4, expansion=4, num_heads=8, dropout=0.1, num_classes=21):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.encoder = SwitchTransformerEncoder(depth, embed_dim, num_experts, expansion, num_heads, dropout)
        self.decoder = SegmentationDecoder(embed_dim, img_size, patch_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x, aux_loss = self.encoder(x)
        segmentation_logits = self.decoder(x)
        return segmentation_logits, aux_loss

# -----------------------------
# 5. Define Dice Score Function
# -----------------------------
def dice_score(pred, target, num_classes, epsilon=1e-6):
    """
    Computes the Dice score per class and returns the average score.
    pred: tensor of shape [B, H, W] with predicted classes.
    target: tensor of shape [B, H, W] with ground truth classes.
    """
    dice = 0.0
    for cls in range(num_classes):
        pred_cls = (pred == cls).float()
        target_cls = (target == cls).float()
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()
        dice += (2 * intersection + epsilon) / (union + epsilon)
    return dice / num_classes


import matplotlib.pyplot as plt
import numpy as np

# Define your class colors (adjust the number of classes/colors as needed)
class_colors = {
    0: (1, 0, 0),   # red  (often used for background, or class 0)
    1: (0, 1, 0),   # green (class 1)
    2: (0, 0, 1),   # blue  (class 2)
    3: (1, 0, 1)    # magenta (class 3)
}


def colorize_mask(mask, class_colors):
    """
    Colorize a segmentation mask given a mapping from class labels to colors.

    Parameters:
        mask (np.ndarray): 2D array (height, width) of class indices.
        class_colors (dict): dictionary mapping class indices to RGB tuples.

    Returns:
        np.ndarray: colorized mask of shape (height, width, 3) for visualization.
    """
    # Create an empty array with an extra channel for color (RGB)
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.float32)

    # Assign colors for each class present in the mask
    for class_id, color in class_colors.items():
        color_mask[mask == class_id] = color

    return color_mask

# -----------------------------
# 6. Helper function to plot predictions
# -----------------------------
def plot_prediction_and_mask(image, pred_mask, true_mask, class_colors):
    """
    Plot the original image, prediction mask, and ground truth mask.

    Parameters:
        image (np.ndarray): the input image of shape (height, width, 3).
        pred_mask (np.ndarray): predicted segmentation mask (class indices).
        true_mask (np.ndarray): ground truth segmentation mask (class indices).
        class_colors (dict): mapping from class labels to colors.
    """

    # Create color masks for the predicted and true masks
    pred_color_mask = colorize_mask(pred_mask, class_colors)
    true_color_mask = colorize_mask(true_mask, class_colors)

    # Plotting using matplotlib
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Show original image
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Show predicted mask
    axes[1].imshow(pred_color_mask)
    axes[1].set_title("Predicted Mask")
    axes[1].axis('off')

    # Show ground truth mask
    axes[2].imshow(true_color_mask)
    axes[2].set_title("Ground Truth Mask")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()


# -----------------------------
# 7. Training loop with Dice score and prediction plotting
# -----------------------------
def train(model, dataloader, optimizer, device, num_epochs=5, aux_loss_weight=0.01, num_classes=21, plot_freq=10):
    model.train()
    criterion = nn.CrossEntropyLoss()  # segmentation cross-entropy loss

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_dice = 0.0
        for i, (images, masks) in enumerate(dataloader):
            images = images.to(device)      # [B, 3, H, W]
            masks = masks.to(device)        # [B, H, W]

            optimizer.zero_grad()
            seg_logits, aux_loss = model(images)
            # seg_logits shape: [B, num_classes, H, W]
            seg_loss = criterion(seg_logits, masks)
            loss = seg_loss + aux_loss_weight * aux_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Calculate Dice score on training mini-batch:
            preds = torch.argmax(seg_logits, dim=1)  # shape: [B, H, W]
            batch_dice = dice_score(preds, masks, num_classes)
            running_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice

            # Every plot_freq iterations, plot a few example predictions from the current batch.
            # if (i + 1) % plot_freq == 0:
            #     print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
            #     # Plot first few images. Adjust n_samples as desired.
            #     plot_predictions(images, masks, preds, n_samples=3, title_prefix="Train")

            if (i + 1) % 5 == 0:
                avg_loss = running_loss / (i+1)
                avg_dice = running_dice / (i+1)
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Avg Loss: {avg_loss:.4f}, Avg Dice: {avg_dice:.4f}")
        print(f"End of epoch {epoch+1}: Avg Loss: {running_loss/len(dataloader):.4f}, Avg Dice: {running_dice/len(dataloader):.4f}")

# -----------------------------
# 8. Evaluation loop with prediction plotting
# -----------------------------
def evaluate(model, dataloader, device, num_classes, n_samples_to_plot=3):
    model.eval()
    total_dice = 0.0
    count = 0
    plotted = False  # flag to plot only once per evaluation run
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)  # [B, H, W]
            seg_logits, _ = model(images)
            preds = torch.argmax(seg_logits, dim=1)  # [B, H, W]
            batch_dice = dice_score(preds, masks, num_classes)
            total_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice
            count += 1

            # Optionally plot predictions from one batch during evaluation
            if not plotted:
                plot_prediction_and_mask(images, masks, preds, n_samples=n_samples_to_plot, title_prefix="Eval")
                plotted = True

    avg_dice = total_dice / count
    print(f"Average Dice score on Validation: {avg_dice:.4f}")
    return avg_dice

# -----------------------------
# 9. Main: Dataset setup, model instantiation, training, and evaluation
# -----------------------------
# def main():
    # Hyperparameters
num_epochs = 50
batch_size = 16
learning_rate = 1e-4
num_classes = 13  # Adjust this based on your segmentation classes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create full dataset
full_dataset = LyftSegmentationDataset(
    image_list=image_list,
    mask_list=mask_list,
    transform_image=transform_image,
    transform_mask=transform_mask
)

# Split the dataset into training and validation sets (80-20 split)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Instantiate model
model = SwitchTransformerSegmentationModel(
    img_size=img_size,
    patch_size=16,
    in_channels=3,
    embed_dim=256,    # Reduced embed_dim to speed testing; adjust as needed
    depth=3,          # Adjust depth as required
    num_experts=4,
    expansion=4,
    num_heads=4,
    dropout=0.1,
    num_classes=num_classes
)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model with Dice score monitoring and prediction plotting
print("Starting training...")
train(model, train_loader, optimizer, device, num_epochs=num_epochs, aux_loss_weight=0.01, num_classes=num_classes, plot_freq=10)
print("Training completed!")

# Evaluate the model using Dice score and plot predictions from the validation set
print("Evaluating model performance...")
evaluate(model, val_loader, device, num_classes, n_samples_to_plot=3)

# if __name__ == "__main__":
#     main()


Starting training...
Epoch [1/50], Step [5/50], Avg Loss: 2.2019, Avg Dice: 0.1030
Epoch [1/50], Step [10/50], Avg Loss: 1.9340, Avg Dice: 0.1568
Epoch [1/50], Step [15/50], Avg Loss: 1.7855, Avg Dice: 0.2019
Epoch [1/50], Step [20/50], Avg Loss: 1.6643, Avg Dice: 0.2339
Epoch [1/50], Step [25/50], Avg Loss: 1.5654, Avg Dice: 0.2811
Epoch [1/50], Step [30/50], Avg Loss: 1.4856, Avg Dice: 0.3328
Epoch [1/50], Step [35/50], Avg Loss: 1.4177, Avg Dice: 0.3854
Epoch [1/50], Step [40/50], Avg Loss: 1.3622, Avg Dice: 0.4288
Epoch [1/50], Step [45/50], Avg Loss: 1.3170, Avg Dice: 0.4624
Epoch [1/50], Step [50/50], Avg Loss: 1.2730, Avg Dice: 0.4896
End of epoch 1: Avg Loss: 1.2730, Avg Dice: 0.4896
Epoch [2/50], Step [5/50], Avg Loss: 0.8502, Avg Dice: 0.7959
Epoch [2/50], Step [10/50], Avg Loss: 0.8439, Avg Dice: 0.8028
Epoch [2/50], Step [15/50], Avg Loss: 0.8236, Avg Dice: 0.8062
Epoch [2/50], Step [20/50], Avg Loss: 0.8168, Avg Dice: 0.8069
Epoch [2/50], Step [25/50], Avg Loss: 0.8085, Av

KeyboardInterrupt: 

In [9]:
def evaluate(model, dataloader, device, num_classes, n_samples_to_plot=3):
    model.eval()
    total_dice = 0.0
    count = 0
    plotted = False  # flag to plot only once per evaluation run
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)  # [B, H, W]
            seg_logits, _ = model(images)
            preds = torch.argmax(seg_logits, dim=1)  # [B, H, W]
            batch_dice = dice_score(preds, masks, num_classes)
            total_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice
            count += 1

            # Optionally plot predictions from one batch during evaluation
            if not plotted:
                plot_prediction_and_mask(images, masks, preds, class_colors = class_colors)
                plotted = True

    avg_dice = total_dice / count
    print(f"Average Dice score on Validation: {avg_dice:.4f}")
    return avg_dice


In [10]:
evaluate(model, val_loader, device, num_classes, n_samples_to_plot=3)


ValueError: too many values to unpack (expected 2)

In [49]:
import os
import numpy as np
from PIL import Image

def find_unique_classes_from_masks(mask_files):
    """
    Given a list of mask file paths, compute the union of unique pixel values
    from each mask, and return a sorted list of unique classes.
    """
    unique_classes = set()

    for mask_path in mask_files:
        # Open the mask image and convert it to a numpy array
        mask = Image.open(mask_path)
        mask_array = np.array(mask)

        # Update the unique classes set with the values found in this mask
        unique_values = np.unique(mask_array)
        unique_classes.update(unique_values)

    # Convert the set to a sorted list before returning
    return sorted(unique_classes)

# Example usage:
# if __name__ == "__main__":
    # Replace this with the path to your mask folder
    # mask_folder = "path/to/your/mask_folder"
    # Get a list of all mask file paths (assuming PNG or JPG images)
mask_files = mask_paths

unique_classes = find_unique_classes_from_masks(mask_files)
num_classes = len(unique_classes)

print("Found unique classes:", unique_classes)
print("Number of classes:", num_classes)


Found unique classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 255]
Number of classes: 20


# BD100K

In [14]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
solesensei_solesensei_bdd100k_path = kagglehub.dataset_download('solesensei/solesensei_bdd100k')

print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/solesensei/solesensei_bdd100k?dataset_version_number=2...


100%|██████████| 7.61G/7.61G [06:17<00:00, 21.6MB/s]

Extracting files...





Data source import complete.
Num GPUs Available:  1
0 data points loaded in total!


In [17]:
print(solesensei_solesensei_bdd100k_path)

/root/.cache/kagglehub/datasets/solesensei/solesensei_bdd100k/versions/2


In [1]:

import tensorflow as tf
import os
import numpy as np
import glob
from sklearn import utils as sk_utils


print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

# define image size
IMG_WIDTH = 256
IMG_HEIGHT = 192
IMG_CHANNELS = 3


TRAIN_PATH = '/root/.cache/kagglehub/datasets/solesensei/solesensei_bdd100k/versions/2/bdd100k_seg/bdd100k/seg/images/train/*.jpg'
MASK_TRAIN = '/root/.cache/kagglehub/datasets/solesensei/solesensei_bdd100k/versions/2/bdd100k_seg/bdd100k/seg/labels/train/*.png'

def load_image(file_path, is_x=False):
    image = tf.io.read_file(file_path)
    image = tf.image.decode_image(image)

    if is_x:
        image = tf.image.resize(image, (IMG_HEIGHT,IMG_WIDTH), method=tf.image.ResizeMethod.LANCZOS3)
        image = image / 255.0  # Normalize the image
    else:
        image = tf.image.resize(image, (IMG_HEIGHT,IMG_WIDTH), method='nearest', antialias=True)
    return image


image_paths = sorted(glob.glob(TRAIN_PATH))
mask_paths = sorted(glob.glob(MASK_TRAIN))
image_paths, mask_paths = sk_utils.shuffle(image_paths, mask_paths, random_state=42)

# X_train = []
# Y_train = []

# for x, y in zip(image_paths, mask_paths):
#     X_train.append(load_image(x, True))
#     Y_train.append(load_image(y))
#     if len(X_train) % 500 == 0:
#         print(len(X_train), 'data points loaded!')
# else:
#     print(len(X_train), 'data points loaded in total!')


# # Convert the lists to NumPy arrays
# X_train = np.array(X_train)
# Y_train = np.array(Y_train)
# # 255 is representing unknown objects
# Y_train[Y_train == 255] = 19


# import matplotlib.pyplot as plt
# color_dict = {
#     0: (0.7, 0.7, 0.7),     # road - gray
#     1:  (0.9, 0.9, 0.2),     # sidewalk - light yellow
#     2: (1.0, 0.4980392156862745, 0.054901960784313725),
#     3: (1.0, 0.7333333333333333, 0.47058823529411764),
#     4: (0.8, 0.5, 0.1),  # Fence - rust orange
#     5: (0.596078431372549, 0.8745098039215686, 0.5411764705882353),
#     6: (0.325, 0.196, 0.361),
#     7: (1.0, 0.596078431372549, 0.5882352941176471),
#     8:  (0.2, 0.6, 0.2),     # vegetation - green
#     9: (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
#     10: (0.5, 0.7, 1.0),     # sky - light blue
#     11: (1.0, 0.0, 0.0), # person - red
#     12: (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
#     13: (0.0, 0.0, 1.0),  # Car - blue
#     14: (0.0, 0.0, 1.0),  # Track - blue
#     15: (0.0, 0.0, 1.0),  # Bus - blue
#     16: (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
#     17: (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
#     18: (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
#     19: (0, 0, 0) # unknown - black
# }

# def colorize_image(image, color_dict):
#     # remove the extra dimension
#     image = np.squeeze(image)
#     # Generate the colored image using the color dictionary
#     colored_image = np.zeros((image.shape[0], image.shape[1], 3))

#     for pixel_value, color in color_dict.items():
#         colored_image[image == pixel_value] = color

#     # Convert the image to 8-bit unsigned integer
#     colored_image = (colored_image * 255).astype(np.uint8)

#     return colored_image

Num GPUs Available:  1


In [2]:
!ls /root/.cache/kagglehub/datasets/solesensei/solesensei_bdd100k/versions/2/bdd100k_seg/bdd100k/seg/

color_labels  images  labels


In [48]:
np.unique(Image.open(mask_paths[100]).convert("L"))

array([  0,   1,   2,   5,   6,   7,   8,   9,  10,  13,  14,  15, 255],
      dtype=uint8)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# -----------------------------
# 1. Define your Dataset class
# -----------------------------
class LyftSegmentationDataset(Dataset):
    """
    Dataset for segmentation that loads images and masks from disk.
    Assumes images are in RGB and masks are single-channel with integer labels.
    """
    def __init__(self, image_list, mask_list, transform_image=None, transform_mask=None):
        assert len(image_list) == len(mask_list), "Image and mask lists must be equal length."
        self.image_list = image_list
        self.mask_list = mask_list
        self.transform_image = transform_image
        self.transform_mask = transform_mask

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = Image.open(self.image_list[idx]).convert("RGB")
        mask = Image.open(self.mask_list[idx]).convert("L")  # Convert to single channel

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)
        else:
            mask = torch.from_numpy(np.array(mask, dtype=np.int64))
        mask[mask == 255] = 19  # or map to another appropriate class index if needed
        return image, mask

# -----------------------------
# 2. Define image and mask transformations
# -----------------------------
img_size = 224  # Defined image size
transform_image = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),  # [0, 1] float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_mask = transforms.Compose([
    transforms.Resize((img_size, img_size), interpolation=Image.NEAREST),
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x, dtype=np.int64)))
])

# -----------------------------
# 3. File lists for images and masks
# -----------------------------
# image_path = '/kaggle/input/lyft-udacity-challenge/dataa/dataA/CameraRGB/'
# mask_path = '/kaggle/input/lyft-udacity-challenge/dataa/dataA/CameraSeg/'

# image_list = sorted([os.path.join(image_path, i) for i in os.listdir(image_path)])
# mask_list = sorted([os.path.join(mask_path, i) for i in os.listdir(mask_path)])

# -----------------------------
# 4. Define the Switch Transformer Segmentation Model and components
# -----------------------------
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

    def forward(self, x):
        # x: [B, in_channels, H, W]
        x = self.proj(x)                    # [B, embed_dim, H/patch_size, W/patch_size]
        x = x.flatten(2)                    # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)               # [B, num_patches, embed_dim]
        x += self.pos_embedding             # Add position embedding
        return x

class ExpertFFN(nn.Module):
    def __init__(self, embed_dim, expansion=4):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * expansion)
        self.fc2 = nn.Linear(embed_dim * expansion, embed_dim)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

class SwitchFFN(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([ExpertFFN(embed_dim, expansion) for _ in range(num_experts)])
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        B, N, D = x.shape
        logits = self.gate(x)               # [B, N, num_experts]
        indices = logits.argmax(dim=-1)     # [B, N] - select expert index with highest score
        output = torch.zeros_like(x)
        expert_counts = torch.zeros(B, self.num_experts, device=x.device)
        for expert_id in range(self.num_experts):
            mask = (indices == expert_id)   # Boolean mask [B, N]
            if mask.sum() == 0:
                continue
            selected_tokens = x[mask]       # [num_selected, D]
            expert_output = self.experts[expert_id](selected_tokens)
            output[mask] = expert_output
            expert_counts[:, expert_id] += mask.float().sum(dim=1) if mask.dim() == 2 else mask.float().sum(dim=0)
        return output, expert_counts

class SwitchTransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.switch_ffn = SwitchFFN(embed_dim, num_experts=num_experts, expansion=expansion)

    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ffn_out, expert_counts = self.switch_ffn(x)
        x = x + self.dropout(ffn_out)
        x = self.norm2(x)
        return x, expert_counts

class SwitchTransformerEncoder(nn.Module):
    def __init__(self, depth, embed_dim, num_experts=4, expansion=4, num_heads=8, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwitchTransformerLayer(embed_dim, num_experts, expansion, num_heads, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        aux_loss = 0.0
        for layer in self.layers:
            x, expert_counts = layer(x)
            # Optionally, incorporate an auxiliary loss for load balancing here if desired.
        return x, aux_loss

class SegmentationDecoder(nn.Module):
    def __init__(self, embed_dim, img_size=224, patch_size=16, num_classes=21):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches_side = img_size // patch_size
        self.conv1 = nn.Conv2d(embed_dim, embed_dim // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(embed_dim // 2)
        self.conv2 = nn.Conv2d(embed_dim // 2, num_classes, kernel_size=1)

    def forward(self, x):
        B, N, D = x.shape
        H = W = self.num_patches_side
        x = x.transpose(1, 2).reshape(B, D, H, W)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        logits = self.conv2(x)
        return logits

class SwitchTransformerSegmentationModel(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=6,
                 num_experts=4, expansion=4, num_heads=8, dropout=0.1, num_classes=21):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.encoder = SwitchTransformerEncoder(depth, embed_dim, num_experts, expansion, num_heads, dropout)
        self.decoder = SegmentationDecoder(embed_dim, img_size, patch_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x, aux_loss = self.encoder(x)
        segmentation_logits = self.decoder(x)
        return segmentation_logits, aux_loss

# -----------------------------
# 5. Define Dice Score Function
# -----------------------------
def dice_score(pred, target, num_classes, epsilon=1e-6):
    """
    Computes the Dice score per class and returns the average score.
    pred: tensor of shape [B, H, W] with predicted classes.
    target: tensor of shape [B, H, W] with ground truth classes.
    """
    dice = 0.0
    for cls in range(num_classes):
        pred_cls = (pred == cls).float()
        target_cls = (target == cls).float()
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()
        dice += (2 * intersection + epsilon) / (union + epsilon)
    return dice / num_classes


import matplotlib.pyplot as plt
import numpy as np

# Define your class colors (adjust the number of classes/colors as needed)
class_colors = {
    0: (1, 0, 0),   # red  (often used for background, or class 0)
    1: (0, 1, 0),   # green (class 1)
    2: (0, 0, 1),   # blue  (class 2)
    3: (1, 0, 1)    # magenta (class 3)
}


def colorize_mask(mask, class_colors):
    """
    Colorize a segmentation mask given a mapping from class labels to colors.

    Parameters:
        mask (np.ndarray): 2D array (height, width) of class indices.
        class_colors (dict): dictionary mapping class indices to RGB tuples.

    Returns:
        np.ndarray: colorized mask of shape (height, width, 3) for visualization.
    """
    # Create an empty array with an extra channel for color (RGB)
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.float32)

    # Assign colors for each class present in the mask
    for class_id, color in class_colors.items():
        color_mask[mask == class_id] = color

    return color_mask

# -----------------------------
# 6. Helper function to plot predictions
# -----------------------------
def plot_prediction_and_mask(image, pred_mask, true_mask, class_colors):
    """
    Plot the original image, prediction mask, and ground truth mask.

    Parameters:
        image (np.ndarray): the input image of shape (height, width, 3).
        pred_mask (np.ndarray): predicted segmentation mask (class indices).
        true_mask (np.ndarray): ground truth segmentation mask (class indices).
        class_colors (dict): mapping from class labels to colors.
    """

    # Create color masks for the predicted and true masks
    pred_color_mask = colorize_mask(pred_mask, class_colors)
    true_color_mask = colorize_mask(true_mask, class_colors)

    # Plotting using matplotlib
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Show original image
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Show predicted mask
    axes[1].imshow(pred_color_mask)
    axes[1].set_title("Predicted Mask")
    axes[1].axis('off')

    # Show ground truth mask
    axes[2].imshow(true_color_mask)
    axes[2].set_title("Ground Truth Mask")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()


# -----------------------------
# 7. Training loop with Dice score and prediction plotting
# -----------------------------
def train(model, dataloader, optimizer, device, num_epochs=5, aux_loss_weight=0.01, num_classes=21, plot_freq=10):
    model.train()
    criterion = nn.CrossEntropyLoss()  # segmentation cross-entropy loss

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_dice = 0.0
        for i, (images, masks) in enumerate(dataloader):
            images = images.to(device)      # [B, 3, H, W]
            masks = masks.to(device)        # [B, H, W]
            masks = masks.long()

            optimizer.zero_grad()
            seg_logits, aux_loss = model(images)
            # assert outputs.shape[1] == num_classes, "The number of output channels must equal the number of classes."

            # seg_logits shape: [B, num_classes, H, W]
            seg_loss = criterion(seg_logits, masks)
            loss = seg_loss + aux_loss_weight * aux_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Calculate Dice score on training mini-batch:
            preds = torch.argmax(seg_logits, dim=1)  # shape: [B, H, W]
            batch_dice = dice_score(preds, masks, num_classes)
            running_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice

            # Every plot_freq iterations, plot a few example predictions from the current batch.
            # if (i + 1) % plot_freq == 0:
            #     print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
            #     # Plot first few images. Adjust n_samples as desired.
            #     plot_predictions(images, masks, preds, n_samples=3, title_prefix="Train")

            if (i + 1) % 5 == 0:
                avg_loss = running_loss / (i+1)
                avg_dice = running_dice / (i+1)
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Avg Loss: {avg_loss:.4f}, Avg Dice: {avg_dice:.4f}")
        print(f"End of epoch {epoch+1}: Avg Loss: {running_loss/len(dataloader):.4f}, Avg Dice: {running_dice/len(dataloader):.4f}")

# -----------------------------
# 8. Evaluation loop with prediction plotting
# -----------------------------
def evaluate(model, dataloader, device, num_classes, n_samples_to_plot=3):
    model.eval()
    total_dice = 0.0
    count = 0
    plotted = False  # flag to plot only once per evaluation run
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)  # [B, H, W]
            seg_logits, _ = model(images)
            preds = torch.argmax(seg_logits, dim=1)  # [B, H, W]
            batch_dice = dice_score(preds, masks, num_classes)
            total_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice
            count += 1

            # Optionally plot predictions from one batch during evaluation
            if not plotted:
                plot_prediction_and_mask(images, masks, preds, n_samples=n_samples_to_plot, title_prefix="Eval")
                plotted = True

    avg_dice = total_dice / count
    print(f"Average Dice score on Validation: {avg_dice:.4f}")
    return avg_dice

# -----------------------------
# 9. Main: Dataset setup, model instantiation, training, and evaluation
# -----------------------------
# def main():
    # Hyperparameters
# ... (Your previous code for model and dataset)

# Hyperparameters
num_epochs = 50
batch_size = 16
learning_rate = 1e-4
num_classes = 21  # Change this to reflect the number of classes in the BDD100K dataset

# ... (Rest of your code for model loading and training)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create full dataset
full_dataset = LyftSegmentationDataset(
    image_list=image_paths,
    mask_list=mask_paths,
    transform_image=transform_image,
    transform_mask=transform_mask
)

# Split the dataset into training and validation sets (80-20 split)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Instantiate model
model = SwitchTransformerSegmentationModel(
    img_size=img_size,
    patch_size=16,
    in_channels=3,
    embed_dim=256,    # Reduced embed_dim to speed testing; adjust as needed
    depth=3,          # Adjust depth as required
    num_experts=4,
    expansion=4,
    num_heads=4,
    dropout=0.1,
    num_classes=num_classes
)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model with Dice score monitoring and prediction plotting
print("Starting training...")
train(model, train_loader, optimizer, device, num_epochs=num_epochs, aux_loss_weight=0.01, num_classes=num_classes, plot_freq=10)
print("Training completed!")

# Evaluate the model using Dice score and plot predictions from the validation set
print("Evaluating model performance...")
evaluate(model, val_loader, device, num_classes, n_samples_to_plot=3)

# if __name__ == "__main__":
#     main()


Starting training...
Epoch [1/50], Step [5/350], Avg Loss: 3.0115, Avg Dice: 0.0392
Epoch [1/50], Step [10/350], Avg Loss: 2.8190, Avg Dice: 0.0641
Epoch [1/50], Step [15/350], Avg Loss: 2.6982, Avg Dice: 0.0879
Epoch [1/50], Step [20/350], Avg Loss: 2.5907, Avg Dice: 0.1077
Epoch [1/50], Step [25/350], Avg Loss: 2.5059, Avg Dice: 0.1292
Epoch [1/50], Step [30/350], Avg Loss: 2.4433, Avg Dice: 0.1441
Epoch [1/50], Step [35/350], Avg Loss: 2.3851, Avg Dice: 0.1549
Epoch [1/50], Step [40/350], Avg Loss: 2.3198, Avg Dice: 0.1624
Epoch [1/50], Step [45/350], Avg Loss: 2.2674, Avg Dice: 0.1754
Epoch [1/50], Step [50/350], Avg Loss: 2.2314, Avg Dice: 0.1901
Epoch [1/50], Step [55/350], Avg Loss: 2.1911, Avg Dice: 0.1999
Epoch [1/50], Step [60/350], Avg Loss: 2.1644, Avg Dice: 0.2055
Epoch [1/50], Step [65/350], Avg Loss: 2.1286, Avg Dice: 0.2143
Epoch [1/50], Step [70/350], Avg Loss: 2.1023, Avg Dice: 0.2226
Epoch [1/50], Step [75/350], Avg Loss: 2.0814, Avg Dice: 0.2301
Epoch [1/50], Step [

In [8]:


optimizer = optim.Adam(model.parameters(), lr=learning_rate)



model.train()
criterion = nn.CrossEntropyLoss()  # segmentation cross-entropy loss

for epoch in range(num_epochs):
    running_loss = 0.0
    running_dice = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)      # [B, 3, H, W]
        masks = masks.to(device)        # [B, H, W]
        masks = masks.long()
        # print(np.unique(masks.detach().cpu().numpy()))
        optimizer.zero_grad()
        seg_logits, aux_loss = model(images)
        assert seg_logits.shape[1] == num_classes, "The number of output channels must equal the number of classes."

        # seg_logits shape: [B, num_classes, H, W]
        seg_loss = criterion(seg_logits, masks)
        loss = seg_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Calculate Dice score on training mini-batch:
        preds = torch.argmax(seg_logits, dim=1)  # shape: [B, H, W]
        batch_dice = dice_score(preds, masks, num_classes)
        running_dice += batch_dice.item() if isinstance(batch_dice, torch.Tensor) else batch_dice

        # Every plot_freq iterations, plot a few example predictions from the current batch.
        # if (i + 1) % plot_freq == 0:
        #     print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
        #     # Plot first few images. Adjust n_samples as desired.
        #     plot_predictions(images, masks, preds, n_samples=3, title_prefix="Train")

        if (i + 1) % 5 == 0:
            avg_loss = running_loss / (i+1)
            avg_dice = running_dice / (i+1)
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Avg Loss: {avg_loss:.4f}, Avg Dice: {avg_dice:.4f}")
    print(f"End of epoch {epoch+1}: Avg Loss: {running_loss/len(train_loader):.4f}, Avg Dice: {running_dice/len(train_loader):.4f}")


















Epoch [1/50], Step [5/350], Avg Loss: 1.1815, Avg Dice: 0.3429
Epoch [1/50], Step [10/350], Avg Loss: 1.1857, Avg Dice: 0.3346
Epoch [1/50], Step [15/350], Avg Loss: 1.1698, Avg Dice: 0.3359
Epoch [1/50], Step [20/350], Avg Loss: 1.1674, Avg Dice: 0.3406
Epoch [1/50], Step [25/350], Avg Loss: 1.1587, Avg Dice: 0.3384
Epoch [1/50], Step [30/350], Avg Loss: 1.1655, Avg Dice: 0.3368
Epoch [1/50], Step [35/350], Avg Loss: 1.1670, Avg Dice: 0.3333
Epoch [1/50], Step [40/350], Avg Loss: 1.1683, Avg Dice: 0.3325
Epoch [1/50], Step [45/350], Avg Loss: 1.1726, Avg Dice: 0.3352
Epoch [1/50], Step [50/350], Avg Loss: 1.1671, Avg Dice: 0.3343
Epoch [1/50], Step [55/350], Avg Loss: 1.1657, Avg Dice: 0.3373
Epoch [1/50], Step [60/350], Avg Loss: 1.1707, Avg Dice: 0.3356
Epoch [1/50], Step [65/350], Avg Loss: 1.1702, Avg Dice: 0.3355
Epoch [1/50], Step [70/350], Avg Loss: 1.1668, Avg Dice: 0.3389
Epoch [1/50], Step [75/350], Avg Loss: 1.1635, Avg Dice: 0.3405
Epoch [1/50], Step [80/350], Avg Loss: 1.

KeyboardInterrupt: 

In [8]:
images.shape

torch.Size([16, 3, 224, 224])