In [1]:
%pip install torchio

Collecting torchio
  Downloading torchio-0.20.22-py3-none-any.whl.metadata (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9->torchio)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9->torchio)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9->torchio)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-no

### imports

In [2]:
from pathlib import Path

import torchio as tio # 3D - DataHandling (sampler/aggregator)
import torch #
import pytorch_lightning as pl # training/fit

from pytorch_lightning.callbacks import ModelCheckpoint # Automatically saves your model during training.
from pytorch_lightning.loggers import TensorBoardLogger # Sends training metrics to TensorBoard so you can visualize them.

In [3]:
import sys
sys.path.append("/kaggle/input/updated-unet3d")  # add directory to Python path

from updated_unet3d import UNet  # import UNet from the file

### set paths

In [4]:
train_path_1 = Path("/kaggle/input/liversegtrainimages1/")
train_path_2 = Path("/kaggle/input/liversegtrainimages2/")

val_path = Path("/kaggle/input/liversegtestimages/")

label_root = Path("/kaggle/input/liversegtrainlabels/")

In [5]:
def change_img_to_label_path(img_path: Path, label_root: Path = label_root) -> Path:
    return label_root / img_path.name

In [6]:
subject_path_list = list(train_path_1.glob("liver_*"))
subject_path_list.extend(train_path_2.glob("liver_*"))

In [7]:
subject_path_list[0].name

'liver_50.nii'

In [8]:
len(subject_path_list)

131

In [9]:
subject_path_list

[PosixPath('/kaggle/input/liversegtrainimages1/liver_50.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_9.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_46.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_90.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_96.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_97.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_55.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_4.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_91.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_38.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_83.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_7.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_51.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_81.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_67.nii'),
 PosixPath('/kaggle/input/liversegtrainimages1/liver_39.ni

### create subjects

In [10]:
subjects = []

for subject_path in subject_path_list:
    label_path = change_img_to_label_path(subject_path)

    # sanity check
    if isinstance(subject_path, dict):
        raise TypeError(f"subject_path is a dictionary! {subject_path}")
    if isinstance(label_path, dict):
        raise TypeError(f"label_path is a dictionary! {label_path}")

    subject = tio.Subject(
        CT=tio.ScalarImage(subject_path),  # lazy load
        Label=tio.LabelMap(label_path)     # lazy load
    )
    
    subjects.append(subject)

In [11]:
print(type(subjects[0]["CT"]), subjects[0]["CT"])
print(type(subjects[0]["Label"]), subjects[0]["Label"])

<class 'torchio.data.image.ScalarImage'> ScalarImage(shape: (1, 512, 512, 240); spacing: (0.91, 0.91, 2.50); orientation: RAS+; path: "/kaggle/input/liversegtrainimages1/liver_50.nii")
<class 'torchio.data.image.LabelMap'> LabelMap(shape: (1, 512, 512, 240); spacing: (0.91, 0.91, 2.50); orientation: RAS+; path: "/kaggle/input/liversegtrainlabels/liver_50.nii")


In [12]:
import torchio as tio

for subject in subjects:
    # Check orientation
    assert subject["CT"].orientation == ("R", "A", "S"), \
        f"CT orientation is wrong for subject: {subject}"

    # Check that CT is a ScalarImage
    assert isinstance(subject["CT"], tio.ScalarImage), \
        f"CT is not a ScalarImage for subject: {subject}"

    # Check that Label is a LabelMap
    assert isinstance(subject["Label"], tio.LabelMap), \
        f"Label is not a LabelMap for subject: {subject}"


### set augmentation pipeline for transform
- applied when creating datasets

In [13]:
process = tio.Compose([                                 # Deterministic: Things inside Compose happen the same way every time you call it.
    tio.ToCanonical(),                     # step 1: fix orientation
    tio.Resample(target = 'CT'),                         # step 2: align all images to same grid
    tio.RescaleIntensity((-1, 1)),          # step 3: normalize intensity
    tio.CropOrPad((256, 256, 200)),         # step 4: now safe to crop/pad
])

augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10)) # Stocastic(Random): Every time you apply it, the result changes slightly.

# OR: Combine both
# pipeline = tio.Compose([
#     tio.RescaleIntensity((-1, 1)),
#     tio.CropOrPad((256, 256, 200)),
#     tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))
# ])

train_transform = tio.Compose([process, augmentation]) # equivalent to transforms.Compose in torchVision
val_transform = process

### create datasets in torchio

### 1️⃣ Why smaller patches speed up training

#### Less memory usage
Smaller patches mean fewer voxels per batch. This allows:

- Larger batch sizes
- Fewer out-of-memory errors on the GPU

#### Fewer computations
Neural networks (especially 3D CNNs) scale roughly with the **number of voxels in the input**.

- Example:  
  - A `96³` patch has `884,736` voxels (`96×96×96`)  
  - A `64³` patch has `262,144` voxels → ~3.3× fewer voxels → faster forward/backward passes

#### Faster data loading
Smaller patches take less disk I/O and less augmentation processing time.

---

### 2️⃣ Trade-offs of smaller patches

#### Less context
Smaller patches capture **less of the surrounding anatomy**, which can hurt network performance for tasks where context matters (e.g., organ segmentation).

#### More patches per volume needed
To cover the same volume or ensure all labels are seen, you may need to sample **more patches per epoch**, which can reduce some of the speed gains.

#### Potential edge effects
Very small patches can cut off important structures, leading to more **boundary artifacts** during training.

---

### 3️⃣ Practical guidelines

- For 3D medical images:
  - Patch sizes around `64³` to `128³` are common.  
  - Larger patches (`128³–256³`) give more context but are slower.  
  - Smaller patches (`32³–64³`) are fast but may lose anatomical context.

- A good strategy:
  1. Start with a patch size that fits your **GPU comfortably**.  
  2. Use **data augmentation** to increase diversity.  
  3. Experiment to find the best trade-off between **accuracy** and **training speed**.

In [14]:
train_dataset = tio.SubjectsDataset(subjects[:105], transform = train_transform) # subjects[:105] → first 105 subjects for training.
val_dataset = tio.SubjectsDataset(subjects[105:], transform = val_transform) # subjects[105:] → remaining subjects for validation.

sampler = tio.data.LabelSampler(
    patch_size = 96, # Size of the cubic patch to sample (here 96×96×96 voxels).
    label_name = 'Label', # The sampler uses this to find where each class is in the volume.
    label_probabilities = {0:0.2, 1:0.3, 2:0.5}
)

Controls how often patches are centered on each label class.

- Class 0 → 20% of patches
- Class 1 → 30% of patches
- Class 2 → 50% of patches

Useful when some labels are rare — ensures your network sees enough examples of each class.

In [15]:
train_patches_queue = tio.Queue(
    train_dataset,
    max_length=40,           # Maximum number of patches stored in the queue at any given time. 
    samples_per_volume=4,    # Number of patches sampled per subject per epoch.
    sampler=sampler,
    num_workers=2            # how many CPU processes are used to sample and apply transformations to patches.
)

val_patches_queue = tio.Queue(
    val_dataset,
    max_length=40,           # Smaller queue → less RAM used, but may increase CPU/GPU waiting if patches are generated on-the-fly.
    samples_per_volume=4,    # Reducing it (fewer patches per volume) → faster queue filling
    sampler=sampler,
    num_workers=2            # higher num_workers → faster patch preparation → GPU waits less. but more RAM usage
)

Original volumes (CPU)
        **>**
tio.Queue (CPU RAM) (patches are stored here)
        **>**
Batch sent to GPU (forward/backward pass)
        **>**
Next batch sent from CPU → GPU

| Parameter                | What it does                            | Memory impact                             | Importance                     |
| ------------------------ | --------------------------------------- | ----------------------------------------- | ------------------------------ |
| `Queue num_workers`      | Parallel patch sampling + augmentation  | High (each worker generates patches)      | Critical for GPU utilization   |
| `DataLoader num_workers` | Parallel fetching of batches from queue | Low (just retrieves preprocessed patches) | Optional for queue-based setup |

In [16]:
def subject_to_tensor(batch):
    """
    Convert a list of TorchIO Subjects into a dict of batched tensors.
    """
    ct_list = []
    label_list = []

    for subject in batch:  # batch is a list of Subjects
        ct_list.append(subject['CT'].data)
        label_list.append(subject['Label'].data)

    # Stack along a new batch dimension
    return {
        'CT': torch.stack(ct_list, dim=0),
        'Label': torch.stack(label_list, dim=0)
    }

In [17]:
train_loader = torch.utils.data.DataLoader(
    train_patches_queue,  
    batch_size=2,            
    num_workers=0,
    collate_fn=subject_to_tensor
)

val_loader = torch.utils.data.DataLoader(
    val_patches_queue, 
    batch_size=2,
    num_workers=0,
    collate_fn=subject_to_tensor
)

| Parameter            | Memory impact                           |
| -------------------- | --------------------------------------- |
| `samples_per_volume` | CPU RAM (more patches in queue)         |
| `batch_size`         | GPU memory (number of patches per step) |
| `patch_size`         | Both CPU (queue) and GPU (batch)        |

In [18]:
batch = next(iter(train_loader))
print(type(batch))  # should be a dict of torch.Tensor batches

<class 'dict'>


In [19]:
class Segmenter(pl.LightningModule):

    def __init__(self):
        super().__init__()

        self.model = UNet()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 1e-4)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, data):

        return self.model(data)

    def training_step(self, batch, batch_idx):
        img = batch["CT"]  # ["data"] already a tensor
        mask = batch["Label"][:, 0]  # remove channel dim
        mask = mask.long()
        
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        
        self.log("Train Loss", loss)    # you could add on_step=False, on_epoch=True to log once per epoch.
        return loss

    def validation_step(self, batch, batch_idx):
        img = batch["CT"]  # ["data"]already a tensor
        mask = batch["Label"][:, 0]  # remove channel dim
        mask = mask.long()
        
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        
        self.log("Val Loss", loss)
        return loss

    def configure_optimizers(self):
        return [self.optimizer]

In [20]:
model = Segmenter()

In [21]:
checkpoint_callback = ModelCheckpoint(monitor = "Val Loss", save_top_k = 10, mode = "min")

With mode="min" and save_top_k=3:

- Epoch 1 → 0.50 → saved (top 1)

- Epoch 2 → 0.48 → saved (top 2)

- Epoch 3 → 0.52 → not saved (higher than current top 2)

- Epoch 4 → 0.46 → saved (top 3, now top losses: 0.46, 0.48, 0.50)

- Epoch 5 → 0.49 → replaces worst in top 3? Depends on save_top_k logic; it may replace 0.50

✅ Key point: Lightning keeps the

In [22]:
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs"),
    log_every_n_steps=10,
    callbacks=checkpoint_callback,
    max_epochs=10
    # precision="16-mixed"
)

In [23]:
# trainer.fit(model, train_loader, val_loader)