In [1]:
# -------------------------------
# Limit threads & CPU-only
# -------------------------------
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""   # disable GPU
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# Verify
import torch
print("Torch version:", torch.__version__)
print("CUDA available?", torch.cuda.is_available())



Torch version: 2.9.0+cu128
CUDA available? False


In [2]:
# =====================================
# Tiny CPU-friendly 2D UNet Training
# Low-memory version
# =====================================

import os
from PIL import Image
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split


In [3]:
# -------------------------------
# 1. Device
# -------------------------------
device = torch.device("cpu")

In [4]:
# -------------------------------
# 2. Dataset helper (pairs only)
# -------------------------------
def collect_image_mask_pairs(root_dir):
    pairs = []
    for patient_dir in os.listdir(root_dir):
        full_patient_dir = os.path.join(root_dir, patient_dir)
        if not os.path.isdir(full_patient_dir):
            continue
        files = sorted(os.listdir(full_patient_dir))
        images = [f for f in files if not f.endswith("_mask.tif")]
        for img_name in images:
            img_path = os.path.join(full_patient_dir, img_name)
            mask_name = img_name.replace(".tif", "_mask.tif")
            mask_path = os.path.join(full_patient_dir, mask_name)
            if os.path.exists(mask_path):
                pairs.append((img_path, mask_path))
    return pairs

In [16]:
# -------------------------------
# 3. On-the-fly Dataset
# -------------------------------
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import torch

class SliceDataset(Dataset):
    def __init__(self, root_dir):
        # collect all image/mask pairs once
        self.samples = []
        for patient_dir in os.listdir(root_dir):
            full_patient_dir = os.path.join(root_dir, patient_dir)
            if not os.path.isdir(full_patient_dir):
                continue
            files = sorted(os.listdir(full_patient_dir))
            images = [f for f in files if not f.endswith("_mask.tif")]
            for img_name in images:
                img_path = os.path.join(full_patient_dir, img_name)
                mask_name = img_name.replace(".tif", "_mask.tif")
                mask_path = os.path.join(full_patient_dir, mask_name)
                if os.path.exists(mask_path):
                    self.samples.append((img_path, mask_path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, mask_path = self.samples[idx]

        # Load image/mask lazily, convert to float32, normalize
        img = np.array(Image.open(img_path), dtype=np.float32) / 255.0
        mask = np.array(Image.open(mask_path), dtype=np.float32) / 255.0

        # If 3D (H, W, D), pick middle slice
        if img.ndim == 3:
            mid = img.shape[2] // 2
            img = img[:, :, mid]
            mask = mask[:, :, mid]

        # Add channel dimension for Conv2d
        img = np.expand_dims(img, 0)
        mask = np.expand_dims(mask, 0)

        return {
            "image": torch.tensor(img, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }


In [17]:
# -------------------------------
# 4. Tiny UNet
# -------------------------------
class UNetTiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                                  nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
                                  nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = nn.Sequential(nn.Conv2d(24, 8, 3, padding=1), nn.ReLU(),
                                  nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())
        self.outc = nn.Conv2d(8, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        d1 = self.up(e2)
        d1 = torch.cat([d1, e1], dim=1)
        out = self.dec1(d1)
        return self.outc(out)

In [18]:
# -------------------------------
# 5. Dice loss / metric
# -------------------------------
def dice_loss(pred, target, eps=1e-6):
    pred = torch.sigmoid(pred)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    return 1 - (2.0 * intersection + eps) / (union + eps)

def dice_metric(pred, target, eps=1e-6):
    pred = (torch.sigmoid(pred) > 0.5).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    if union == 0:
        return 1.0
    return (2.0 * intersection + eps) / (union + eps)

In [19]:
# -------------------------------
# 6. Data
# -------------------------------
data_root = "/home/jovyan/.cache/kagglehub/datasets/mateuszbuda/lgg-mri-segmentation/versions/2/kaggle_3m"
pairs = collect_image_mask_pairs(data_root)
dataset = SliceDataset(pairs)
if len(dataset) == 0:
    raise RuntimeError("No image/mask pairs found. Check paths!")
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

TypeError: listdir: path should be string, bytes, os.PathLike, integer or None, not list

In [15]:

# -------------------------------
# 7. Model / optimizer
# -------------------------------
model = UNetTiny().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# -------------------------------
# 8. Training loop
# -------------------------------
max_epochs = 3  # keep tiny for testing
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    val_dice = 0.0
    
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            val_dice += dice_metric(outputs, masks)
    val_dice /= len(val_loader)
    print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Dice: {val_dice:.4f}")

    # Save checkpoint each epoch
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.state_dict(), f"checkpoints/unet_epoch{epoch+1}.pth")

AttributeError: 'SliceDataset' object has no attribute 'samples'