In [None]:
import os
import numpy as np
import nibabel as nib
import cv2
from tqdm import tqdm
# Configuration. Both paths live under /content, i.e. they assume the Colab filesystem.
# Folder that is scanned for per-case sub-folders containing the .nii volumes.
SOURCE_DATA_DIR = "/content/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
# Output folder that receives one .npy file per preprocessed slice.
PREPROCESSED_DIR = "/content/preprocessed_data"
# Side length, in pixels, that every image slice and mask is resized to.
IMG_SIZE = 192

# One-time preprocessing so that training reads small per-slice .npy files
# instead of decoding full NIfTI volumes on every access.
def _load_volume(sample_id, modality):
    """Load ``<sample_id>_<modality>.nii`` from the case folder via ``get_fdata()``."""
    path = os.path.join(SOURCE_DATA_DIR, sample_id, f"{sample_id}_{modality}.nii")
    return nib.load(path).get_fdata()


def _list_valid_cases(max_case_id):
    """Return the sorted case folder names that are numbered and complete."""
    cases = []
    for s in sorted(os.listdir(SOURCE_DATA_DIR)):
        try:
            num = int(s.split("_")[-1])
        except ValueError:
            # Entry name does not end in "_<integer>", so it is not a case folder.
            continue
        if num > max_case_id:
            continue

        base_path = os.path.join(SOURCE_DATA_DIR, s)
        required = [f"{s}_flair.nii", f"{s}_t1.nii", f"{s}_t1ce.nii", f"{s}_t2.nii", f"{s}_seg.nii"]
        if all(os.path.exists(os.path.join(base_path, f)) for f in required):
            cases.append(s)
    return cases


def _process_slice(modality_volumes, seg_volume, slice_idx):
    """Build the ``{'image', 'mask', 'boundary'}`` record for one axial slice."""
    image_stack = np.stack([vol[:, :, slice_idx] for vol in modality_volumes], axis=-1)
    mask = seg_volume[:, :, slice_idx].astype(np.int64)

    image_resized = cv2.resize(image_stack, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour keeps the mask values as valid class labels.
    mask_resized = cv2.resize(mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

    # Channels-first layout, then one z-score over all four channels jointly.
    image_processed = image_resized.transpose(2, 0, 1)
    image_processed = (image_processed - np.mean(image_processed)) / (np.std(image_processed) + 1e-8)

    # Remap label 4 to 3 so the labels are the contiguous range 0..3.
    mask_processed = mask_resized
    mask_processed[mask_processed == 4] = 3

    # Boundary target: contours of the foreground (label > 0), drawn 2 px thick.
    binary_mask = (mask_processed > 0).astype(np.uint8)
    boundary_map = np.zeros(mask_processed.shape, dtype=np.float32)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(boundary_map, contours, -1, (1.0), thickness=2)

    return {'image': image_processed, 'mask': mask_processed, 'boundary': boundary_map}


def preprocess_and_save(max_case_id=354):
    """Convert every annotated slice of the dataset into a ``slice_NNNNN.npy`` file.

    Case folders under ``SOURCE_DATA_DIR`` are used when their name ends in
    ``_<number>`` with ``number <= max_case_id`` and all five volumes
    (flair, t1, t1ce, t2, seg) exist. Only slices along the third axis whose
    segmentation contains a non-zero label are kept. Each kept slice is
    resized to ``IMG_SIZE`` x ``IMG_SIZE``, normalised, and saved to
    ``PREPROCESSED_DIR`` as a pickled dict with keys ``'image'``
    (4 x H x W), ``'mask'`` (H x W, labels 0..3) and ``'boundary'``
    (H x W float32 contour map). Files are numbered in (sorted case,
    ascending slice index) order.

    Args:
        max_case_id: Highest case number to include. Defaults to 354, the
            cutoff previously hard-coded here.
            NOTE(review): the reason for 354 is not visible in this file.
    """
    os.makedirs(PREPROCESSED_DIR, exist_ok=True)

    print("Starting one-time preprocessing of the dataset on Colab storage...")

    # Pass 1: only the segmentation is needed to decide which slices to keep.
    slices_per_case = []
    for s in _list_valid_cases(max_case_id):
        seg_volume = _load_volume(s, "seg")
        slice_ids = [k for k in range(seg_volume.shape[2]) if np.any(seg_volume[:, :, k] > 0)]
        if slice_ids:
            slices_per_case.append((s, slice_ids))

    total_slices = sum(len(slice_ids) for _, slice_ids in slices_per_case)
    print(f"Found {total_slices} slices. Now processing and saving them as .npy files...")

    # Pass 2: load each case's volumes ONCE and cut all of its slices from
    # memory, rather than re-decoding five full volumes for every slice.
    out_index = 0
    with tqdm(total=total_slices, desc="Preprocessing Slices") as progress:
        for sample_id, slice_ids in slices_per_case:
            modality_volumes = [_load_volume(sample_id, m) for m in ("flair", "t1", "t1ce", "t2")]
            seg_volume = _load_volume(sample_id, "seg")

            for slice_idx in slice_ids:
                record = _process_slice(modality_volumes, seg_volume, slice_idx)
                filename = f"slice_{out_index:05d}.npy"
                filepath = os.path.join(PREPROCESSED_DIR, filename)
                # Saving a dict pickles it; readers must pass allow_pickle=True.
                np.save(filepath, record)
                out_index += 1
                progress.update(1)

    print("Preprocessing complete! Data is ready in /content/preprocessed_data/")

# Run the one-time preprocessing when this cell/script is executed directly.
if __name__ == '__main__':
    preprocess_and_save()

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Folder holding the preprocessed per-slice .npy files.
PREPROCESSED_DIR = "/content/preprocessed_data"
# Number of classes (output channels) of the segmentation model.
CLASS_COUNT = 4
# Number of training epochs; also used as T_max of the cosine LR schedule.
EPOCHS = 25
# Batch size for both the training and the validation DataLoader.
BATCH_SIZE = 16
# Initial learning rate passed to AdamW.
LEARNING_RATE = 1e-4

class BraTSDataset(Dataset):
    """Dataset over the ``.npy`` slice files located directly inside a folder.

    Each file is expected to hold a pickled dict with the keys ``'image'``,
    ``'mask'`` and ``'boundary'``. Items are returned as the tuple
    ``(image float32, mask int64, boundary float32 with a leading channel axis)``.
    """

    def __init__(self, root_dir):
        """Index every ``*.npy`` file in ``root_dir`` in sorted path order."""
        self.root_dir = root_dir
        npy_names = (name for name in os.listdir(root_dir) if name.endswith('.npy'))
        self.filepaths = sorted(os.path.join(root_dir, name) for name in npy_names)
        print(f"Found {len(self.filepaths)} preprocessed slices in {root_dir}.")

    def __len__(self):
        """Return the number of indexed slice files."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load slice ``idx`` from disk and convert its arrays to tensors."""
        # The file stores a pickled dict, hence allow_pickle; only load trusted files.
        record = np.load(self.filepaths[idx], allow_pickle=True).item()
        image = torch.tensor(record['image'], dtype=torch.float32)
        mask = torch.tensor(record['mask'], dtype=torch.long)
        # Add a channel axis so the boundary map is (1, H, W).
        boundary = torch.tensor(record['boundary'], dtype=torch.float32).unsqueeze(0)
        return image, mask, boundary

# --- Boundary-Aware Loss Function ---
# Segmentation terms, both fed the raw multi-class logits.
dice_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
focal_loss = smp.losses.FocalLoss(mode='multiclass')
# Binary term for the boundary map; it applies the sigmoid itself, so it must
# be given a logit, never a probability.
boundary_loss = nn.BCEWithLogitsLoss()

def boundary_aware_loss(pred, target_mask, target_boundary):
    """Combine Dice + Focal segmentation loss with a weighted boundary BCE term.

    Args:
        pred: Raw class logits from the model; channel 0 is treated as
            background and channels 1.. as foreground classes.
        target_mask: Integer class-label mask passed to the Dice/Focal losses.
        target_boundary: Float boundary map with a singleton channel axis,
            matching the ``keepdim`` foreground logit computed below.

    Returns:
        Scalar tensor ``dice + focal + 2.0 * boundary_bce``.
    """
    seg_loss = dice_loss(pred, target_mask) + focal_loss(pred, target_mask)
    # BCEWithLogitsLoss needs a logit. Previously it received the probability
    # 1 - softmax(pred)[:, 0], which was then squashed by a second sigmoid and
    # could never drive the term towards zero. The logit of that same
    # foreground probability is
    #     log((1 - p_bg) / p_bg) = logsumexp(foreground logits) - background logit
    # which is also numerically stable.
    foreground_logit = torch.logsumexp(pred[:, 1:], dim=1, keepdim=True) - pred[:, :1]
    # NOTE(review): the target marks only contour pixels while this logit is
    # high across the whole foreground region, so interior pixels are pushed
    # towards background by this term. Confirm this is intended, or predict
    # the boundary from a dedicated output instead.
    b_loss = boundary_loss(foreground_logit, target_boundary)
    combined_loss = seg_loss + (2.0 * b_loss)
    return combined_loss
# Training loop
def train(model, loader, optimizer, scheduler, scaler):
    """Train ``model`` for ``EPOCHS`` epochs with mixed precision.

    Args:
        model: Network mapping an image batch to class logits.
        loader: Sized iterable yielding ``(images, masks, boundaries)`` batches.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scheduler: LR scheduler stepped once per epoch.
        scaler: Gradient scaler used for the scaled backward pass and step.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: If ``loader`` has no batches (previously this surfaced as
            a ZeroDivisionError after a no-op epoch).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train() received an empty loader; there is nothing to train on.")

    # torch.cuda.amp.autocast() is deprecated; use the device-generic form and
    # enable it only on CUDA, which is where the old call was active.
    device_type = torch.device(DEVICE).type
    use_amp = device_type == 'cuda'

    loss_history = []
    model.train()
    for epoch in range(EPOCHS):
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=True)
        epoch_loss = 0.0
        for images, masks, boundaries in progress_bar:
            images, masks, boundaries = images.to(DEVICE), masks.to(DEVICE), boundaries.to(DEVICE)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type, enabled=use_amp):
                outputs = model(images)
                loss = boundary_aware_loss(outputs, masks, boundaries)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # .item() forces a device sync, so read the value only once per step.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            progress_bar.set_postfix(loss=f"{batch_loss:.4f}")
        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        scheduler.step()
    return loss_history

# Visualization
def visualize_results(model, dataset, count=10):
    """Plot input, ground-truth mask and predicted mask for random samples.

    Draws ``count`` rows of three panels each. Sample indices are drawn with
    ``np.random.randint``, so the same sample may appear more than once.
    The model is switched to eval mode and is left in eval mode afterwards.

    Args:
        model: Network mapping a ``(1, C, H, W)`` batch to class logits.
        dataset: Indexable dataset returning ``(image, mask, boundary)`` tensors.
        count: Number of rows (samples) to display.
    """
    # squeeze=False keeps axs 2-D even when count == 1; without it the
    # axs[i, 0] indexing below fails on the 1-D array matplotlib returns.
    fig, axs = plt.subplots(count, 3, figsize=(15, count * 5), squeeze=False)
    fig.suptitle("Model Predictions vs. Ground Truth", fontsize=20)
    model.eval()
    with torch.no_grad():
        for i in range(count):
            idx = np.random.randint(0, len(dataset))
            image, mask, _ = dataset[idx]
            image_for_pred = image.unsqueeze(0).to(DEVICE)
            pred = model(image_for_pred)
            # Highest-scoring class per pixel.
            pred_mask = torch.argmax(pred.squeeze(0), dim=0).cpu().numpy()
            # Only the first channel is shown as the greyscale input image.
            display_image = image[0, :, :].numpy()
            true_mask_color = colorize_mask(mask.numpy())
            pred_mask_color = colorize_mask(pred_mask)
            axs[i, 0].imshow(display_image, cmap='bone')
            axs[i, 0].set_title(f"Input Image (Flair) - Sample {idx}")
            axs[i, 0].axis('off')
            axs[i, 1].imshow(true_mask_color)
            axs[i, 1].set_title("Ground Truth Mask")
            axs[i, 1].axis('off')
            axs[i, 2].imshow(pred_mask_color)
            axs[i, 2].set_title("Predicted Mask")
            axs[i, 2].axis('off')
    plt.tight_layout()
    plt.show()

def colorize_mask(mask):
    """Turn a label mask into a uint8 RGB image.

    Labels 0, 1, 2 and 3 become black, green, yellow and red respectively;
    any other value stays black. The result has the shape of ``mask`` with a
    trailing axis of length 3.
    """
    palette = (
        (0, (0, 0, 0)),
        (1, (0, 255, 0)),
        (2, (255, 255, 0)),
        (3, (255, 0, 0)),
    )
    rgb = np.zeros(tuple(mask.shape) + (3,), dtype=np.uint8)
    for class_id, rgb_value in palette:
        selected = mask == class_id
        rgb[selected] = rgb_value
    return rgb

# Main entry point: build the data loaders, train the model, then plot the
# loss curve and a few predictions.
if __name__ == '__main__':
    full_dataset = BraTSDataset(root_dir=PREPROCESSED_DIR)
    # 80/20 train/validation split. No generator or seed is passed here, so
    # the split is not fixed by this code.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    # NOTE(review): val_loader is created but not passed to train(); only
    # val_dataset is used, for the visualization at the end.
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # U-Net with a ResNet-34 encoder; 4 input channels (one per stacked
    # modality) and CLASS_COUNT output channels.
    model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=4, classes=CLASS_COUNT).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # Cosine decay from LEARNING_RATE down to 1e-6 over EPOCHS scheduler steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
    # Gradient scaler for the mixed-precision training loop.
    scaler = GradScaler()

    print("Starting training from preprocessed data on Colab storage...")
    loss_history = train(model, train_loader, optimizer, scheduler, scaler)

    # Plot the mean training loss of each epoch.
    plt.figure(figsize=(10, 5))
    plt.plot(loss_history, label='Training Loss')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

    print("Visualizing results on validation data...")
    visualize_results(model, val_dataset)

In [None]:
Found 23415 preprocessed slices in /content/preprocessed_data.
/tmp/ipython-input-758465914.py:124: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
Starting training from preprocessed data on Colab storage...
Epoch 1/25:   0%|          | 0/1171 [00:00<?, ?it/s]/tmp/ipython-input-758465914.py:65: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast():
Epoch 1/25: 100%|██████████| 1171/1171 [02:30<00:00,  7.78it/s, loss=1.7039]
Epoch 2/25: 100%|██████████| 1171/1171 [02:21<00:00,  8.27it/s, loss=1.5865]
Epoch 3/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.51it/s, loss=1.5173]
Epoch 4/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.51it/s, loss=1.5119]
Epoch 5/25: 100%|██████████| 1171/1171 [02:18<00:00,  8.48it/s, loss=1.5457]
Epoch 6/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.55it/s, loss=1.5448]
Epoch 7/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.56it/s, loss=1.5214]
Epoch 8/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.60it/s, loss=1.5079]
Epoch 9/25: 100%|██████████| 1171/1171 [02:15<00:00,  8.62it/s, loss=1.5005]
Epoch 10/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.60it/s, loss=1.4944]
Epoch 11/25: 100%|██████████| 1171/1171 [02:15<00:00,  8.62it/s, loss=1.5095]
Epoch 12/25: 100%|██████████| 1171/1171 [02:15<00:00,  8.64it/s, loss=1.4912]
Epoch 13/25: 100%|██████████| 1171/1171 [02:15<00:00,  8.61it/s, loss=1.4735]
Epoch 14/25: 100%|██████████| 1171/1171 [02:15<00:00,  8.66it/s, loss=1.4791]
Epoch 15/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.57it/s, loss=1.4743]
Epoch 16/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.50it/s, loss=1.4796]
Epoch 17/25: 100%|██████████| 1171/1171 [02:16<00:00,  8.58it/s, loss=1.4748]
Epoch 18/25: 100%|██████████| 1171/1171 [02:18<00:00,  8.47it/s, loss=1.4675]
Epoch 19/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.54it/s, loss=1.4642]
Epoch 20/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.54it/s, loss=1.4661]
Epoch 21/25: 100%|██████████| 1171/1171 [02:18<00:00,  8.46it/s, loss=1.4689]
Epoch 22/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.49it/s, loss=1.4637]
Epoch 23/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.51it/s, loss=1.4828]
Epoch 24/25: 100%|██████████| 1171/1171 [02:17<00:00,  8.51it/s, loss=1.4907]
Epoch 25/25: 100%|██████████| 1171/1171 [02:18<00:00,  8.46it/s, loss=1.4725]

Visualizing results on validation data...
