# CT-Based Lung Nodule Segmentation using 2D U-Net

### Objective
The objective of this project is to **reduce false negatives in lung cancer detection**
by performing **pixel-level lung nodule segmentation** on CT scan slices.

Since missing a malignant nodule has severe clinical consequences,
this work **prioritizes recall (sensitivity)** over accuracy.

---

### Key Highlights
- 2D U-Net-based segmentation pipeline
- Multi-radiologist annotation fusion (union strategy)
- Recall-optimized training and early stopping
- Patch-based learning with padding for variable slice sizes
- Pixel-level explainability using segmentation overlays


In [None]:
%pip install torch torchvision numpy matplotlib tqdm opencv-python kagglehub pandas


In [35]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2


In [None]:
BASE_DIR = os.getcwd()   
# Root directory containing LIDC-IDRI slice-wise data
# Each patient folder contains nodule-wise subfolders with images and masks
ROOT = os.path.join(BASE_DIR, "LIDC-IDRI-slices")

OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
MODEL_DIR = os.path.join(OUTPUT_DIR, "models")
OVERLAY_DIR = os.path.join(OUTPUT_DIR, "overlays")

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(OVERLAY_DIR, exist_ok=True)

print("Dataset root:", ROOT)
print("Patients found:", len(os.listdir(ROOT)))


Dataset root: /Users/gottipalligopi/Documents/LungNoduleSegmentation/LIDC-IDRI-slices
Patients found: 875


## Preprocessing and Patch Extraction

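- Intensities are normalized to [0, 1]. The lung window (center -600 HU, width 1500 HU) applies only to raw Hounsfield-unit data, because the PNG slices loaded here are 8-bit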
- No aggressive filtering is used to preserve subtle nodules
- Training is performed on **128×128 patches**
- CT slices smaller than patch size are **zero-padded**, not discarded

This strategy avoids data loss while supporting variable CT resolutions.
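The padding-and-cropping scheme described above can be sketched in NumPy as follows. This is a standalone illustration that assumes single-channel 2D arrays. `pad_and_random_crop` is a hypothetical helper that mirrors the logic of `LIDCDataset.__getitem__`; the notebook does not call it.

```python
import numpy as np

def pad_and_random_crop(img, mask, patch=128, rng=None):
    """Zero-pad a slice/mask pair up to `patch` on each axis, then take a
    random patch x patch crop at the same location in both arrays."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape
    pad_h, pad_w = max(0, patch - h), max(0, patch - w)
    # Pad only at the bottom/right, as LIDCDataset.__getitem__ does
    img = np.pad(img, ((0, pad_h), (0, pad_w)), mode="constant")
    mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant")
    h, w = img.shape
    # Same random offset for image and mask keeps them aligned
    top = rng.integers(0, h - patch + 1)
    left = rng.integers(0, w - patch + 1)
    return (img[top:top + patch, left:left + patch],
            mask[top:top + patch, left:left + patch])

# A hypothetical 100x90 slice is padded up to 128x128 (nothing left to crop)
small = np.ones((100, 90), dtype=np.float32)
img_p, mask_p = pad_and_random_crop(small, np.zeros_like(small))
print(img_p.shape)  # (128, 128)
```

With this scheme, small slices are never discarded. Their original pixels stay in the top-left corner and the rest of the patch is zeros.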


In [24]:
def lung_window(ct, center=-600, width=1500):
    min_v = center - width // 2
    max_v = center + width // 2
    ct = np.clip(ct, min_v, max_v)
    ct = (ct - min_v) / (max_v - min_v)
    return ct.astype(np.float32)
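Here is a quick worked check of what `lung_window` does, assuming the input is in Hounsfield units (HU). With the defaults, the window spans -1350 HU to 150 HU. That means -600 HU maps to 0.5 and anything above 150 HU saturates at 1.0. The second call shows why the function should not be applied to 8-bit PNG intensities: values from 0 to 255 all land in [0.9, 1.0]. The HU values below are rough, illustrative tissue values.

```python
import numpy as np

def lung_window(ct, center=-600, width=1500):
    # Copy of the function above: clip to [center - width/2, center + width/2] HU, scale to [0, 1]
    min_v = center - width // 2
    max_v = center + width // 2
    ct = np.clip(ct, min_v, max_v)
    ct = (ct - min_v) / (max_v - min_v)
    return ct.astype(np.float32)

# Raw Hounsfield units: air, lung parenchyma (roughly), window center, soft tissue, bone
hu = np.array([-1000, -850, -600, 40, 700], dtype=np.float32)
print(lung_window(hu))

# 8-bit PNG intensities are NOT Hounsfield units; the window squashes them toward 1.0
png = np.array([0, 128, 255], dtype=np.float32)
print(lung_window(png))
```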


## Dataset: LIDC-IDRI

- Source: **LIDC-IDRI**
- Format: Slice-wise CT images (PNG)
- Each slice contains **up to four independent radiologist annotations**
- Nodules are small, sparse, and often ambiguous

### Annotation Handling
To reduce false negatives caused by inter-observer variability,
annotations from all radiologists are combined using a **union (OR)** strategy.

> A pixel is considered a nodule if **any radiologist** marked it as such.
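Below is a minimal sketch of the union rule on toy data, assuming each radiologist's mask is a binary array of the same shape. A majority-vote variant is included only for contrast; this notebook does not use it. The comparison shows that the union keeps pixels that only one reader marked.

```python
import numpy as np

def fuse_masks(masks, strategy="union"):
    """Fuse binary masks from several readers.
    union    : pixel is nodule if ANY reader marked it (the strategy used in this notebook)
    majority : pixel is nodule if at least half of the readers marked it
               (shown only for comparison; not used here)"""
    stack = np.stack([m > 0 for m in masks])  # (readers, H, W) booleans
    if strategy == "union":
        return stack.any(axis=0).astype(np.uint8)
    if strategy == "majority":
        votes = stack.sum(axis=0)
        return (2 * votes >= len(masks)).astype(np.uint8)
    raise ValueError(f"unknown strategy: {strategy}")

# Four toy 1x4 masks from four readers who disagree on the nodule extent
readers = [np.array([[0, 1, 1, 0]]),
           np.array([[0, 1, 0, 0]]),
           np.array([[0, 1, 1, 1]]),
           np.array([[0, 0, 0, 0]])]
print(fuse_masks(readers, "union"))     # [[0 1 1 1]]
print(fuse_masks(readers, "majority"))  # [[0 1 1 0]]
```

The last pixel was marked by a single reader. The union keeps it as nodule, and the majority vote drops it.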


In [None]:
class LIDCDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.samples = []
        self.patch = patch_size

        for patient in tqdm(os.listdir(root_dir)):
            p_path = os.path.join(root_dir, patient)
            if not os.path.isdir(p_path):
                continue

            for nodule in os.listdir(p_path):
                case = os.path.join(p_path, nodule)
                if not os.path.isdir(case):
                    continue

                img_dir = os.path.join(case, "images")
                if not os.path.isdir(img_dir):
                    continue

                mask_dirs = [os.path.join(case, f"mask-{i}") for i in range(4)]

                for f in sorted(os.listdir(img_dir)):
                    if not f.endswith(".png"):
                        continue   # skip non-PNG files

                    # ---- Load image ----
                    img_path = os.path.join(img_dir, f)
                    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                    if img is None:
                        continue

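                    # cv2.IMREAD_GRAYSCALE returns 8-bit values (0-255), not Hounsfield
                    # units; lung_window() would squash them into [0.9, 1.0], so rescale
                    # directly to [0, 1] (use lung_window only on raw HU arrays)
                    img = img.astype(np.float32) / 255.0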

                    # ---- Mask union ----

                    # Combine annotations from multiple radiologists using union (OR)
                    # A pixel is considered nodule if ANY radiologist marked it
                    # This reduces false negatives caused by inter-observer variability

                    mask_union = np.zeros_like(img, dtype=np.uint8)
                    for md in mask_dirs:
                        mask_path = os.path.join(md, f)
                        if os.path.exists(mask_path):
                            m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                            if m is not None:
                                mask_union |= (m > 0)

                    # Skip slices without any nodule annotation
                    # Keep even 1-pixel nodules to avoid missing subtle cancers
                    if mask_union.sum() == 0:
                        continue

                    # Slices smaller than the patch size are kept; __getitem__ zero-pads them

                    self.samples.append((img, mask_union.astype(np.float32)))

        print("Total slices with nodules:", len(self.samples))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, mask = self.samples[idx]
        h, w = img.shape

        # Pad slices smaller than patch size instead of discarding
        # Prevents unnecessary data loss and supports small CT resolutions
        pad_h = max(0, self.patch - h)
        pad_w = max(0, self.patch - w)

        if pad_h > 0 or pad_w > 0:
            img = np.pad(
                img,
                ((0, pad_h), (0, pad_w)),
                mode="constant"
            )
            mask = np.pad(
                mask,
                ((0, pad_h), (0, pad_w)),
                mode="constant"
            )

        h, w = img.shape  # update after padding

        # -------- SAFE RANDOM CROP --------
        max_x = h - self.patch
        max_y = w - self.patch
       
        # Random patch extraction to increase spatial diversity
        # Ensures model does not overfit to fixed nodule locations
        x = np.random.randint(0, max_x + 1)
        y = np.random.randint(0, max_y + 1)

        img = img[x:x+self.patch, y:y+self.patch]
        mask = mask[x:x+self.patch, y:y+self.patch]

        return (
            torch.tensor(img).unsqueeze(0),
            torch.tensor(mask).unsqueeze(0)
        )
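A uniform random crop can easily miss a small nodule if the stored slices are much larger than 128×128. On a 512×512 slice, for example, there are 385 possible offsets per axis, and only 128 of them contain a given interior pixel. So roughly (128/385)^2 ≈ 11% of crops contain it, and the rest produce an all-background target. A common alternative is foreground-biased cropping. The sketch below is hypothetical: `nodule_aware_crop` and `p_fg` are not part of the notebook. It assumes the arrays were already padded as in `__getitem__`.

```python
import numpy as np

def nodule_aware_crop(img, mask, patch=128, p_fg=0.8, rng=None):
    """Hypothetical alternative to the uniform random crop: with probability
    p_fg the crop is placed so that it contains a randomly chosen nodule pixel.
    Assumes img/mask are already padded to at least patch x patch."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape
    ys, xs = np.nonzero(mask)
    if len(ys) > 0 and rng.random() < p_fg:
        i = rng.integers(len(ys))
        cy, cx = ys[i], xs[i]
        # Any top-left corner in these ranges keeps (cy, cx) inside the patch
        top = rng.integers(max(0, cy - patch + 1), min(cy, h - patch) + 1)
        left = rng.integers(max(0, cx - patch + 1), min(cx, w - patch) + 1)
    else:
        # Fall back to a uniform crop (also used when the mask is empty)
        top = rng.integers(0, h - patch + 1)
        left = rng.integers(0, w - patch + 1)
    return (img[top:top + patch, left:left + patch],
            mask[top:top + patch, left:left + patch])

rng = np.random.default_rng(0)
img = np.zeros((512, 512), dtype=np.float32)
mask = np.zeros((512, 512), dtype=np.float32)
mask[300:305, 400:404] = 1.0  # a tiny 5x4-pixel "nodule"
hits = sum(nodule_aware_crop(img, mask, p_fg=1.0, rng=rng)[1].sum() > 0 for _ in range(100))
print(hits)  # 100
```

Keeping `p_fg` below 1.0 still shows the model some background-only patches, so it can learn to suppress false positives.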


In [41]:
dataset = LIDCDataset(ROOT)
print("Dataset length:", len(dataset))

x, y = dataset[0]
print("Image shape:", x.shape)
print("Mask shape:", y.shape)
print("Mask pixels:", y.sum())


100%|██████████| 875/875 [00:09<00:00, 94.21it/s]

Total slices with nodules: 15486
Dataset length: 15486
Image shape: torch.Size([1, 128, 128])
Mask shape: torch.Size([1, 128, 128])
Mask pixels: tensor(18.)





In [None]:
VAL_RATIO = 0.2
val_size = int(len(dataset) * VAL_RATIO)
train_size = len(dataset) - val_size

train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

print("Train samples:", len(train_ds))
print("Val samples:", len(val_ds))
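`random_split` splits at the slice level. Adjacent, highly correlated slices of the same nodule or patient can therefore end up in both the training and validation sets, which can make validation scores optimistic. A patient-level split avoids this. The sketch below is hypothetical: it assumes a list `patient_ids` with one patient folder name per sample. The current `LIDCDataset` does not store this list, so it would need to record it when filling `self.samples`.

```python
import random

def patient_level_split(patient_ids, val_ratio=0.2, seed=42):
    """Split sample indices so that all slices of a patient land in the same subset.
    patient_ids[i] is the patient folder name of sample i (hypothetical: the
    LIDCDataset above would need to record it alongside each sample)."""
    patients = sorted(set(patient_ids))
    random.Random(seed).shuffle(patients)          # reproducible shuffle of patients
    n_val = max(1, int(len(patients) * val_ratio))
    val_patients = set(patients[:n_val])
    train_idx = [i for i, p in enumerate(patient_ids) if p not in val_patients]
    val_idx = [i for i, p in enumerate(patient_ids) if p in val_patients]
    return train_idx, val_idx

ids = ["P1", "P1", "P2", "P3", "P3", "P3", "P4", "P5"]
train_idx, val_idx = patient_level_split(ids, val_ratio=0.2)
print(train_idx, val_idx)
```

You can wrap the resulting index lists with `torch.utils.data.Subset(dataset, train_idx)` and `Subset(dataset, val_idx)` and use them in place of the `random_split` result.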


## Model Architecture: 2D U-Net

A compact **2D U-Net** (three encoder levels with 64, 128, and 256 channels and two pooling steps) is used for segmentation.

### Why U-Net?
- Encoder-decoder structure captures context
- Skip connections preserve spatial detail
- Proven effectiveness in medical image segmentation tasks

The final layer uses a **sigmoid activation** to produce a pixel-wise probability map.
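You can trace the tensor shapes through the network defined in the next cell by hand. This sketch is a hypothetical helper that the notebook does not call. It records (channels, H, W) after each block. The 3×3 convolutions use `padding=1`, so only the two 2×2 max-pools change the spatial size. As a result, H and W must be divisible by 4 for each upsampled map to match its skip connection in `torch.cat`.

```python
def unet_shapes(h, w, in_ch=1):
    """Trace (channels, H, W) through the 3-level U-Net used in this notebook.
    Both 2x2 max-pools must divide evenly, so H and W must be multiples of 4."""
    if h % 4 or w % 4:
        raise ValueError("H and W must be divisible by 4 for the skip connections to align")
    return [
        ("input", (in_ch, h, w)),
        ("d1", (64, h, w)),
        ("d2 (after pool)", (128, h // 2, w // 2)),
        ("d3 (after pool)", (256, h // 4, w // 4)),
        ("up + cat c2", (256 + 128, h // 2, w // 2)),
        ("u1", (128, h // 2, w // 2)),
        ("up + cat c1", (128 + 64, h, w)),
        ("u2", (64, h, w)),
        ("out (sigmoid)", (1, h, w)),
    ]

for name, shape in unet_shapes(128, 128):
    print(f"{name:16s} {shape}")
```

The 128×128 training patches satisfy this constraint. Full slices passed at inference time would need to be padded to a multiple of 4 first.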


In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)

# Compact 2D U-Net: 3 encoder levels (64-128-256 channels), 2 pooling steps
# Encoder-decoder with skip connections for spatial detail preservation
# Chosen due to proven effectiveness in medical image segmentation
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.d1 = DoubleConv(1, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

        self.u1 = DoubleConv(256+128, 128)
        self.u2 = DoubleConv(128+64, 64)

        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        c1 = self.d1(x)
        c2 = self.d2(self.pool(c1))
        c3 = self.d3(self.pool(c2))

        u1 = self.up(c3)
        u1 = self.u1(torch.cat([u1, c2], dim=1))

        u2 = self.up(u1)
        u2 = self.u2(torch.cat([u2, c1], dim=1))

        return torch.sigmoid(self.out(u2))


In [None]:
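# Soft Dice loss (used in the training loop): computed on probabilities with no
# threshold, so it stays differentiable. Dice handles the severe foreground/background
# imbalance better than plain cross-entropy for small objects.
def dice_loss(pred, target, eps=1e-6):
    target = target.float()
    intersection = (pred * target).sum()
    return 1 - (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

# Hard Dice score (evaluation metric): predictions are binarized at `thresh` first
def dice_score(pred, target, eps=1e-6, thresh=0.35):
    pred = (pred > thresh).float()
    target = target.float()

    intersection = (pred * target).sum()
    dice = (2 * intersection + eps) / (pred.sum() + target.sum() + eps)
    return dice.item()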

# Lower threshold used intentionally to favor recall
# Missing a nodule (false negative) is clinically more costly than false positives
def recall_score(pred, target, eps=1e-6, thresh=0.35):
    pred = (pred > thresh).float()
    target = target.float()

    tp = (pred * target).sum()
    fn = ((1 - pred) * target).sum()

    recall = (tp + eps) / (tp + fn + eps)
    return recall.item()


def confusion_matrix_counts(pred, target, thresh=0.35):
    pred = (pred > thresh).float()
    target = target.float()

    tp = (pred * target).sum().item()
    fp = (pred * (1 - target)).sum().item()
    fn = ((1 - pred) * target).sum().item()
    tn = ((1 - pred) * (1 - target)).sum().item()

    return tp, fp, fn, tn


In [48]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

EPOCHS = 20          # upper bound
PATIENCE = 5         # early stopping patience
best_val_recall = 0
patience_counter = 0


Using device: mps


## Loss Function and Training Strategy

### Loss Function
- **Dice Loss** is used to handle severe class imbalance
- More suitable than cross-entropy for small object segmentation

### False-Negative-Aware Design
- A lower segmentation threshold (0.35 instead of the conventional 0.5) is used when binarizing predictions
- Validation **recall** is prioritized over accuracy
- Early stopping is based on recall improvement
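Here is a toy NumPy illustration of both ideas. It assumes the predictions are already sigmoid probabilities and uses a soft Dice formulation. The soft Dice loss works on probabilities directly, so it stays differentiable. Lowering the binarization threshold from 0.5 to 0.35 raises recall on this example from 0.5 to 1.0. The cost is one extra false positive: the background pixel at 0.38. The numbers are illustrative only.

```python
import numpy as np

def soft_dice_loss(prob, target, eps=1e-6):
    # Differentiable Dice: uses probabilities directly, no threshold
    inter = (prob * target).sum()
    return 1 - (2 * inter + eps) / (prob.sum() + target.sum() + eps)

def pixel_recall(prob, target, thresh, eps=1e-6):
    # Binarize at `thresh`, then recall = TP / (TP + FN)
    pred = (prob > thresh).astype(np.float32)
    tp = (pred * target).sum()
    fn = ((1 - pred) * target).sum()
    return (tp + eps) / (tp + fn + eps)

# Toy probabilities for six pixels; the first four are annotated nodule
prob = np.array([0.9, 0.6, 0.45, 0.4, 0.38, 0.1], dtype=np.float32)
target = np.array([1, 1, 1, 1, 0, 0], dtype=np.float32)

print(round(float(soft_dice_loss(prob, target)), 4))
for t in (0.5, 0.35):
    print(t, round(float(pixel_recall(prob, target, t)), 4))
```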


## Early Stopping Criterion

Early stopping is applied based on **validation recall**, not loss.

This ensures that training stops only when the model's ability
to detect nodules (sensitivity) no longer improves.

This choice reflects the **clinical cost of false negatives**.
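The stopping rule can be factored into a small helper. This is a hedged sketch: `RecallEarlyStopping` and its `min_delta` parameter are hypothetical additions, not part of the notebook. It reproduces the training loop's logic. The counter resets whenever validation recall improves, and training stops once `patience` consecutive epochs show no improvement.

```python
class RecallEarlyStopping:
    """Minimal recall-based early stopping (higher is better), mirroring the loop below.
    `min_delta` is a hypothetical extra knob: the minimum gain that counts as improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.counter = 0

    def step(self, val_recall):
        """Return (improved, should_stop) for this epoch's validation recall."""
        if val_recall > self.best + self.min_delta:
            self.best = val_recall
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience

stopper = RecallEarlyStopping(patience=2)
for epoch, r in enumerate([0.60, 0.70, 0.69, 0.70, 0.65, 0.80], start=1):
    improved, stop = stopper.step(r)
    print(epoch, r, "improved" if improved else f"no improvement ({stopper.counter}/2)")
    if stop:
        break
```

With `patience=2`, this run stops after epoch 4 because neither 0.69 nor 0.70 beats the best value of 0.70. The later 0.80 is never reached, which is the usual trade-off of a small patience.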


In [None]:
# Training loop with recall-based early stopping
# Validation recall is monitored instead of loss to reduce false negatives
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # -------- TRAIN --------
    model.train()
    train_loss = 0.0

    train_bar = tqdm(
        train_loader,
        desc="Training",
        leave=False
    )

    for x, y in train_bar:
        x, y = x.to(device), y.to(device)
        pred = model(x)

        loss = dice_loss(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        train_bar.set_postfix(
            loss=f"{loss.item():.4f}"
        )

    train_loss /= len(train_loader)

    # -------- VALIDATE --------
    model.eval()
    val_recall = 0.0

    val_bar = tqdm(
        val_loader,
        desc="Validation",
        leave=False
    )

    with torch.no_grad():
        for x, y in val_bar:
            x, y = x.to(device), y.to(device)
            pred = model(x)

            r = recall_score(pred, y)  # already returns a Python float
            val_recall += r

            val_bar.set_postfix(
                recall=f"{r:.4f}"
            )

    val_recall /= len(val_loader)

    # -------- EPOCH SUMMARY --------
    print(
        f"Epoch {epoch+1:02d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Recall: {val_recall:.4f}"
    )

    # Save model only when validation recall improves
    # Prevents overfitting while maintaining sensitivity
    if val_recall > best_val_recall:
        best_val_recall = val_recall
        patience_counter = 0

        torch.save(
            model.state_dict(),
            os.path.join(MODEL_DIR, "best_unet.pth")
        )
        print("✓ Improved, model saved")

    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{PATIENCE})")

    if patience_counter >= PATIENCE:
        print("🛑 Early stopping triggered")
        break


In [None]:
model.load_state_dict(
    torch.load(os.path.join(MODEL_DIR, "best_unet.pth"), map_location=device)
)
model.eval()
print("Best model loaded")


In [None]:
img, mask = dataset[0]

with torch.no_grad():
    pred = model(img.unsqueeze(0).to(device)).cpu()[0, 0]

plt.figure(figsize=(5,5))
plt.imshow(img[0], cmap="gray")
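# Show only predicted nodule pixels (above the 0.35 threshold); background is masked
# so the "Reds" colormap does not tint the whole slice
nodule = (pred > 0.35).numpy().astype(np.float32)
plt.imshow(np.ma.masked_where(nodule == 0, nodule), alpha=0.4, cmap="Reds", vmin=0, vmax=1)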
plt.axis("off")

overlay_path = os.path.join(OVERLAY_DIR, "sample_overlay.png")
plt.savefig(overlay_path, bbox_inches="tight")
plt.show()

print("Overlay saved to:", overlay_path)


## Evaluation Metrics

The model is evaluated using **pixel-level metrics**:

- **Dice Score**: segmentation overlap quality
- **Recall (Sensitivity)**: ability to detect nodules
- **Confusion Matrix (TP, FP, FN, TN)**: error analysis

Recall is emphasized because false negatives are the critical error in lung cancer detection.
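The notebook averages Dice and recall per validation batch, so the printed values differ slightly from metrics pooled over all pixels. Below is a minimal sketch that computes pooled (micro-averaged) metrics directly from the confusion-matrix counts reported in this notebook's final validation output. The helper name `pixel_metrics` is hypothetical.

```python
def pixel_metrics(tp, fp, fn, tn):
    """Pooled (micro-averaged) pixel-level metrics from confusion-matrix counts."""
    return {
        "dice": 2 * tp / (2 * tp + fp + fn),
        "recall": tp / (tp + fn),          # sensitivity
        "precision": tp / (tp + fp),
        "specificity": tn / (tn + fp),
    }

# Counts taken from the final validation output of this notebook
m = pixel_metrics(tp=377968, fp=126516, fn=120963, tn=50115801)
for k, v in m.items():
    print(f"{k:12s} {v:.4f}")
```

The pooled values are approximately Dice 0.7534, recall 0.7576, precision 0.7492, and specificity 0.9975. The near-perfect specificity mostly reflects how many background pixels there are. The `eps` smoothing in `recall_score` also skews the per-batch average: a batch whose random crops contain no nodule pixels scores a recall of 1.0, which can nudge the average upward.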


In [None]:
model.eval()

total_dice = 0.0
total_recall = 0.0

TP = FP = FN = TN = 0

with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)

        # Dice & Recall
        total_dice += dice_score(pred, y)
        total_recall += recall_score(pred, y)

        # Confusion matrix
        tp, fp, fn, tn = confusion_matrix_counts(pred, y)
        TP += tp
        FP += fp
        FN += fn
        TN += tn

# Average scores
avg_dice = total_dice / len(val_loader)
avg_recall = total_recall / len(val_loader)


# Final evaluation performed using Dice, Recall, and Confusion Matrix
# Metrics reported at pixel level
print("==== FINAL VALIDATION METRICS ====")
print(f"Dice Score   : {avg_dice:.4f}")
print(f"Recall       : {avg_recall:.4f}")

print("\nConfusion Matrix (pixel-level):")
print(f"TP: {TP}")
print(f"FP: {FP}")
print(f"FN: {FN}")
print(f"TN: {TN}")


==== FINAL VALIDATION METRICS ====
Dice Score   : 0.7314
Recall       : 0.7597

Confusion Matrix (pixel-level):
TP: 377968.0
FP: 126516.0
FN: 120963.0
TN: 50115801.0


In [None]:
import pandas as pd

# Rows = ground truth, columns = prediction
cm = pd.DataFrame(
    [[TP, FN],
     [FP, TN]],
    columns=["Predicted Nodule", "Predicted Background"],
    index=["Actual Nodule", "Actual Background"]
)

cm


|                   | Predicted Nodule | Predicted Background |
|-------------------|------------------|----------------------|
| Actual Nodule     | 377968.0         | 120963.0             |
| Actual Background | 126516.0         | 50115801.0           |


## Results and Interpretation

Final validation performance:

- **Dice Score:** 0.73
- **Recall:** 0.76

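At the 0.35 threshold, the model recovers roughly three quarters of the annotated nodule pixels (recall 0.76) while keeping segmentation overlap reasonable (Dice 0.73).
Both values are pixel-level averages over validation batches of random 128×128 crops, not per-nodule detection rates.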

The Dice score reflects realistic performance given:
- Small nodule size
- Sparse annotations
- Inter-observer variability in LIDC-IDRI
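All of the metrics above are pixel-level, while the project's goal is to miss fewer nodules, so a nodule-level view can complement them. The sketch below is hypothetical. It labels 4-connected nodule components with a small BFS, a pure-Python stand-in for a library routine such as OpenCV's `cv2.connectedComponents`. As an assumed criterion, it counts a nodule as detected if any predicted pixel falls inside it.

```python
from collections import deque

import numpy as np

def label_components(mask):
    """4-connected component labelling with a breadth-first search."""
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    n = 0
    for sy in range(h):
        for sx in range(w):
            if mask[sy, sx] and labels[sy, sx] == 0:
                n += 1                      # start a new component
                labels[sy, sx] = n
                queue = deque([(sy, sx)])
                while queue:
                    y, x = queue.popleft()
                    for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = n
                            queue.append((ny, nx))
    return labels, n

def lesion_recall(pred_mask, gt_mask):
    """Fraction of ground-truth nodules hit by at least one predicted pixel."""
    labels, n = label_components(gt_mask > 0)
    if n == 0:
        return None  # undefined: no nodules in this slice
    hit = sum(1 for k in range(1, n + 1) if pred_mask[labels == k].any())
    return hit / n

gt = np.zeros((8, 8), dtype=np.uint8)
gt[1:3, 1:3] = 1   # nodule A (4 pixels)
gt[5:8, 5:8] = 1   # nodule B (9 pixels)
pred = np.zeros_like(gt)
pred[1, 1] = 1     # one pixel of A found, B missed entirely
print(lesion_recall(pred, gt))  # 0.5
```

In this toy example, only 1 of 13 nodule pixels is predicted (pixel recall about 0.08), yet one of the two nodules is hit (nodule-level recall 0.5). The two views can diverge sharply.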
