In [1]:
%pip install torchio --q
%pip install monai --q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m87.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m68.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [2]:
from pathlib import Path

import numpy as np

import torchio as tio 
import torch
import pytorch_lightning as pl 

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger 

## **PREPROCESING**

In [3]:
root_path = Path("/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/")

In [4]:
def get_img_path(patient_path: Path) -> Path:
    return next((patient_path / "data").glob("*.nii"))

def get_label_path(patient_path: Path) -> Path:
    return next((patient_path / "label").glob("*.nii"))

In [5]:
subject_path_list = list(root_path.glob("*"))

In [6]:
print(subject_path_list[0])
len(subject_path_list)

/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/7


63

In [7]:
subjects = []

for subject_path in subject_path_list:

    img_path = get_img_path(subject_path)
    label_path = get_label_path(subject_path)

    subject = tio.Subject(
        CT=tio.ScalarImage(img_path),  # lazy load
        Label=tio.LabelMap(label_path)     # lazy load
    )
    
    subjects.append(subject)

In [8]:
print(type(subjects[15]["CT"]), subjects[15]["CT"])
print(type(subjects[15]["Label"]), subjects[15]["Label"])

<class 'torchio.data.image.ScalarImage'> ScalarImage(shape: (1, 256, 256, 95); spacing: (1.00, 1.00, 1.00); orientation: RAS+; path: "/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/60/data/60_data.nii")
<class 'torchio.data.image.LabelMap'> LabelMap(shape: (1, 256, 256, 95); spacing: (1.00, 1.00, 1.00); orientation: RAS+; path: "/kaggle/input/medical-decathlon-lung-tumor-segmentation/Lung-Tumor-Segmentation/60/label/60_mask.nii")


In [9]:
depths = [sub["CT"].shape[3] for sub in subjects]
median_depth = int(np.median(depths)) # use median cuz, mean is sensitive to outliers

median_depth

222

In [10]:
process = tio.Compose([
    tio.ToCanonical(),                              # step 1: fix orientation - RAS
    tio.Resample(target = 'CT'),                    # step 2: align all images
    tio.RescaleIntensity((-1, 1)),                  # step 3: normalize intensity
    tio.CropOrPad((256, 256, median_depth))         # step 4: crop or pad (of course:|)
])

augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))

train_transform = tio.Compose([process, augmentation])
val_transform = tio.Compose([process])

In [11]:
train_dataset = tio.SubjectsDataset(subjects[:50], transform = train_transform) # 80/20 split
val_dataset = tio.SubjectsDataset(subjects[50:], transform = val_transform)     # ~50 train, ~13 val

In [12]:
label_sampler = tio.data.LabelSampler(
    patch_size = 64, 
    label_name = 'Label', 
    label_probabilities = {0:0.3, 1:0.7}                         
)

In [13]:
train_queue = tio.Queue(
    train_dataset,
    samples_per_volume=4,    
    max_length=40,           
    sampler=label_sampler,   
    num_workers=2         
)

val_queue = tio.Queue(
    val_dataset,
    samples_per_volume=4, 
    max_length=40,           
    sampler=label_sampler,
    num_workers=2           
)

In [14]:
def subject_to_tensor(batch):
    
    ct_list = []
    label_list = []

    for subject in batch:                         # Batch is a list of Subjects
        ct_list.append(subject['CT'].data)
        label_list.append(subject['Label'].data)

    return {                                      # Stack along a new batch dimension
        'CT': torch.stack(ct_list, dim=0),
        'Label': torch.stack(label_list, dim=0)
    }

In [15]:
train_loader = torch.utils.data.DataLoader(
    train_queue,  
    batch_size=2,            # So tensor shape for images: (2, 1, 64, 64, 64)
    num_workers=0,
    collate_fn=subject_to_tensor,
    shuffle = True
)

val_loader = torch.utils.data.DataLoader(
    val_queue, 
    batch_size=2,
    num_workers=0,
    collate_fn=subject_to_tensor
)

In [16]:
batch = next(iter(train_loader))
print(type(batch)) 

<class 'dict'>


## **TRAIN**

In [17]:
# LOSS FUNCTION
import torch.nn as nn
from monai.losses import GeneralizedDiceLoss, TverskyLoss

class MaskedGDL(nn.Module):
        
    def __init__(self, include_background=False, smooth_nr=1e-5, smooth_dr=1e-5):
        super().__init__()
        self.gdl = GeneralizedDiceLoss(
            include_background=include_background,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            to_onehot_y=True,
            softmax=True
        )
    def forward(self, pred, target, mask=None):
        """
        Args:
            pred: torch.Tensor, shape [B, C, H, W, D], predicted probabilities/logits
            target: torch.Tensor, shape [B, C, H, W, D], one-hot encoded ground truth
            mask: torch.Tensor, shape [B, 1, H, W, D] or [B, H, W, D], binary mask to focus on ROI
        Returns:
            Loss scalar
        """
        if mask is not None:
            # Ensure mask has same number of channels as pred/target
            if mask.ndim == 4:
                mask = mask.unsqueeze(1)  # [B,1,H,W,D]
            mask = mask.float()
            pred = pred * mask
            target = target * mask

        loss = self.gdl(pred, target)
        return loss

2025-10-02 14:10:50.344369: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759414250.664574      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759414250.754298      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [18]:
from monai.networks.nets import UNet
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete
from monai.networks.utils import one_hot

class LungTumorSegmentationModel(pl.LightningModule):
    
    def __init__(self, learning_rate=1e-4):
        super().__init__()

        self.save_hyperparameters()
        
        self.model = UNet(
            spatial_dims=3,                  # specifies 3D convolutions because input is 3D CT data
            in_channels=1,                   
            out_channels=2,                  
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),            # downsampling factors for each encoder level (16>32, , , 128>256) (i.e. patches of 64 > 32 > 16 > 8 > 4)
            num_res_units=2,                 
        )
        
        self.loss_fn = MaskedGDL(include_background=False)
        self.dice_metric = DiceMetric(include_background=False, reduction="mean")
        self.post_pred = AsDiscrete(argmax=True)
    
    def forward(self, x):
        return self.model(x)
    
    def compute_loss(self, pred, target, mask=None):
        return  self.loss_fn(pred, target, mask)
    
    def training_step(self, batch, batch_idx):
        images = batch['CT'].data
        labels = batch['Label'].data      # shape [B, 1, H, W, D]
        mask = batch.get('Mask', None)    # optional mask
    
        outputs = self(images)            # shape [B, 1, H, W, D]
        loss = self.compute_loss(outputs, labels, mask)
    
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        images = batch['CT'].data
        labels = batch['Label'].data
        mask = batch.get('Mask', None)
    
        outputs = self(images)
        loss = self.compute_loss(outputs, labels, mask)
    
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

        # calculate and log dice score
        preds = self.post_pred(outputs)
        self.dice_metric(y_pred=preds, y=one_hot(labels, num_classes=2))
        self.log('val_dice', self.dice_metric.aggregate().item(), on_epoch=True, prog_bar=True)
        self.dice_metric.reset()
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

In [19]:
model = LungTumorSegmentationModel()

In [20]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_dice",
    mode="max",         
    save_top_k=3
)

In [21]:
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs"),
    log_every_n_steps=10,
    callbacks=checkpoint_callback,
    max_epochs=30
)

In [22]:
trainer.fit(model, train_loader, val_loader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]