In [None]:
!pip install monai nibabel matplotlib


In [None]:
# Mount Google Drive so the dataset and training outputs persist across Colab sessions.
import os  # imported here too: this cell runs before the main import cell

from google.colab import drive

drive.mount('/content/drive')

# Define paths
DATASET_DIR = "/content/drive/MyDrive/brats2023/train"
OUTPUT_DIR = "/content/drive/MyDrive/segresnet_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [9]:
import os
from glob import glob

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from monai.data import CacheDataset, Dataset, list_data_collate
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
)
from monai.utils import set_determinism


In [10]:
set_determinism(seed=42)

# Training pipeline: load -> channel-first -> resample -> reorient -> binarize label
# -> intensity window -> random 96^3 crops with flip/rotate augmentation -> tensors.
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # The model predicts a single sigmoid foreground channel, so merge every non-zero
    # label value into one class (value >= 0.5 -> 1). Presumably the BraTS seg maps carry
    # several tumor sub-region values (e.g. 1/2/3); unbinarized, those would reach
    # DiceLoss as targets > 1.
    AsDiscreted(keys=["label"], threshold=0.5),
    # NOTE(review): a fixed 0-500 window assumes this range suits the t1c scans; MRI
    # intensities are not standardized across scanners - confirm against the data.
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=500, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                           spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    ToTensord(keys=["image", "label"]),
])


In [None]:
# Pair each t1c image with its segmentation; the model sees only the t1c modality.
image_paths = sorted(glob(os.path.join(DATASET_DIR, "*-t1c.nii.gz")))
label_paths = sorted(glob(os.path.join(DATASET_DIR, "*-seg.nii.gz")))

# zip() pairs by sorted position, so verify both lists describe the same cases:
# one missing file would otherwise silently shift every later image/label pair.
image_cases = [os.path.basename(p)[: -len("-t1c.nii.gz")] for p in image_paths]
label_cases = [os.path.basename(p)[: -len("-seg.nii.gz")] for p in label_paths]
assert image_paths, f"No '*-t1c.nii.gz' files found in {DATASET_DIR}"
assert image_cases == label_cases, "t1c images and seg labels do not pair up case-for-case"

data_dicts = [{"image": img, "label": seg} for img, seg in zip(image_paths, label_paths)]

train_ds = CacheDataset(data=data_dicts, transform=train_transforms, cache_rate=1.0)
# RandCropByPosNegLabeld returns a *list* of num_samples crops per volume.
# list_data_collate flattens that list into a single batch. torch's default collate would
# instead yield a list of dicts, and batch_data["image"] in the training loop would fail.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=list_data_collate)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 3D SegResNet: one input channel (t1c) -> one output logit channel (foreground).
model = SegResNet(
    spatial_dims=3,
    init_features=32,
    in_channels=1,
    out_channels=1,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
)
model = model.to(device)

# Dice loss applied to sigmoid(logits), optimized with Adam at a fixed learning rate.
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
max_epochs = 10
# Checkpoint every `val_interval` epochs. (Despite the name, no validation runs inside
# this loop; evaluation happens later on a saved checkpoint.)
val_interval = 2
epoch_loss_values = []  # mean training loss of each epoch, used by the loss plot

if len(train_loader) == 0:
    raise RuntimeError("train_loader is empty - check DATASET_DIR and the file glob patterns")

for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0.0
    step = 0

    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        epoch_loss += step_loss
        print(f"{step}/{len(train_loader)}, Loss: {step_loss:.4f}")

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Save a checkpoint every `val_interval` epochs, and always after the final epoch,
    # so the last weights are kept even when max_epochs is not a multiple of val_interval.
    is_last_epoch = (epoch + 1) == max_epochs
    if (epoch + 1) % val_interval == 0 or is_last_epoch:
        model_path = os.path.join(OUTPUT_DIR, f"segresnet_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), model_path)
        print(f"Saved model checkpoint: {model_path}")


In [None]:
# Mean training loss per epoch. Epochs are 1-indexed to match the training log;
# the original index-based x-axis put epoch 1 at x=0.
fig, ax = plt.subplots()
ax.plot(range(1, len(epoch_loss_values) + 1), epoch_loss_values, marker="o")
ax.set_title("Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()


In [None]:
# Deterministic validation pipeline: same preprocessing as training, no crops or augmentation.
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # Merge all non-zero label values into one foreground class (value >= 0.5 -> 1), so
    # DiceMetric compares the binary single-channel prediction with a binary target.
    AsDiscreted(keys=["label"], threshold=0.5),
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=500, b_min=0.0, b_max=1.0, clip=True),
    ToTensord(keys=["image", "label"])
])


In [None]:
VAL_DIR = "/content/drive/MyDrive/brats2023/val"

val_images = sorted(glob(os.path.join(VAL_DIR, "*-t1c.nii.gz")))
val_labels = sorted(glob(os.path.join(VAL_DIR, "*-seg.nii.gz")))

# zip() pairs by sorted position, so confirm both lists cover the same cases.
val_image_cases = [os.path.basename(p)[: -len("-t1c.nii.gz")] for p in val_images]
val_label_cases = [os.path.basename(p)[: -len("-seg.nii.gz")] for p in val_labels]
assert val_images, f"No '*-t1c.nii.gz' files found in {VAL_DIR}"
assert val_image_cases == val_label_cases, "validation t1c images and seg labels do not pair up case-for-case"

val_data = [{"image": img, "label": lbl} for img, lbl in zip(val_images, val_labels)]
val_ds = CacheDataset(data=val_data, transform=val_transforms, cache_rate=1.0)
# list_data_collate batches MONAI MetaTensors together with their metadata.
val_loader = DataLoader(val_ds, batch_size=1, collate_fn=list_data_collate)


In [None]:
# The model has a single output channel, and with include_background=False MONAI treats
# channel 0 as background. Keep it explicitly: some MONAI versions special-case
# single-channel input, while others would drop the only channel and score nothing.
dice_metric = DiceMetric(include_background=True, reduction="mean")

# Checkpoint written by the training loop (epoch 10 with the default settings).
CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "segresnet_epoch_10.pth")
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()


In [None]:
# Score the loaded model on full validation volumes using sliding-window inference
# (96^3 windows, sw_batch_size=1). Loop variables use their own names so they do not
# overwrite the global `val_data` / `val_labels` lists built earlier.
model.eval()
dice_metric.reset()  # discard any state left by an earlier or interrupted run

with torch.no_grad():
    for val_batch in val_loader:
        val_inputs = val_batch["image"].to(device)
        val_targets = val_batch["label"].to(device)

        val_logits = sliding_window_inference(val_inputs, (96, 96, 96), 1, model)
        val_preds = torch.sigmoid(val_logits) > 0.5

        dice_metric(y_pred=val_preds, y=val_targets)

mean_dice = dice_metric.aggregate().item()
print("Mean Dice score on validation set:", mean_dice)
dice_metric.reset()


In [None]:
# Predict foreground masks for the (unlabeled) test cases and save them as NIfTI files.
TEST_DIR = "/content/drive/MyDrive/brats2023/test"
test_images = sorted(glob(os.path.join(TEST_DIR, "*-t1c.nii.gz")))
assert test_images, f"No '*-t1c.nii.gz' files found in {TEST_DIR}"

test_data = [{"image": img} for img in test_images]

# Same deterministic image preprocessing as validation (there are no labels here).
test_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=500, b_min=0.0, b_max=1.0, clip=True),
    ToTensord(keys=["image"])
])

test_ds = CacheDataset(data=test_data, transform=test_transforms, cache_rate=1.0)
# list_data_collate batches MONAI MetaTensors with their metadata, including the affine.
test_loader = DataLoader(test_ds, batch_size=1, collate_fn=list_data_collate)

OUTPUT_PRED_DIR = "/content/drive/MyDrive/segresnet_predictions"
os.makedirs(OUTPUT_PRED_DIR, exist_ok=True)

model.eval()
with torch.no_grad():  # inference only: don't build autograd graphs
    for i, test_case in enumerate(test_loader):
        test_input = test_case["image"].to(device)
        test_output = sliding_window_inference(test_input, (96, 96, 96), 1, model)
        pred = (torch.sigmoid(test_output) > 0.5).float()

        # Convert back to NIfTI. The prediction lives on the *preprocessed* grid
        # (1.5 x 1.5 x 2.0 spacing, RAS orientation), so save it with that grid's affine.
        # The source file's affine describes its original spacing/orientation and would
        # misplace the mask. (The original code also read `image_meta_dict`, which current
        # MONAI no longer produces.) For masks on the original grid, invert the transforms
        # (e.g. monai.transforms.Invertd) before saving.
        pred_np = pred[0, 0].cpu().numpy().astype(np.uint8)
        affine = np.asarray(test_case["image"].affine[0])
        pred_nii = nib.Nifti1Image(pred_np, affine)
        nib.save(pred_nii, os.path.join(OUTPUT_PRED_DIR, f"pred_{i:03}.nii.gz"))

print("Saved all predictions.")
