In [None]:

import os, sys, torch
# Report interpreter / framework versions and whether a CUDA device is visible.
print('Python:', sys.version)
print('PyTorch:', torch.__version__)
has_cuda = torch.cuda.is_available()
print('CUDA available:', has_cuda)
if has_cuda:
    print('CUDA device:', torch.cuda.get_device_name(0))

# Look up mamba_ssm (optional) and mamba_model (required) via importlib's spec finder.
try:
    import importlib.util
    spec = importlib.util.find_spec('mamba_ssm')
    print('mamba_ssm found:', spec is not None)
except Exception as e:
    # Best-effort probe: report the problem and keep going.
    print('mamba_ssm check error:', e)

try:
    spec2 = importlib.util.find_spec('mamba_model')
    print('mamba_model found:', spec2 is not None)
except Exception as e:
    print('mamba_model check error:', e)


Python: 3.11.13 (main, Jun  5 2025, 13:12:00) [GCC 11.2.0]
PyTorch: 2.8.0+cu129
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3070 Ti
mamba_ssm found: True
mamba_model found: True


In [None]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Optional RMSNorm from mamba_ssm; fall back to LayerNorm if missing ---
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm
except Exception:
    RMSNorm = None

# --- REQUIRED: your local Mamba implementation exported from mamba_model.py ---
from mamba_model import Mamba


# ------------------------- Utils -------------------------
class DropPath(nn.Module):
    """Stochastic depth: during training, zero the whole branch output of random samples.

    A per-sample binary mask (broadcast over every non-batch dimension) is drawn
    with keep probability ``1 - drop_prob``; kept samples are rescaled by
    ``1 / keep_prob``. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's branch, in ``[0, 1]``.

    Raises:
        ValueError: if ``drop_prob`` lies outside ``[0, 1]``.
    """
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcast across the remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        if keep_prob > 0.:
            # Only rescale when something can be kept; with keep_prob == 0 the
            # division would give inf and inf * 0 == NaN.
            x = x.div(keep_prob)
        return x * random_tensor


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy (``None`` or ``0``). The same dropout module is applied after
    both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


# ------------------------- Involution stem -------------------------
class Involution2D(nn.Module):
    """
    Involution: location-specific, channel-agnostic kernel.
    Paper: "Involution: Inverting the Inherence of Convolution for Visual Recognition" (CVPR 2021)

    A k*k kernel is predicted for every output location (one per group) from the
    input itself and applied to the k x k neighbourhood of that location; all
    channels within a group share the kernel.

    Args:
        channels: number of input (and output) channels.
        kernel_size: side length k of the generated kernels. The patch extraction
            pads by ``k // 2``, so kernel and patch grids line up for odd k.
        stride: spatial stride; values > 1 downsample the output.
        reduction_ratio: channel reduction inside the kernel-generating bottleneck.
        groups: number of channel groups, each with its own kernel.
    """
    def __init__(self, channels, kernel_size=3, stride=1, reduction_ratio=4, groups=1):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups

        hidden = max(channels // reduction_ratio, 1)
        self.reduce = nn.Conv2d(channels, hidden, 1)
        self.act = nn.ReLU(inplace=True)
        self.span = nn.Conv2d(hidden, groups * (kernel_size * kernel_size), 1)
        # ceil_mode=True makes the pooled grid ceil(H / stride) x ceil(W / stride),
        # the same size F.unfold produces below for odd k. With the default floor
        # mode, inputs whose H or W is not a multiple of the stride gave
        # mismatching grids and the reshape in forward() failed.
        self.sigma = (
            nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
            if stride > 1 else nn.Identity()
        )

    def forward(self, x):
        """Apply the involution to ``x`` of shape (B, C, H, W); returns (B, C, H', W')."""
        B, C, H, W = x.shape
        k, g = self.kernel_size, self.groups
        assert C % g == 0, f"Channels ({C}) must be divisible by groups ({g})."

        # 1) Generate spatially-varying kernels on possibly downsampled features
        x_k = self.sigma(x)                                   # (B, C, H', W')
        K = self.span(self.act(self.reduce(x_k)))             # (B, g*k*k, H', W')
        Hout, Wout = K.shape[2], K.shape[3]
        K = K.view(B, g, k * k, Hout, Wout)                   # (B, g, k*k, H', W')

        # 2) Unfold patches on original x with stride/padding so H',W' align
        patches = F.unfold(x, kernel_size=k, dilation=1, padding=k // 2, stride=self.stride)
        patches = patches.view(B, g, C // g, k * k, Hout, Wout)  # (B, g, Cg, k*k, H', W')

        # 3) Channel-agnostic weighted sum within group
        out = (patches * K.unsqueeze(2)).sum(dim=3)           # (B, g, Cg, H', W')
        out = out.view(B, C, Hout, Wout)
        return out


class OverlapPatchEmbedInvo(nn.Module):
    """Stem: strided involution, then a 1x1 conv to ``embed_dim`` channels, then BatchNorm.

    NOTE(review): ``padding`` is accepted but never used in this class;
    Involution2D derives its own padding from ``kernel_size``.
    """
    def __init__(self, in_chans=3, embed_dim=32, kernel_size=3, stride=2, padding=1,
                 reduction_ratio=4, groups=1):
        super().__init__()
        self.invo = Involution2D(
            in_chans,
            kernel_size=kernel_size,
            stride=stride,
            reduction_ratio=reduction_ratio,
            groups=groups,
        )
        # The involution keeps the channel count, so the projection starts from in_chans.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1, padding=0)
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        mixed = self.invo(x)                 # spatial mixing (+ downsample when stride > 1)
        return self.norm(self.proj(mixed))   # project to embed_dim, then normalise


# ------------------------- Conv stem (fast option) -------------------------
class OverlapPatchEmbedConv(nn.Module):
    """Stem built from a bias-free strided convolution, BatchNorm and ReLU."""
    def __init__(self, in_chans=3, embed_dim=32, kernel_size=3, stride=2, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_chans, embed_dim, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        embedded = self.proj(x)
        return embedded


# ------------------------- Mixer Block (Mamba) -------------------------
class MambaBlock(nn.Module):
    """
    Pre-norm residual block on token sequences (B, N, C):
    norm -> Mamba -> drop-path -> add, then norm -> MLP -> drop-path -> add.

    RMSNorm is used when requested and importable; otherwise nn.LayerNorm.
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, drop_path=0.0, d_state=16, use_rmsnorm=True):
        super().__init__()
        if use_rmsnorm and RMSNorm is not None:
            norm_cls = RMSNorm
        else:
            norm_cls = nn.LayerNorm

        self.norm1 = norm_cls(dim)
        self.mamba = Mamba(d_model=dim, d_state=d_state)
        # A single drop-path module is shared by both residual branches.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_cls(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=nn.GELU, drop=drop)

    def forward(self, x, H=None, W=None):
        # H and W are accepted for call-site compatibility but are not used here.
        x = self.drop_path(self.mamba(self.norm1(x))) + x   # sequence-mixing branch
        x = self.drop_path(self.mlp(self.norm2(x))) + x     # channel-mixing branch
        return x


# ------------------------- MiT Stage (stem + Mamba blocks) -------------------------
class MiTStage(nn.Module):
    """One backbone stage: a stride-2 stem (involution or conv) followed by ``depth`` MambaBlocks."""
    def __init__(self, in_chs, embed_dim, depth, drop_path_rates=None, use_involution=True):
        super().__init__()
        # Both stem types share these settings; the involution stem takes two extra ones.
        stem_kwargs = dict(in_chans=in_chs, embed_dim=embed_dim, kernel_size=3, stride=2, padding=1)
        if use_involution:
            self.patch_embed = OverlapPatchEmbedInvo(reduction_ratio=4, groups=1, **stem_kwargs)
        else:
            self.patch_embed = OverlapPatchEmbedConv(**stem_kwargs)

        stage_blocks = []
        for i in range(depth):
            rate = drop_path_rates[i] if drop_path_rates is not None else 0.0
            stage_blocks.append(
                MambaBlock(dim=embed_dim, mlp_ratio=4.0, drop=0.0, drop_path=rate,
                           d_state=16, use_rmsnorm=True)
            )
        self.blocks = nn.ModuleList(stage_blocks)

    def forward(self, x):
        feat = self.patch_embed(x)                     # (B, C, H', W')
        B, C, H, W = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)       # (B, N, C) with N = H' * W'
        for block in self.blocks:
            tokens = block(tokens, H, W)
        # Back to a feature map for the next stage / decoder.
        return tokens.transpose(1, 2).reshape(B, C, H, W)


# ------------------------- Mix Vision Transformer (backbone) -------------------------
class MixVisionTransformer(nn.Module):
    """
    Four-stage backbone. Every stage is a MiTStage with a stride-2 stem, so the
    returned list [S1, S2, S3, S4] has strides [2, 4, 8, 16] wrt the input.
    """
    def __init__(self,
                 in_chans: int = 3,
                 embed_dims = [32, 64, 160, 256],
                 depths    = [2, 2, 2, 2],
                 drop_path_rate: float = 0.0,
                 use_invo_stages = (True, True, False, False)):
        super().__init__()
        assert len(embed_dims) == 4 and len(depths) == 4 and len(use_invo_stages) == 4

        # One drop-path rate per block, rising linearly from 0 to drop_path_rate
        # over all blocks of all stages.
        schedule = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        self.stages = nn.ModuleList()
        prev_channels, start = in_chans, 0
        for idx in range(4):
            stop = start + depths[idx]
            stage = MiTStage(
                in_chs=prev_channels,
                embed_dim=embed_dims[idx],
                depth=depths[idx],
                drop_path_rates=schedule[start:stop],
                use_involution=use_invo_stages[idx],
            )
            self.stages.append(stage)
            prev_channels, start = embed_dims[idx], stop

    def forward(self, x):
        pyramid = []
        current = x
        for stage in self.stages:
            current = stage(current)   # each stage consumes the previous stage's output
            pyramid.append(current)
        return pyramid


# ------------------------------- Decoder --------------------------------
class SegFormerHead(nn.Module):
    """Decoder: project every feature map to ``embed_dim``, upsample all of them to the
    first map's resolution, concatenate, fuse, and classify per pixel."""
    def __init__(self, in_channels: List[int], embed_dim=128, num_classes=19):
        super().__init__()

        def conv_bn_relu(c_in, c_out, **conv_kwargs):
            # Bias-free conv + BatchNorm + ReLU triple used throughout the head.
            return [
                nn.Conv2d(c_in, c_out, bias=False, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.proj_layers = nn.ModuleList([
            nn.Sequential(*conv_bn_relu(in_ch, embed_dim, kernel_size=1))
            for in_ch in in_channels
        ])
        self.fuse = nn.Sequential(
            *conv_bn_relu(embed_dim * len(in_channels), embed_dim, kernel_size=1),
            *conv_bn_relu(embed_dim, embed_dim, kernel_size=3, padding=1),
        )
        self.classifier = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features: List[torch.Tensor]):
        # The first feature map defines the output resolution.
        target_h, target_w = features[0].shape[2:]
        aligned = []
        for i, feat in enumerate(features):
            y = self.proj_layers[i](feat)
            if y.shape[2:] != (target_h, target_w):
                y = F.interpolate(y, size=(target_h, target_w), mode='bilinear', align_corners=False)
            aligned.append(y)
        fused = self.fuse(torch.cat(aligned, dim=1))
        return self.classifier(fused)


# -------------------------- Full SegFormer Model -------------------------
class SegFormer(nn.Module):
    """Segmentation model: MixVisionTransformer backbone + SegFormerHead decoder.

    ``variant`` selects a preset (embed_dims, depths); explicitly passed
    ``embed_dims`` / ``depths`` take precedence over the preset. For an unknown
    variant both must be supplied, otherwise a ValueError is raised.
    ``pretrained`` only triggers a notice; no weights are loaded here.
    """
    def __init__(self, num_classes=19, variant='mit_b0', pretrained=False,
                 drop_path_rate=0.1, in_chans=3, use_invo_stages=(True, True, False, False),
                 embed_dims=None, depths=None):
        super().__init__()
        variants = {
            'mit_b0': dict(embed_dims=[16, 32, 64, 128], depths=[1, 1, 1, 1]),
            'mit_b1': dict(embed_dims=[64, 128, 320, 512], depths=[2, 2, 2, 2]),
            'mit_b2': dict(embed_dims=[64, 128, 320, 512], depths=[3, 4, 6, 3]),
            'mit_b3': dict(embed_dims=[64, 128, 320, 512], depths=[3, 4, 18, 3]),
            'mit_b4': dict(embed_dims=[64, 128, 320, 512], depths=[3, 8, 27, 3]),
            'mit_b5': dict(embed_dims=[64, 128, 320, 512], depths=[3, 6, 40, 3]),
        }

        preset = variants.get(variant)
        if preset is not None:
            # Explicit arguments win; the preset only fills the gaps.
            embed_dims = embed_dims or preset['embed_dims']
            depths = depths or preset['depths']
        elif embed_dims is None or depths is None:
            raise ValueError(f'Unknown variant: {variant} and embed_dims or depths not provided')

        self.backbone = MixVisionTransformer(
            in_chans=in_chans,
            embed_dims=embed_dims,
            depths=depths,
            drop_path_rate=drop_path_rate,
            use_invo_stages=use_invo_stages,
        )
        self.decoder = SegFormerHead(in_channels=embed_dims, embed_dim=128, num_classes=num_classes)
        # NOTE(review): apply() visits every submodule, so nn.Linear layers nested
        # inside the Mamba blocks (if any) are re-initialised too; confirm this is intended.
        self.apply(self._init_weights)

        if pretrained:
            print('pretrained=True selected but no loader is implemented in this script.')

    @staticmethod
    def _init_weights(m):
        """Kaiming init for convs, truncated normal for linears, identity for LayerNorm."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        in_h, in_w = x.shape[2], x.shape[3]
        logits = self.decoder(self.backbone(x))        # logits at the first stage's resolution
        # Bring the logits back to the input resolution.
        return F.interpolate(logits, size=(in_h, in_w), mode='bilinear', align_corners=False)

In [None]:
#imports
import os
import h5py
import numpy as np
import cv2
import torch
import random
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim import AdamW


In [None]:
# Constants
# Default side length (pixels) of the square patches cut by GlacierHDF5PatchDataset3.
PATCH_SIZE = 256
# NOTE(review): not referenced in the visible cells; confirm it is still needed.
TARGET_SIZE = (256, 256)
# NOTE(review): the dataset class reads the literal 'outlines' key rather than this constant.
LABEL_KEY = 'outlines'
# HDF5 dataset names a tile must contain to be used by GlacierHDF5PatchDataset3.
MODALITIES_3 = ['dem', 'optical', 'bright_dark_outlines']
# Presumably the channel count per modality (2 + 6 + 3 = 11, matching the
# in_chans=11 used for the model); verify against the data.
CHANNEL_INFO_3 = {'dem': 2, 'optical': 6, 'bright_dark_outlines': 3}

def normalize(arr):
    """Return ``arr`` as float32, standardised with the mean and std of all its elements.

    A small epsilon (1e-5) is added to the std so constant inputs map to zeros
    instead of dividing by zero.
    """
    values = arr.astype(np.float32)
    centered = values - values.mean()
    scale = values.std() + 1e-5
    return centered / scale

In [None]:

def normalize(arr):
    """Return ``arr`` as float32, standardised with the mean and std of all its elements.

    A small epsilon (1e-5) is added to the std so constant inputs map to zeros
    instead of dividing by zero.
    """
    values = arr.astype(np.float32)
    centered = values - values.mean()
    scale = values.std() + 1e-5
    return centered / scale

def augment_patch(img, label):
    """Randomly flip ``img`` and ``label`` together along axis 0 and then axis 1.

    Each flip is applied independently with probability 0.5, using the global
    ``random`` module. The results are NumPy views (possibly with negative
    strides), not copies.
    """
    for axis in (0, 1):
        if random.random() < 0.5:
            img, label = np.flip(img, axis=axis), np.flip(label, axis=axis)
    return img, label

class GlacierHDF5PatchDataset3(torch.utils.data.Dataset):
    """Random fixed-size patches drawn from the tiles of an HDF5 file.

    Only tiles that contain every modality in ``MODALITIES_3`` are used. The
    ``length`` patch locations are sampled once in ``__init__`` with the global
    ``random`` module, so an index always maps to the same (tile, y, x).

    Each item is ``(orig_tensor, aug_tensor)``: the channel-first float patch and
    a randomly flipped version of it.

    The file stays open for the lifetime of the object. Call ``close()`` or use
    the dataset as a context manager (``with GlacierHDF5PatchDataset3(...) as ds``).

    Args:
        hdf5_file_path: path of the HDF5 file, opened read-only.
        patch_size: side length of the square patches.
        length: number of patch locations to pre-sample (the dataset length).

    Raises:
        IndexError: if patches are requested but no tile has all modalities.
        ValueError: if a sampled tile is smaller than ``patch_size``.
    """
    def __init__(self, hdf5_file_path, patch_size=PATCH_SIZE, length=2000):
        self.hdf5 = h5py.File(hdf5_file_path, 'r')
        try:
            self.tiles = [name for name in self.hdf5.keys() if all(m in self.hdf5[name] for m in MODALITIES_3)]
            self.patch_size = patch_size
            self.length = length
            self.selected_indices = [self._get_random_patch_coords() for _ in range(length)]
        except Exception:
            # Do not leak the open file handle when construction fails.
            self.hdf5.close()
            raise

    def __len__(self):
        return self.length

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _get_random_patch_coords(self):
        """Pick a random tile and a random top-left corner for a patch inside it."""
        if not self.tiles:
            raise IndexError(
                f"no tile in the HDF5 file contains all required modalities {MODALITIES_3}"
            )
        tile_name = random.choice(self.tiles)
        tile = self.hdf5[tile_name]
        h, w = tile[MODALITIES_3[0]].shape[:2]
        if h < self.patch_size or w < self.patch_size:
            raise ValueError(
                f"tile {tile_name!r} is {h}x{w}, smaller than patch_size={self.patch_size}"
            )
        y = random.randint(0, h - self.patch_size)
        x = random.randint(0, w - self.patch_size)
        return tile_name, y, x

    def __getitem__(self, idx):
        tile_name, y, x = self.selected_indices[idx]
        tile = self.hdf5[tile_name]
        rows = slice(y, y + self.patch_size)
        cols = slice(x, x + self.patch_size)

        # dem / optical are standardised per patch; the outline channels are kept
        # as-is and only rescaled when the patch holds values above 1.
        dem = normalize(tile['dem'][rows, cols, :])
        opt = normalize(tile['optical'][rows, cols, :])
        bd  = tile['bright_dark_outlines'][rows, cols, :].astype(np.float32)
        if bd.max() > 1:  # scale to [0,1]
            bd = bd / 255.0

        full_patch = np.concatenate([dem, opt, bd], axis=2)
        label = tile['outlines'][rows, cols]
        if label.ndim == 3:
            label = label[:, :, 0]

        # NOTE(review): the (augmented) label is computed but not returned; the
        # item is an (original, augmented) image pair. Confirm this is intended
        # before feeding the loader to a supervised training loop.
        aug_patch, aug_label = augment_patch(full_patch.copy(), label.copy())
        # ascontiguousarray: the flips return views with negative strides.
        orig_tensor = torch.tensor(np.ascontiguousarray(full_patch)).permute(2,0,1).float()
        aug_tensor  = torch.tensor(np.ascontiguousarray(aug_patch)).permute(2,0,1).float()

        return orig_tensor, aug_tensor

    def close(self):
        """Close the underlying HDF5 file."""
        self.hdf5.close()

In [None]:
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    jaccard_score,
    precision_score,
    recall_score,
    confusion_matrix
)

def segmentation_metrics(y_true, y_pred, num_classes=3):
    """Compute pixel-level segmentation metrics over classes ``0 .. num_classes - 1``.

    Both inputs are flattened first. Returns a dict with pixel accuracy,
    macro-averaged IoU / Dice / precision / recall, per-class IoU and Dice
    arrays, and the confusion matrix.
    """
    truth = y_true.ravel()
    pred = y_pred.ravel()
    class_ids = list(range(num_classes))

    cm = confusion_matrix(truth, pred, labels=class_ids)
    pixel_acc = accuracy_score(truth, pred)

    # IoU is the Jaccard index; Dice coincides with the F1 score.
    per_class_iou = jaccard_score(truth, pred, average=None, labels=class_ids)
    mean_iou = jaccard_score(truth, pred, average="macro", labels=class_ids)
    per_class_dice = f1_score(truth, pred, average=None, labels=class_ids)
    mean_dice = f1_score(truth, pred, average="macro", labels=class_ids)

    precision = precision_score(truth, pred, average="macro", labels=class_ids)
    recall = recall_score(truth, pred, average="macro", labels=class_ids)

    return dict(
        pixel_acc=pixel_acc,
        mean_iou=mean_iou,
        mean_dice=mean_dice,
        per_class_iou=per_class_iou,
        per_class_dice=per_class_dice,
        precision=precision,
        recall=recall,
        confusion_matrix=cm,
    )


In [None]:
from tqdm import tqdm
def train_epoch(loader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    # Average over batches (not samples).
    return running_loss / len(loader)

def eval_epoch(loader, model, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns ``(mean per-batch loss, accuracy)``, where accuracy is the fraction
    of target elements whose argmax-over-dim-1 prediction matches.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            running_loss += criterion(logits, targets).item()
            predicted = logits.argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.numel()
    return running_loss / len(loader), n_correct / n_seen


In [None]:
from torch.utils.tensorboard import SummaryWriter
import os

# TensorBoard event files for this experiment are written under runs/layerwise.
log_dir = "runs/layerwise"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)
# IPython magics: load the TensorBoard notebook extension and open it on the
# parent `runs` directory (which contains log_dir).
%load_ext tensorboard
%tensorboard --logdir runs



The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 204389), started 12:24:58 ago. (Use '!kill 204389' to kill it.)

In [None]:
# --- Data / output locations (machine-specific absolute paths; adjust per environment) ---
TRAIN_PATH = "/home/mt/dataset/20230905_train_global_ps384.hdf5"
VAL_PATH = "/home/mt/dataset/20230905_val_global_ps384.hdf5"
WEIGHT_DIR = "weights/layerwise"
os.makedirs(WEIGHT_DIR, exist_ok=True)

# --- Run configuration ---
EPOCHS = 35
NUM_CLASSES = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Derived from `device` so the string form can never disagree with it
# (previously hard-coded to "cuda", which is wrong on CPU-only machines).
DEVICE = device.type

# --- Data ---
dataset_train = GlacierHDF5PatchDataset3(TRAIN_PATH, length=2000)
dataset_val = GlacierHDF5PatchDataset3(VAL_PATH, length = 500)
train_loader= DataLoader(dataset_train, batch_size=30, shuffle=True)
dev_loader = DataLoader(dataset_val, batch_size=30, shuffle=False)

# --- Model / loss / optimiser ---
# in_chans=11 matches the channels concatenated by the dataset (dem + optical + outlines).
model = SegFormer(num_classes=NUM_CLASSES, variant='mit_b0', drop_path_rate=0.2, in_chans=11).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=3e-6)



In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:   # frozen parameters are not counted
            total += param.numel()
    return total

# Full model params
# A dedicated name is used so this throw-away model does not rebind the global
# `model` that the optimizer created in the training-setup cell points to.
param_count_model = SegFormer(num_classes=19, variant="mit_b0", in_chans=3)
print("Total params:", count_parameters(param_count_model))

# Encoder vs Decoder
encoder_params = count_parameters(param_count_model.backbone)
decoder_params = count_parameters(param_count_model.decoder)
print("Encoder params:", encoder_params)
print("Decoder params:", decoder_params)


Total params: 679242
Encoder params: 431543
Decoder params: 247699


In [None]:
import torch
import time
from thop import profile
from torchinfo import summary

def analyze_model(model, input_size=(1, 11, 256, 256), device=None):
    """Profile a model: MACs/FLOPs (thop), layer summary (torchinfo) and forward latency.

    Args:
        model: the module to analyse; it is moved to ``device`` and left in eval mode.
        input_size: shape of the random input fed to the model.
        device: device string or ``torch.device``; defaults to CUDA when available,
            else CPU.

    Returns:
        dict with ``device`` (string), ``params``, ``macs`` and ``flops`` (both in
        millions) and ``avg_inference_time_ms`` (mean of 20 timed forward passes
        after 5 warm-up passes). If thop profiling fails, params/macs/flops are 0.

    Raises:
        RuntimeError: if model and input end up on different devices.
    """
    # --- Device selection ---
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")
    # Normalise to torch.device so both strings and torch.device objects work
    # (a torch.device has no .startswith()).
    device_obj = torch.device(device)
    on_cuda = device_obj.type == "cuda"

    # --- Move model and input to device ---
    model = model.to(device_obj)
    dummy_input = torch.randn(*input_size).to(device_obj)

    # --- Verify device consistency (skipped for parameterless models) ---
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device != dummy_input.device:
        raise RuntimeError(f"Model and input device mismatch: model on {first_param.device}, input on {dummy_input.device}")

    # --- FLOPs & MACs ---
    try:
        macs, params = profile(model, inputs=(dummy_input,), verbose=False)
        flops = 2 * macs  # 1 MAC = 2 FLOPs
    except Exception as e:
        print(f"[ERROR] thop profiling failed: {e}")
        macs, params, flops = 0, 0, 0  # Fallback values

    # --- Torchinfo summary ---
    try:
        print(summary(model, input_size=input_size, device=device_obj))
    except Exception as e:
        print(f"[ERROR] torchinfo summary failed: {e}")

    # --- Inference time ---
    warmup_runs, timed_runs = 5, 20
    model.eval()
    with torch.no_grad():
        for _ in range(warmup_runs):
            _ = model(dummy_input)
        if on_cuda:
            torch.cuda.synchronize()  # finish queued GPU work before starting the clock
        start = time.perf_counter()   # monotonic, high-resolution timer
        for _ in range(timed_runs):
            _ = model(dummy_input)
        if on_cuda:
            torch.cuda.synchronize()  # make sure all timed kernels have completed
        end = time.perf_counter()
    avg_time = (end - start) / timed_runs * 1000  # ms per forward pass

    return {
        "device": str(device_obj),
        "params": params,
        "macs": macs / 1e6,  # Convert to millions
        "flops": flops / 1e6,  # Convert to millions
        "avg_inference_time_ms": avg_time
    }

# Example usage
# The profiled network gets its own name so it does not replace the global
# `model` whose parameters the training optimizer was built from.
try:
    profiled_model = SegFormer(num_classes=2, in_chans=11)  # Ensure SegFormer is defined
    stats = analyze_model(profiled_model, input_size=(1, 11, 256, 256), device="cuda")
    print(stats)
except Exception as e:
    # Best-effort demo: report the failure instead of aborting the notebook run.
    print(f"[ERROR] Analysis failed: {e}")

[INFO] Using device: cuda
Layer (type:depth-idx)                             Output Shape              Param #
SegFormer                                          [1, 2, 256, 256]          --
├─MixVisionTransformer: 1-1                        [1, 16, 128, 128]         --
│    └─ModuleList: 2-1                             --                        --
│    │    └─MiTStage: 3-1                          [1, 16, 128, 128]         5,795
│    │    └─MiTStage: 3-2                          [1, 32, 64, 64]           19,057
│    │    └─MiTStage: 3-3                          [1, 64, 32, 32]           84,416
│    │    └─MiTStage: 3-4                          [1, 128, 16, 16]          322,432
├─SegFormerHead: 1-2                               [1, 2, 128, 128]          --
│    └─ModuleList: 2-2                             --                        --
│    │    └─Sequential: 3-5                        [1, 128, 128, 128]        2,304
│    │    └─Sequential: 3-6                        [1, 128, 64, 64]   

In [None]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Optional RMSNorm from mamba_ssm; fall back to LayerNorm if missing ---
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm
except Exception:
    RMSNorm = None

# --- REQUIRED: your local Mamba implementation exported from mamba_model.py ---
from mamba_model import Mamba


# ------------------------- Utils -------------------------
class DropPath(nn.Module):
    """Stochastic depth: during training, zero the whole branch output of random samples.

    A per-sample binary mask (broadcast over every non-batch dimension) is drawn
    with keep probability ``1 - drop_prob``; kept samples are rescaled by
    ``1 / keep_prob``. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's branch, in ``[0, 1]``.

    Raises:
        ValueError: if ``drop_prob`` lies outside ``[0, 1]``.
    """
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcast across the remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        if keep_prob > 0.:
            # Only rescale when something can be kept; with keep_prob == 0 the
            # division would give inf and inf * 0 == NaN.
            x = x.div(keep_prob)
        return x * random_tensor


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy (``None`` or ``0``). The same dropout module is applied after
    both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


# ------------------------- Involution stem -------------------------
class Involution2D(nn.Module):
    """
    Involution: location-specific, channel-agnostic kernel.
    Paper: "Involution: Inverting the Inherence of Convolution for Visual Recognition" (CVPR 2021)

    A k*k kernel is predicted for every output location (one per group) from the
    input itself and applied to the k x k neighbourhood of that location; all
    channels within a group share the kernel.

    Args:
        channels: number of input (and output) channels.
        kernel_size: side length k of the generated kernels. The patch extraction
            pads by ``k // 2``, so kernel and patch grids line up for odd k.
        stride: spatial stride; values > 1 downsample the output.
        reduction_ratio: channel reduction inside the kernel-generating bottleneck.
        groups: number of channel groups, each with its own kernel.
    """
    def __init__(self, channels, kernel_size=3, stride=1, reduction_ratio=4, groups=1):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups

        hidden = max(channels // reduction_ratio, 1)
        self.reduce = nn.Conv2d(channels, hidden, 1)
        self.act = nn.ReLU(inplace=True)
        self.span = nn.Conv2d(hidden, groups * (kernel_size * kernel_size), 1)
        # ceil_mode=True makes the pooled grid ceil(H / stride) x ceil(W / stride),
        # the same size F.unfold produces below for odd k. With the default floor
        # mode, inputs whose H or W is not a multiple of the stride gave
        # mismatching grids and the reshape in forward() failed.
        self.sigma = (
            nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
            if stride > 1 else nn.Identity()
        )

    def forward(self, x):
        """Apply the involution to ``x`` of shape (B, C, H, W); returns (B, C, H', W')."""
        B, C, H, W = x.shape
        k, g = self.kernel_size, self.groups
        assert C % g == 0, f"Channels ({C}) must be divisible by groups ({g})."

        # 1) Generate spatially-varying kernels on possibly downsampled features
        x_k = self.sigma(x)                                   # (B, C, H', W')
        K = self.span(self.act(self.reduce(x_k)))             # (B, g*k*k, H', W')
        Hout, Wout = K.shape[2], K.shape[3]
        K = K.view(B, g, k * k, Hout, Wout)                   # (B, g, k*k, H', W')

        # 2) Unfold patches on original x with stride/padding so H',W' align
        patches = F.unfold(x, kernel_size=k, dilation=1, padding=k // 2, stride=self.stride)
        patches = patches.view(B, g, C // g, k * k, Hout, Wout)  # (B, g, Cg, k*k, H', W')

        # 3) Channel-agnostic weighted sum within group
        out = (patches * K.unsqueeze(2)).sum(dim=3)           # (B, g, Cg, H', W')
        out = out.view(B, C, Hout, Wout)
        return out


class OverlapPatchEmbedInvo(nn.Module):
    """Stem: strided involution, then a 1x1 conv to ``embed_dim`` channels, then BatchNorm.

    NOTE(review): ``padding`` is accepted but never used in this class;
    Involution2D derives its own padding from ``kernel_size``.
    """
    def __init__(self, in_chans=3, embed_dim=32, kernel_size=3, stride=2, padding=1,
                 reduction_ratio=4, groups=1):
        super().__init__()
        self.invo = Involution2D(
            in_chans,
            kernel_size=kernel_size,
            stride=stride,
            reduction_ratio=reduction_ratio,
            groups=groups,
        )
        # The involution keeps the channel count, so the projection starts from in_chans.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=1, padding=0)
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        mixed = self.invo(x)                 # spatial mixing (+ downsample when stride > 1)
        return self.norm(self.proj(mixed))   # project to embed_dim, then normalise


# ------------------------- Conv stem (fast option) -------------------------
class OverlapPatchEmbedConv(nn.Module):
    """Stem built from a bias-free strided convolution, BatchNorm and ReLU."""
    def __init__(self, in_chans=3, embed_dim=32, kernel_size=3, stride=2, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_chans, embed_dim, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        embedded = self.proj(x)
        return embedded


# ------------------------- Mixer Block (Mamba) -------------------------
class MambaBlock(nn.Module):
    """
    Pre-norm residual block on token sequences (B, N, C):
    norm -> Mamba -> drop-path -> add, then norm -> MLP -> drop-path -> add.

    RMSNorm is used when requested and importable; otherwise nn.LayerNorm.
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, drop_path=0.0, d_state=16, use_rmsnorm=True):
        super().__init__()
        if use_rmsnorm and RMSNorm is not None:
            norm_cls = RMSNorm
        else:
            norm_cls = nn.LayerNorm

        self.norm1 = norm_cls(dim)
        self.mamba = Mamba(d_model=dim, d_state=d_state)
        # A single drop-path module is shared by both residual branches.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_cls(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=nn.GELU, drop=drop)

    def forward(self, x, H=None, W=None):
        # H and W are accepted for call-site compatibility but are not used here.
        x = self.drop_path(self.mamba(self.norm1(x))) + x   # sequence-mixing branch
        x = self.drop_path(self.mlp(self.norm2(x))) + x     # channel-mixing branch
        return x

# --- Placeholder for Transformer Block ---
class TransformerBlockPlaceholder(nn.Module):
    """
    Stand-in for a Transformer block, used when a stage is built without Mamba.

    The attention slot is an nn.Identity, so the first residual branch only
    applies LayerNorm; swap in a real attention module for a proper ablation.
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Identity()   # placeholder for attention
        # A single drop-path module is shared by both residual branches.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=nn.GELU, drop=drop)

    def forward(self, x, H=None, W=None):
        # x: (B, N, C); H and W are accepted for interface parity and not used.
        x = self.drop_path(self.attn(self.norm1(x))) + x    # (placeholder) attention branch
        x = self.drop_path(self.mlp(self.norm2(x))) + x     # MLP branch
        return x


# ------------------------- MiT Stage (stem + Mamba/Transformer blocks) -------------------------
class MiTStage(nn.Module):
    """One backbone stage: a stride-2 stem (involution or conv) followed by ``depth``
    blocks, either MambaBlock or TransformerBlockPlaceholder depending on ``use_mamba``."""
    def __init__(self, in_chs, embed_dim, depth, drop_path_rates=None, use_involution=True, use_mamba=True):
        super().__init__()
        self.use_mamba = use_mamba

        # Both stem types share these settings; the involution stem takes two extra ones.
        stem_kwargs = dict(in_chans=in_chs, embed_dim=embed_dim, kernel_size=3, stride=2, padding=1)
        if use_involution:
            self.patch_embed = OverlapPatchEmbedInvo(reduction_ratio=4, groups=1, **stem_kwargs)
        else:
            self.patch_embed = OverlapPatchEmbedConv(**stem_kwargs)

        # d_state / use_rmsnorm are only understood by MambaBlock.
        if use_mamba:
            block_cls = MambaBlock
            extra_kwargs = dict(d_state=16, use_rmsnorm=True)
        else:
            block_cls = TransformerBlockPlaceholder
            extra_kwargs = {}

        stage_blocks = []
        for i in range(depth):
            rate = drop_path_rates[i] if drop_path_rates is not None else 0.0
            stage_blocks.append(
                block_cls(dim=embed_dim, mlp_ratio=4.0, drop=0.0, drop_path=rate, **extra_kwargs)
            )
        self.blocks = nn.ModuleList(stage_blocks)

    def forward(self, x):
        feat = self.patch_embed(x)                     # (B, C, H', W')
        B, C, H, W = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)       # (B, N, C) with N = H' * W'
        for block in self.blocks:
            tokens = block(tokens, H, W)
        # Back to a feature map for the next stage / decoder.
        return tokens.transpose(1, 2).reshape(B, C, H, W)


# ------------------------- Mix Vision Transformer (backbone) -------------------------
class MixVisionTransformer(nn.Module):
    """
    Four-stage MiT backbone producing a feature pyramid.

    Every stage is a MiTStage whose stem is built with stride 2, so the
    returned list [S1, S2, S3, S4] has strides [2, 4, 8, 16] relative to the input.

    Args:
        in_chans: channels of the input image.
        embed_dims: output channels of each of the 4 stages.
        depths: number of blocks in each of the 4 stages.
        drop_path_rate: final stochastic-depth rate; rates grow linearly from 0
            to this value over all blocks of all stages.
        use_invo_stages: per-stage flag forwarded as `use_involution`.
        use_mamba_stages: per-stage flag forwarded as `use_mamba`.
    """
    def __init__(self,
                 in_chans: int = 3,
                 embed_dims = [32, 64, 160, 256],
                 depths    = [2, 2, 2, 2],
                 drop_path_rate: float = 0.0,
                 use_invo_stages = (True, True, False, False),
                 use_mamba_stages = (True, True, True, True)):
        super().__init__()
        assert all(len(cfg) == 4 for cfg in (embed_dims, depths, use_invo_stages, use_mamba_stages))

        # One stochastic-depth rate per block, shared schedule across the whole backbone.
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        built_stages = []
        offset = 0
        prev_channels = in_chans
        for width, n_blocks, invo_flag, mamba_flag in zip(embed_dims, depths, use_invo_stages, use_mamba_stages):
            built_stages.append(
                MiTStage(
                    in_chs=prev_channels,
                    embed_dim=width,
                    depth=n_blocks,
                    drop_path_rates=dp_rates[offset:offset + n_blocks],  # this stage's slice of the schedule
                    use_involution=invo_flag,
                    use_mamba=mamba_flag,
                )
            )
            offset += n_blocks
            prev_channels = width  # the next stage consumes this stage's output
        self.stages = nn.ModuleList(built_stages)

    def forward(self, x):
        """Run the stages in sequence and return the output of every stage, in order."""
        pyramid = []
        for stage in self.stages:
            x = stage(x)
            pyramid.append(x)
        return pyramid


# ------------------------------- Decoder --------------------------------
class SegFormerHead(nn.Module):
    """
    Decoder head that fuses a list of feature maps into class logits.

    Each input map is projected to `embed_dim` channels with a 1x1
    conv + BatchNorm + ReLU, resized to the spatial size of the first map,
    concatenated along channels, fused (1x1 conv then 3x3 conv, each followed by
    BatchNorm + ReLU) and classified with a 1x1 convolution.

    Args:
        in_channels: channel count of each incoming feature map, in order.
        embed_dim: common channel width used inside the head.
        num_classes: number of output logit channels.
    """

    def __init__(self, in_channels: List[int], embed_dim=128, num_classes=19):
        super().__init__()

        def conv_bn_relu(c_in, c_out, **conv_kwargs):
            # Bias-free convolution followed by BatchNorm and an in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, bias=False, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        projections = [
            nn.Sequential(*conv_bn_relu(channels, embed_dim, kernel_size=1))
            for channels in in_channels
        ]
        self.proj_layers = nn.ModuleList(projections)
        self.fuse = nn.Sequential(
            *conv_bn_relu(embed_dim * len(in_channels), embed_dim, kernel_size=1),
            *conv_bn_relu(embed_dim, embed_dim, kernel_size=3, padding=1),
        )
        self.classifier = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features: List[torch.Tensor]):
        """Return logits at the spatial resolution of `features[0]`."""
        out_h, out_w = features[0].shape[2:]

        def project(idx, feat):
            # Project to the shared width, then bring the map to the target resolution if needed.
            y = self.proj_layers[idx](feat)
            if y.shape[2:] == (out_h, out_w):
                return y
            return F.interpolate(y, size=(out_h, out_w), mode='bilinear', align_corners=False)

        stacked = torch.cat([project(idx, feat) for idx, feat in enumerate(features)], dim=1)
        return self.classifier(self.fuse(stacked))


# -------------------------- Full SegFormer Model -------------------------
class SegFormer(nn.Module):
    """
    Segmentation model: MixVisionTransformer backbone + SegFormerHead decoder.

    Args:
        num_classes: number of output logit channels.
        variant: key into the built-in size table ('mit_b0' ... 'mit_b5'); it
            supplies `embed_dims` / `depths` when those are not given.
        pretrained: accepted for API compatibility; no weight loader exists, a
            message is printed when True.
        drop_path_rate: maximum stochastic-depth rate forwarded to the backbone.
        in_chans: channels of the input image.
        use_invo_stages: per-stage stem selection forwarded to the backbone.
        use_mamba_stages: per-stage block selection forwarded to the backbone.
        embed_dims: optional override of the per-stage channel widths.
        depths: optional override of the per-stage block counts.

    Raises:
        ValueError: if `variant` is unknown and `embed_dims` or `depths` is missing.
    """

    def __init__(self, num_classes=19, variant='mit_b0', pretrained=False,
                 drop_path_rate=0.1, in_chans=3, use_invo_stages=(True, True, False, False),
                 use_mamba_stages=(True, True, True, True), embed_dims=None, depths=None): # Added embed_dims and depths
        super().__init__()
        variants = {
            'mit_b0': dict(embed_dims=[16, 32, 64, 128], depths=[1, 1, 1, 1]),
            'mit_b1': dict(embed_dims=[64, 128, 320, 512], depths=[2, 2, 2, 2]),
            'mit_b2': dict(embed_dims=[64, 128, 320, 512], depths=[3, 4, 6, 3]),
            'mit_b3': dict(embed_dims=[64, 128, 320, 512], depths=[3, 4, 18, 3]),
            'mit_b4': dict(embed_dims=[64, 128, 320, 512], depths=[3, 8, 27, 3]),
            'mit_b5': dict(embed_dims=[64, 128, 320, 512], depths=[3, 6, 40, 3]),
        }

        if variant in variants:
            # Explicit overrides win; the table only fills in what was not provided.
            cfg = variants[variant]
            embed_dims = embed_dims or cfg['embed_dims']
            depths = depths or cfg['depths']
        elif embed_dims is None or depths is None:
            raise ValueError(f'Unknown variant: {variant} and embed_dims or depths not provided')

        self.backbone = MixVisionTransformer(
            in_chans=in_chans,
            embed_dims=embed_dims,
            depths=depths,
            drop_path_rate=drop_path_rate,
            use_invo_stages=use_invo_stages,
            use_mamba_stages=use_mamba_stages
        )
        self.decoder = SegFormerHead(in_channels=embed_dims, embed_dim=128, num_classes=num_classes)
        # NOTE(review): apply() visits every submodule, so any nn.Linear / nn.Conv2d
        # nested inside the backbone blocks is re-initialised here as well; confirm
        # this does not override an initialisation those blocks rely on.
        self.apply(self._init_weights)

        if pretrained:
            print('pretrained=True selected but no loader is implemented in this script.')

    @staticmethod
    def _init_weights(m):
        """Initialise one module in place: Kaiming for Conv2d, truncated normal for Linear, ones/zeros for norms."""
        # RMSNorm is None when the optional mamba_ssm import failed. isinstance()
        # rejects None inside its class tuple, so only include it when it is a real class.
        norm_types = (nn.LayerNorm,) if RMSNorm is None else (nn.LayerNorm, RMSNorm)

        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, norm_types):
            nn.init.ones_(m.weight)
            # Some norm layers carry no bias parameter.
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return logits resized (bilinear) to the spatial size of the input `x`."""
        H, W = x.shape[2], x.shape[3]
        feats = self.backbone(x)                       # [S1,S2,S3,S4]
        out = self.decoder(feats)                      # logits at S1 resolution
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
        return out

In [None]:
import os
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# Define ablation variants
variants = [
    {"name": "early_mamba_late_attn", "use_mamba_stages": (True, True, False, False),
     "embed_dims": [16, 32, 64, 128], "depths": [1, 1, 1, 1]}
]

# Initialize dataset (fixed): the same datasets and loaders are shared by every variant,
# so they must stay open until the last variant has finished.
dataset_train = GlacierHDF5PatchDataset3(TRAIN_PATH, length=2000)
dataset_val = GlacierHDF5PatchDataset3(VAL_PATH, length=500)
train_loader = DataLoader(dataset_train, batch_size=30, shuffle=True)
dev_loader = DataLoader(dataset_val, batch_size=30, shuffle=False)

try:
    # Run ablation experiments
    for variant in variants:
        print(f"\nRunning ablation: {variant['name']}")
        log_dir = f"runs/ablation_{variant['name']}"
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)

        try:
            # Initialize model (num_classes here must agree with NUM_CLASSES used for the metrics below)
            model = SegFormer(
                num_classes=2,
                variant='mit_b0',
                in_chans=11,
                use_mamba_stages=variant['use_mamba_stages'],
                use_invo_stages=(True, True, False, False),
                embed_dims=variant['embed_dims'],
                depths=variant['depths'],
                drop_path_rate=0.2
            ).to(DEVICE)

            # Compute FLOPs/params
            stats = analyze_model(model, input_size=(1, 11, 256, 256), device=DEVICE)
            writer.add_scalar("Params/total", stats["params"], 0)
            writer.add_scalar("FLOPs/total", stats["flops"], 0)
            writer.add_scalar("InferenceTime_ms", stats["avg_inference_time_ms"], 0)

            # Train
            optimizer = AdamW(model.parameters(), lr=3e-6)
            criterion = nn.CrossEntropyLoss()
            best_val_iou = 0.0
            weight_dir = os.path.join(WEIGHT_DIR, variant['name'])
            os.makedirs(weight_dir, exist_ok=True)

            for epoch in range(EPOCHS):
                # The validation pass below switches the model to eval mode; restore
                # training mode explicitly in case train_epoch does not do it itself.
                model.train()
                train_loss = train_epoch(train_loader, model, criterion, optimizer, DEVICE)

                # Validation: collect hard predictions and labels for the whole dev set.
                model.eval()
                all_preds, all_labels = [], []
                with torch.no_grad():
                    for x, y in dev_loader:
                        x, y = x.to(DEVICE), y.to(DEVICE)
                        outputs = model(x)
                        preds = torch.argmax(outputs, dim=1)
                        all_preds.append(preds.cpu().numpy())
                        all_labels.append(y.cpu().numpy())

                y_true = np.concatenate(all_labels)
                y_pred = np.concatenate(all_preds)
                metrics = segmentation_metrics(y_true, y_pred, NUM_CLASSES)

                print(f"Variant {variant['name']} | Epoch {epoch+1}/{EPOCHS} | "
                      f"Train Loss {train_loss:.4f} | "
                      f"PixelAcc {metrics['pixel_acc']:.4f} | "
                      f"mIoU {metrics['mean_iou']:.4f} | "
                      f"Dice {metrics['mean_dice']:.4f} | "
                      f"Precision {metrics['precision']:.4f} | "
                      f"Recall {metrics['recall']:.4f}")

                writer.add_scalar("Loss/train", train_loss, epoch)
                writer.add_scalar("Acc/pixel", metrics["pixel_acc"], epoch)
                writer.add_scalar("IoU/mean", metrics["mean_iou"], epoch)
                writer.add_scalar("Dice/mean", metrics["mean_dice"], epoch)
                writer.add_scalar("Precision/mean", metrics["precision"], epoch)
                writer.add_scalar("Recall/mean", metrics["recall"], epoch)

                # Keep every epoch's weights, plus a separate copy of the best-mIoU epoch.
                ckpt_path = os.path.join(weight_dir, f"epoch_{epoch+1:03d}.pth")
                torch.save(model.state_dict(), ckpt_path)
                if metrics["mean_iou"] > best_val_iou:
                    best_val_iou = metrics["mean_iou"]
                    best_ckpt_path = os.path.join(weight_dir, "best_model.pth")
                    torch.save(model.state_dict(), best_ckpt_path)
                    print(f"Saved best model for {variant['name']}")
        finally:
            # One writer per variant: flush and close it even if this variant failed.
            writer.close()
finally:
    # Close the shared HDF5-backed datasets only once, after all variants have run
    # (closing them inside the loop would break every variant after the first).
    dataset_train.close()
    dataset_val.close()


Running ablation: early_mamba_late_attn
[INFO] Using device: cuda
Layer (type:depth-idx)                                       Output Shape              Param #
SegFormer                                                    [1, 2, 256, 256]          --
├─MixVisionTransformer: 1-1                                  [1, 16, 128, 128]         --
│    └─ModuleList: 2-1                                       --                        --
│    │    └─MiTStage: 3-1                                    [1, 16, 128, 128]         5,795
│    │    └─MiTStage: 3-2                                    [1, 32, 64, 64]           19,057
│    │    └─MiTStage: 3-3                                    [1, 64, 32, 32]           51,904
│    │    └─MiTStage: 3-4                                    [1, 128, 16, 16]          206,208
├─SegFormerHead: 1-2                                         [1, 2, 128, 128]          --
│    └─ModuleList: 2-2                                       --                        --
│    │    └─

100%|██████████| 67/67 [01:43<00:00,  1.55s/it]


Variant early_mamba_late_attn | Epoch 1/35 | Train Loss 2.3127 | PixelAcc 0.6212 | mIoU 0.4488 | Dice 0.6189 | Precision 0.6812 | Recall 0.7027
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:44<00:00,  1.56s/it]


Variant early_mamba_late_attn | Epoch 2/35 | Train Loss 1.7998 | PixelAcc 0.7135 | mIoU 0.5511 | Dice 0.7097 | Precision 0.7441 | Recall 0.7770
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:46<00:00,  1.58s/it]


Variant early_mamba_late_attn | Epoch 3/35 | Train Loss 1.3415 | PixelAcc 0.7316 | mIoU 0.5656 | Dice 0.7200 | Precision 0.7392 | Recall 0.7966
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 4/35 | Train Loss 1.1693 | PixelAcc 0.7279 | mIoU 0.5673 | Dice 0.7228 | Precision 0.7538 | Recall 0.7950
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:44<00:00,  1.56s/it]


Variant early_mamba_late_attn | Epoch 5/35 | Train Loss 0.8717 | PixelAcc 0.8333 | mIoU 0.7076 | Dice 0.8280 | Precision 0.8297 | Recall 0.8695
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:44<00:00,  1.57s/it]


Variant early_mamba_late_attn | Epoch 6/35 | Train Loss 0.7550 | PixelAcc 0.8332 | mIoU 0.6992 | Dice 0.8212 | Precision 0.8142 | Recall 0.8701


100%|██████████| 67/67 [01:41<00:00,  1.52s/it]


Variant early_mamba_late_attn | Epoch 7/35 | Train Loss 0.6503 | PixelAcc 0.8644 | mIoU 0.7460 | Dice 0.8531 | Precision 0.8416 | Recall 0.8979
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 8/35 | Train Loss 0.5407 | PixelAcc 0.8743 | mIoU 0.7561 | Dice 0.8593 | Precision 0.8432 | Recall 0.9028
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:41<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 9/35 | Train Loss 0.5104 | PixelAcc 0.9108 | mIoU 0.8198 | Dice 0.9000 | Precision 0.8836 | Recall 0.9326
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 10/35 | Train Loss 0.4694 | PixelAcc 0.9012 | mIoU 0.8053 | Dice 0.8912 | Precision 0.8763 | Recall 0.9250


100%|██████████| 67/67 [01:40<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 11/35 | Train Loss 0.4315 | PixelAcc 0.9148 | mIoU 0.8278 | Dice 0.9050 | Precision 0.8891 | Recall 0.9362
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:40<00:00,  1.50s/it]


Variant early_mamba_late_attn | Epoch 12/35 | Train Loss 0.3746 | PixelAcc 0.9191 | mIoU 0.8391 | Dice 0.9119 | Precision 0.8987 | Recall 0.9387
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:39<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 13/35 | Train Loss 0.3613 | PixelAcc 0.9237 | mIoU 0.8456 | Dice 0.9157 | Precision 0.9016 | Recall 0.9416
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:43<00:00,  1.54s/it]


Variant early_mamba_late_attn | Epoch 14/35 | Train Loss 0.3164 | PixelAcc 0.9439 | mIoU 0.8870 | Dice 0.9399 | Precision 0.9308 | Recall 0.9541
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:40<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 15/35 | Train Loss 0.2846 | PixelAcc 0.9490 | mIoU 0.8934 | Dice 0.9434 | Precision 0.9321 | Recall 0.9597
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:41<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 16/35 | Train Loss 0.3267 | PixelAcc 0.9452 | mIoU 0.8827 | Dice 0.9372 | Precision 0.9229 | Recall 0.9581


100%|██████████| 67/67 [01:41<00:00,  1.52s/it]


Variant early_mamba_late_attn | Epoch 17/35 | Train Loss 0.2840 | PixelAcc 0.9533 | mIoU 0.8978 | Dice 0.9457 | Precision 0.9323 | Recall 0.9639
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:43<00:00,  1.54s/it]


Variant early_mamba_late_attn | Epoch 18/35 | Train Loss 0.2316 | PixelAcc 0.9594 | mIoU 0.9085 | Dice 0.9517 | Precision 0.9382 | Recall 0.9692
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:39<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 19/35 | Train Loss 0.2499 | PixelAcc 0.9448 | mIoU 0.8842 | Dice 0.9382 | Precision 0.9261 | Recall 0.9555


100%|██████████| 67/67 [01:39<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 20/35 | Train Loss 0.2097 | PixelAcc 0.9604 | mIoU 0.9132 | Dice 0.9544 | Precision 0.9434 | Recall 0.9683
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:40<00:00,  1.50s/it]


Variant early_mamba_late_attn | Epoch 21/35 | Train Loss 0.2156 | PixelAcc 0.9638 | mIoU 0.9196 | Dice 0.9579 | Precision 0.9472 | Recall 0.9712
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 22/35 | Train Loss 0.2146 | PixelAcc 0.9530 | mIoU 0.8957 | Dice 0.9446 | Precision 0.9313 | Recall 0.9620


100%|██████████| 67/67 [01:39<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 23/35 | Train Loss 0.2024 | PixelAcc 0.9513 | mIoU 0.9001 | Dice 0.9472 | Precision 0.9386 | Recall 0.9594


100%|██████████| 67/67 [01:39<00:00,  1.49s/it]


Variant early_mamba_late_attn | Epoch 24/35 | Train Loss 0.2192 | PixelAcc 0.9443 | mIoU 0.8833 | Dice 0.9376 | Precision 0.9259 | Recall 0.9543


100%|██████████| 67/67 [01:40<00:00,  1.50s/it]


Variant early_mamba_late_attn | Epoch 25/35 | Train Loss 0.1982 | PixelAcc 0.9696 | mIoU 0.9349 | Dice 0.9662 | Precision 0.9590 | Recall 0.9750
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 26/35 | Train Loss 0.1725 | PixelAcc 0.9501 | mIoU 0.8918 | Dice 0.9424 | Precision 0.9299 | Recall 0.9592


100%|██████████| 67/67 [01:40<00:00,  1.50s/it]


Variant early_mamba_late_attn | Epoch 27/35 | Train Loss 0.1880 | PixelAcc 0.9693 | mIoU 0.9323 | Dice 0.9648 | Precision 0.9559 | Recall 0.9755


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 28/35 | Train Loss 0.1591 | PixelAcc 0.9701 | mIoU 0.9346 | Dice 0.9660 | Precision 0.9582 | Recall 0.9754


100%|██████████| 67/67 [01:41<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 29/35 | Train Loss 0.1362 | PixelAcc 0.9694 | mIoU 0.9327 | Dice 0.9651 | Precision 0.9570 | Recall 0.9747


100%|██████████| 67/67 [01:41<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 30/35 | Train Loss 0.1519 | PixelAcc 0.9728 | mIoU 0.9390 | Dice 0.9684 | Precision 0.9614 | Recall 0.9765
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 31/35 | Train Loss 0.1544 | PixelAcc 0.9625 | mIoU 0.9171 | Dice 0.9565 | Precision 0.9452 | Recall 0.9709


100%|██████████| 67/67 [01:41<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 32/35 | Train Loss 0.1374 | PixelAcc 0.9697 | mIoU 0.9334 | Dice 0.9654 | Precision 0.9570 | Recall 0.9756


100%|██████████| 67/67 [01:40<00:00,  1.51s/it]


Variant early_mamba_late_attn | Epoch 33/35 | Train Loss 0.1521 | PixelAcc 0.9519 | mIoU 0.8978 | Dice 0.9458 | Precision 0.9332 | Recall 0.9638


100%|██████████| 67/67 [01:42<00:00,  1.53s/it]


Variant early_mamba_late_attn | Epoch 34/35 | Train Loss 0.1535 | PixelAcc 0.9795 | mIoU 0.9524 | Dice 0.9755 | Precision 0.9697 | Recall 0.9819
Saved best model for early_mamba_late_attn


100%|██████████| 67/67 [01:40<00:00,  1.50s/it]


Variant early_mamba_late_attn | Epoch 35/35 | Train Loss 0.1228 | PixelAcc 0.9666 | mIoU 0.9283 | Dice 0.9626 | Precision 0.9569 | Recall 0.9693


In [None]:
# viz_predictions_with_metrics.py

import os, json, random
import numpy as np
import h5py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ====== USER SETTINGS ======
# Input data and output locations.
TEST_PATH = "/home/mt/dataset/20230905_test_global_ps384.hdf5"
SAVE_BASE_DIR = "ablation_viz"
WEIGHT_BASE_DIR = "weights/layerwise"  # one sub-directory of weights per variant name

# Evaluation size and model I/O.
NUM_IMAGES = 300
BATCH_SIZE = 4
IN_CHANS = 11
NUM_CLASSES = 2
DROP_PATH = 0.2       # keep consistent with training
IGNORE_INDEX = None   # e.g., 255 if your labels have an ignore id; else None

# Ablation variants (should match the training script). The model size is set per
# variant through embed_dims / depths rather than through a global variant name.
# Columns: name, use_mamba_stages, embed_dims, depths.
variants = [
    {"name": name, "use_mamba_stages": mamba_stages, "embed_dims": list(dims), "depths": list(stage_depths)}
    for name, mamba_stages, dims, stage_depths in (
        ("default_mamba",         (True, True, True, True),   (16, 32, 64, 128), (1, 1, 1, 1)),
        ("stage1_deep",           (True, True, True, True),   (16, 32, 64, 128), (2, 1, 1, 1)),
        ("stage4_deep",           (True, True, True, True),   (16, 32, 64, 128), (1, 1, 1, 2)),
        ("stage4_skip",           (True, True, True, True),   (16, 32, 64, 128), (1, 1, 1, 0)),
        ("early_mamba_late_attn", (True, True, False, False), (16, 32, 64, 128), (1, 1, 1, 1)),
        ("stage1_light",          (True, True, True, True),   (8, 32, 64, 128),  (1, 1, 1, 1)),
    )
]

# ====== VIZ UTILS ======
def to_pseudo_rgb(x: torch.Tensor) -> np.ndarray:
    """
    Build an (H, W, 3) float32 image from the leading channels of a (C, H, W) tensor.

    Each channel used is min-max scaled to [0, 1] on its own; a channel whose
    value range is below 1e-8 becomes all zeros. With three or more channels the
    first three become R, G, B; with two channels B is zero; otherwise channel 0
    is replicated into all three.
    """
    def _rescale(channel):
        # Independent min-max scaling; flat channels map to zero instead of dividing by ~0.
        lo, hi = torch.min(channel), torch.max(channel)
        if float(hi - lo) < 1e-8:
            return torch.zeros_like(channel)
        return (channel - lo) / (hi - lo)

    x = x.float()
    n_channels, _, _ = x.shape
    if n_channels >= 3:
        planes = [_rescale(x[idx]) for idx in range(3)]
    elif n_channels == 2:
        red = _rescale(x[0])
        green = _rescale(x[1])
        planes = [red, green, torch.zeros_like(red)]
    else:
        gray = _rescale(x[0])
        planes = [gray, gray, gray]
    stacked = torch.stack(planes, dim=-1)
    return stacked.cpu().numpy().astype(np.float32)

def colorize_mask(mask: np.ndarray, class_colors=None) -> np.ndarray:
    """
    Turn a 2-D array of class ids into an (H, W, 3) uint8 colour image.

    Args:
        mask: 2-D integer array of class ids.
        class_colors: optional sequence of (R, G, B) tuples indexed by class id.
            Defaults to dark grey for class 0 and cyan for class 1. The sequence
            is copied, never modified, so lists and tuples are both accepted.

    Returns:
        uint8 array of shape (H, W, 3). Ids without a colour (negative ids) stay black.

    Class ids beyond the supplied palette get a colour derived from the id
    itself, so the same id is drawn with the same colour on every call (for
    example in the ground-truth and the prediction panel of one figure).
    """
    if class_colors is None:
        palette = [(30, 30, 30), (0, 200, 255)]
    else:
        palette = list(class_colors)  # private copy: do not grow the caller's sequence

    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)
    max_id = int(mask.max()) if mask.size else 0

    # Extend the palette deterministically (no RNG) up to the largest id present.
    for cid in range(len(palette), max_id + 1):
        palette.append(((cid * 67 + 29) % 256, (cid * 131 + 71) % 256, (cid * 197 + 113) % 256))

    for cid, color in enumerate(palette):
        out[mask == cid] = color
    return out

def overlay_image(image_rgb: np.ndarray, mask_rgb: np.ndarray, alpha=0.45) -> np.ndarray:
    """
    Alpha-blend a colour mask over an image.

    `mask_rgb` is scaled by 1/255 before blending, so it is expected in 0..255,
    while `image_rgb` is blended as-is. The result is
    (1 - alpha) * image + alpha * mask, clipped to [0, 1].
    """
    mask_unit = mask_rgb.astype(np.float32) / 255.0
    blended = (1 - alpha) * image_rgb + alpha * mask_unit
    return np.clip(blended, 0, 1)

@torch.no_grad()
def save_grid(rgb, y_rgb, p_rgb, over_gt, over_pred, out_path):
    """
    Save a five-panel comparison figure to `out_path`.

    Layout (2 rows x 3 columns): the pseudo-RGB input spans the whole first
    column; the second column shows the ground-truth mask above the predicted
    mask; the third column shows the ground-truth overlay above the prediction
    overlay.

    Args:
        rgb, y_rgb, p_rgb, over_gt, over_pred: images accepted by `Axes.imshow`.
        out_path: destination file; its parent directory is created if it
            does not exist. A bare file name (no directory part) is allowed.
    """
    # os.makedirs('') raises, so only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 6))
    try:
        gs = fig.add_gridspec(2, 3, hspace=0.15, wspace=0.05)
        panels = (
            (gs[:, 0], rgb,       "Pseudo-RGB"),
            (gs[0, 1], y_rgb,     "GT Mask"),
            (gs[1, 1], p_rgb,     "Pred Mask"),
            (gs[0, 2], over_gt,   "Overlay GT"),
            (gs[1, 2], over_pred, "Overlay Pred"),
        )
        for cell, image, title in panels:
            ax = fig.add_subplot(cell)
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed; this
        # function is called once per image and open figures accumulate memory.
        plt.close(fig)

# ====== DATASET (3-modal, 11-chan) ======
# Default side length of the square patches cut by GlacierHDF5PatchDataset3.
PATCH_SIZE = 256
# Default `target_size` of GlacierHDF5PatchDataset3; the class stores it but the
# visible code never resizes with it.
TARGET_SIZE = (256, 256)
# Name of the per-tile HDF5 entry that holds the label raster.
LABEL_KEY = 'outlines'
# Per-tile input modalities, read and concatenated along the channel axis in this
# order; tiles missing any of them are excluded by the dataset.
MODALITIES_3 = ['dem', 'optical', 'bright_dark_outlines']  # 2 + 6 + 3 = 11 chans

def normalize(arr):
    """Return `arr` as float32, shifted by its global mean and divided by (global std + 1e-5)."""
    data = arr.astype(np.float32)
    centered = data - data.mean()
    scale = data.std() + 1e-5  # epsilon keeps constant inputs from dividing by zero
    return centered / scale

# Augmentation not needed for evaluation
# def augment_patch(img, label):
#     if random.random() < 0.5:
#         img = np.flip(img, axis=0); label = np.flip(label, axis=0)
#     if random.random() < 0.5:
#         img = np.flip(img, axis=1); label = np.flip(label, axis=1)
#     return img, label

class GlacierHDF5PatchDataset3(torch.utils.data.Dataset):
    """
    Fixed-length dataset of square patches sampled from the tiles of an HDF5 file.

    The patch locations (tile name, top row, left column) are drawn once in the
    constructor, so indexing the same dataset object twice returns the same patch.

    Args:
        hdf5_file_path: path of the HDF5 file, opened read-only.
        patch_size: side length of the sampled patches.
        target_size: stored on the instance; not used by this class.
        length: number of patches to pre-sample (the dataset length).
        seed: optional seed for the patch sampling. With None (default) the
            global `random` module is used, as before; with a value the sampled
            locations are reproducible across runs.

    A tile is eligible when it contains every entry of MODALITIES_3 and the
    first modality is at least `patch_size` in both leading dimensions.
    The object can be used as a context manager to make sure the file is closed.
    """

    def __init__(self, hdf5_file_path, patch_size=PATCH_SIZE, target_size=TARGET_SIZE, length=2000, seed=None):
        self.hdf5 = h5py.File(hdf5_file_path, 'r')
        self.patch_size = patch_size
        self.target_size = target_size
        self.length = length
        # Dedicated generator only when a seed is requested; otherwise fall back to the global one.
        self._rng = random.Random(seed) if seed is not None else None
        try:
            self.tiles = [name for name in self.hdf5.keys() if self._tile_is_usable(self.hdf5[name])]
            # Pre-select random patch locations for the fixed number of images.
            self.selected_indices = [self._get_random_patch_coords() for _ in range(length)]
        except Exception:
            # Do not leak the open file handle when construction fails.
            self.hdf5.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __len__(self):
        return self.length

    def _tile_is_usable(self, tile):
        """True if `tile` has all modalities and is large enough to cut one patch from."""
        if not all(m in tile for m in MODALITIES_3):
            return False
        shape = tile[MODALITIES_3[0]].shape
        if len(shape) < 2:
            return False
        return shape[0] >= self.patch_size and shape[1] >= self.patch_size

    def _get_random_patch_coords(self):
        """Draw one (tile_name, y, x) patch origin, uniformly over tiles and then over valid offsets."""
        if not self.tiles:
            raise ValueError(
                f"No tile contains all of {MODALITIES_3} with a spatial size of at least "
                f"{self.patch_size}x{self.patch_size}"
            )
        rng = self._rng if self._rng is not None else random
        tile_name = rng.choice(self.tiles)
        tile = self.hdf5[tile_name]
        h, w = tile[MODALITIES_3[0]].shape[:2]
        y = rng.randint(0, h - self.patch_size)
        x = rng.randint(0, w - self.patch_size)
        return tile_name, y, x

    def __getitem__(self, idx):
        """Return (input, label): a float (C, H, W) tensor and a long (H, W) tensor for patch `idx`."""
        tile_name, y, x = self.selected_indices[idx]
        tile = self.hdf5[tile_name]

        # Each modality patch is normalised on its own before the channels are stacked.
        input_channels = []
        for key in MODALITIES_3:
            arr = tile[key][y:y+self.patch_size, x:x+self.patch_size, :]
            arr = normalize(arr)
            input_channels.append(arr)
        input_patch = np.concatenate(input_channels, axis=2)  # (H,W,11)

        label = tile[LABEL_KEY][y:y+self.patch_size, x:x+self.patch_size]
        if label.ndim == 3:
            # Single-channel labels are squeezed; multi-channel labels are reduced with argmax.
            label = label[:, :, 0] if label.shape[2] == 1 else np.argmax(label, axis=2)

        # No augmentation: this dataset is used for evaluation.
        input_tensor = torch.tensor(np.ascontiguousarray(input_patch)).permute(2, 0, 1).float()  # (C,H,W)
        label_tensor = torch.tensor(np.ascontiguousarray(label), dtype=torch.long)               # (H,W)
        return input_tensor, label_tensor

    def close(self):
        """Close the underlying HDF5 file."""
        self.hdf5.close()

# ====== METRICS ======
def _fast_confusion_matrix(pred, target, num_classes, ignore_index=None):
    """
    pred, target: 1D arrays with same length (flattened), type int
    Returns (K,K) confusion matrix where rows = GT, cols = Pred.
    """
    if ignore_index is not None:
        mask = target != ignore_index
        pred = pred[mask]
        target = target[mask]
    # remove out-of-range labels just in case
    mask = (target >= 0) & (target < num_classes)
    pred = pred[mask]; target = target[mask]
    cm = np.bincount(num_classes * target + pred, minlength=num_classes**2).reshape(num_classes, num_classes)
    return cm.astype(np.int64)

def compute_metrics_from_confmat(conf_mat):
    """
    Derive segmentation metrics from a confusion matrix.

    Args:
        conf_mat: (K, K) array, rows = ground truth, cols = prediction.

    Returns:
        dict with overall pixel accuracy, IoU and Dice (per class and their
        unweighted means over all K classes), the matrix as nested lists and the
        total pixel count. A class with an empty denominator scores 0.0.
    """
    def _ratio(numer, denom):
        # Element-wise numer / denom, with 0.0 wherever the denominator is not positive.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(denom > 0, numer / denom, 0.0)

    true_pos = np.diag(conf_mat).astype(np.float64)
    false_pos = conf_mat.sum(0) - true_pos   # predicted as the class, labelled otherwise
    false_neg = conf_mat.sum(1) - true_pos   # labelled as the class, predicted otherwise

    per_class_iou = _ratio(true_pos, true_pos + false_pos + false_neg)
    per_class_dice = _ratio(2*true_pos, 2*true_pos + false_pos + false_neg)

    n_pixels = conf_mat.sum()
    accuracy = float(true_pos.sum() / n_pixels) if n_pixels > 0 else 0.0

    return {
        "pixel_acc": accuracy,
        "mean_iou":  float(np.mean(per_class_iou)),
        "mean_dice": float(np.mean(per_class_dice)),
        "per_class_iou":  per_class_iou.tolist(),
        "per_class_dice": per_class_dice.tolist(),
        "confusion_matrix": conf_mat.tolist(),
        "total_pixels": int(n_pixels),
    }

# ====== MAIN ======
def main():
    """
    Evaluate every entry of `variants` on the test set and return the collected metrics.

    For each variant this builds a `SegFormer` from the variant's settings, loads
    `<WEIGHT_BASE_DIR>/<name>/best_model.pth`, runs inference over the test loader while
    accumulating a confusion matrix, saves up to `NUM_IMAGES` visualisation grids and a
    `metrics.json` under `<SAVE_BASE_DIR>/<name>/`, and finally prints a summary.

    Returns:
        dict mapping variant name -> metrics dict from `compute_metrics_from_confmat`.
        (Returned so the caller can write the combined summary; the dict used to be a
        local that was never exposed, which made the driver code fail with NameError.)
    """
    # ---- Device ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    def clean_state_dict(state):
        # Drop entries whose key mentions total_ops / total_params
        # (presumably profiler counters saved with the checkpoint - confirm).
        return {k: v for k, v in state.items() if "total_ops" not in k and "total_params" not in k}

    all_variants_metrics = {}  # variant name -> metrics dict

    # ---- Dataset & Loader ----
    dataset = GlacierHDF5PatchDataset3(TEST_PATH, length=NUM_IMAGES)
    try:
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,   # No shuffle for evaluation
            num_workers=0,   # h5py-safe
            pin_memory=True
        )

        for variant in variants:
            variant_name = variant['name']
            print(f"\nEvaluating variant: {variant_name}")

            # ---- Model ----
            model = SegFormer(num_classes=NUM_CLASSES,
                              variant='mit_b0',  # Keep base variant name
                              drop_path_rate=DROP_PATH,
                              in_chans=IN_CHANS,
                              use_mamba_stages=variant['use_mamba_stages'],
                              use_invo_stages=(True, True, False, False),  # Assuming consistent stem setup as in training
                              embed_dims=variant['embed_dims'],
                              depths=variant['depths']
                              ).to(device)

            checkpoint_path = os.path.join(WEIGHT_BASE_DIR, variant_name, "best_model.pth")
            assert os.path.isfile(checkpoint_path), f"Checkpoint not found for {variant_name}: {checkpoint_path}"

            state = torch.load(checkpoint_path, map_location=device)
            state = clean_state_dict(state)
            # strict=False tolerates key mismatches; report them so a partially
            # initialised model does not go unnoticed in the metrics.
            load_result = model.load_state_dict(state, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"WARNING [{variant_name}]: checkpoint/model key mismatch - "
                    f"{len(load_result.missing_keys)} missing, "
                    f"{len(load_result.unexpected_keys)} unexpected"
                )
            model.eval()
            print(f"Loaded cleaned weights from: {checkpoint_path}")

            save_dir = os.path.join(SAVE_BASE_DIR, variant_name)
            os.makedirs(save_dir, exist_ok=True)

            # ---- Inference, Save viz, and Accumulate Confusion Matrix ----
            saved = 0
            conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)

            with torch.no_grad():
                for xb, yb in tqdm(loader, desc=f"Inferring {variant_name}"):
                    xb = xb.to(device)                 # [B,C,H,W]

                    logits = model(xb)                 # [B,K,H,W]  (K=NUM_CLASSES)
                    if logits.dim() == 4 and logits.size(1) > 1:
                        preds = torch.argmax(logits, dim=1)  # [B,H,W]
                    else:
                        preds = (logits.squeeze(1) > 0).long()

                    # Labels are only needed as numpy on the host, so they are never
                    # sent to the device; predictions are copied back once per batch.
                    labels_np = yb.long().detach().cpu().numpy()   # [B,H,W]
                    preds_np = preds.detach().cpu().numpy()        # [B,H,W]

                    # ---- Metrics accumulation ----
                    conf_mat += _fast_confusion_matrix(
                        preds_np.reshape(-1), labels_np.reshape(-1),
                        num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX
                    )

                    # ---- Visualization ----
                    # At most NUM_IMAGES grids are written per variant.
                    for i in range(xb.size(0)):
                        if saved >= NUM_IMAGES:
                            break
                        x = xb[i].detach().cpu()           # (C,H,W)

                        rgb    = to_pseudo_rgb(x)
                        y_rgb  = colorize_mask(labels_np[i])
                        p_rgb  = colorize_mask(preds_np[i])
                        over_y = overlay_image(rgb, y_rgb, alpha=0.45)
                        over_p = overlay_image(rgb, p_rgb, alpha=0.45)

                        out_path = os.path.join(save_dir, f"pred_{saved:04d}.png")
                        save_grid(rgb, y_rgb, p_rgb, over_y, over_p, out_path)
                        saved += 1

            # ---- Compute and Save metrics for the current variant ----
            metrics = compute_metrics_from_confmat(conf_mat)
            metrics_path = os.path.join(save_dir, "metrics.json")
            with open(metrics_path, "w") as f:
                json.dump(metrics, f, indent=2)
            print(f"Saved metrics to {metrics_path}")

            all_variants_metrics[variant_name] = metrics
    finally:
        # ---- Close dataset even if a variant failed, so the HDF5 handle is released ----
        if hasattr(dataset, "close"):
            dataset.close()

    print("\n--- Summary of all variants ---")
    for name, metrics in all_variants_metrics.items():
        print(f"Variant: {name}")
        print(f"  PixelAcc: {metrics['pixel_acc']:.4f}")
        print(f"  mIoU: {metrics['mean_iou']:.4f}")
        print(f"  Dice: {metrics['mean_dice']:.4f}")
        print("-" * 20)

    return all_variants_metrics


def _write_combined_summary(all_variants_metrics):
    """Write headline metrics of all variants to a CSV and the full metrics to a JSON file.

    Args:
        all_variants_metrics: dict mapping variant name -> metrics dict containing at
            least "pixel_acc", "mean_iou" and "mean_dice".

    Both files are created directly under `SAVE_BASE_DIR`.
    """
    import csv  # only needed here

    summary_csv = os.path.join(SAVE_BASE_DIR, "all_variants_summary.csv")
    summary_json = os.path.join(SAVE_BASE_DIR, "all_variants_summary.json")

    os.makedirs(SAVE_BASE_DIR, exist_ok=True)

    # Write CSV
    with open(summary_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Variant", "PixelAcc", "mIoU", "Dice"])  # header
        for name, metrics in all_variants_metrics.items():
            writer.writerow([
                name,
                f"{metrics['pixel_acc']:.4f}",
                f"{metrics['mean_iou']:.4f}",
                f"{metrics['mean_dice']:.4f}"
            ])

    # Write JSON
    with open(summary_json, "w") as f:
        json.dump(all_variants_metrics, f, indent=2)

    print(f"\nSaved combined results to:\n  {summary_csv}\n  {summary_json}")


if __name__ == "__main__":
    # Keep the result at module level so later notebook cells can reuse it.
    all_variants_metrics = main()

    # ---- Save combined metrics (CSV + JSON) ----
    _write_combined_summary(all_variants_metrics)


Device: cuda

Evaluating variant: default_mamba
Loaded cleaned weights from: weights/layerwise/default_mamba/best_model.pth


  plt.tight_layout()
Inferring default_mamba: 100%|██████████| 75/75 [01:27<00:00,  1.16s/it]


Saved metrics to ablation_viz/default_mamba/metrics.json

Evaluating variant: stage1_deep
Loaded cleaned weights from: weights/layerwise/stage1_deep/best_model.pth


Inferring stage1_deep: 100%|██████████| 75/75 [01:09<00:00,  1.08it/s]


Saved metrics to ablation_viz/stage1_deep/metrics.json

Evaluating variant: stage4_deep
Loaded cleaned weights from: weights/layerwise/stage4_deep/best_model.pth


Inferring stage4_deep: 100%|██████████| 75/75 [01:10<00:00,  1.07it/s]


Saved metrics to ablation_viz/stage4_deep/metrics.json

Evaluating variant: stage4_skip
Loaded cleaned weights from: weights/layerwise/stage4_skip/best_model.pth


Inferring stage4_skip: 100%|██████████| 75/75 [01:08<00:00,  1.09it/s]


Saved metrics to ablation_viz/stage4_skip/metrics.json

Evaluating variant: early_mamba_late_attn
Loaded cleaned weights from: weights/layerwise/early_mamba_late_attn/best_model.pth


Inferring early_mamba_late_attn: 100%|██████████| 75/75 [01:09<00:00,  1.08it/s]


Saved metrics to ablation_viz/early_mamba_late_attn/metrics.json

Evaluating variant: stage1_light
Loaded cleaned weights from: weights/layerwise/stage1_light/best_model.pth


Inferring stage1_light: 100%|██████████| 75/75 [01:10<00:00,  1.07it/s]

Saved metrics to ablation_viz/stage1_light/metrics.json

--- Summary of all variants ---
Variant: default_mamba
  PixelAcc: 0.6014
  mIoU: 0.4297
  Dice: 0.6009
--------------------
Variant: stage1_deep
  PixelAcc: 0.9547
  mIoU: 0.8998
  Dice: 0.9468
--------------------
Variant: stage4_deep
  PixelAcc: 0.9041
  mIoU: 0.8089
  Dice: 0.8933
--------------------
Variant: stage4_skip
  PixelAcc: 0.9599
  mIoU: 0.9110
  Dice: 0.9531
--------------------
Variant: early_mamba_late_attn
  PixelAcc: 0.9807
  mIoU: 0.9574
  Dice: 0.9782
--------------------
Variant: stage1_light
  PixelAcc: 0.9218
  mIoU: 0.8391
  Dice: 0.9117
--------------------





NameError: name 'all_variants_metrics' is not defined