In [1]:
# Import required libraries
from segmentation_pipeline import nnUNet, nnUNetConfidence
from lungmask import LMInferer
from segmentation_pipeline import pydicom_to_nifti
from segmentation_pipeline import apply_windowing
from segmentation_pipeline import random_pad_3d_box
import torch
import torch.nn.functional as F
import cc3d
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pydicom
import os
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


## Helper Functions

In [2]:
def get_dicom_files_with_slice_locations(directory_path):
    """
    Loop through all DICOM files in a directory and extract their paths and slice locations.
    
    Args:
        directory_path (str): Path to the directory containing DICOM files
        
    Returns:
        tuple: (list of file paths, list of corresponding slice locations)
    """
    dicom_paths = []
    slice_locations = []
    
    # Convert to Path object for easier handling
    dir_path = Path(directory_path)
    
    # Check if directory exists
    if not dir_path.exists():
        raise ValueError(f"Directory does not exist: {directory_path}")
    
    # Loop through all files in directory (including subdirectories)
    for file_path in dir_path.rglob('*'):
        # Skip directories
        if file_path.is_dir():
            continue
            
        try:
            # Try to read as DICOM file
            dcm = pydicom.dcmread(str(file_path), stop_before_pixels=True)
            
            # Check if SliceLocation exists in metadata
            if hasattr(dcm, 'Slice Location'):
                dicom_paths.append(str(file_path))
                slice_locations.append(float(dcm['Slice Location'].value))
            elif hasattr(dcm, 'ImagePositionPatient'):
                # Alternative: use Z coordinate from ImagePositionPatient if SliceLocation not available
                dicom_paths.append(str(file_path))
                slice_locations.append(float(dcm.ImagePositionPatient[2]))
        except Exception as e:
            # Skip files that are not valid DICOM files
            continue
    
    return dicom_paths, slice_locations

In [3]:
def order_slices(img_paths, slice_locations, reverse=False):
    sorted_ids = np.argsort(slice_locations)
    if reverse:
        sorted_ids = sorted_ids[::-1]
    sorted_img_paths = np.array(img_paths)[sorted_ids].tolist()
    sorted_slice_locs = np.sort(slice_locations).tolist()

    return sorted_img_paths, sorted_slice_locs

## Load Models

In [4]:
# Load segmentation model checkpoint
segmentation_model_checkpoint = torch.load(
    "/data/rbg/scratch/lung_ct/checkpoints/5678b14bb8a563a32f448d19a7d12e6b/last.ckpt",
    weights_only=False
)

new_segmentation_model_state_dict = {}
for k, v in segmentation_model_checkpoint["state_dict"].items():
    if "classifier" not in k:
        new_k = k.replace("model.model", "model")  
        new_segmentation_model_state_dict[new_k] = v

In [5]:
# Load confidence model checkpoint
confidence_model_checkpoint = torch.load(
    "/data/rbg/scratch/lung_ct/checkpoints/4296b4b6cda063e96d52aabfb0694a04/4296b4b6cda063e96d52aabfb0694a04epoch=9.ckpt",
    weights_only=False
)

new_confidence_model_state_dict = {}
for k, v in confidence_model_checkpoint["state_dict"].items():
    new_k = k.replace("model.model", "model")  
    if "model.classifier" in new_k:
        new_k = new_k.replace("model.classifier", "classifier")
    new_confidence_model_state_dict[new_k] = v

In [6]:
# Initialize models
segmentation_model = nnUNet(
    segmentation_model_checkpoint["hyper_parameters"]["args"]
)
segmentation_model.load_state_dict(new_segmentation_model_state_dict)

confidence_model = nnUNetConfidence(
    confidence_model_checkpoint["hyper_parameters"]["args"]
)
confidence_model.load_state_dict(new_confidence_model_state_dict)

# Load lungmask model
model = LMInferer(
    modelpath="/data/rbg/users/pgmikhael/current/lungmask/checkpoints/unet_r231-d5d2fc3d.pth",
    tqdm_disable=True,
    batch_size=100,
    force_cpu=False,
)

# Set to eval mode
segmentation_model.eval()
confidence_model.eval()

nnUNetConfidence(
  (model): ResidualEncoder(
    (stem): StackedConvBlocks(
      (convs): Sequential(
        (0): ConvDropoutNormReLU(
          (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
          (all_modules): Sequential(
            (0): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
    (stages): Sequential(
      (0): StackedResidualBlocks(
        (blocks): Sequential(
          (0): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (norm

## Load Image Data

In [7]:
# Option 1: Load from DICOM directory (commented out)
dicom_dir = "/data/rbg/shared/datasets/NLST/NLST/all_nlst-ct/set2/batch1/102676/T0/1.2.840.113654.2.55.106468547949258489874106374248199128625/"
img_paths, slice_locations = get_dicom_files_with_slice_locations(dicom_dir)
sorted_img_paths, sorted_slice_locs = order_slices(img_paths, slice_locations, reverse=False)
depth = len(sorted_img_paths)
# sorted_img_paths = sorted_img_paths[depth//2 - 10: depth//2 + 10]
print(f"Number of slices found: {len(sorted_img_paths)}")
image = pydicom_to_nifti(
    sorted_img_paths,
    return_nifti=False, save_nifti=False,
    output_path="buffer",
)


# # Option 2: Load from numpy file
# voxel_spacing = [0.8, 0.8, 1.5]  # y, x, z
# affine = torch.diag(torch.tensor(voxel_spacing + [1]))

# image = np.load("image_array_depth20.npy")
# print(f"Image shape: {image.shape}, dtype: {image.dtype}")
# print(f"Original image type: {type(image)}, shape: {image.shape}")

Number of slices found: 118


## Generate Lung Mask

In [8]:
# Run lung mask - transpose to put depth first
image_ = np.transpose(image, (2, 0, 1))
lung_mask = model.apply(image_)
print(f"Lung mask shape: {lung_mask.shape}")
print(f"Lung mask unique values: {np.unique(lung_mask)}")

mid_slice = lung_mask.shape[0] // 2

lungmask 2025-12-03 19:38:18 Postprocessing
Lung mask shape: (118, 512, 512)
Lung mask unique values: [0 1 2]


## Preprocess Image

In [9]:
# Apply windowing
image = apply_windowing(image.astype(np.float64), -600, 1600)
print(f"image type after windowing: {type(image)}, shape: {image.shape}")

# Convert to tensor and normalize
image = torch.tensor(image) // 256
image = image.permute(2, 0, 1).unsqueeze(0).unsqueeze(0)  # shape: [1, 1, D, H, W]
print(f"Image shape after unsqueezes: {image.shape}")

# Interpolate to target size
image = F.interpolate(
    image,
    size=(image.shape[2], 1024, 1024),
    mode="trilinear",
    align_corners=False,
)
image = image.squeeze(1)
image = image[None]
print(f"Image shape after interpolation: {image.shape}")

image type after windowing: <class 'numpy.ndarray'>, shape: (512, 512, 118)
Image shape after unsqueezes: torch.Size([1, 1, 118, 512, 512])
Image shape after interpolation: torch.Size([1, 1, 118, 1024, 1024])


In [10]:
# Interpolate lung mask
lung_mask = torch.tensor(lung_mask).unsqueeze(1)
lung_mask = F.interpolate(
    lung_mask,
    size=(1024, 1024),
    mode="nearest-exact",
)
lung_mask = lung_mask.squeeze()
print(f"Lung mask shape after interpolation: {lung_mask.shape}")

Lung mask shape after interpolation: torch.Size([118, 1024, 1024])


## Run Segmentation Model

In [11]:
# Run segmentation
with torch.no_grad():
    segmentation_outputs = segmentation_model.predict(image.float())

print(f"segmentation unique values: {torch.unique(segmentation_outputs)}")

# Create binary segmentation
binary_segmentation = (
    1 * (F.softmax(segmentation_outputs, 1)[0, 1] > 0.5) * lung_mask
)
print(f"Binary segmentation shape: {binary_segmentation.shape}")

segmentation unique values: tensor([-201.7136, -201.0410, -198.1624,  ...,  198.1781,  201.0568,
         201.7296])
Binary segmentation shape: torch.Size([118, 1024, 1024])


## Extract Connected Components

In [12]:
# Get connected components
instance_segmentation, num_instances = cc3d.connected_components(
    binary_segmentation.cpu().numpy(),
    return_N=True,
)
# Get unique values from instance segmentation
unique_vals = np.unique(instance_segmentation)
print(f"Unique instance values: {unique_vals}")
print(f"Number of instances found: {num_instances}")
print(f"Instance segmentation shape: {instance_segmentation.shape}")

# Convert to sparse tensor
sparse_segmentation = torch.tensor(instance_segmentation, dtype=torch.int32).to_sparse()
print(f"Sparse segmentation indices shape: {sparse_segmentation.indices().shape}")

Unique instance values: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33]
Number of instances found: 33
Instance segmentation shape: (118, 1024, 1024)
Sparse segmentation indices shape: torch.Size([3, 2440])


In [13]:
# Reshape image for patch extraction
image = image.squeeze(0).squeeze(0).permute(1, 2, 0)  # shape: H, W, D
print(f"Image shape: {image.shape}")

Image shape: torch.Size([1024, 1024, 118])


## Extract Patches for Each Instance

In [None]:
patches = []
patch_sizes = []  # Track sizes to determine max dimensions
temp_patches = []

# First pass: extract patches and track sizes
for inst_id in range(1, num_instances + 1):
    zs, ys, xs = sparse_segmentation.indices()[
        :, sparse_segmentation.values() == inst_id
    ]
    box = {
        "x_start": torch.min(xs).item(),
        "x_stop": torch.max(xs).item(),
        "y_start": torch.min(ys).item(),
        "y_stop": torch.max(ys).item(),
        "z_start": torch.min(zs).item(),
        "z_stop": torch.max(zs).item(),
    }
    patch = torch.zeros_like(image)
    print(f"patch shape: {patch.shape}")
    patch[
        box["y_start"] : box["y_stop"] + 1,
        box["x_start"] : box["x_stop"] + 1,
        box["z_start"] : box["z_stop"] + 1,
    ] = binary_segmentation[
        box["z_start"] : box["z_stop"] + 1,
        box["y_start"] : box["y_stop"] + 1,
        box["x_start"] : box["x_stop"] + 1,
    ].permute(1, 2, 0)
    cbbox = random_pad_3d_box(
        box,
        image,
        min_height=128,
        min_width=128,
        min_depth=32,  # Increased to 32 to ensure enough depth for network layers
        random_hw=False,
        random_d=False,
    )
    patchx = image[cbbox]
    patchl = patch[cbbox]
    print(f"patchx shape: {patchx.shape}, patchl shape: {patchl.shape}")
    temp_patches.append((patchx, patchl))
    patch_sizes.append(patchx.shape)

print(f"Extracted {len(temp_patches)} patches.")

patch shape: torch.Size([1024, 1024, 118])
[0. 2.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 1.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 1.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 2.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 1.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 2.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 1.]
patchx shape: torch.Size([128, 128, 32]), patchl shape: torch.Size([128, 128, 32])
patch shape: torch.Size([1024, 1024, 118])
[0. 2.]
patchx shap

In [20]:
# Determine maximum dimensions across all patches
# Also save images of slices with most non-zero values
# Also get the number of pixels greater than zero in each patch
output_path = Path("patch_visualizations")
output_path.mkdir(exist_ok=True)
pixelCounts = {}

for inst_id, (patchx, patchl) in enumerate(temp_patches, start=1):
    # Find slice with most non-zero values
    non_zero_counts = (patchl > 0).sum(dim=(0, 1))  # Sum over H and W dimensions
    num_pixels_greater_than_zero = non_zero_counts.sum().item()
    pixelCounts[inst_id] = num_pixels_greater_than_zero
    max_slice_idx = torch.argmax(non_zero_counts).item()
    
    # Extract the slice
    img_slice = patchx[:, :, max_slice_idx].numpy()
    mask_slice = patchl[:, :, max_slice_idx].numpy()
    
    # Create visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image
    axes[0].imshow(img_slice, cmap='gray')
    axes[0].set_title(f'Instance {inst_id} - Image (Slice {max_slice_idx})')
    axes[0].axis('off')
    
    # Mask
    axes[1].imshow(mask_slice, cmap='gray')
    axes[1].set_title(f'Instance {inst_id} - Mask (Slice {max_slice_idx})')
    axes[1].axis('off')
    
    # Overlay
    axes[2].imshow(img_slice, cmap='gray')
    axes[2].imshow(mask_slice, cmap='Reds', alpha=0.5)
    axes[2].set_title(f'Instance {inst_id} - Overlay (Slice {max_slice_idx})')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.savefig(output_path / f'instance_{inst_id:03d}_slice_{max_slice_idx}.png', dpi=100, bbox_inches='tight')
    plt.close()

print(f"Saved visualizations to {output_path}")
print(f"Pixel counts greater than zero per instance: {pixelCounts}")

Saved visualizations to patch_visualizations
Pixel counts greater than zero per instance: {1: 49, 2: 46, 3: 224, 4: 68, 5: 2, 6: 113, 7: 274, 8: 1, 9: 11, 10: 76, 11: 1, 12: 400, 13: 4, 14: 125, 15: 1, 16: 2, 17: 9, 18: 217, 19: 390, 20: 2, 21: 3, 22: 1, 23: 156, 24: 2, 25: 117, 26: 3, 27: 1, 28: 42, 29: 1, 30: 2, 31: 9, 32: 2, 33: 93}


In [21]:
pixel_spacing = [0.703125, 0.703125, 2.5]
pixel_volume = pixel_spacing[0] * pixel_spacing[1] * pixel_spacing[2]
patch_volumes = {inst_id: count * pixel_volume for inst_id, count in pixelCounts.items()}
print(f"Patch volumes (in mm^3) per instance: {patch_volumes}")

Patch volumes (in mm^3) per instance: {1: 60.5621337890625, 2: 56.854248046875, 3: 276.85546875, 4: 84.04541015625, 5: 2.471923828125, 6: 139.6636962890625, 7: 338.653564453125, 8: 1.2359619140625, 9: 13.5955810546875, 10: 93.93310546875, 11: 1.2359619140625, 12: 494.384765625, 13: 4.94384765625, 14: 154.4952392578125, 15: 1.2359619140625, 16: 2.471923828125, 17: 11.1236572265625, 18: 268.2037353515625, 19: 482.025146484375, 20: 2.471923828125, 21: 3.7078857421875, 22: 1.2359619140625, 23: 192.81005859375, 24: 2.471923828125, 25: 144.6075439453125, 26: 3.7078857421875, 27: 1.2359619140625, 28: 51.910400390625, 29: 1.2359619140625, 30: 2.471923828125, 31: 11.1236572265625, 32: 2.471923828125, 33: 114.9444580078125}


## Pad Patches to Uniform Size

In [None]:
# Determine maximum dimensions across all patches
max_h = max(s[0] for s in patch_sizes)
max_w = max(s[1] for s in patch_sizes)
# max_d = max(s[2] for s in patch_sizes)
max_d = 32
print(f"Max patch dimensions: H={max_h}, W={max_w}, D={max_d}")

# Second pass: pad all patches to max dimensions
for patchx, patchl in temp_patches:
    # Pad to max dimensions
    pad_h = max_h - patchx.shape[0]
    pad_w = max_w - patchx.shape[1]
    pad_d = max_d - patchx.shape[2]
    
    if pad_h > 0 or pad_w > 0 or pad_d > 0:
        patchx = F.pad(patchx, (0, pad_d, 0, pad_w, 0, pad_h), mode='constant', value=0)
        patchl = F.pad(patchl, (0, pad_d, 0, pad_w, 0, pad_h), mode='constant', value=0)
    
    patch = torch.stack([patchx, patchl])
    print(f"Padded patch shape: {patch.shape}")
    patches.append(patch)

patches = torch.stack(patches)
print(f"Total patches shape: {patches.shape}")

Max patch dimensions: H=128, W=128, D=32
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded pa

## Run Confidence Model

In [16]:
# Run confidence model
with torch.no_grad():
    confidence_outputs = confidence_model(patches.float())

# print(f"Confidence outputs shape: {torch.as_tensor(confidence_outputs).shape}")
# print(f"Confidence outputs: {confidence_outputs}")

logits = confidence_outputs['logit']
print(f"Logits shape: {logits.shape}")
# run softmax on logits to get confidence scores
confidence_scores = F.softmax(logits, dim=1)
print(f"confidence scores: {confidence_scores}")

Logits shape: torch.Size([33, 2])
confidence scores: tensor([[0.7551, 0.2449],
        [0.8005, 0.1995],
        [0.7366, 0.2634],
        [0.6641, 0.3359],
        [0.6298, 0.3702],
        [0.5936, 0.4064],
        [0.6993, 0.3007],
        [0.6840, 0.3160],
        [0.6881, 0.3119],
        [0.5337, 0.4663],
        [0.7260, 0.2740],
        [0.7104, 0.2896],
        [0.5772, 0.4228],
        [0.5021, 0.4979],
        [0.5042, 0.4958],
        [0.6560, 0.3440],
        [0.6987, 0.3013],
        [0.5870, 0.4130],
        [0.6801, 0.3199],
        [0.6517, 0.3483],
        [0.6830, 0.3170],
        [0.7110, 0.2890],
        [0.7970, 0.2030],
        [0.7895, 0.2105],
        [0.6574, 0.3426],
        [0.7392, 0.2608],
        [0.4851, 0.5149],
        [0.5064, 0.4936],
        [0.5082, 0.4918],
        [0.6633, 0.3367],
        [0.6563, 0.3437],
        [0.7020, 0.2980],
        [0.7968, 0.2032]])
