In [100]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

In [101]:
def get_nifti_nib(series_id: str, nifti_dir: str, is_label: bool):
    """Checks all subfolders and returns the matching Nibabel object (for `series_id`) within those subfolders."""
    
    assert any(["group" in f for f in os.listdir(nifti_dir)])

    for f in os.listdir(nifti_dir):
        if "group" in f and os.path.isdir(os.path.join(nifti_dir, f)):
            for dir_fname in os.listdir(os.path.join(nifti_dir, f)):
                if is_label:
                    if dir_fname == f"{series_id}_output.nii.gz":
                        return nib.load(os.path.join(nifti_dir, f, dir_fname))
                else:
                    if dir_fname == f"{series_id}_0000.nii.gz":
                        return nib.load(os.path.join(nifti_dir, f, dir_fname))
    
    raise FileNotFoundError

In [102]:
def process_nifti(series_id, slice_num, crop_box, image_dir, mask_dir, output_dir):
    """
    Extracts a slice from a 3D NIfTI image and mask, applies cropping, and 
    overlays red contours from the mask onto the image.

    Parameters:
        image_path (str): Path to the NIfTI image file.
        mask_path (str): Path to the NIfTI mask file.
        slice_num (int): Slice index to extract.
        crop_box (tuple): (x, y, width, height) defining the crop area.
        output_dir (str): Directory to save the output images.
    """
    # Load the NIfTI files
    image_nii = get_nifti_nib(series_id, image_dir, False)
    mask_nii = get_nifti_nib(series_id, mask_dir, True)

    # Get data arrays
    image_data = image_nii.get_fdata()  # Shape: (Z, Y, X) [NLST-specific]
    mask_data = mask_nii.get_fdata()    # Shape: (Z, Y, X) [NLST-specific]

    # Extract the specified slice
    image_slice = image_data[slice_num, :, :]
    mask_slice = mask_data[slice_num, :, :]

    print(image_slice.shape)
    print(mask_slice.shape)
    

    # Apply cropping (x, y, w, h)
    x, y, w, h = crop_box
    cropped_image = image_slice[y:y+h, x:x+w]
    cropped_mask = mask_slice[y:y+h, x:x+w]
    # cropped_image = image_slice[x:x+w, y:y+h]
    # cropped_mask = mask_slice[x:x+w, y:y+h]

    # Normalize image to [0, 255] for display
    cropped_image = (255 * (cropped_image - np.min(cropped_image)) / 
                     (np.max(cropped_image) - np.min(cropped_image))).astype(np.uint8)

    # Convert to RGB
    overlay_image = cv2.cvtColor(cropped_image, cv2.COLOR_GRAY2RGB)

    # Find contours
    contours, _ = cv2.findContours(cropped_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Draw contours in red (BGR: (0,0,255))
    cv2.drawContours(overlay_image, contours, -1, (255, 0, 0), 2 if w == 512 and h == 512 else 1)  # Red contour

    # Save the overlaid image
    plt.imsave(os.path.join(output_dir, f"{series_id}_slice{slice_num}_cropped_overlay.png"), overlay_image)
    plt.imsave(os.path.join(output_dir, f"{series_id}_slice{slice_num}_cropped.png"), cv2.cvtColor(cropped_image, cv2.COLOR_GRAY2RGB))

In [107]:
image_dir = "/data/scratch/erubel/nlst/niftis"
mask_dir = "/data/scratch/erubel/nlst/nnInteractive/bmp2d_nn_all_mask_min_3"
output_dir = "/data/scratch/erubel/nlst/v2/demo"

# pid = 126967

# image_mask_slices = [
#     (f"nlst_{pid}T0", 150),
#     (f"nlst_{pid}T1", 150),
#     (f"nlst_{pid}T2", 150),
# ]

# pid = 126967

# # (x, y)
# image_mask_slices = [
#     (f"nlst_{pid}T0", 69, (108, 395, 40, 40)), 
#     (f"nlst_{pid}T1", 66, (122, 376, 40, 40)),
#     (f"nlst_{pid}T2", 89, (118, 402, 40, 40)),
# ]

# pid = 115571

# # (x, y)
# image_mask_slices = [
#     # (f"nlst_{pid}T0", 77, (352, 174, 40, 40)), 
#     # (f"nlst_{pid}T1", 82, (352, 180, 40, 40)), 
#     (f"nlst_{pid}T2", 78, (364, 188, 40, 40)), 
# ]

# pid = 123521

# # (x, y)
# image_mask_slices = [
#     # (f"nlst_{pid}T0", 79, (60, 268, 40, 40)), 
#     # (f"nlst_{pid}T1", 83, (64, 274, 40, 40)), 
#     (f"nlst_{pid}T2", 87, (68, 264, 40, 40)), 
# ]

pid = 100012
image_mask_slices = [
    (f"nlst_{pid}T0", 125, (0, 0, 512, 512)), 
]

# pid = 204711
# image_mask_slices = [
#     (f"nlst_{pid}T0", 88, (0, 0, 512, 512)), 
# ]

# pid = 102581
# image_mask_slices = [
#     (f"nlst_{pid}T0", 58, (0, 0, 512, 512)), 
# ]

os.makedirs(output_dir, exist_ok=True)

for series_id, slice_num, crop_box in image_mask_slices:
    process_nifti(series_id, slice_num, crop_box, image_dir, mask_dir, output_dir)

(512, 512)
(512, 512)
