In [1]:
%pip install torchio --q
%pip install monai --q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m84.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m67.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [2]:
from pathlib import Path

import numpy as np

import torchio as tio 
import torch
import pytorch_lightning as pl 

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger 

## **PREPROCESING**

In [3]:
root_path = Path("/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/")

In [4]:
def get_img_path(patient_path: Path) -> Path:
    return next((patient_path / "data").glob("*.nii"))

def get_label_path(patient_path: Path) -> Path:
    return next((patient_path / "label").glob("*.nii"))

In [5]:
subject_path_list = list(root_path.glob("*"))

In [6]:
print(subject_path_list[0])
len(subject_path_list)

/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/7


63

In [7]:
subjects = []

for subject_path in subject_path_list:

    img_path = get_img_path(subject_path)
    label_path = get_label_path(subject_path)

    subject = tio.Subject(
        CT=tio.ScalarImage(img_path),  # lazy load
        Label=tio.LabelMap(label_path)     # lazy load
    )
    
    subjects.append(subject)

In [8]:
print(type(subjects[15]["CT"]), subjects[15]["CT"])
print(type(subjects[15]["Label"]), subjects[15]["Label"])

<class 'torchio.data.image.ScalarImage'> ScalarImage(shape: (1, 256, 256, 95); spacing: (1.00, 1.00, 1.00); orientation: RAS+; path: "/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/60/data/60_data.nii")
<class 'torchio.data.image.LabelMap'> LabelMap(shape: (1, 256, 256, 95); spacing: (1.00, 1.00, 1.00); orientation: RAS+; path: "/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/60/label/60_mask.nii")


In [9]:
labs = [sub["Label"].data.numpy() for sub in subjects]
labels = [np.any(lab != 0) for lab in labs]

if False in labels:
    print("There is at least one False")

**just realised it folks, there's no negative case in dataset**

**medumb:/**

**gotta evaluate overlap between predicted mask and ground truth**

In [10]:
depths = [sub["CT"].shape[3] for sub in subjects]
median_depth = int(np.median(depths)) # use median cuz, mean is sensitive to outliers

median_depth

222

In [11]:
process = tio.Compose([
    tio.ToCanonical(),                              # step 1: fix orientation - RAS
    tio.Resample(target = 'CT'),                    # step 2: align all images
    tio.RescaleIntensity((-1, 1)),                  # step 3: normalize intensity
    tio.CropOrPad((256, 256, median_depth))         # step 4: crop or pad (of course:|)
])

augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))

train_transform = tio.Compose([process, augmentation])
val_transform = tio.Compose([process])

In [12]:
train_dataset = tio.SubjectsDataset(subjects[:50], transform = train_transform) # 80/20 split
val_dataset = tio.SubjectsDataset(subjects[50:], transform = val_transform)     # ~50 train, ~13 val

In [13]:
label_sampler = tio.data.LabelSampler(
    patch_size = 64, 
    label_name = 'Label', 
    label_probabilities = {0:0.3, 1:0.7}                         
)

| Patch size (HxWxD) | Notes                                                     |
| ------------------ | --------------------------------------------------------- |
| (64, 64, 64)       | Standard, fits most GPUs, may crop large tumors           |
| (96, 96, 96)       | Larger context, fewer patches per volume, higher memory   |
| (128, 128, 128)    | Only if GPU can handle; almost full volume in-plane (HxW) |

In [14]:
train_queue = tio.Queue(
    train_dataset,
    samples_per_volume=4,    # 4 patches (sampled) per subject/volume
    max_length=40,           # until 40 patches filled up 
    sampler=label_sampler,   # sampled acc. to label_probs
    num_workers=2         
)

val_queue = tio.Queue(
    val_dataset,
    samples_per_volume=4, 
    max_length=40,           
    sampler=label_sampler,
    num_workers=2           
)

In [15]:
def subject_to_tensor(batch):
    
    ct_list = []
    label_list = []

    for subject in batch:                         # batch is a list of Subjects
        ct_list.append(subject['CT'].data)
        label_list.append(subject['Label'].data)

    return {                                      # Stack along a new batch dimension
        'CT': torch.stack(ct_list, dim=0),
        'Label': torch.stack(label_list, dim=0)
    }

In [16]:
train_loader = torch.utils.data.DataLoader(
    train_queue,  
    batch_size=2,            # So tensor shape for images: (2, 1, 64, 64, 64)
    num_workers=0,
    collate_fn=subject_to_tensor,
    shuffle = True
)

val_loader = torch.utils.data.DataLoader(
    val_queue, 
    batch_size=2,
    num_workers=0,
    collate_fn=subject_to_tensor
)

In [17]:
batch = next(iter(train_loader))
print(type(batch)) 

<class 'dict'>


## **TRAIN**

In [18]:
from monai.networks.nets import UNet
from monai.losses import DiceLoss

class LungTumorSegmentationModel(pl.LightningModule):
    
    def __init__(self, learning_rate=1e-4):
        super().__init__()

        self.save_hyperparameters()
        
        self.model = UNet(
            spatial_dims=3,                  # specifies 3D convolutions because input is 3D CT data
            in_channels=1,                   # grayscale CT
            out_channels=2,                  # number of output channels (i.e. the number of segmentation classes)
            channels=(16, 32, 64, 128, 256), # number of feature maps at each level of the U-Net encoder/decoder
            strides=(2, 2, 2, 2),            # downsampling factors for each encoder level (16>32, , , 128>256) (i.e. patches of 64 > 32 > 16 > 8 > 4)
            num_res_units=2,                 # number of residual units per level
        )
        
        self.dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 1e-4)
    
    def forward(self, x):
        return self.model(x)
    
    def compute_loss(self, outputs, labels_ce, labels_dice):
        return self.dice_loss(outputs, labels_dice) + self.ce_loss(outputs, labels_ce)
    
    def training_step(self, batch, batch_idx):
        images = batch['CT'].data         
        labels = batch['Label'].data 

        labels_ce = labels.squeeze(1).long()      # (B, H, W, D) for CrossEntropy
        labels_dice = labels                      # (B, 1, H, W, D) for DiceLoss
        
        outputs = self(images)                    # output: (B, H, W, D)
        loss = self.compute_loss(outputs, labels_ce, labels_dice)
        
        self.log('train_loss', loss, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        images = batch['CT'].data
        labels = batch['Label'].data

        labels_ce = labels.squeeze(1).long()  # (B, 1, H, W, D) -> (B, H, W, D) - B dimension needed by CELoss
        labels_dice = labels                  # (B, 1, H, W, D) for DiceLoss
        
        outputs = self(images)
        loss = self.compute_loss(outputs, labels_ce, labels_dice)
        
        self.log('val_loss', loss, on_epoch=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

2025-09-25 15:09:57.241686: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758812997.465923      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758812997.530222      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


**Shapes for 3D Lung Tumor Segmentation**

- CT: `(1, 256, 256, ~222)` → when you make a batch of size b: `(b, channelSize(c)=1, 256, 256, ~222)`

- **CrossEntropyLoss**  
  - input: `(b, c, H, W, D)`  
  - target: `(b, H, W, D)` → `labels_ce = labels.squeeze(1).long()`

- **DiceLoss(to_onehot_y=True)**  
  - input: `(b, c, H, W, D)`  
  - target: `(b, 1, H, W, D)` → `labels_dice = labels`


| Loss                       | Labels type             | Labels shape      |
| -------------------------- | ----------------------- | ----------------- |
| CrossEntropyLoss           | `torch.long` (integers) | `(B, H, W, D)`    |
| DiceLoss(to_onehot_y=True) | `torch.float`           | `(B, 1, H, W, D)` |

In [19]:
model = LungTumorSegmentationModel()

In [20]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',        
    save_top_k=3,               # save top 3 models; 1 for saving only the best model
    mode='min'
)

In [21]:
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs"),
    log_every_n_steps=10,
    callbacks=checkpoint_callback,
    max_epochs=10
)

| Parameter           | Where it’s used   | What it controls                   |
| ------------------- | ----------------- | ---------------------------------- |
| `every_n_epochs`    | `ModelCheckpoint` | Save checkpoint every N epochs     |
| `log_every_n_steps` | `Trainer`         | Log metrics every N training steps |

In [22]:
trainer.fit(model, train_loader, val_loader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## **EVALUATION**

model.eval(); # semicolon; to suppress the output

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device); 

In [24]:
from monai.metrics import DiceMetric

dice_metric = DiceMetric(
    include_background=False,   # ignore the background class (0) when computing Dice, because tumors are usually sparse.
    reduction="mean"            # average the Dice across the batch and classes.
)

with torch.no_grad():
    
    for batch in val_loader:
        
        image = batch['CT'].data.to(device)
        label = batch['Label'].data.to(device)

        outputs = model(image)  # (B, C, H, W, D) # raw logits per voxel - kind of - output
        
        preds = torch.argmax(    # for each voxel, selects the class with highest predicted probability
            outputs, 
            dim=1, 
            keepdim=True         # since performing argmax removes channel dim, so preserve preds shape: (B, C, H, W, D).
        )  

        dice_metric(y_pred=preds, y=label)

### **What `torch.argmax` does**

* Converts raw logits outputs of shape (2,1,H,W,D) to class values (0: bg, 1: tumor) voxeled output of same shape 
* Each voxel has **C values**, one for each class (logits).
* `torch.argmax(outputs, dim=1)` → selects the class **with highest probability** per voxel.

**Shape after argmax (without keepdim):** `(B, H, W, D)`

**With keepdim=True:** `(B, 1, H, W, D)`

---

### **Why not dim=0 or dim=2,3,4?**

* `dim=0` → across batch → nonsensical, you’d mix different images
* `dim=2,3,4` → across spatial dimensions → nonsensical, you’d mix different voxels
* Only `dim=1` makes sense because **that’s the class/channel axis**

In [25]:
dice_score = dice_metric.aggregate().item()
dice_metric.reset()
print("Mean Dice Score:", dice_score)

Mean Dice Score: 0.4951266646385193


- Even though we use a **combination of CrossEntropyLoss and DiceLoss** for training, we often report **Dice metric** because:  
  - It directly measures the **overlap between predicted and ground truth masks**.  
  - It handles **class imbalance** well, making it sensitive to small tumors.

- **CrossEntropyLoss** measures **voxel-wise classification accuracy**. (Computationaly Taxing)

- **Dice metric values:**  
  - Range: **0 to 1**, where 1 = perfect overlap  
  - ~0.7–0.8 → considered good  
  - ~0.5 → moderate/average