In [1]:
import os
import numpy as np
import nibabel as nib
import networkx as nx
from skimage.morphology import skeletonize, ball
from skimage.measure import regionprops, label as skimage_label # Renamed to avoid conflict
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree
from skan import csr
from scipy.sparse import coo_matrix
from scipy.ndimage import uniform_filter1d, binary_erosion, distance_transform_edt, binary_closing, binary_opening, binary_fill_holes
from scipy.signal import savgol_filter
from skimage.draw import line_nd
import pandas as pd

In [2]:
def load_nifti(file_path):
    """Read a NIfTI file and reorient it to the closest canonical (RAS+) orientation.

    Returns
    -------
    tuple
        (voxel data as a float array from ``get_fdata()``, 4x4 affine, header) of the
        reoriented image.
    """
    canonical_img = nib.as_closest_canonical(nib.load(file_path))
    voxel_data = canonical_img.get_fdata()
    return voxel_data, canonical_img.affine, canonical_img.header

def clean_mask(mask_bin):
    """Tidy a binary mask morphologically and return it as a uint8 0/1 array.

    The steps are a closing (bridges small gaps), an opening (drops small specks), and finally
    a fill of fully enclosed holes. All steps use a radius-1 ``ball`` structuring element, so
    the mask is presumably 3-D.

    NOTE(review): scipy's binary morphology treats outside the array as background by default.
    Foreground touching the volume edge may therefore be eroded by the closing; confirm this is
    acceptable if masks reach the boundary.
    """
    footprint = ball(1)
    closed = binary_closing(mask_bin, structure=footprint)
    opened = binary_opening(closed, structure=footprint)
    filled = binary_fill_holes(opened)
    return filled.astype(np.uint8)

def get_binary_mask(mask_data, label_val=2):
    """Return a uint8 0/1 mask of the voxels whose rounded label equals ``label_val``.

    Values are rounded to the nearest integer and cast to uint8 before the comparison, which
    tolerates float-valued label maps such as those returned by ``get_fdata()``.
    """
    integer_labels = np.round(mask_data).astype(np.uint8)
    is_selected = integer_labels == label_val
    return is_selected.astype(np.uint8)

def save_skeleton_nifti(skeleton_array, affine_mat, output_path_str):
    """Write a skeleton volume to ``output_path_str`` as a uint8 NIfTI image.

    An all-zero input is still written, after a warning, so downstream steps always find a
    file. In that case the data is replaced by zeros of the same shape, or by a 1x1x1 zero
    volume when the input is not a 3-D array. ``affine_mat`` is stored as given.
    """
    if np.sum(skeleton_array) == 0:
        print(f"⚠ Warning: Skeleton is empty for {output_path_str}. Saving empty NIfTI.")
        is_3d_array = hasattr(skeleton_array, 'shape') and skeleton_array.ndim == 3
        if is_3d_array:
            skeleton_array = np.zeros_like(skeleton_array, dtype=np.uint8)
        else:
            skeleton_array = np.zeros((1, 1, 1), dtype=np.uint8)  # smallest valid 3-D volume

    nifti_out = nib.Nifti1Image(skeleton_array.astype(np.uint8), affine_mat)
    nib.save(nifti_out, output_path_str)
    print(f"Saved skeleton to {output_path_str}")

def _sort_rows_by_z(coords_zyx):
    """Return the rows of a ZYX coordinate array reordered by ascending Z (column 0)."""
    if coords_zyx.shape[0] > 0:
        return coords_zyx[np.argsort(coords_zyx[:, 0])]
    return coords_zyx


def _skan_coords_to_array(raw_coords_output, ndim, case_name):
    """Normalise skan's node-coordinate output to an (N, ndim) array; None (after a warning) if unrecognised.

    skan versions differ: some return an (N, ndim) ndarray, others a tuple of ``ndim`` 1-D
    index arrays. Both forms are accepted.
    """
    if isinstance(raw_coords_output, np.ndarray):
        if raw_coords_output.ndim == 2 and raw_coords_output.shape[1] == ndim:
            return raw_coords_output
        if raw_coords_output.size == 0:
            return np.array([]).reshape(0, ndim)
        print(f"Warning: Coords from skan is ndarray with unexpected shape {raw_coords_output.shape} for {case_name}.")
        return None

    if isinstance(raw_coords_output, tuple):
        if len(raw_coords_output) == ndim and \
           all(isinstance(arr, np.ndarray) and arr.ndim == 1 for arr in raw_coords_output):
            lengths = [len(arr) for arr in raw_coords_output]
            if len(set(lengths)) == 1:
                if lengths[0] > 0:
                    return np.array(raw_coords_output).T
                return np.array([]).reshape(0, ndim)
            print(f"Warning: Coordinate arrays in tuple from skan have inconsistent lengths for {case_name}. Lengths: {lengths}")
            return None
        print(f"Warning: Coords tuple from skan has unexpected content for {case_name}. Value: {raw_coords_output}")
        return None

    print(f"Warning: Unexpected coordinate format from skeleton_to_csgraph for {case_name}. Type: {type(raw_coords_output)}.")
    return None


def _longest_endpoint_path(graph, endpoints):
    """Return (node list, length) of the longest weighted shortest path between two endpoints; ([], 0) if none."""
    longest_path_nodes, max_length = [], 0
    for pos, source in enumerate(endpoints):
        # One Dijkstra run per endpoint yields distances and paths to every other endpoint at
        # once. Pairs are unordered, so only endpoints after `source` need checking.
        lengths, paths = nx.single_source_dijkstra(graph, source, weight='weight')
        for target in endpoints[pos + 1:]:
            if target in lengths and lengths[target] > max_length:
                max_length, longest_path_nodes = lengths[target], paths[target]
    return longest_path_nodes, max_length


def extract_skeleton(mask_arr, affine_mat, output_path_nifti, spacing_val=(1, 1, 1)):
    """Skeletonize a mask and return its longest centreline path; returns ZYX coords ordered by Z.

    Pipeline:
    1. Binarize (> 0) and apply skimage ``skeletonize``.
    2. Keep only the largest connected skeleton component.
    3. Write that skeleton to ``output_path_nifti``.
    4. Build skan's pixel graph and return the longest weighted shortest path between any two
       degree-1 endpoints.

    Parameters
    ----------
    mask_arr : ndarray
        Mask volume; any value > 0 is foreground. Returned coordinates index this array's axes
        in order. In this notebook that is ZYX, because a (2, 1, 0)-transposed volume is passed.
    affine_mat : ndarray
        Affine written into the skeleton NIfTI unchanged.
    output_path_nifti : str
        Destination of the skeleton NIfTI; its basename is also used in log messages.
    spacing_val : tuple
        Forwarded as ``spacing`` to ``skan.csr.skeleton_to_csgraph``, which determines the edge
        weights.

    Returns
    -------
    ndarray
        (N, ndim) path coordinates, oriented so that the first point's Z is <= the last point's
        Z. If the graph has no edges, fewer than two endpoints, or no connecting path, all
        skeleton nodes sorted by Z are returned instead. ``np.array([])`` is returned for empty
        input, unrecognised skan output, or any error during graph processing.

    NOTE(review): the skeleton volume is saved in ``mask_arr``'s axis order with ``affine_mat``
    as given. The notebook passes a (2, 1, 0)-transposed mask together with the un-transposed
    RAS affine, so the saved file's axes do not match its affine. Confirm whether it should be
    transposed back before saving.
    """
    case_name = os.path.basename(output_path_nifti)

    mask_arr = (mask_arr > 0).astype(np.uint8)
    if np.sum(mask_arr) == 0:
        save_skeleton_nifti(np.zeros_like(mask_arr), affine_mat, output_path_nifti)
        return np.array([])

    skeleton = skeletonize(mask_arr)
    if np.sum(skeleton) == 0:
        save_skeleton_nifti(skeleton, affine_mat, output_path_nifti)
        return np.array([])

    labeled_skel, num_features = skimage_label(skeleton, return_num=True)
    if num_features == 0:
        save_skeleton_nifti(skeleton, affine_mat, output_path_nifti)
        return np.array([])
    if num_features > 1:
        component_sizes = np.bincount(labeled_skel.ravel())
        largest_component_label = component_sizes[1:].argmax() + 1  # label 0 is background
        skeleton = (labeled_skel == largest_component_label).astype(np.uint8)

    save_skeleton_nifti(skeleton, affine_mat, output_path_nifti)

    try:
        graph_data_tuple = csr.skeleton_to_csgraph(skeleton, spacing=spacing_val)
        pixel_graph = graph_data_tuple[0]
        coords_zyx = _skan_coords_to_array(graph_data_tuple[1], skeleton.ndim, case_name)
        if coords_zyx is None or coords_zyx.shape[0] == 0:
            return np.array([])

        # Take weights from the COO view itself so they stay aligned with its (row, col) pairs,
        # whatever sparse format skan returns.
        graph_coo = coo_matrix(pixel_graph)
        graph = nx.Graph()
        for start_node, end_node, weight in zip(graph_coo.row, graph_coo.col, graph_coo.data):
            graph.add_edge(start_node, end_node, weight=weight)

        # At this point the skeleton is a single component: multi-component input was reduced
        # to its largest piece above. The "return every node" fallbacks below therefore apply
        # however many components the raw skeleton had.
        if graph.number_of_nodes() == 0:
            print(f"Warning: Graph has 0 edges, but {coords_zyx.shape[0]} nodes. Returning all nodes for single component: {case_name}.")
            return _sort_rows_by_z(coords_zyx)

        endpoints = [node for node, degree in graph.degree() if degree == 1]
        if len(endpoints) < 2:
            print(f"Warning: Less than 2 endpoints ({len(endpoints)}) for {case_name}. Returning all {coords_zyx.shape[0]} nodes for this component.")
            return _sort_rows_by_z(coords_zyx)

        longest_path_nodes, _ = _longest_endpoint_path(graph, endpoints)
        if not longest_path_nodes:
            print(f"Warning: No path found between endpoints for {case_name}. Returning all {coords_zyx.shape[0]} nodes for this component as fallback.")
            return _sort_rows_by_z(coords_zyx)

        longest_path_coords = coords_zyx[np.asarray(longest_path_nodes)]

        # Orient the path from low Z to high Z.
        if longest_path_coords.shape[0] > 1 and longest_path_coords[0, 0] > longest_path_coords[-1, 0]:
            longest_path_coords = longest_path_coords[::-1]
        return longest_path_coords

    except Exception as e:
        # Best effort: log and return an empty path so one bad case does not stop the batch.
        print(f"ERROR in skeleton_to_csgraph or graph processing for {case_name}: {e}")
        return np.array([])

def _smooth_axis(values, window_size, jump_threshold):
    """Moving-average one float coordinate series, then patch isolated jumps above jump_threshold."""
    window = min(window_size, len(values))
    # uniform_filter1d returns a new array; otherwise copy, because the patch step below writes
    # in place and `values` may be a view of the caller's array.
    smoothed = uniform_filter1d(values, size=window, mode='nearest') if window > 1 else values.copy()

    # Each index refers to a step between point idx and point idx + 1. The later point is
    # replaced by the mean of its two neighbours, for interior points only.
    for idx in np.where(np.abs(np.diff(smoothed)) > jump_threshold)[0]:
        if 0 < idx + 1 < len(smoothed) - 1:
            smoothed[idx + 1] = (smoothed[idx] + smoothed[idx + 2]) / 2
    return smoothed


def smooth_skeleton(skeleton_points_zyx, window_size=5, jump_threshold=3.0):
    """Smooth Y and X coordinates of ZYX skeleton points (Z is left unchanged).

    Each of Y and X is passed through a moving average of width ``min(window_size, N)``, using
    edge mode 'nearest'. Then any interior point whose step from its predecessor exceeds
    ``jump_threshold`` in the smoothed series is replaced by the mean of its two neighbours.

    Parameters
    ----------
    skeleton_points_zyx : ndarray, shape (N, >=3)
        Ordered centreline points; columns 0, 1, 2 are Z, Y, X.
    window_size : int
        Moving-average width (clipped to N; values <= 1 disable averaging).
    jump_threshold : float
        Step size, in coordinate units, above which a point is treated as a spike.

    Returns
    -------
    ndarray
        (N, 3) float array [z, y_smooth, x_smooth]. Input that is not a 2-D array with at least
        two rows is returned unchanged.
    """
    if not hasattr(skeleton_points_zyx, 'shape') or skeleton_points_zyx.ndim != 2 or skeleton_points_zyx.shape[0] < 2:
        return skeleton_points_zyx

    # Work in float. scipy.ndimage filters keep the input dtype, so integer voxel indices would
    # otherwise lose the fractional part of the average and of the jump patch.
    points = np.asarray(skeleton_points_zyx, dtype=float)
    z_coords = points[:, 0]
    y_smooth = _smooth_axis(points[:, 1], window_size, jump_threshold)
    x_smooth = _smooth_axis(points[:, 2], window_size, jump_threshold)
    return np.column_stack((z_coords, y_smooth, x_smooth))

def compute_z_slice_com(skeleton_coords_zyx, mask_3d_zyx):
    """Map each Z slice touched by the skeleton to the centroid of its largest mask region.

    Z values are the skeleton's column 0 cast with ``astype(int)`` and de-duplicated. A slice
    contributes an entry only if it is within ``mask_3d_zyx``'s first axis and has foreground
    there. The centroid is that of the largest connected region, by area, of the YX slice.

    Returns
    -------
    dict
        ``{z: (x_centroid, y_centroid)}``. Note the (X, Y) order of the values.
    """
    centroids_by_z = {}
    if skeleton_coords_zyx.size == 0:
        return centroids_by_z

    n_slices = mask_3d_zyx.shape[0]
    for z in np.unique(skeleton_coords_zyx[:, 0].astype(int)):
        if not (0 <= z < n_slices):
            continue
        yx_plane = mask_3d_zyx[z, :, :]
        if not np.sum(yx_plane) > 0:
            continue
        labelled_plane, region_count = skimage_label(yx_plane, return_num=True)
        if not region_count > 0:
            continue
        regions = regionprops(labelled_plane)
        if not regions:
            continue
        biggest = max(regions, key=lambda region: region.area)
        centroid_row, centroid_col = biggest.centroid  # (row, col) == (y, x) in this slice
        centroids_by_z[z] = (centroid_col, centroid_row)
    return centroids_by_z

def replace_skeleton_endpoints(skeleton_coords_zyx, mask_3d_zyx, affine_mat, output_path_nifti=None, replace_fraction=0.05):
    """Snap the start of a ZYX centreline onto per-slice centroids of the mask.

    Only the start of the path is corrected; the end is left alone. The first
    ``int(replace_fraction * N)`` points are processed, raised to 1 when that rounds down to 0.
    For each of them, if ``compute_z_slice_com`` found a centroid for the point's rounded Z
    slice, the point's Y and X are replaced by that centroid. Z is never changed.

    Parameters
    ----------
    skeleton_coords_zyx : ndarray, shape (N, 3)
        Ordered centreline points (Z, Y, X) in voxel indices of ``mask_3d_zyx``.
    mask_3d_zyx : ndarray
        Mask indexed [z, y, x].
    affine_mat : ndarray
        Affine written unchanged into the optional NIfTI output.
    output_path_nifti : str or None
        If truthy, the corrected points are rounded and rasterised into a uint8 volume of
        ``mask_3d_zyx.shape``, dropping out-of-bounds points, and saved there.
    replace_fraction : float
        Fraction of the points, counted from the start, to correct.

    Returns
    -------
    ndarray
        Float (N, 3) corrected copy of the input, or ``np.array([])`` for empty input.

    NOTE(review): the rasterised volume is in ZYX order but is saved with ``affine_mat`` as
    given. In this notebook that is the RAS affine of the un-transposed (XYZ) volume; confirm
    whether it should be transposed back before saving.
    """
    if not hasattr(skeleton_coords_zyx, 'shape') or skeleton_coords_zyx.shape[0] == 0:
        if output_path_nifti:
            save_skeleton_nifti(np.zeros(mask_3d_zyx.shape if hasattr(mask_3d_zyx, 'shape') else (1, 1, 1)), affine_mat, output_path_nifti)
        return np.array([])

    num_points = len(skeleton_coords_zyx)
    replace_count = int(replace_fraction * num_points)
    if replace_count == 0:
        replace_count = 1  # always correct at least the very first point

    centroids_xy = compute_z_slice_com(skeleton_coords_zyx, mask_3d_zyx)  # {z: (x, y)}

    # Float copy: centroids are sub-voxel. Writing them into an integer array, as voxel-index
    # skeleton coordinates often are, would silently truncate them.
    corrected_skeleton_zyx = np.array(skeleton_coords_zyx, dtype=float)

    for i in range(min(replace_count, num_points)):
        z_key = int(np.round(corrected_skeleton_zyx[i, 0]))
        if z_key in centroids_xy:
            centroid_x, centroid_y = centroids_xy[z_key]
            corrected_skeleton_zyx[i, 1] = centroid_y
            corrected_skeleton_zyx[i, 2] = centroid_x

    if output_path_nifti:
        corrected_skel_img_arr = np.zeros(mask_3d_zyx.shape, dtype=np.uint8)
        voxel_idx = np.round(corrected_skeleton_zyx).astype(int)
        in_bounds = np.all((voxel_idx >= 0) & (voxel_idx < np.asarray(mask_3d_zyx.shape)), axis=1)
        corrected_skel_img_arr[tuple(voxel_idx[in_bounds].T)] = 1
        save_skeleton_nifti(corrected_skel_img_arr, affine_mat, output_path_nifti)
    return corrected_skeleton_zyx

def find_boundary_crossing_3d(line_coords_tuple_zyx, mask_arr_zyx):
    """Walk a rasterised ray and return the first background voxel reached from inside.

    ``line_coords_tuple_zyx`` is (z_coords, y_coords, x_coords), e.g. as produced by
    ``line_nd``. The walk starts "inside" and stays inside while voxels equal 1. The first
    voxel with value 0 reached while inside (which may be the very first voxel) is returned as
    an int array [z, y, x].

    Returns None if the ray leaves the array before that happens, or never reaches such a voxel.
    """
    shape = mask_arr_zyx.shape
    previously_inside = True
    for raw_z, raw_y, raw_x in zip(*line_coords_tuple_zyx):
        voxel = (int(round(raw_z)), int(round(raw_y)), int(round(raw_x)))
        within_volume = (0 <= voxel[0] < shape[0] and
                         0 <= voxel[1] < shape[1] and
                         0 <= voxel[2] < shape[2])
        if not within_volume:
            return None
        value = mask_arr_zyx[voxel]
        if value == 0 and previously_inside:
            return np.array(voxel)
        previously_inside = (value == 1)
    return None

def save_diameters_to_csv(diameter_df, output_path_csv):
    """Write the diameter table to ``output_path_csv`` (without the index) and log the path."""
    diameter_df.to_csv(path_or_buf=output_path_csv, index=False)
    print(f"Diameters saved to {output_path_csv}")

def get_skeleton_tangents(points_zyx, k=1):
    """Estimate a unit tangent vector for every point of an ordered ZYX polyline.

    With N points the half-window is h = 1 when N == 2, otherwise h = min(k, (N - 1) // 2).
    - Interior points use the central difference p[i + h] - p[i - h].
    - The first h points use the forward difference p[i + h] - p[i].
    - The last h points use the backward difference p[i] - p[i - h].
    - When h <= 0 (k <= 0) no differences are taken.

    Each vector is then normalised. A (near-)zero vector instead inherits the previous point's
    tangent, or falls back to [1, 0, 0] (the Z direction) when there is none. A single point
    gets [1, 0, 0]; an empty input gives an empty result.

    Returns
    -------
    ndarray
        Float array with the same shape as ``points_zyx``.
    """
    n_pts = len(points_zyx)
    result = np.zeros_like(points_zyx, dtype=float)
    if n_pts == 0:
        return result
    if n_pts == 1:
        result[0] = np.array([1.0, 0.0, 0.0])
        return result

    half = 1 if n_pts == 2 else min(k, (n_pts - 1) // 2)

    if half > 0:
        for j in range(half):  # leading points: forward difference
            result[j] = points_zyx[j + half] - points_zyx[j]
        for j in range(half, n_pts - half):  # interior points: central difference
            result[j] = points_zyx[j + half] - points_zyx[j - half]
        for j in range(n_pts - half, n_pts):  # trailing points: backward difference
            result[j] = points_zyx[j] - points_zyx[j - half]

    # Normalise in order, since a degenerate vector may borrow the already-normalised previous one.
    for j in range(n_pts):
        magnitude = np.linalg.norm(result[j])
        if magnitude > 1e-9:
            result[j] /= magnitude
            continue
        if j > 0 and np.linalg.norm(result[j - 1]) > 1e-9:
            result[j] = result[j - 1]
        if np.linalg.norm(result[j]) < 1e-9:
            result[j] = np.array([1.0, 0.0, 0.0])
    return result

def get_orthonormal_basis_from_normal(normal_vec_zyx):
    """Create an orthonormal basis (U, V) perpendicular to ``normal_vec_zyx`` (ZYX order).

    The normal is first normalised, because the closed-form construction below is only
    orthonormal for a unit vector. U and V are built with Frisvad's (2012) branch-free method,
    using component 2 (X) as the singular axis. Explicit bases are used when the normal points
    (almost) exactly along -X or +X.

    Parameters
    ----------
    normal_vec_zyx : array-like, shape (3,)
        Plane normal, e.g. a centreline tangent. It does not need to be unit length.

    Returns
    -------
    tuple of ndarray
        (U, V): unit vectors orthogonal to each other and to the normal. A (near-)zero normal
        returns the fixed pair ([0, 1, 0], [0, 0, 1]).
    """
    N = np.asarray(normal_vec_zyx, dtype=float)
    length = np.linalg.norm(N)
    if length < 1e-9:  # no direction information
        return np.array([0., 1., 0.]), np.array([0., 0., 1.])
    N = N / length

    if N[2] < -0.9999999:  # normal along -X: Frisvad's formula is singular here
        U = np.array([0.0, -1.0, 0.0])
        V = np.array([-1.0, 0.0, 0.0])
    elif N[2] > 0.9999999:  # normal along +X
        U = np.array([0.0, 1.0, 0.0])
        V = np.array([1.0, 0.0, 0.0])
    else:
        # Frisvad, "Building an Orthonormal Basis from a 3D Unit Vector Without Normalization".
        a = 1.0 / (1.0 + N[2])
        b = -N[0] * N[1] * a
        U = np.array([1.0 - N[0] * N[0] * a, b, -N[0]])
        V = np.array([b, 1.0 - N[1] * N[1] * a, -N[1]])

    # Defensive fallback (should not trigger for a finite unit normal): cross-product construction.
    if np.linalg.norm(U) < 1e-6 or np.linalg.norm(V) < 1e-6:
        helper = np.array([1.0, 0.0, 0.0]) if np.abs(N[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        U_fallback = np.cross(N, helper)
        if np.linalg.norm(U_fallback) < 1e-6:
            U_fallback = np.cross(N, np.array([0.0, 0.0, 1.0]))
        if np.linalg.norm(U_fallback) < 1e-6:
            return np.array([0., 1., 0.]), np.array([0., 0., 1.])
        U = U_fallback / np.linalg.norm(U_fallback)
        V_fallback = np.cross(N, U)
        V = V_fallback / np.linalg.norm(V_fallback)

    return U, V

def _raw_diameter_at_point(point_zyx, tangent_zyx, mask_arr_zyx, distance_map, voxel_spacing_np,
                           max_search_dist_phys, orthogonality_threshold_cos, num_rays=8):
    """Raw diameter at one centreline point; NaN if the rounded point lies outside the volume.

    The result starts from 2 x the distance-transform value. If at least half of ``num_rays``
    chords hit the mask boundary on both sides, it is averaged with the median chord length.
    """
    skel_z, skel_y, skel_x = np.round(point_zyx).astype(int)
    if not (0 <= skel_z < mask_arr_zyx.shape[0] and
            0 <= skel_y < mask_arr_zyx.shape[1] and
            0 <= skel_x < mask_arr_zyx.shape[2]):
        return np.nan

    dt_diameter = 2 * distance_map[skel_z, skel_y, skel_x]
    angles = np.linspace(0, 2 * np.pi, num_rays, endpoint=False)

    # A centreline running (nearly) along Z gets rays in the axial YX plane. Otherwise rays are
    # cast in the plane perpendicular to the local tangent.
    cos_vs_z = np.abs(np.dot(tangent_zyx, np.array([1.0, 0.0, 0.0])))
    if cos_vs_z >= orthogonality_threshold_cos:
        directions = [np.array([0, np.cos(a), np.sin(a)]) for a in angles]
    else:
        U_vec, V_vec = get_orthonormal_basis_from_normal(tangent_zyx)
        directions = [np.cos(a) * U_vec + np.sin(a) * V_vec for a in angles]

    chord_lengths = []
    for direction in directions:
        norm_dir = np.linalg.norm(direction)
        if norm_dir < 1e-9:
            continue
        # The direction is treated as physical; dividing by spacing gives the voxel-space step.
        delta_voxels = (max_search_dist_phys / voxel_spacing_np) * (direction / norm_dir)
        boundary_fwd = find_boundary_crossing_3d(line_nd(point_zyx, point_zyx + delta_voxels), mask_arr_zyx)
        boundary_bwd = find_boundary_crossing_3d(line_nd(point_zyx, point_zyx - delta_voxels), mask_arr_zyx)
        if boundary_fwd is not None and boundary_bwd is not None:
            chord_lengths.append(np.linalg.norm((boundary_fwd - boundary_bwd) * voxel_spacing_np))

    if len(chord_lengths) >= num_rays / 2:
        return (dt_diameter + np.median(chord_lengths)) / 2.0
    return dt_diameter


def _smooth_diameter_profile(raw_diameters, polyorder=3, max_window=51):
    """Savitzky-Golay smooth a 1-D diameter profile, with a 3-point mean for short profiles.

    The window is the largest odd length <= min(max_window, N). It is reduced, never enlarged,
    so it always fits the profile. Savitzky-Golay is used when that window exceeds
    ``polyorder``; otherwise a centred 3-point rolling mean is used when N >= 3, and the
    profile is returned unchanged below that.
    """
    n = len(raw_diameters)
    window_length = min(max_window, n)
    if window_length % 2 == 0:
        window_length -= 1
    if window_length > polyorder:
        return savgol_filter(raw_diameters, window_length, polyorder)
    if n >= 3:
        return pd.Series(raw_diameters).rolling(window=3, center=True, min_periods=1).mean().values
    return raw_diameters.copy()


def calculate_improved_diameter(mask_arr_zyx, skeleton_coords_zyx, voxel_spacing_vals=(1, 1, 1),
                                tangent_k=1, orthogonality_threshold_cos=0.90):
    """Compute a diameter profile along a ZYX centreline, handling non-orthogonal sections.

    For each point, ``_raw_diameter_at_point`` combines the Euclidean distance transform with
    ray-cast chords. Rays lie in the YX plane when the tangent's |cos| to Z is at least
    ``orthogonality_threshold_cos``, and otherwise in the plane perpendicular to the tangent.
    Points outside the volume get NaN; NaNs are then linearly interpolated and edge-filled.
    Finally the profile is smoothed by ``_smooth_diameter_profile``.

    Parameters
    ----------
    mask_arr_zyx : ndarray
        Binary vessel mask indexed [z, y, x].
    skeleton_coords_zyx : ndarray, shape (N, 3)
        Ordered centreline points (Z, Y, X) in voxel indices.
    voxel_spacing_vals : tuple
        Voxel size per axis in the mask's (Z, Y, X) order. Diameters are reported in these
        units; the default (1, 1, 1) gives voxel units.
    tangent_k : int
        Half-window passed to ``get_skeleton_tangents``.
    orthogonality_threshold_cos : float
        |cos(tangent, Z)| at or above which axial (YX-plane) rays are used.

    Returns
    -------
    pandas.DataFrame
        Columns X, Y, Z, Diameter (smoothed), RawDiameter, rounded to 5 decimals. The frame has
        no rows for an empty skeleton.

    NOTE(review): tangents and ray planes are computed in voxel-index space. For anisotropic
    spacing the "perpendicular" plane is only approximately perpendicular in physical space.
    """
    columns = ["X", "Y", "Z", "Diameter", "RawDiameter"]
    if not hasattr(skeleton_coords_zyx, 'shape') or skeleton_coords_zyx.shape[0] == 0:
        return pd.DataFrame(columns=columns)

    points_zyx = np.asarray(skeleton_coords_zyx)
    voxel_spacing_np = np.array(voxel_spacing_vals)
    distance_map = distance_transform_edt(mask_arr_zyx, sampling=voxel_spacing_np)
    tangents_zyx = get_skeleton_tangents(points_zyx, k=tangent_k)
    # Loop-invariant: rays long enough to cross the whole volume in any direction.
    max_search_dist_phys = max(mask_arr_zyx.shape) * max(voxel_spacing_np)

    raw_diameters = np.array([
        _raw_diameter_at_point(point, tangent, mask_arr_zyx, distance_map, voxel_spacing_np,
                               max_search_dist_phys, orthogonality_threshold_cos)
        for point, tangent in zip(points_zyx, tangents_zyx)
    ], dtype=float)

    # Fill diameters of out-of-volume points from their neighbours.
    raw_diameters = pd.Series(raw_diameters, dtype=float).interpolate(method='linear').bfill().ffill().values
    smoothed_diameters = _smooth_diameter_profile(raw_diameters)

    return pd.DataFrame({
        "X": points_zyx[:, 2],
        "Y": points_zyx[:, 1],
        "Z": points_zyx[:, 0],
        "Diameter": np.round(smoothed_diameters, 5),
        "RawDiameter": np.round(raw_diameters, 5),
    })

In [51]:
# Input: one multi-label segmentation NIfTI per case. Alternative input kept for reference:
# input_dir_main = '/Users/mahdiislam/Higher_Studies/THESIS/Aorta_seg/Diameter_Cal_Output_Robust/GT/gt'
input_dir_main = 'Enrique_results/Multimodal_ablation/Single_strong/post/umamba'

# Root folder for every output of this run. Alternative output kept for reference:
# base_output_dir = 'Diameter_Cal_Output_Robust/GT_v2/'
base_output_dir = "Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba"

# Per-vessel output folders. The "Illiac" spelling is kept so existing result paths stay valid.
output_skeleton_dir_illiac = os.path.join(base_output_dir, "Illiac/Skeletons")
output_diameter_dir_illiac = os.path.join(base_output_dir, "Illiac/Diameters")
output_skeleton_aorta = os.path.join(base_output_dir, "Aorta/Skeletons")
output_diameter_aorta = os.path.join(base_output_dir, "Aorta/Diameters")

for _output_dir in (output_skeleton_dir_illiac, output_diameter_dir_illiac,
                    output_skeleton_aorta, output_diameter_aorta):
    os.makedirs(_output_dir, exist_ok=True)

# When True, the processing loop also runs replace_skeleton_endpoints and writes the
# *_corrected_* skeleton files.
correct_skel_flag = True

In [52]:
# Sorted so cases are processed and logged in a reproducible order (os.listdir order is arbitrary).
mask_files_list = sorted(f for f in os.listdir(input_dir_main) if f.endswith((".nii.gz", ".nii")))

for mask_file_name in mask_files_list:
    mask_file_path = os.path.join(input_dir_main, mask_file_name)
    
    # Attempt to extract a robust case_id
    base_name_for_id = mask_file_name.replace(".nii.gz", "").replace(".nii", "")
    parts = base_name_for_id.split("_")
    case_id_str = parts[1] if len(parts) > 1 else base_name_for_id # Fallback to basename if split fails

    # Define output paths using case_id_str
    path_skel_r_illiac = os.path.join(output_skeleton_dir_illiac, f"{case_id_str}_r_illiac_skeleton.nii.gz")
    path_skel_l_illiac = os.path.join(output_skeleton_dir_illiac, f"{case_id_str}_l_illiac_skeleton.nii.gz")
    path_skel_aorta = os.path.join(output_skeleton_aorta, f"{case_id_str}_aorta_skeleton.nii.gz")

    path_diam_r_illiac_csv = os.path.join(output_diameter_dir_illiac, f"{case_id_str}_diameters_r_illiac.csv")
    path_diam_l_illiac_csv = os.path.join(output_diameter_dir_illiac, f"{case_id_str}_diameters_l_illiac.csv")
    path_diam_aorta_csv = os.path.join(output_diameter_aorta, f"{case_id_str}_diameters_aorta.csv")

    path_plot_illiac_png = os.path.join(output_diameter_dir_illiac, f"{case_id_str}_diameter_plot_iliac.png")
    path_plot_aorta_png = os.path.join(output_diameter_aorta, f"{case_id_str}_diameter_plot_aorta.png")

    path_corr_skel_r_illiac, path_corr_skel_l_illiac, path_corr_skel_aorta = None, None, None
    if correct_skel_flag:
        path_corr_skel_r_illiac = os.path.join(output_skeleton_dir_illiac, f"{case_id_str}_corrected_r_illiac_skeleton.nii.gz")
        path_corr_skel_l_illiac = os.path.join(output_skeleton_dir_illiac, f"{case_id_str}_corrected_l_illiac_skeleton.nii.gz")
        path_corr_skel_aorta = os.path.join(output_skeleton_aorta, f"{case_id_str}_corrected_aorta_skeleton.nii.gz")

    print(f"\nProcessing case: {case_id_str} from file {mask_file_name}")
    # fdata is typically XYZ, but nib.as_closest_canonical orients to RAS (XYZ)
    # We will internally use ZYX for skeleton points for consistency with np.where and array indexing
    mask_data_ras, affine_matrix_ras, _ = load_nifti(mask_file_path)
    # Assuming mask_data_ras is (X,Y,Z). For internal ZYX processing, if needed, transpose.
    # However, most skimage functions handle data as is. np.where(mask_data_ras > 0) for XYZ skeleton.
    # To be safe, let's be explicit if we change order.
    # The crucial part is that extract_skeleton returns ZYX if it uses np.where((z,y,x))
    # And mask_3d in functions like compute_z_slice_com needs to match this.
    # For now, assume mask_data_ras is used directly by get_binary_mask, and its output is passed.
    # extract_skeleton will define its internal coordinate system (ZYX) from its np.where call.

    # Initialize variables
    skel_coords_aorta, skel_coords_l_illiac, skel_coords_r_illiac = np.array([]), np.array([]), np.array([])
    diam_df_aorta, diam_df_l_illiac, diam_df_r_illiac = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    # --- Process Aorta (label 1) ---
    mask_aorta = get_binary_mask(mask_data_ras, label_val=1)
    mask_aorta = clean_mask(mask_aorta) # Clean mask to remove small artifacts
    if np.any(mask_aorta):
        print(f"Processing Aorta for case: {case_id_str}")
        # Optional: struct_el = np.ones((3,3,3), dtype=bool); mask_aorta = binary_closing(mask_aorta, structure=struct_el).astype(np.uint8)
        mask_aorta_zyx = np.transpose(mask_aorta, (2, 1, 0)) 
        raw_skel_aorta = extract_skeleton(mask_aorta_zyx, affine_matrix_ras, path_skel_aorta) # Returns ZYX
        if raw_skel_aorta.size > 0:
            # Reorient based on Z (first column of ZYX)
            # if raw_skel_aorta.shape[0] > 10 and np.argmin(raw_skel_aorta[:, 0]) > 10:
            #     raw_skel_aorta = raw_skel_aorta[::-1]
            skel_coords_aorta = smooth_skeleton(raw_skel_aorta)
            if correct_skel_flag and skel_coords_aorta.size > 0:
                # Pass mask_aorta (ZYX if mask_data_ras was transposed, or XYZ if not and functions handle it)
                # Assuming functions now expect ZYX for mask_3d if skeleton is ZYX
                skel_coords_aorta = replace_skeleton_endpoints(skel_coords_aorta, mask_aorta_zyx, affine_matrix_ras, path_corr_skel_aorta) 
            if skel_coords_aorta.size > 0:
                diam_df_aorta = calculate_improved_diameter(mask_aorta_zyx, skel_coords_aorta)
                if not diam_df_aorta.empty:
                    save_diameters_to_csv(diam_df_aorta, path_diam_aorta_csv)
    else:
        print(f"Aorta mask (label 1) is empty for case: {case_id_str}.")
        save_skeleton_nifti(np.zeros_like(mask_aorta), affine_matrix_ras, path_skel_aorta) # Save empty

    # --- Process Left Iliac (label 2) ---
    # Reset per case so the plots below never show diameters left over from a
    # previous case when this vessel is empty or yields no skeleton/diameters.
    diam_df_l_illiac = pd.DataFrame()
    mask_l_illiac = get_binary_mask(mask_data_ras, label_val=2)
    # clean_mask applies closing, opening and hole filling to drop small artifacts.
    mask_l_illiac = clean_mask(mask_l_illiac)
    if np.any(mask_l_illiac):
        print(f"Processing Left Iliac for case: {case_id_str}")
        # Downstream helpers are fed the (2, 1, 0)-transposed mask, i.e. ZYX order.
        mask_l_illiac_zyx = np.transpose(mask_l_illiac, (2, 1, 0))
        raw_skel_l_illiac = extract_skeleton(mask_l_illiac_zyx, affine_matrix_ras, path_skel_l_illiac)  # ZYX
        if raw_skel_l_illiac.size > 0:
            # NOTE(review): no reorientation by Z is applied here; point order is whatever
            # smooth_skeleton returns, although the plot label says "from bottom to top".
            skel_coords_l_illiac = smooth_skeleton(raw_skel_l_illiac)
            if correct_skel_flag and skel_coords_l_illiac.size > 0:
                skel_coords_l_illiac = replace_skeleton_endpoints(
                    skel_coords_l_illiac, mask_l_illiac_zyx, affine_matrix_ras, path_corr_skel_l_illiac
                )
            if skel_coords_l_illiac.size > 0:
                diam_df_l_illiac = calculate_improved_diameter(mask_l_illiac_zyx, skel_coords_l_illiac)
                if not diam_df_l_illiac.empty:
                    save_diameters_to_csv(diam_df_l_illiac, path_diam_l_illiac_csv)
    else:
        print(f"Left Iliac mask (label 2) is empty for case: {case_id_str}.")
        # NOTE(review): this empty volume has the un-transposed (XYZ) shape, while the
        # non-empty path hands a ZYX mask to extract_skeleton — confirm both outputs agree.
        save_skeleton_nifti(np.zeros_like(mask_l_illiac), affine_matrix_ras, path_skel_l_illiac)

    # --- Process Right Iliac (label 3) ---
    # Reset per case so the plots below never show diameters left over from a
    # previous case when this vessel is empty or yields no skeleton/diameters.
    diam_df_r_illiac = pd.DataFrame()
    mask_r_illiac = get_binary_mask(mask_data_ras, label_val=3)
    # clean_mask applies closing, opening and hole filling to drop small artifacts.
    mask_r_illiac = clean_mask(mask_r_illiac)
    if np.any(mask_r_illiac):
        print(f"Processing Right Iliac for case: {case_id_str}")
        # Downstream helpers are fed the (2, 1, 0)-transposed mask, i.e. ZYX order.
        mask_r_illiac_zyx = np.transpose(mask_r_illiac, (2, 1, 0))
        raw_skel_r_illiac = extract_skeleton(mask_r_illiac_zyx, affine_matrix_ras, path_skel_r_illiac)  # ZYX
        if raw_skel_r_illiac.size > 0:
            # NOTE(review): no reorientation by Z is applied here; point order is whatever
            # smooth_skeleton returns, although the plot label says "from bottom to top".
            skel_coords_r_illiac = smooth_skeleton(raw_skel_r_illiac)
            if correct_skel_flag and skel_coords_r_illiac.size > 0:
                skel_coords_r_illiac = replace_skeleton_endpoints(
                    skel_coords_r_illiac, mask_r_illiac_zyx, affine_matrix_ras, path_corr_skel_r_illiac
                )
            if skel_coords_r_illiac.size > 0:
                diam_df_r_illiac = calculate_improved_diameter(mask_r_illiac_zyx, skel_coords_r_illiac)
                if not diam_df_r_illiac.empty:
                    save_diameters_to_csv(diam_df_r_illiac, path_diam_r_illiac_csv)
    else:
        print(f"Right Iliac mask (label 3) is empty for case: {case_id_str}.")
        # NOTE(review): this empty volume has the un-transposed (XYZ) shape, while the
        # non-empty path hands a ZYX mask to extract_skeleton — confirm both outputs agree.
        save_skeleton_nifti(np.zeros_like(mask_r_illiac), affine_matrix_ras, path_skel_r_illiac)

    # --- Visualize Diameters ---
    # Diameter (mm) goes on the x-axis and the point index along the skeleton on
    # the y-axis, so each vessel is drawn vertically.
    fig_iliac, ax_iliac = plt.subplots(figsize=(8, 10))
    try:
        plot_iliac_ok = False
        if not diam_df_l_illiac.empty:
            ax_iliac.plot(diam_df_l_illiac["Diameter"], diam_df_l_illiac.index, marker='o', linestyle='-', label="Left Iliac")
            plot_iliac_ok = True
        if not diam_df_r_illiac.empty:
            ax_iliac.plot(diam_df_r_illiac["Diameter"], diam_df_r_illiac.index, marker='x', linestyle='--', label="Right Iliac")
            plot_iliac_ok = True

        if plot_iliac_ok:
            ax_iliac.set_xlabel("Diameter (mm)")
            ax_iliac.set_ylabel("Point Index along Skeleton (from bottom to top)")
            ax_iliac.set_title(f"Iliac Artery Diameters - Case {case_id_str}")
            # Vertical reference line at a hard-coded 5 mm diameter.
            ax_iliac.axvline(x=5, color='red', linestyle=':', linewidth=2.5, label="5mm Threshold")
            ax_iliac.legend()
            ax_iliac.grid(True)
        else:
            ax_iliac.text(0.5, 0.5, "No Iliac data to plot", ha='center', va='center', transform=ax_iliac.transAxes)
            ax_iliac.set_title(f"Iliac Diameters (No Data) - Case {case_id_str}")
        fig_iliac.tight_layout()
        fig_iliac.savefig(path_plot_illiac_png)
    finally:
        # Release this specific figure even if saving fails; one is created per case.
        plt.close(fig_iliac)

    # Aorta: diameter (mm) on the x-axis against point index along the skeleton on the y-axis.
    fig_aorta, ax_aorta = plt.subplots(figsize=(8, 10))
    try:
        if not diam_df_aorta.empty:
            ax_aorta.plot(diam_df_aorta["Diameter"], diam_df_aorta.index, marker='s', linestyle='-', label="Aorta", color='red')
            ax_aorta.set_xlabel("Diameter (mm)")
            ax_aorta.set_ylabel("Point Index along Skeleton (from bottom to top)")
            ax_aorta.set_title(f"Aorta Diameter - Case {case_id_str}")
            ax_aorta.legend()
            ax_aorta.grid(True)
        else:
            ax_aorta.text(0.5, 0.5, "No Aorta data to plot", ha='center', va='center', transform=ax_aorta.transAxes)
            ax_aorta.set_title(f"Aorta Diameter (No Data) - Case {case_id_str}")
        fig_aorta.tight_layout()
        fig_aorta.savefig(path_plot_aorta_png)
    finally:
        # Release this specific figure even if saving fails; one is created per case.
        plt.close(fig_aorta)

# Runs once after the per-case loop above has finished.
print("\nProcessing complete for all files.")


Processing case: 207 from file prepoststrong_207.nii.gz
Processing Aorta for case: 207
Saved skeleton to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Aorta/Skeletons/207_aorta_skeleton.nii.gz
Saved skeleton to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Aorta/Skeletons/207_corrected_aorta_skeleton.nii.gz
Diameters saved to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Aorta/Diameters/207_diameters_aorta.csv
Processing Left Iliac for case: 207
Saved skeleton to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Illiac/Skeletons/207_l_illiac_skeleton.nii.gz
Saved skeleton to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Illiac/Skeletons/207_corrected_l_illiac_skeleton.nii.gz
Diameters saved to Enrique_results/Multimodal_ablation/Single_strong/post/Diameter_cal/umamba/Illiac/Diameters/207_diameters_l_illiac.csv
Processing Right Iliac for case: 207