<a href="https://www.kaggle.com/code/muichimon/lungtumor-attentioncustomresunet2d?scriptVersionId=270711806" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# **# 00 - setup**

In [None]:
# Install extra dependencies for this notebook.
# NOTE(review): pytorch_lightning is imported later but is not installed here;
# presumably it is preinstalled in the Kaggle image — confirm. monai is
# installed but not imported in any visible cell.
%pip install torchio --q
%pip install monai --q
%pip install celluloid --q

In [None]:
# Plotting backend: this `notebook` backend is overridden by the
# `%matplotlib inline` at the end of this cell, so inline rendering is what
# actually takes effect.
%matplotlib notebook
    
from pathlib import Path
import nibabel as nib  # NOTE(review): not used in any visible cell
import numpy as np
import warnings
warnings.filterwarnings('ignore')  # hides ALL Python warnings (incl. library warnings) for the session

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

from celluloid import Camera  # records matplotlib figure states as animation frames
from IPython.display import HTML

import matplotlib.pyplot as plt
# Final backend choice: inline plots.
%matplotlib inline

In [None]:
import os
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed every RNG the notebook relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, the current
    CUDA device and all CUDA devices), exports ``PYTHONHASHSEED`` and turns
    off cuDNN autotuning so convolution algorithms are picked deterministically.

    Note: ``PYTHONHASHSEED`` only affects interpreters started after this
    call (e.g. subprocesses); the running kernel's hash salt is already fixed.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

## **set root paths**

In [None]:
# Root folders of the Kaggle "lung-cancer-segment" dataset; each contains one
# sub-folder per patient id (see get_patient_path).
train_path, val_path = (
    Path("/kaggle/input/lung-cancer-segment") / split for split in ("train", "val")
)

In [None]:
# understanding single slice: load slice file 100.npy of training patient 0
# and display it.
img = np.load(train_path/"0"/"data"/"100.npy")
print(img.shape)  # raw slice shape as stored on disk

fig = plt.figure()
plt.imshow(img, cmap="bone")

In [None]:
def get_sorted_slice_paths(folder: Path):
    """Return the ``*.npy`` files in *folder* ordered by their integer stem.

    Slice files are named by index (e.g. ``100.npy``); ordering on the numeric
    stem keeps ``10.npy`` after ``9.npy``, which plain string order would not.

    Raises:
        ValueError: if *folder* holds no ``.npy`` files, or a stem is not an
            integer.
    """
    slice_files = list(folder.glob("*.npy"))
    slice_files.sort(key=lambda path: int(path.stem))
    if len(slice_files) == 0:
        raise ValueError(f"No .npy slices found in {folder}")
    return slice_files

def get_patient_path(patient_id: str, train: bool) -> Path:
    """Return the folder of one patient under the train or validation root.

    Args:
        patient_id: patient folder name; ints are accepted as well (they are
            converted with ``str``), which is how the notebook calls it.
        train: use ``train_path`` when truthy, ``val_path`` otherwise.
    """
    split_root = train_path if train else val_path
    return split_root / str(patient_id)

def get_img_path(patient_path: Path) -> Path:
    """Return ``<patient_path>/data``, the folder holding the image slices.

    Raises:
        FileNotFoundError: if the ``data`` sub-folder does not exist.
    """
    data_dir = patient_path.joinpath("data")
    if data_dir.exists():
        return data_dir
    raise FileNotFoundError(f"Image folder not found: {data_dir}")

def get_label_path(patient_path: Path) -> Path:
    """Return ``<patient_path>/masks``, the folder holding the mask slices.

    Raises:
        FileNotFoundError: if the ``masks`` sub-folder does not exist.
    """
    masks_dir = patient_path.joinpath("masks")
    if masks_dir.exists():
        return masks_dir
    raise FileNotFoundError(f"Label folder not found: {masks_dir}")

## **visualization**

In [None]:
# Load training patient 0 as full volumes for the animation below.
patient_path = get_patient_path("0", True)

img_path = get_img_path(patient_path)
label_path = get_label_path(patient_path)

sorted_img_files = get_sorted_slice_paths(img_path)
sorted_mask_files = get_sorted_slice_paths(label_path)

# Stack the 2D slices along the last axis, so the volumes are indexed [:, :, slice].
# NOTE(review): image and mask slices are paired purely by sorted position;
# this assumes both folders contain the same slice indices — confirm.
img_volume = np.stack([np.load(f) for f in sorted_img_files], axis=-1).astype(np.float32)
mask_volume = np.stack([np.load(f) for f in sorted_mask_files], axis=-1).astype(np.uint8)

In [None]:
# Animate every slice of the volume with the tumor mask overlaid.
fig = plt.figure()
camera = Camera(fig)

for i in range(img_volume.shape[2]):
    
    plt.imshow(img_volume[:, :, i], cmap="bone")
    # Hide background (label 0) so only mask pixels are tinted.
    mask_ = np.ma.masked_where(mask_volume[:, :, i] == 0, mask_volume[:, :, i])
    plt.imshow(mask_, cmap="autumn", alpha=0.5)
    
    camera.snap();  # record the current figure state as one frame

animation = camera.animate();
HTML(animation.to_html5_video())

# **# 01 - dataset**

In [None]:
from torch.utils.data import Dataset

class LungTumorDataset(Dataset):
    
    """
    Returns individual 2D slices (img, mask) from 3D CT volumes.

    Each sample is one slice of one patient. The slice is wrapped as a
    single-slice volume of shape (1, H, W, 1) so that TorchIO transforms can
    be applied, then squeezed back to a (1, H, W) tensor. Slice file listings
    are cached per patient, so ``__getitem__`` does not re-scan the folders.
    """

    def __init__(self, patient_ids, train=True, transform=None, skip_slices=30):
        """
        patient_ids : list of patient identifiers (str or int)
        train       : bool, whether to use train or val dataset
        transform   : optional callable applied to a ``tio.Subject`` with keys
                      'image' and 'mask' (e.g. a ``tio.Compose``)
        skip_slices : number of leading (lowest-index) slices dropped per
                      patient; the default of 30 matches the original behavior
        """
        if skip_slices < 0:
            raise ValueError(f"skip_slices must be >= 0, got {skip_slices}")

        self.samples = []
        self.train = train
        self.transform = transform
        self.skip_slices = skip_slices
        # pid -> sorted image slice files (filled here).
        self._img_files = {}
        # pid -> sorted mask slice files (filled lazily on first access, so a
        # missing mask folder still only raises when a sample is read).
        self._mask_files = {}

        # Build a list of (patient_id, slice_idx)
        for pid in patient_ids:
            patient_path = get_patient_path(pid, self.train)
            img_files = get_sorted_slice_paths(get_img_path(patient_path))
            self._img_files[pid] = img_files
            self.samples.extend(
                (pid, slice_idx) for slice_idx in range(skip_slices, len(img_files))
            )

    def __len__(self):
        return len(self.samples)

    def _mask_files_for(self, pid):
        """Return (and cache) the sorted mask slice files of patient ``pid``."""
        files = self._mask_files.get(pid)
        if files is None:
            patient_path = get_patient_path(pid, self.train)
            files = get_sorted_slice_paths(get_label_path(patient_path))
            self._mask_files[pid] = files
        return files

    def __getitem__(self, idx):
        pid, slice_idx = self.samples[idx]

        # Image and mask slices are paired by their position in the sorted listings.
        img_file = self._img_files[pid][slice_idx]
        mask_file = self._mask_files_for(pid)[slice_idx]

        img = np.load(img_file).astype(np.float32)
        mask = np.load(mask_file).astype(np.uint8)

        # convert 2D slice to fake 3D volume (C,H,W,D)
        img_volume = torch.tensor(img).unsqueeze(0).unsqueeze(-1)   # (1,H,W,1)
        mask_volume = torch.tensor(mask).unsqueeze(0).unsqueeze(-1) # (1,H,W,1)

        if self.transform:
            # Imported here so the class does not depend on the notebook cell
            # that imports torchio having been run first.
            import torchio as tio

            # TorchIO expects a Subject
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=img_volume),
                mask=tio.LabelMap(tensor=mask_volume)
            )
            transformed = self.transform(subject)
            img_volume = transformed['image'].data
            mask_volume = transformed['mask'].data

        # squeeze back to 2D (C,H,W)
        img_tensor = img_volume.squeeze(-1)
        mask_tensor = mask_volume.squeeze(-1)

        return img_tensor, mask_tensor

## **define transforms**

In [None]:
import torchio as tio

# Deterministic preprocessing shared by both splits. LungTumorDataset wraps
# each 2D slice as a single-slice (1, H, W, 1) subject, so these steps run on
# ONE slice at a time; inference code should apply them per slice as well.
process = tio.Compose([
    tio.ToCanonical(),               # fix orientation (subjects use TorchIO's default affine, so presumably a no-op here)
    tio.RescaleIntensity((-1,1)),    # normalize intensity of this slice to [-1, 1] (default percentiles 0-100, i.e. min..max)
    tio.CropOrPad((256,256,1))       # crop/pad HxWxD=1
])

# Training-only geometric augmentation: random scaling in [0.9, 1.1],
# rotation in [-20, 20] degrees and translation of up to 5 (TorchIO units: mm).
augmentation = tio.RandomAffine(scales=(0.9,1.1), degrees=(-20,20), translation=5)

train_transform = tio.Compose([process, augmentation])
val_transform = tio.Compose([process])

In [None]:
# Patient folder ids: 0-56 are read from train_path, 57-62 from val_path
# (the `train` flag selects the root in get_patient_path).
train_ids = list(range(57))
val_ids = list(range(57, 63))

train_dataset = LungTumorDataset(train_ids, train=True, transform=train_transform)
val_dataset   = LungTumorDataset(val_ids, train=False, transform=val_transform)

print(len(val_dataset), len(train_dataset))  # number of slice samples per split

# **# 03 - model**

In [None]:
import torch.nn as nn

class DoubleConv2D(nn.Module):
    """Two 3x3 conv + ReLU layers with an additive residual shortcut.

    Spatial size is preserved (padding=1). The shortcut is the identity when
    ``in_ch == out_ch`` and a 1x1 convolution projection otherwise. The
    shortcut is added *after* the second ReLU, so outputs can be negative.

    Shapes: (N, in_ch, H, W) -> (N, out_ch, H, W).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
        ]
        self.conv_block = nn.Sequential(*layers)
        # 1x1 projection so the shortcut has out_ch channels; None = identity.
        # Attribute names and creation order are kept: they define the
        # state_dict keys and the seeded weight initialization.
        self.residual_conv = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def forward(self, x):
        features = self.conv_block(x)
        shortcut = x if self.residual_conv is None else self.residual_conv(x)
        return features + shortcut

In [None]:
import torch
import torch.nn as nn

class AttentionGate2D(nn.Module):
    """Additive attention gate applied to a U-Net skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm). The projections are summed and
    passed through ReLU. A 1x1 conv + BatchNorm + sigmoid then reduces the
    result to a single-channel map in (0, 1), which rescales ``x`` pixel-wise.

    Args:
        F_g: number of channels in the gating (decoder) signal
        F_l: number of channels in the skip connection (encoder)
        F_int: number of intermediate channels (usually smaller)

    ``x`` and ``g`` must have the same spatial size; the output has the shape of ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch, out_ch):
            # 1x1 convolution followed by BatchNorm.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        gate_conv, gate_bn = project(F_int, 1)
        self.psi = nn.Sequential(gate_conv, gate_bn, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, g):
        """
        x: encoder feature map (skip connection)
        g: decoder feature map (gating signal)
        Returns x multiplied by the attention map.
        """
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention  # 1-channel map broadcasts over x's channels

In [None]:
class UNet2D_Attention(nn.Module):
    """Residual 2D U-Net whose skip connections pass through attention gates.

    The encoder has four DoubleConv2D levels (32 -> 64 -> 128 -> 256 channels)
    separated by 2x max-pooling. The decoder upsamples three times with
    transposed convolutions. Before each skip is concatenated, it is reweighted
    by an AttentionGate2D driven by the upsampled decoder features. The head
    is a 1x1 conv that returns raw logits (no sigmoid).

    Input (N, in_channels, H, W) -> output (N, out_channels, H, W). H and W
    must be divisible by 8 so the upsampled maps align with the skips.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # NOTE: creation order is kept as-is; it fixes the seeded weight
        # initialization and the state_dict key order.

        # Encoder
        self.layer1 = DoubleConv2D(in_channels, 32)
        self.layer2 = DoubleConv2D(32, 64)
        self.layer3 = DoubleConv2D(64, 128)
        self.layer4 = DoubleConv2D(128, 256)

        # Attention gates, one per skip connection (gating and skip widths match).
        self.att3 = AttentionGate2D(F_g=128, F_l=128, F_int=64)
        self.att2 = AttentionGate2D(F_g=64, F_l=64, F_int=32)
        self.att1 = AttentionGate2D(F_g=32, F_l=32, F_int=16)

        # Decoder: upsample, then fuse [upsampled, gated skip].
        self.upconv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.layer5 = DoubleConv2D(128 + 128, 128)

        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.layer6 = DoubleConv2D(64 + 64, 64)

        self.upconv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.layer7 = DoubleConv2D(32 + 32, 32)

        # 1x1 output head (logits)
        self.layer8 = nn.Conv2d(32, out_channels, kernel_size=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def _decode(self, upconv, gate, block, deeper, skip):
        """Upsample ``deeper``, gate ``skip`` with it, concatenate and refine."""
        upsampled = upconv(deeper)
        gated_skip = gate(skip, upsampled)
        return block(torch.cat([upsampled, gated_skip], dim=1))

    def forward(self, x):
        # Encoder: every level below the first halves H and W.
        skip1 = self.layer1(x)                      # 32 ch,  H
        skip2 = self.layer2(self.maxpool(skip1))    # 64 ch,  H/2
        skip3 = self.layer3(self.maxpool(skip2))    # 128 ch, H/4
        bottom = self.layer4(self.maxpool(skip3))   # 256 ch, H/8

        # Decoder: every level doubles H and W again.
        up3 = self._decode(self.upconv1, self.att3, self.layer5, bottom, skip3)  # 128 ch, H/4
        up2 = self._decode(self.upconv2, self.att2, self.layer6, up3, skip2)     # 64 ch,  H/2
        up1 = self._decode(self.upconv3, self.att1, self.layer7, up2, skip1)     # 32 ch,  H

        return self.layer8(up1)

In [None]:
# Smoke test: a random single-channel 256x256 batch must come back with the
# same shape (in_channels == out_channels == 1 by default).
model = UNet2D_Attention()

random_input = torch.randn(1, 1, 256, 256)
output = model(random_input)
assert output.shape == random_input.shape

# **# 04 - train**

## **sampler**

In [None]:
# Per-slice binary label for the sampler: 1 if the slice's mask contains any
# tumor pixel, else 0.
# NOTE(review): iterating train_dataset runs train_transform (incl. the random
# affine) on every slice. That makes this slow, and the labels reflect the
# augmented/cropped masks. Reading the raw mask files would be cheaper.
target_list = [1 if label.any() else 0 for _, label in tqdm(train_dataset)]

In [None]:
from torch.utils.data import WeightedRandomSampler

# Inverse-frequency weighting: each class receives the same total weight, so
# tumor and non-tumor slices are drawn equally often on average.
target_list = np.array(target_list)

class_counts = np.bincount(target_list)   # [count_0, count_1]
class_weights = 1. / class_counts         # inverse of frequency

# Gather each sample's weight by indexing with its class label.
sample_weights = torch.as_tensor(class_weights[target_list], dtype=torch.float64)

# Draw len(dataset) indices per epoch, with replacement.
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

## **dataloader**

In [None]:
# Training loader: the weighted sampler replaces shuffling (DataLoader does
# not accept shuffle=True together with a sampler).
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    sampler=sampler,
    num_workers=2,
)

In [None]:
# Validation loader: sequential order (no sampler; shuffle defaults to False).
val_loader = DataLoader(val_dataset, batch_size=8, num_workers=2)

In [None]:
# Sanity check: number of batches per loader and the shapes of one
# validation batch.
print(len(val_loader), len(train_loader))

batch = next(iter(val_loader))
print(type(batch)) 

imgs, masks = batch
# NOTE(review): with val_transform these are presumably (8, 1, 256, 256) — confirm.
print(imgs.shape, masks.shape)

## **pytorch lightning model**

In [None]:
import pytorch_lightning as pl

class LungTumorModel(pl.LightningModule):
    """Lightning wrapper training UNet2D_Attention with BCE-with-logits loss.

    ``forward`` returns raw logits; apply ``torch.sigmoid`` to get
    probabilities. The optimizer is Adam, with ReduceLROnPlateau on
    ``val_loss``.

    Args:
        lr: Adam learning rate. It is saved in the checkpoint hyper-parameters,
            so ``load_from_checkpoint`` restores it.
    """

    def __init__(self, lr=1e-4):
        super().__init__()
        # Store the constructor arguments (lr) in self.hparams and in every
        # checkpoint, so the model can be rebuilt with the same settings.
        self.save_hyperparameters()
        self.model = UNet2D_Attention()
        self.lr = lr
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute the BCE loss of one batch and log it as ``f"{stage}_loss"``."""
        imgs, masks = batch
        logits = self(imgs)
        # Masks may arrive as an integer label map; BCE needs float targets.
        loss = self.loss_fn(logits.float(), masks.float())
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # Halve the LR when val_loss has not improved for 3 epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [None]:
# Fresh Lightning module; re-binds `model`, replacing the smoke-test UNet.
model = LungTumorModel()

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint 

# Keep only the single checkpoint with the lowest validation loss, named
# "best_model" (Lightning appends the .ckpt extension).
checkpoint_callback = ModelCheckpoint(
    filename="best_model",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# Up to 30 epochs on whatever accelerator is available. TensorBoard logs go
# to ./logs; checkpoint_callback tracks the best val_loss checkpoint.
trainer = pl.Trainer(
    max_epochs=30,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback],
    logger=TensorBoardLogger(save_dir="logs"),
    log_every_n_steps=10,
)

In [None]:
# Train; val_loader provides the val_loss monitored by the LR scheduler and the checkpoint callback.
trainer.fit(model, train_loader, val_loader)

# **# 05 - evaluation**

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trainer.fit() leaves `model` at the weights of the LAST epoch. Evaluate the
# checkpoint ModelCheckpoint kept (lowest val_loss) instead, whenever
# training in this session produced one.
best_ckpt = checkpoint_callback.best_model_path
if best_ckpt:
    model = LungTumorModel.load_from_checkpoint(best_ckpt, map_location=device)

# model = LungTumorModel.load_from_checkpoint("/kaggle/input/checkpoints/LungTumor_BCE_CustomUNET2D.ckpt")
model.to(device);
model.freeze();  # disables gradients for all parameters and switches to eval mode
model.eval();

## **IoU, sensitivity, specificity**

In [None]:
def compute_iou_sens_spec(preds, targets, threshold=0.5, eps=1e-6, from_logits=False):
    """
    Compute mean IoU, sensitivity, and specificity over a batch of predictions and targets.

    Each metric is computed per sample (over all of that sample's pixels) and
    then averaged over the batch. Because of ``eps``:
    - an empty prediction on an empty target gives IoU 1;
    - sensitivity is 1 on any sample whose target has no positive pixels.

    Args:
        preds (torch.Tensor): predicted probabilities, shape [B, 1, H, W]
            (any [B, ...] layout works). Pass logits with ``from_logits=True``.
        targets (torch.Tensor): binary ground-truth masks with the same number
            of elements per sample; bool, integer and float dtypes are accepted
        threshold (float): probability threshold used to binarize predictions
        eps (float): small epsilon to prevent division by zero
        from_logits (bool): apply ``torch.sigmoid`` to ``preds`` before thresholding

    Returns:
        dict: {'iou': float, 'sensitivity': float, 'specificity': float}
    """
    if from_logits:
        preds = torch.sigmoid(preds)

    # Binarize predictions. Cast targets to float so `1 - targets` also works
    # for bool masks.
    preds = (preds > threshold).float()
    targets = targets.float()

    # Flatten everything except the batch axis. reshape (unlike view) also
    # accepts non-contiguous inputs such as permuted tensors.
    preds_flat = preds.reshape(preds.size(0), -1)
    targets_flat = targets.reshape(targets.size(0), -1)

    # True positives, false positives, false negatives, true negatives
    TP = (preds_flat * targets_flat).sum(dim=1)
    FP = (preds_flat * (1 - targets_flat)).sum(dim=1)
    FN = ((1 - preds_flat) * targets_flat).sum(dim=1)
    TN = ((1 - preds_flat) * (1 - targets_flat)).sum(dim=1)

    # IoU
    iou = (TP + eps) / (TP + FP + FN + eps)

    # Sensitivity (Recall)
    sensitivity = (TP + eps) / (TP + FN + eps)

    # Specificity
    specificity = (TN + eps) / (TN + FP + eps)

    # Return mean across batch
    return {
        'iou': iou.mean().item(),
        'sensitivity': sensitivity.mean().item(),
        'specificity': specificity.mean().item()
    }

In [None]:
# Pixel metrics over the whole validation set. The model returns logits, so
# they are turned into probabilities before compute_iou_sens_spec thresholds
# them at 0.5. This is the same sigmoid + 0.5 rule used for the
# visualization below.
total_iou = 0.0
total_sens = 0.0
total_spec = 0.0
num_batches = 0
num_samples = 0

with torch.no_grad():
    for images, masks in tqdm(val_loader):

        images = images.to(device)
        masks = masks.to(device)

        probs = torch.sigmoid(model(images))
        metrics = compute_iou_sens_spec(probs.float(), masks.float())

        # compute_iou_sens_spec returns batch means; weight them by batch size
        # so a short final batch counts no more than its samples.
        batch_size = images.size(0)
        total_iou += metrics['iou'] * batch_size
        total_sens += metrics['sensitivity'] * batch_size
        total_spec += metrics['specificity'] * batch_size
        num_batches += 1
        num_samples += batch_size

if num_samples == 0:
    raise RuntimeError("val_loader produced no samples; cannot compute metrics")

mean_iou = total_iou / num_samples
mean_sens = total_sens / num_samples
mean_spec = total_spec / num_samples

print(f"Mean IoU: {mean_iou:.4f}")
print(f"Mean Sensitivity: {mean_sens:.4f}")
print(f"Mean Specificity: {mean_spec:.4f}")


## **result visualization**

In [None]:
# Load validation patient 57 and preprocess it exactly as the model saw data
# during training. LungTumorDataset feeds ONE slice at a time through
# `process` (ToCanonical -> RescaleIntensity(-1, 1) -> CropOrPad(256, 256, 1)),
# so intensities are normalized per slice. Rescaling the whole 3D volume as
# one subject would give the model differently scaled inputs. The dataset's
# global `val_transform` is intentionally left untouched.
patient_path = get_patient_path("57", False)
img_path = get_img_path(patient_path)
label_path = get_label_path(patient_path)

sorted_img_files = get_sorted_slice_paths(img_path)
sorted_mask_files = get_sorted_slice_paths(label_path)

if len(sorted_img_files) != len(sorted_mask_files):
    raise ValueError(
        f"Slice count mismatch for {patient_path}: "
        f"{len(sorted_img_files)} images vs {len(sorted_mask_files)} masks"
    )


def _preprocess_slice(img_file, mask_file):
    """Run one (image, mask) slice pair through `process`; return two (1, H, W) tensors."""
    img = torch.from_numpy(np.load(img_file).astype(np.float32))
    mask = torch.from_numpy(np.load(mask_file).astype(np.uint8))
    subject = tio.Subject(
        image=tio.ScalarImage(tensor=img[None, :, :, None]),  # (1, H, W, 1), as in the dataset
        mask=tio.LabelMap(tensor=mask[None, :, :, None]),
    )
    subject = process(subject)
    return subject.image.data[..., 0], subject.mask.data[..., 0]


slice_pairs = [
    _preprocess_slice(img_file, mask_file)
    for img_file, mask_file in zip(sorted_img_files, sorted_mask_files)
]

# Stack along a new leading (depth) axis -> (D, 1, H, W), the layout the model expects.
img_volume = torch.stack([img for img, _ in slice_pairs]).float().to(device)
mask_volume = torch.stack([mask for _, mask in slice_pairs]).float().to(device)

print("Prepared volume shape:", img_volume.shape)

In [None]:
chunk_size = 50  # slices per forward pass, to bound GPU memory

# Predict tumor probabilities chunk by chunk along the depth (first) axis;
# each chunk's result is moved to the CPU immediately.
with torch.no_grad():
    preds_list = [
        torch.sigmoid(model(chunk.to(device).float())).cpu()  # [chunk, 1, H, W]
        for chunk in torch.split(img_volume, chunk_size)
    ]

# Concatenate predictions along depth axis
preds = torch.cat(preds_list, dim=0)  # shape: [D, 1, H, W]

# Binarize at probability 0.5
preds_bin = (preds > 0.5).float()

print("Unique values in preds_bin:", torch.unique(preds_bin))


def _to_hwd(volume):
    """Reorder a (D, 1, H, W) tensor to (H, W, D) so slices are indexed [:, :, i]."""
    return volume[:, 0].permute(1, 2, 0)


# One consistent reordering for prediction, image and mask.
pred_mask = _to_hwd(preds_bin)  # (H, W, D)
img_volume_np = _to_hwd(img_volume).cpu().numpy()
mask_np = _to_hwd(mask_volume).cpu().numpy()
pred_mask_np = pred_mask.cpu().numpy()

print("Shapes:")
print("img_volume:", img_volume_np.shape)
print("mask_volume:", mask_np.shape)
print("pred_mask:", pred_mask_np.shape)

In [None]:
# Animate the validation patient slice by slice: CT image (bone), ground-truth
# mask (autumn) and predicted mask (winter), both overlays at 50% opacity.
fig = plt.figure()
camera = Camera(fig)

for i in range(img_volume_np.shape[2]):
    
    plt.imshow(img_volume_np[:, :, i], cmap="bone")
    
    # Mask out background (0) so only labelled pixels are colored.
    mask_ = np.ma.masked_where(mask_np[:, :, i] == 0, mask_np[:, :, i])
    plt.imshow(mask_, cmap="autumn", alpha=0.5)
    
    pred_ = np.ma.masked_where(pred_mask_np[:, :, i] == 0, pred_mask_np[:, :, i])
    plt.imshow(pred_, cmap="winter", alpha=0.5)
    
    camera.snap()  # capture this slice as one frame

animation = camera.animate();
HTML(animation.to_html5_video())