<a href="https://colab.research.google.com/github/nimrashaheen001/nimrashaheen001/blob/main/proposedMethodology4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# =============================
# 1. Install and Imports
# =============================
!pip install ml-collections -q

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from scipy.ndimage import zoom
from ml_collections import ConfigDict

# =============================
# 2. Mount Google Drive
# =============================
from google.colab import drive
drive.mount("/content/drive")

# Set paths
dataset_path = "/content/drive/MyDrive/dataset/BrainMRI_Seg_dataset_small.npz"
pretrained_path = "/content/drive/MyDrive/imagenet21k_R50+ViT-B_16.npz"

# =============================
# 3. Load Dataset
# =============================
data = np.load(dataset_path)
print("Available keys in dataset:", list(data.keys()))

# Here we assume keys[0] = images, keys[1] = masks
images = data[list(data.keys())[0]]   # shape (N, H, W, 3)
labels = data[list(data.keys())[1]]   # shape (N, H, W)

print("Images shape:", images.shape)
print("Labels shape:", labels.shape)

# =============================
# 4. Preprocessing / Dataset Class
# =============================
class BrainMRIDataset(Dataset):
    def __init__(self, images, labels, output_size=(224,224), transform=None):
        self.images = images
        self.labels = labels
        self.output_size = output_size
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]

        # Resize
        h, w = image.shape[:2]
        image = zoom(image, (self.output_size[0]/h, self.output_size[1]/w, 1), order=3)  # RGB keep channels
        label = zoom(label, (self.output_size[0]/h, self.output_size[1]/w), order=0)     # masks

        # Convert to (C, H, W) for PyTorch
        image = np.transpose(image, (2, 0, 1))   # (3, H, W)
        label = np.expand_dims(label, axis=0)    # (1, H, W)

        # Convert to tensors
        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).long()

        sample = {"image": image, "label": label}
        if self.transform:
            sample = self.transform(sample)

        return sample

# =============================
# 5. Split into Train/Test
# =============================
dataset = BrainMRIDataset(images, labels, output_size=(224,224))

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# =============================
# 6. Vision Transformer Config (R50 + ViT-B/16)
# =============================
def get_r50_b16_config():
    config_vit = ConfigDict()
    config_vit.patches = ConfigDict({'size': (16, 16)})
    config_vit.hidden_size = 768
    config_vit.transformer = ConfigDict({
        'mlp_dim': 3072,
        'num_heads': 12,
        'num_layers': 12,
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })
    config_vit.classifier = 'seg'
    config_vit.representation_size = None
    config_vit.resnet = ConfigDict({
        'num_layers': (3, 4, 9),
        'width_factor': 1,
    })
    return config_vit

config = get_r50_b16_config()

# =============================
# 7. Load Pretrained Weights (if file exists)
# =============================
if os.path.exists(pretrained_path):
    vit_pretrained = np.load(pretrained_path)
    print("Pretrained weights keys:", vit_pretrained.files)
else:
    print(f"⚠️ Pretrained weights not found at {pretrained_path}")

# =============================
# 8. Training Loop Placeholder
# =============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for batch in train_loader:
    imgs, masks = batch["image"].to(device), batch["label"].to(device)
    print("Batch images:", imgs.shape)  # (B, 3, 224, 224)
    print("Batch labels:", masks.shape) # (B, 1, 224, 224)
    break

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/76.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive
Available keys in dataset: ['X_train', 'M_train', 'y_train', 'X_val', 'M_val', 'y_val', 'X_test', 'M_test', 'y_test']
Images shape: (10503, 224, 224, 3)
Labels shape: (10503, 224, 224)
Train samples: 8402, Test samples: 2101
Pretrained weights keys: ['Transformer/encoder_norm/bias', 'Transformer/encoder_norm/scale', 'Transformer/encoderblock_0/LayerNorm_0/bias', 'Transformer/encoderblock_0/LayerNorm_0/scale', 'Transformer/encoderblock_0/LayerNorm_2/bias', 'Transformer/encoderblock_0/LayerNorm_2/scale', 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/bias', 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/kernel', 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/bias', 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/kernel', 'Transform

In [None]:
# =============================
# TransUNet Training (Notebook-Friendly, Debug-Safe, Label-Remap)
# =============================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from scipy.ndimage import zoom
from ml_collections import ConfigDict
import torch.nn.functional as F

# =============================
# Dataset Class (handles RGB or grayscale images)
# Returns: image tensor (C,H,W), label tensor (H,W) (long)
# =============================
class BrainMRIDataset(Dataset):
    def __init__(self, images, labels, output_size=(224,224)):
        """
        images: numpy array (N, H, W) or (N, H, W, C)
        labels: numpy array (N, H, W)
        """
        self.images = images
        self.labels = labels
        self.output_size = output_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        # Resize depending on channel layout
        if image.ndim == 2:  # grayscale H,W
            H, W = image.shape
            image = zoom(image, (self.output_size[0]/H, self.output_size[1]/W), order=3)
            label = zoom(label, (self.output_size[0]/H, self.output_size[1]/W), order=0)
            image = np.expand_dims(image, axis=0)  # (1,H,W)

        elif image.ndim == 3:  # H,W,C
            H, W, C = image.shape
            image = zoom(image, (self.output_size[0]/H, self.output_size[1]/W, 1), order=3)
            label = zoom(label, (self.output_size[0]/H, self.output_size[1]/W), order=0)
            # to (C,H,W)
            image = np.transpose(image, (2,0,1))
        else:
            raise ValueError(f"Unsupported image ndim {image.ndim} for image with shape {image.shape}")

        # Ensure types
        image = image.astype(np.float32)
        label = label.astype(np.int64)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).long()  # (H, W)

        return {"image": image, "label": label}

# =============================
# Vision Transformer Config (kept for compatibility)
# =============================
def get_r50_b16_config():
    config_vit = ConfigDict()
    config_vit.patches = ConfigDict({'size': (16, 16)})
    config_vit.hidden_size = 768
    config_vit.transformer = ConfigDict({
        'mlp_dim': 3072,
        'num_heads': 12,
        'num_layers': 4,  # small for debug; increase later
        'attention_dropout_rate': 0.0,
        'dropout_rate': 0.1,
    })
    config_vit.classifier = 'seg'
    config_vit.representation_size = None
    config_vit.resnet = ConfigDict({
        'num_layers': (3, 4, 9),
        'width_factor': 1,
    })
    return config_vit

# =============================
# Simple ViT-UNet (mini) - accepts 1 or 3 channel images
# =============================
class ViT_UNet(nn.Module):
    def __init__(self, config, in_channels=3, img_size=224, num_classes=2):
        super(ViT_UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, stride=2, padding=1),  # downsample
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, config.hidden_size, 3, stride=2, padding=1),
            nn.BatchNorm2d(config.hidden_size),
            nn.ReLU(inplace=True)
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.transformer.num_heads,
                dim_feedforward=config.transformer.mlp_dim,
                dropout=config.transformer.dropout_rate,
                batch_first=True
            ),
            num_layers=config.transformer.num_layers
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2)
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.encoder(x)   # (B, hidden, H/4, W/4)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, HW, C)
        x = self.transformer(x)           # (B, HW, C)
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.decoder(x)               # (B, num_classes, H, W)
        return x

# =============================
# Helper: remap labels to 0..K-1 or binary map
# =============================
def build_label_mapping(unique_vals, target_num_classes):
    """
    Decide mapping based on unique_vals and target_num_classes.
    Returns mapping dict {old_val: new_val} and new_num_classes.
    Behavior:
      - If unique_vals subset of [0..target_num_classes-1]: identity mapping.
      - If unique_vals equals {0, 255} and target_num_classes==2: map 255->1
      - Else if len(unique_vals) <= target_num_classes: map sorted unique to 0..k-1
      - Else: set new_num_classes = len(unique_vals) and map sorted unique -> 0..new_num_classes-1
    """
    uniq = np.array(sorted([int(x) for x in unique_vals]))
    mapping = {}
    if np.all((uniq >= 0) & (uniq <= target_num_classes - 1)):
        for v in uniq: mapping[int(v)] = int(v)
        return mapping, target_num_classes

    # Common case: masks use 0 and 255
    if set(uniq.tolist()) == {0, 255} and target_num_classes == 2:
        return {0: 0, 255: 1}, 2

    # If there are <= target_num_classes, map to 0..k-1 (preserve order)
    if len(uniq) <= target_num_classes:
        for i, v in enumerate(uniq):
            mapping[int(v)] = i
        return mapping, target_num_classes

    # Otherwise expand num_classes to fit unique values
    for i, v in enumerate(uniq):
        mapping[int(v)] = i
    new_num_classes = len(uniq)
    return mapping, new_num_classes

def apply_label_mapping(arr, mapping):
    # arr: numpy array of labels (H,W) or (N,H,W)
    # mapping: dict old->new
    if isinstance(arr, np.ndarray) and arr.ndim == 3:
        out = np.zeros_like(arr, dtype=np.int64)
        for old, new in mapping.items():
            out[arr == old] = new
        return out
    else:
        # single mask
        out = np.zeros_like(arr, dtype=np.int64)
        for old, new in mapping.items():
            out[arr == old] = new
        return out

# =============================
# Training Loop (with remap & debug)
# =============================
def train(args):
    # -------------------------
    # Load dataset arrays
    # -------------------------
    data = np.load(args.dataset_path)
    # explicit keys present in your .npz
    X_train, M_train = data['X_train'], data['M_train']
    X_val,   M_val   = data['X_val'],   data['M_val']
    X_test,  M_test  = data['X_test'],  data['M_test']

    # Examine unique labels across all splits (fast, numpy)
    unique_vals = np.unique(np.concatenate([
        np.unique(M_train), np.unique(M_val), np.unique(M_test)
    ]))
    print("Unique raw label values found:", unique_vals)

    # Build mapping and maybe update num_classes
    mapping, new_num_classes = build_label_mapping(unique_vals, args.num_classes)
    if new_num_classes != args.num_classes:
        print(f"Warning: target num_classes={args.num_classes} doesn't match labels; "
              f"setting num_classes={new_num_classes}")
        args.num_classes = new_num_classes
    print("Using label mapping (old->new):", mapping)

    # Apply mapping to label arrays (this keeps everything on CPU and safe)
    M_train_mapped = apply_label_mapping(M_train, mapping)
    M_val_mapped   = apply_label_mapping(M_val, mapping)
    M_test_mapped  = apply_label_mapping(M_test, mapping)

    # -------------------------
    # Create datasets & loaders
    # -------------------------
    # detect in_channels from images shape
    in_channels = 1
    sample_img = X_train[0]
    if sample_img.ndim == 3:
        in_channels = sample_img.shape[2]  # H,W,C
    else:
        in_channels = 1

    train_ds = BrainMRIDataset(X_train, M_train_mapped, output_size=(args.img_size, args.img_size))
    val_ds   = BrainMRIDataset(X_val,   M_val_mapped,   output_size=(args.img_size, args.img_size))
    test_ds  = BrainMRIDataset(X_test,  M_test_mapped,  output_size=(args.img_size, args.img_size))

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False, num_workers=1)
    test_loader  = DataLoader(test_ds,  batch_size=args.batch_size, shuffle=False, num_workers=1)

    # -------------------------
    # Model
    # -------------------------
    config = get_r50_b16_config()
    model = ViT_UNet(config, in_channels=in_channels, img_size=args.img_size, num_classes=args.num_classes).to(args.device)
    print(f"Model created. in_channels={in_channels}, num_classes={args.num_classes}")

    # -------------------------
    # Optimizer / Loss
    # -------------------------
    optimizer = optim.Adam(model.parameters(), lr=args.base_lr)
    criterion = nn.CrossEntropyLoss()

    # -------------------------
    # Quick sanity check on a single CPU batch BEFORE GPU training (to avoid silent CUDA assert)
    # -------------------------
    model.eval()
    with torch.no_grad():
        batch0 = next(iter(train_loader))
        imgs0 = batch0['image']  # (B,C,H,W)
        masks0 = batch0['label']  # (B,H,W)
        print("Sanity batch shapes (before device):", imgs0.shape, masks0.shape)
        print("Unique labels in sample batch:", np.unique(masks0.numpy()))
    model.train()

    # -------------------------
    # Training loop
    # -------------------------
    for epoch in range(args.max_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            imgs = batch['image'].to(args.device)
            masks = batch['label'].to(args.device)  # (B,H,W)
            masks = masks.long()
            # clamp again just in case
            masks = torch.clamp(masks, 0, args.num_classes - 1)

            preds = model(imgs)  # (B, num_classes, H, W)

            # Debug on first iteration
            if epoch == 0 and batch_idx == 0:
                print("DEBUG pred shape:", preds.shape, "mask shape:", masks.shape,
                      "unique mask values (batch):", torch.unique(masks).cpu().numpy().tolist())

            try:
                loss = criterion(preds, masks)
            except Exception as e:
                print("Error computing loss — printing batch info for debugging.")
                print("preds.shape", preds.shape)
                print("masks.shape", masks.shape)
                print("unique masks (cpu):", torch.unique(masks).cpu().numpy())
                raise

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{args.max_epochs}  Avg Loss: {avg_loss:.4f}")

        # (Optional) quick validation step
        if (epoch + 1) % args.val_every == 0:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for vbatch in val_loader:
                    vimgs = vbatch['image'].to(args.device)
                    vmasks = vbatch['label'].to(args.device).long()
                    vmasks = torch.clamp(vmasks, 0, args.num_classes - 1)
                    vpreds = model(vimgs)
                    vloss = criterion(vpreds, vmasks)
                    val_loss += vloss.item()
            val_loss /= len(val_loader)
            print(f"  Validation loss: {val_loss:.4f}")
            model.train()

    # -------------------------
    # Save final model
    # -------------------------
    os.makedirs(args.save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.save_dir, "vit_unet_final.pth"))
    print("Training finished. Model saved to", os.path.join(args.save_dir, "vit_unet_final.pth"))

# =============================
# Notebook-Friendly args
# =============================
class Args:
    dataset_path = "/content/drive/MyDrive/dataset/BrainMRI_Seg_dataset_small.npz"
    save_dir = "/content/drive/MyDrive/transunet_snapshots"
    img_size = 224
    batch_size = 2
    max_epochs = 5
    base_lr = 0.001
    num_classes = 2            # initial target; may be updated automatically to match labels
    val_every = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = Args()

# Run training
train(args)


Unique raw label values found: [0 1 2 3]
Using label mapping (old->new): {0: 0, 1: 1, 2: 2, 3: 3}
Model created. in_channels=3, num_classes=4
Sanity batch shapes (before device): torch.Size([2, 3, 224, 224]) torch.Size([2, 224, 224])
Unique labels in sample batch: [0 3]
DEBUG pred shape: torch.Size([2, 4, 224, 224]) mask shape: torch.Size([2, 224, 224]) unique mask values (batch): [0, 1, 2]
Epoch 1/5  Avg Loss: 1.0730
  Validation loss: 1.0682
