In [25]:
# Import required libraries
from segmentation_pipeline import nnUNet, nnUNetConfidence
from lungmask import LMInferer
from segmentation_pipeline import pydicom_to_nifti
from segmentation_pipeline import apply_windowing
from segmentation_pipeline import random_pad_3d_box
import torch
import torch.nn.functional as F
import cc3d
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import pydicom
import os
from pathlib import Path

## Helper Functions

In [26]:
def get_dicom_files_with_slice_locations(directory_path):
    """
    Loop through all DICOM files in a directory and extract their paths and slice locations.
    
    Args:
        directory_path (str): Path to the directory containing DICOM files
        
    Returns:
        tuple: (list of file paths, list of corresponding slice locations)
    """
    dicom_paths = []
    slice_locations = []
    
    # Convert to Path object for easier handling
    dir_path = Path(directory_path)
    
    # Check if directory exists
    if not dir_path.exists():
        raise ValueError(f"Directory does not exist: {directory_path}")
    
    # Loop through all files in directory (including subdirectories)
    for file_path in dir_path.rglob('*'):
        # Skip directories
        if file_path.is_dir():
            continue
            
        try:
            # Try to read as DICOM file
            dcm = pydicom.dcmread(str(file_path), stop_before_pixels=True)
            
            # Check if SliceLocation exists in metadata
            if hasattr(dcm, 'Slice Location'):
                dicom_paths.append(str(file_path))
                slice_locations.append(float(dcm['Slice Location'].value))
            elif hasattr(dcm, 'ImagePositionPatient'):
                # Alternative: use Z coordinate from ImagePositionPatient if SliceLocation not available
                dicom_paths.append(str(file_path))
                slice_locations.append(float(dcm.ImagePositionPatient[2]))
        except Exception as e:
            # Skip files that are not valid DICOM files
            continue
    
    return dicom_paths, slice_locations

In [27]:
def order_slices(img_paths, slice_locations, reverse=False):
    sorted_ids = np.argsort(slice_locations)
    if reverse:
        sorted_ids = sorted_ids[::-1]
    sorted_img_paths = np.array(img_paths)[sorted_ids].tolist()
    sorted_slice_locs = np.sort(slice_locations).tolist()

    return sorted_img_paths, sorted_slice_locs

## Load Models

In [28]:
# Load segmentation model checkpoint
segmentation_model_checkpoint = torch.load(
    "/data/rbg/scratch/lung_ct/checkpoints/5678b14bb8a563a32f448d19a7d12e6b/last.ckpt",
    weights_only=False
)

new_segmentation_model_state_dict = {}
for k, v in segmentation_model_checkpoint["state_dict"].items():
    if "classifier" not in k:
        new_k = k.replace("model.model", "model")  
        new_segmentation_model_state_dict[new_k] = v

In [29]:
# Load confidence model checkpoint
confidence_model_checkpoint = torch.load(
    "/data/rbg/scratch/lung_ct/checkpoints/4296b4b6cda063e96d52aabfb0694a04/4296b4b6cda063e96d52aabfb0694a04epoch=9.ckpt",
    weights_only=False
)

new_confidence_model_state_dict = {}
for k, v in confidence_model_checkpoint["state_dict"].items():
    new_k = k.replace("model.model", "model")  
    if "model.classifier" in new_k:
        new_k = new_k.replace("model.classifier", "classifier")
    new_confidence_model_state_dict[new_k] = v

In [30]:
# Initialize models
segmentation_model = nnUNet(
    segmentation_model_checkpoint["hyper_parameters"]["args"]
)
segmentation_model.load_state_dict(new_segmentation_model_state_dict)

confidence_model = nnUNetConfidence(
    confidence_model_checkpoint["hyper_parameters"]["args"]
)
confidence_model.load_state_dict(new_confidence_model_state_dict)

# Load lungmask model
model = LMInferer(
    modelpath="/data/rbg/users/pgmikhael/current/lungmask/checkpoints/unet_r231-d5d2fc3d.pth",
    tqdm_disable=True,
    batch_size=100,
    force_cpu=False,
)

# Set to eval mode
segmentation_model.eval()
confidence_model.eval()

nnUNetConfidence(
  (model): ResidualEncoder(
    (stem): StackedConvBlocks(
      (convs): Sequential(
        (0): ConvDropoutNormReLU(
          (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
          (all_modules): Sequential(
            (0): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
    (stages): Sequential(
      (0): StackedResidualBlocks(
        (blocks): Sequential(
          (0): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (norm

## Load Image Data

In [None]:
# Option 1: Load from DICOM directory (commented out)
dicom_dir = "/data/rbg/shared/datasets/NLST/NLST/all_nlst-ct/set2/batch1/102676/T0/1.2.840.113654.2.55.106468547949258489874106374248199128625/"
img_paths, slice_locations = get_dicom_files_with_slice_locations(dicom_dir)
sorted_img_paths, sorted_slice_locs = order_slices(img_paths, slice_locations, reverse=False)
depth = len(sorted_img_paths)
# sorted_img_paths = sorted_img_paths[depth//2 - 10: depth//2 + 10]
print(f"Number of slices found: {len(sorted_img_paths)}")
image = pydicom_to_nifti(
    sorted_img_paths,
    return_nifti=False, save_nifti=False,
    output_path="buffer",
)


# # Option 2: Load from numpy file
# voxel_spacing = [0.8, 0.8, 1.5]  # y, x, z
# affine = torch.diag(torch.tensor(voxel_spacing + [1]))

# image = np.load("image_array_depth20.npy")
# print(f"Image shape: {image.shape}, dtype: {image.dtype}")
# print(f"Original image type: {type(image)}, shape: {image.shape}")

Image shape: (512, 512, 20), dtype: float64
Original image type: <class 'numpy.ndarray'>, shape: (512, 512, 20)


## Generate Lung Mask

In [32]:
# Run lung mask - transpose to put depth first
image_ = np.transpose(image, (2, 0, 1))
lung_mask = model.apply(image_)
print(f"Lung mask shape: {lung_mask.shape}")
print(f"Lung mask unique values: {np.unique(lung_mask)}")

mid_slice = lung_mask.shape[0] // 2

lungmask 2025-11-27 17:56:37 Postprocessing
Lung mask shape: (20, 512, 512)
Lung mask unique values: [0 1 2]


## Preprocess Image

In [33]:
# Apply windowing
image = apply_windowing(image.astype(np.float64), -600, 1600)
print(f"image type after windowing: {type(image)}, shape: {image.shape}")

# Convert to tensor and normalize
image = torch.tensor(image) // 256
image = image.permute(2, 0, 1).unsqueeze(0).unsqueeze(0)  # shape: [1, 1, D, H, W]
print(f"Image shape after unsqueezes: {image.shape}")

# Interpolate to target size
image = F.interpolate(
    image,
    size=(image.shape[2], 1024, 1024),
    mode="trilinear",
    align_corners=False,
)
image = image.squeeze(1)
image = image[None]
print(f"Image shape after interpolation: {image.shape}")

image type after windowing: <class 'numpy.ndarray'>, shape: (512, 512, 20)
Image shape after unsqueezes: torch.Size([1, 1, 20, 512, 512])
Image shape after interpolation: torch.Size([1, 1, 20, 1024, 1024])


In [34]:
# Interpolate lung mask
lung_mask = torch.tensor(lung_mask).unsqueeze(1)
lung_mask = F.interpolate(
    lung_mask,
    size=(1024, 1024),
    mode="nearest-exact",
)
lung_mask = lung_mask.squeeze()
print(f"Lung mask shape after interpolation: {lung_mask.shape}")

Lung mask shape after interpolation: torch.Size([20, 1024, 1024])


## Run Segmentation Model

In [None]:
# Run segmentation
with torch.no_grad():
    segmentation_outputs = segmentation_model.predict(image.float())

print(f"segmentation unique values: {torch.unique(segmentation_outputs)}")

# Create binary segmentation
binary_segmentation = (
    1 * (F.softmax(segmentation_outputs, 1)[0, 1] > 0.5) * lung_mask
)
print(f"Binary segmentation shape: {binary_segmentation.shape}")

## Extract Connected Components

In [None]:
# Get connected components
instance_segmentation, num_instances = cc3d.connected_components(
    binary_segmentation.cpu().numpy(),
    return_N=True,
)
print(f"Number of instances found: {num_instances}")
print(f"Instance segmentation shape: {instance_segmentation.shape}")

# Convert to sparse tensor
sparse_segmentation = torch.tensor(instance_segmentation, dtype=torch.int32).to_sparse()
print(f"Sparse segmentation indices shape: {sparse_segmentation.indices().shape}")

Number of instances found: 11
Instance segmentation shape: (20, 1024, 1024)
Sparse segmentation indices shape: torch.Size([3, 2587])


In [None]:
# Reshape image for patch extraction
image = image.squeeze(0).squeeze(0).permute(1, 2, 0)  # shape: H, W, D
print(f"Image shape: {image.shape}")

Image shape: torch.Size([1024, 1024, 20])


## Extract Patches for Each Instance

In [None]:
patches = []
patch_sizes = []  # Track sizes to determine max dimensions
temp_patches = []

# First pass: extract patches and track sizes
for inst_id in range(1, num_instances + 1):
    zs, ys, xs = sparse_segmentation.indices()[
        :, sparse_segmentation.values() == inst_id
    ]
    box = {
        "x_start": torch.min(xs).item(),
        "x_stop": torch.max(xs).item(),
        "y_start": torch.min(ys).item(),
        "y_stop": torch.max(ys).item(),
        "z_start": torch.min(zs).item(),
        "z_stop": torch.max(zs).item(),
    }
    patch = torch.zeros_like(image)
    print(f"patch shape: {patch.shape}")
    patch[
        box["y_start"] : box["y_stop"] + 1,
        box["x_start"] : box["x_stop"] + 1,
        box["z_start"] : box["z_stop"] + 1,
    ] = binary_segmentation[
        box["z_start"] : box["z_stop"] + 1,
        box["y_start"] : box["y_stop"] + 1,
        box["x_start"] : box["x_stop"] + 1,
    ].permute(1, 2, 0)
    cbbox = random_pad_3d_box(
        box,
        image,
        min_height=128,
        min_width=128,
        min_depth=32,  # Increased to 32 to ensure enough depth for network layers
        random_hw=False,
        random_d=False,
    )
    patchx = image[cbbox]
    patchl = patch[cbbox]
    print(f"patchx shape: {patchx.shape}, patchl shape: {patchl.shape}")
    temp_patches.append((patchx, patchl))
    patch_sizes.append(patchx.shape)

print(f"Extracted {len(temp_patches)} patches.")

patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])
patch shape: torch.Size([1024, 1024, 20])
patchx shape: torch.Size([128, 128, 20]), patchl shape: torch.Size([128, 128, 20])


## Pad Patches to Uniform Size

In [None]:
# Determine maximum dimensions across all patches
max_h = max(s[0] for s in patch_sizes)
max_w = max(s[1] for s in patch_sizes)
# max_d = max(s[2] for s in patch_sizes)
max_d = 32
print(f"Max patch dimensions: H={max_h}, W={max_w}, D={max_d}")

# Second pass: pad all patches to max dimensions
for patchx, patchl in temp_patches:
    # Pad to max dimensions
    pad_h = max_h - patchx.shape[0]
    pad_w = max_w - patchx.shape[1]
    pad_d = max_d - patchx.shape[2]
    
    if pad_h > 0 or pad_w > 0 or pad_d > 0:
        patchx = F.pad(patchx, (0, pad_d, 0, pad_w, 0, pad_h), mode='constant', value=0)
        patchl = F.pad(patchl, (0, pad_d, 0, pad_w, 0, pad_h), mode='constant', value=0)
    
    patch = torch.stack([patchx, patchl])
    print(f"Padded patch shape: {patch.shape}")
    patches.append(patch)

patches = torch.stack(patches)
print(f"Total patches shape: {patches.shape}")

Max patch dimensions: H=128, W=128, D=32
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Padded patch shape: torch.Size([2, 128, 128, 32])
Total patches shape: torch.Size([11, 2, 128, 128, 32])


## Run Confidence Model

In [None]:
# Run confidence model
with torch.no_grad():
    confidence_outputs = confidence_model(patches.float())

# print(f"Confidence outputs shape: {torch.as_tensor(confidence_outputs).shape}")
print(f"Confidence outputs: {confidence_outputs}")

logits = confidence_outputs['logit']
print(f"Logits shape: {logits.shape}")
# run softmax on logits to get confidence scores
confidence_scores = F.softmax(logits, dim=1)
print(f"confidence scores: {confidence_scores}")

AttributeError: 'Tensor' object has no attribute 'as_tensor'