In [7]:
# In a fresh notebook cell before doing any heavy imports
# Purpose: reduce crash/traceback noise for the rest of this kernel session.
# NOTE(review): `%config` only reconfigures objects the shell knows about;
# confirm that `Application.verbose_crash` is actually honoured by this kernel.
%config Application.verbose_crash=False
# Switch IPython to its least verbose traceback mode ("Minimal").
%xmode Minimal

Collecting monai
  Using cached monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Using cached monai-1.5.1-py3-none-any.whl (2.7 MB)
Installing collected packages: monai
Successfully installed monai-1.5.1


In [8]:
# ==============================
# LGG MRI Segmentation Experiment
# CPU-only, single-image proof-of-concept
# ==============================

# ------------------------------
# 1. Imports
# ------------------------------
import os
import torch
import numpy as np
import nibabel as nib
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from torch.utils.data import Dataset, DataLoader
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityd, ToTensord



Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/jovyan/blueberry/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 2194, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/blueberry/lib/python3.12/site-packages/IPython/core/ultratb.py", line 1185, in structured_traceback
  File "/home/jovyan/blueberry/lib/python3.12/site-packages/IPython/core/ultratb.py", line 1056, in structured_traceback
    return VerboseTB.structured_traceback(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/blueberry/lib/python3.12/site-packages/IPython/core/ultratb.py", line 864, in structured_traceback
    formatted_exceptions: list[list[str]] = self.format_exception_as_a_whole(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/blueberry/lib/python3.12/site-packages/IPython/core/ultratb.py", line 749, in format_exception_as_a_whole
    recor

In [None]:
# ------------------------------
# 2. Paths (KaggleHub cache)
# ------------------------------
# Root of the KaggleHub download of the LGG MRI segmentation dataset.
# Override with the LGG_DATA_ROOT environment variable; otherwise default to the
# KaggleHub cache under the current user's home (no hardcoded /home/<user>).
DATA_ROOT = os.environ.get(
    "LGG_DATA_ROOT",
    os.path.expanduser("~/.cache/kagglehub/datasets/mateuszbuda/lgg-mri-segmentation/versions/2"),
)
IMAGES_DIR = os.path.join(DATA_ROOT, "kaggle_3m")          # adjust folder if necessary
# NOTE(review): the dataset cell looks up each mask in LABELS_DIR under the *same*
# file name as its image; confirm this folder exists and follows that layout.
LABELS_DIR = os.path.join(DATA_ROOT, "kaggle_3m_masks")    # adjust folder if necessary



In [None]:
# ------------------------------
# 3. Minimal Dataset Class
# ------------------------------
class SingleSliceDataset(Dataset):
    """Expose every 2-D slice of a few image/mask volume pairs as dataset items.

    Image files are listed (sorted) from ``images_dir``. Hidden entries and
    sub-directories are skipped. Each image's mask is looked up in
    ``labels_dir`` under the *same* file name. Both are loaded with nibabel as
    float32, the mask is binarised to {0, 1}, and both volumes are cut into
    slices along ``slice_axis``.

    Args:
        images_dir: Directory containing the image volumes.
        labels_dir: Directory containing masks with matching file names.
        slice_axis: Axis of the loaded volume to slice along. Negative values
            count from the end, as in NumPy.
        max_images: Number of image files (in sorted order) to load. ``None``
            loads all of them.

    Raises:
        FileNotFoundError: If an image has no mask with the same name.
        ValueError: If an image and its mask have different shapes.

    NOTE(review): volumes are assumed to be 3-D (so slices are 2-D); confirm for
    the files actually present in ``images_dir``.
    """

    def __init__(self, images_dir, labels_dir, slice_axis=2, max_images=1):
        # Each entry: (image_slice, mask_slice, file_name, index_along_slice_axis).
        self.slices = []
        img_files = self._list_volume_files(images_dir)[:max_images]
        for fname in img_files:
            img_path = os.path.join(images_dir, fname)
            mask_path = os.path.join(labels_dir, fname)
            if not os.path.isfile(mask_path):
                raise FileNotFoundError(
                    f"No mask named {fname!r} in {labels_dir!r} for image {img_path!r}"
                )
            img = nib.load(img_path).get_fdata(dtype=np.float32)
            mask = nib.load(mask_path).get_fdata(dtype=np.float32)
            if img.shape != mask.shape:
                raise ValueError(
                    f"Image {img_path!r} has shape {img.shape} but mask {mask_path!r} has shape {mask.shape}"
                )
            mask = (mask > 0).astype(np.float32)  # normalize mask to 0/1

            # Store views along slice_axis; works for any valid (incl. negative) axis.
            img_parts = self._split_along(img, slice_axis)
            mask_parts = self._split_along(mask, slice_axis)
            for i, (img_slice, mask_slice) in enumerate(zip(img_parts, mask_parts)):
                self.slices.append((img_slice, mask_slice, fname, i))

    @staticmethod
    def _list_volume_files(directory):
        """Return the sorted names of visible regular files in ``directory``."""
        # Skips e.g. `.ipynb_checkpoints` dirs and dotfiles that nibabel cannot load.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _split_along(volume, axis):
        """Return views of ``volume`` where element ``i`` is index ``i`` along ``axis``."""
        return list(np.moveaxis(volume, axis, 0))

    def __len__(self):
        """Number of stored slices across all loaded volumes."""
        return len(self.slices)

    def __getitem__(self, idx):
        """Return one slice as float32 tensors with a leading channel axis.

        Returns:
            dict with ``"image"`` and ``"label"`` tensors of shape (1, *slice_shape),
            plus the source ``"fname"`` and ``"slice_idx"`` along ``slice_axis``.
        """
        img_slice, mask_slice, fname, slice_idx = self.slices[idx]
        img_slice = np.expand_dims(img_slice, axis=0)
        mask_slice = np.expand_dims(mask_slice, axis=0)
        return {
            "image": torch.tensor(img_slice, dtype=torch.float32),
            "label": torch.tensor(mask_slice, dtype=torch.float32),
            "fname": fname,
            "slice_idx": slice_idx
        }



In [None]:
# ------------------------------
# 4. Prepare dataset / loader
# ------------------------------
# Build the (tiny) training set from the first image volume and wrap it in a loader.
dataset = SingleSliceDataset(IMAGES_DIR, LABELS_DIR, max_images=1)
# Fail here with a clear message instead of an opaque error from the
# DataLoader/training loop when the paths yield no slices.
assert len(dataset) > 0, f"No slices loaded from {IMAGES_DIR!r}; check IMAGES_DIR / LABELS_DIR"
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)


In [None]:

# ------------------------------
# 5. Model, loss, optimizer
# ------------------------------
device = "cpu"
SEED = 0
# Seed torch's global RNG before building the model so weight initialisation is
# reproducible. The DataLoader has no explicit generator, so its shuffling also
# draws from this global RNG.
torch.manual_seed(SEED)
# NOTE(review): two stride-2 stages, so slice height/width presumably need to be
# divisible by 4 for the UNet skip connections to line up — confirm for your data.
model = UNet(spatial_dims=2, in_channels=1, out_channels=1,
             channels=(16, 32, 64), strides=(2, 2)).to(device)
loss_fn = DiceLoss(sigmoid=True)  # applies sigmoid to the raw logits internally
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The network outputs a single foreground channel (out_channels=1), so there is no
# separate background channel to exclude; state that explicitly.
dice_metric = DiceMetric(include_background=True, reduction="mean")



In [None]:
# ------------------------------
# 6. Training loop
# ------------------------------
max_epochs = 5  # short for quick test
if len(train_loader) == 0:
    raise RuntimeError("train_loader yields no batches; check the dataset paths")

for epoch in range(max_epochs):
    # ---- training pass: mean Dice loss over all batches ----
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        imgs = batch["image"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # ---- validation / Dice ----
    # Calling dice_metric(...) accumulates per-slice scores into the metric's
    # buffer, so clear it first and aggregate once at the end of the epoch.
    model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for batch in train_loader:  # reuses the training slices for this demo (no held-out split)
            imgs = batch["image"].to(device)
            labels = batch["label"].to(device)
            # Binarise at p=0.5 and cast to float so the metric receives numeric masks.
            preds = (torch.sigmoid(model(imgs)) > 0.5).float()
            dice_metric(y_pred=preds, y=labels)
    # aggregate() applies the configured "mean" reduction over all accumulated
    # slices; MONAI excludes NaN entries (e.g. slices with an empty ground-truth
    # mask) instead of letting them turn the average into NaN.
    avg_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Dice: {avg_dice:.4f}")



In [None]:
# ------------------------------
# 7. Save model checkpoint
# ------------------------------
# Persist only the learned weights (state_dict). To reload, rebuild the UNet with
# the same constructor arguments and call load_state_dict on it.
CHECKPOINT_PATH = "./models/unet_cpu.pth"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f"Model saved to {CHECKPOINT_PATH}")
