## **Heuristic and topological-based methods**

Code for the baseline ("normal") heuristic method and for the final method (based on distances to the heart chambers).

### **Imports and Installs**

In [None]:
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from collections import defaultdict
import plotly.graph_objects as go
import networkx as nx
from scipy.spatial.distance import cdist
from scipy.spatial import distance_matrix
from scipy.spatial import ConvexHull
from scipy.spatial import KDTree
import nibabel as nib
import os

### **Google Drive connection**

In [None]:
# Mount Google Drive at /content/drive so that the data paths defined below resolve.
# This cell relies on the `google.colab` package, so it only runs inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### **Upload files**

In [None]:
# Names of the four heart-chamber segmentations. The same names are the keys of
# CAVITY_COLORS and the base names used to build "<name>.nii.gz" file paths
# (presumably this list is what gets passed to load_all_cavities_as_kdtrees — TODO confirm).
cavities = [
    "heart_ventricle_right",
    "heart_ventricle_left",
    "heart_atrium_right",
    "heart_atrium_left"
]

##### Normal 1

In [None]:
# Folder with the heart-chamber segmentations of case Normal_1
# (presumably one "<cavity>.nii.gz" file per entry of `cavities` — TODO confirm).
path_chamb_n1 = '/content/drive/Shared drives/TFGs Coronarias 2024_25/Maren/Data/Heartchambers Segmentations/Segmentations/Normal_1_0000/'

In [None]:
# Centerline text exports (RCA and LCA) of case Normal_1.
# NOTE(review): these two paths live under 'Shared drives/TFG Maren/', whereas the other
# Normal_1 paths in this notebook use 'Shared drives/TFGs Coronarias 2024_25/Maren/' — confirm
# that both shared drives exist and that this is intentional.
file_path_rca_n1 = '/content/drive/Shared drives/TFG Maren/Data/ASOCA 4 Cases Mimics/Normal_1/rca_centerline.txt'
file_path_lca_n1= '/content/drive/Shared drives/TFG Maren/Data/ASOCA 4 Cases Mimics/Normal_1/lca_centerline.txt'

In [None]:
# CSV with the skeletonization dataframe of the RCA for case Normal_1.
path_skel_rca_n1 = '/content/drive/Shared drives/TFGs Coronarias 2024_25/Maren/Data/Skeletonization/Dataframes/df_rca_n1.csv'

In [None]:
# CSV with the skeletonization dataframe of the LCA for case Normal_1.
path_skel_lca_n1 = '/content/drive/Shared drives/TFGs Coronarias 2024_25/Maren/Data/Skeletonization/Dataframes/df_lca_n1.csv'

### **Functions**

#### **Heuristic rules algorithm code**

In [None]:
def parse_and_export_centerlines(file_path):
    """
    Parse a centerline text export into a DataFrame with one row per point.

    Lines of the form "[New Branch Set] Branch Segment <id>:" switch the current
    branch id. Lines that start with a decimal number are treated as points:
    they are split on whitespace and every token is converted to float (tokens
    that cannot be converted become None). Points found before any branch
    header are assigned to branch 0. All other lines are ignored.

    Args:
        file_path (str): Path of the text file to read.

    Returns:
        pd.DataFrame: Column "Branch ID" followed by the 21 per-point columns
        listed below. Rows are grouped by branch, in order of first appearance.
    """
    header_pattern = r"\[New Branch Set\] Branch Segment (\d+):"
    point_pattern = r"^\s*-?\d+\.\d+"

    def _to_number(token):
        # Non-numeric tokens are kept as None so the row keeps its width.
        try:
            return float(token)
        except ValueError:
            return None

    points_by_branch = defaultdict(list)
    active_branch = None

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()

            header = re.match(header_pattern, text)
            if header:
                active_branch = int(header.group(1))
                continue

            # Files without branch headers fall back to a single branch 0.
            if active_branch is None:
                active_branch = 0

            if re.match(point_pattern, text):
                points_by_branch[active_branch].append(
                    [_to_number(token) for token in text.split()]
                )

    rows = [
        [branch_id] + point
        for branch_id, branch_points in points_by_branch.items()
        for point in branch_points
    ]

    columns = ["Branch ID", "Px", "Py", "Pz", "Tx", "Ty", "Tz", "Nx", "Ny", "Nz",
               "BNx", "BNy", "BNz", "Dfit", "Dmin", "Dmax", "C", "Dh", "Xh", "Scf", "Area", "E"]

    return pd.DataFrame(rows, columns=columns)

In [None]:
def visualize_segments_names(df):
    """
    Show the labelled artery points in an interactive 3D scatter plot.

    One trace is drawn per distinct 'Branch ID'. Each trace gets its own colour,
    sampled from the viridis colourscale, and a legend entry with the segment
    name from the table below (or "Branch <id>" for ids not in the table).

    Parameters:
    df (pd.DataFrame): DataFrame containing artery points with columns
        'Px', 'Py', 'Pz' and 'Branch ID'.
    """
    # Local import: only this helper of plotly.colors is needed here.
    from plotly.colors import sample_colorscale

    # Mapping of Branch IDs to their names
    branch_name_map = {
      1: "1 - pRCA",
      2: "2 - mRCA",
      3: "3 - dRCA",
      4: "4 - R-PDA",
      5: "5 - LM",
      6: "6 - pLAD",
      7: "7 - mLAD",
      8: "8 - dLAD",
      9: "9 - D1",
      10: "10 - D2",
      11: "11 - pCx",
      12: "12 - OM1",
      13: "13 - LCx",
      14: "14 - OM2",
      15: "15 - L-PDA",
      16: "16 - R-PLB",
      17: "17 - RI",
      18: "18 - L-PLB"
    }

    unique_branches = df['Branch ID'].unique()

    # Evenly spaced positions in [0, 1] along the colourscale, one per branch.
    # They are converted to explicit 'rgb(...)' strings because a single number
    # passed as marker.color is not mapped through `colorscale`; plotly only
    # documents that mapping for arrays of numbers.
    positions = np.linspace(0, 1, len(unique_branches)).tolist()
    palette = sample_colorscale('viridis', positions) if positions else []
    color_map = dict(zip(unique_branches, palette))

    fig = go.Figure()

    for branch_id in unique_branches:
        branch_data = df[df['Branch ID'] == branch_id]
        branch_name = branch_name_map.get(branch_id, f"Branch {branch_id}")

        fig.add_trace(go.Scatter3d(
            x=branch_data['Px'],
            y=branch_data['Py'],
            z=branch_data['Pz'],
            mode='markers',
            marker=dict(size=3, color=color_map[branch_id]),
            name=branch_name
        ))

    # Configure layout
    fig.update_layout(
        title='Segmented Coronary Artery Visualization',
        scene=dict(
            xaxis=dict(title='Px', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(title='Py', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(title='Pz', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            aspectmode='data'  # Keep the same scale on the three axes
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800  # Set figure height
    )

    fig.show()


In [None]:
# information retrieved with Mimics
def clean_dataframe(df):
    """
    Strip the per-point descriptor columns and remove duplicate rows.

    The columns listed below are dropped when present (missing ones are
    ignored); every other column is kept. Duplicate rows are then removed.
    The index is not reset, so it may contain gaps afterwards. The input
    DataFrame is not modified.

    Args:
        df (pd.DataFrame): Centerline dataframe.

    Returns:
        pd.DataFrame: New dataframe without the listed columns and without
        duplicate rows.
    """
    descriptor_columns = (
        ['Branch ID']
        + ['Tx', 'Ty', 'Tz']
        + ['C']
        + ['Nx', 'Ny', 'Nz']          # normal vector
        + ['BNx', 'BNy', 'BNz']       # binormal vector
        + ['Dfit', 'Dmin', 'Dmax']    # circle diameters
        + ['Dh', 'Xh']                # hydraulic properties
        + ['Scf', 'Area', 'E']        # sectional properties and ellipticity
    )

    return df.drop(columns=descriptor_columns, errors='ignore').drop_duplicates()

In [None]:
def calculate_mst(df):
    """
    Build the Euclidean minimum spanning tree of the points in `df`.

    A complete graph is created whose nodes are the row positions
    0..len(df)-1 and whose edge weights are the pairwise Euclidean distances
    between the ('Px', 'Py', 'Pz') coordinates. The minimum spanning tree of
    that graph is returned.

    Note that the full distance matrix and the complete graph are held in
    memory, so cost grows quadratically with the number of points.

    Args:
        df (pd.DataFrame): Must contain columns 'Px', 'Py', 'Pz'.

    Returns:
        networkx.Graph: Minimum spanning tree; edges carry a 'weight' attribute.
    """
    # Extract the coordinates of the points
    positions = df[['Px', 'Py', 'Pz']].values
    num_nodes = len(positions)

    # Euclidean distance matrix between all points
    dist_matrix = distance_matrix(positions, positions)

    G = nx.Graph()
    # Add the nodes explicitly: relying on add_edge alone loses the node
    # when there is a single point (no edges to create it).
    G.add_nodes_from(range(num_nodes))

    # Upper-triangle index pairs (i < j) in row-major order, i.e. the same
    # order in which a nested i/j loop would insert the edges.
    rows, cols = np.triu_indices(num_nodes, k=1)
    G.add_weighted_edges_from(
        zip(rows.tolist(), cols.tolist(), dist_matrix[rows, cols])
    )

    # Get the Minimum Spanning Tree (MST)
    return nx.minimum_spanning_tree(G)

In [None]:
def visualize_mst(df, MST):
    """
    Draw the spanning tree in an interactive 3D plot.

    Nodes are drawn as red markers at the ('Px', 'Py', 'Pz') coordinates of
    `df`, and each edge of `MST` as a blue line between its two end points.
    Node ids of `MST` are interpreted as row positions of `df`.

    Args:
        df (pd.DataFrame): Must contain columns 'Px', 'Py', 'Pz'.
        MST: Graph exposing `edges()` that yields pairs of node ids.
    """
    coords = df[['Px', 'Py', 'Pz']].values
    coords_by_node = dict(enumerate(coords))

    node_x, node_y, node_z = zip(*coords)

    # Each edge contributes (start, end, None); the None breaks the line so
    # consecutive edges are not joined together.
    segments = {'x': [], 'y': [], 'z': []}
    for source, target in MST.edges():
        start = coords_by_node[source]
        end = coords_by_node[target]
        for axis, key in enumerate('xyz'):
            segments[key] += [start[axis], end[axis], None]

    edge_trace = go.Scatter3d(
        x=segments['x'], y=segments['y'], z=segments['z'],
        mode='lines',
        line=dict(color='blue', width=2),
        name='Edges'
    )
    node_trace = go.Scatter3d(
        x=node_x, y=node_y, z=node_z,
        mode='markers',
        marker=dict(size=3, color='red'),
        name='Nodes'
    )

    def _axis(label):
        # Shared look for the three scene axes.
        return dict(title=label, backgroundcolor='white', gridcolor='lightgrey', showbackground=True)

    fig = go.Figure()
    fig.add_trace(edge_trace)
    fig.add_trace(node_trace)
    fig.update_layout(
        title='Minimum Spanning Tree (MST)',
        scene=dict(
            xaxis=_axis('Px'),
            yaxis=_axis('Py'),
            zaxis=_axis('Pz'),
            aspectmode='data'  # same scale on the three axes
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800
    )

    fig.show()

In [None]:
def visualize_segments(df):
    """
    Show the segmented artery points in an interactive 3D scatter plot.

    One trace is drawn per distinct 'Branch ID', labelled "Branch <id>" in the
    legend. Each trace gets its own colour sampled from the viridis colourscale.

    Parameters:
    df (pd.DataFrame): DataFrame containing artery points with columns
        'Px', 'Py', 'Pz' and 'Branch ID'.
    """
    # Local import: only this helper of plotly.colors is needed here.
    from plotly.colors import sample_colorscale

    unique_branches = df['Branch ID'].unique()

    # Evenly spaced positions in [0, 1] along the colourscale, one per branch.
    # They are converted to explicit 'rgb(...)' strings because a single number
    # passed as marker.color is not mapped through `colorscale`; plotly only
    # documents that mapping for arrays of numbers.
    positions = np.linspace(0, 1, len(unique_branches)).tolist()
    palette = sample_colorscale('viridis', positions) if positions else []
    color_map = dict(zip(unique_branches, palette))

    fig = go.Figure()

    for branch_id in unique_branches:
        branch_data = df[df['Branch ID'] == branch_id]

        fig.add_trace(go.Scatter3d(
            x=branch_data['Px'],
            y=branch_data['Py'],
            z=branch_data['Pz'],
            mode='markers',
            marker=dict(size=3, color=color_map[branch_id]),
            name=f'Branch {branch_id}'
        ))

    # Configure layout
    fig.update_layout(
        title='Segmented Coronary Artery Visualization',
        scene=dict(
            xaxis=dict(title='Px', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(title='Py', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(title='Pz', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            aspectmode='data'  # Keep the same scale on the three axes
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800  # Set figure height
    )

    fig.show()

In [None]:
def labeling_rca(MST, df):
    """
    Label the points of a right-coronary tree using Pz-based heuristic rules.

    Steps:
      1. The leaf (degree-1 node) with the highest Pz is taken as the start.
      2. The path from that leaf to the first bifurcation (node of degree > 2,
         included) is split into thirds labelled 1, 2 and 3.
      3. Of the two branches leaving the bifurcation, the one whose first node
         has the lower Pz is labelled 4 (PDA) and the other 16 (PL); each
         label is propagated to the whole subtree of that branch.

    Node ids of `MST` are interpreted as row positions of `df`.

    Args:
        MST: Tree exposing `nodes`, `degree(node)` and `neighbors(node)`.
        df (pd.DataFrame): Must contain column 'Pz', one row per node.

    Returns:
        pd.DataFrame: Copy of `df` with a fresh 0..n-1 index and a new
        'Branch ID' column (0 for nodes that were not reached). The input
        DataFrame is not modified.

    Raises:
        ValueError: If the graph has no leaf, if no bifurcation is reachable
            from the start leaf, or if the bifurcation does not have exactly
            two outgoing branches.
    """
    # Work on a copy with a consecutive index so node ids match row positions.
    df = df.reset_index(drop=True)
    df.index.name = None
    pz = df['Pz'].to_numpy()

    # Step 1: start from the leaf with the highest Pz.
    leaf_nodes = [node for node in MST.nodes if MST.degree(node) == 1]
    if not leaf_nodes:
        raise ValueError("The graph has no leaf nodes; cannot choose a start node.")
    start_node = max(leaf_nodes, key=lambda n: pz[n])
    print("start node:", start_node)

    # Step 2: breadth-first search until the first bifurcation. Starting from
    # a leaf, every node before the bifurcation has at most one unvisited
    # neighbour, so the queue never holds more than one entry.
    visited = set()
    path = []
    queue = [(start_node, [start_node])]
    bifurcation_node = None

    while queue:
        current_node, current_path = queue.pop(0)
        if current_node in visited:
            continue
        visited.add(current_node)
        if MST.degree(current_node) > 2:
            bifurcation_node = current_node
            path = current_path
            break
        for neighbor in MST.neighbors(current_node):
            if neighbor not in visited:
                queue.append((neighbor, current_path + [neighbor]))

    if bifurcation_node is None or not path:
        raise ValueError("No bifurcation found in the graph starting from highest Pz leaf.")

    # Step 3: label the trunk in thirds (1, 2, 3); the bifurcation node
    # itself falls in the last third.
    labels = np.zeros(len(df), dtype=int)
    split1 = len(path) // 3
    split2 = 2 * len(path) // 3
    for position, node in enumerate(path):
        if position < split1:
            labels[node] = 1
        elif position < split2:
            labels[node] = 2
        else:
            labels[node] = 3

    # Step 4: the two branches after the bifurcation.
    parent = path[-2] if len(path) >= 2 else None
    children = [n for n in MST.neighbors(bifurcation_node) if n != parent]
    if len(children) != 2:
        raise ValueError("Expected exactly two children after bifurcation.")

    # Lower Pz at the first node of the branch -> PDA (4); the other -> PL (16).
    pda_start, pl_start = sorted(children, key=lambda c: pz[c])

    for branch_start, segment_label in ((pda_start, 4), (pl_start, 16)):
        stack = [branch_start]
        visited_branch = {bifurcation_node}  # never walk back into the trunk
        while stack:
            node = stack.pop()
            if node in visited_branch:
                continue
            visited_branch.add(node)
            labels[node] = segment_label
            stack.extend(n for n in MST.neighbors(node) if n not in visited_branch)

    df['Branch ID'] = labels
    return df

In [None]:
# 2. Function labeling_lca_lad:
# Labels segments 6, 7, 8, 9 and 10 (LAD side of the left coronary tree).

def labeling_lca_lad(MST, df, node_bif_5, node_6_start, labels):
    """
    Label the LAD side of the left coronary tree (segments 6, 7, 8, 9, 10).

    Rules, applied from `node_6_start` moving away from `node_bif_5`:
      * Segment 6 runs up to the next bifurcation (node of degree > 2). That
        bifurcation node itself is then relabelled 9.
      * Of the two branches leaving it, the one whose first node has the lower
        Pz is segment 7 and the other is segment 9 (whole subtree).
      * Segment 7 runs up to the next bifurcation, whose node is relabelled 8.
      * Of the two branches leaving it, lower Pz is segment 8 and the other is
        segment 10 (whole subtrees).

    When both candidate branches have the same Pz, the first one in neighbour
    order takes the "lower Pz" role, so the two branches always get different
    labels.

    Node ids are interpreted as row positions of `df`.

    Args:
        MST: Tree exposing `degree(node)` and `neighbors(node)`.
        df (pd.DataFrame): Must contain column 'Pz', one row per node.
        node_bif_5: Bifurcation node at the end of segment 5 (not relabelled).
        node_6_start: First node of segment 6.
        labels (np.ndarray): Label array indexed by node id; modified in place.

    Returns:
        np.ndarray: The same `labels` array.

    Raises:
        ValueError: If an expected bifurcation is missing or does not have
            exactly two outgoing branches. `labels` may already be partially
            updated when this happens.
    """
    # Positional lookup: node ids are row positions, whatever the df index is.
    pz = df['Pz'].to_numpy()

    def _walk_to_bifurcation(first, blocked, label):
        # Label nodes from `first`, never crossing `blocked`, and stop at the
        # first node of degree > 2 (labelled too). Returns (node or None, seen).
        seen = {blocked}
        stack = [first]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            labels[node] = label
            if MST.degree(node) > 2:
                return node, seen
            stack.extend(n for n in MST.neighbors(node) if n not in seen)
        return None, seen

    def _label_subtree(first, blocked, label):
        # Label every node reachable from `first` without crossing `blocked`.
        seen = {blocked}
        stack = [first]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            labels[node] = label
            stack.extend(n for n in MST.neighbors(node) if n not in seen)

    # ------------------ Segment 6 ------------------
    node_6_end, seen_6 = _walk_to_bifurcation(node_6_start, node_bif_5, 6)
    if node_6_end is None:
        raise ValueError("No bifurcation found after segment 6.")

    labels[node_6_end] = 9
    children_after_6 = [n for n in MST.neighbors(node_6_end) if n not in seen_6]
    if len(children_after_6) != 2:
        raise ValueError("Expected two children after segment 6.")

    # Stable sort: on a Pz tie the two children still get different roles.
    node_7_start, node_9_start = sorted(children_after_6, key=lambda c: pz[c])

    # ------------------ Segments 7 and 9 ------------------
    node_7_end, seen_7 = _walk_to_bifurcation(node_7_start, node_6_end, 7)
    _label_subtree(node_9_start, node_6_end, 9)

    if node_7_end is None:
        raise ValueError("No bifurcation found after segment 7.")

    # ------------------ Segments 8 and 10 ------------------
    labels[node_7_end] = 8
    children_after_7 = [n for n in MST.neighbors(node_7_end) if n not in seen_7]
    if len(children_after_7) != 2:
        raise ValueError("Expected two branches after segment 7.")

    node_8_start, node_10_start = sorted(children_after_7, key=lambda c: pz[c])
    _label_subtree(node_8_start, node_7_end, 8)
    _label_subtree(node_10_start, node_7_end, 10)

    return labels

In [None]:
# 3. Function labeling_lca_lcx:
# Labels segments 11, 12, 13 and 17 (circumflex side of the left coronary tree).

def labeling_lca_lcx(MST, df, node_bif_5, node_11_start, labels):
    """
    Label the circumflex side of the left coronary tree (segments 11, 12, 13, 17).

    Rules, applied from `node_11_start` moving away from `node_bif_5`:
      * Segment 11 runs up to the first bifurcation (node of degree > 2).
      * Of the branches leaving it, the one whose first node has the highest
        Px is segment 17 (whole subtree); the one with the lowest Px continues
        as segment 11 up to a second bifurcation.
      * Of the branches leaving the second bifurcation, highest Px is segment
        12 and lowest Px is segment 13 (whole subtrees).

    When the candidate branches have the same Px, the first one in neighbour
    order takes the "lowest Px" role and the last one the "highest Px" role,
    so two different branches are always chosen.

    Node ids are interpreted as row positions of `df`.

    Args:
        MST: Tree exposing `degree(node)` and `neighbors(node)`.
        df (pd.DataFrame): Must contain column 'Px', one row per node.
        node_bif_5: Bifurcation node at the end of segment 5 (not relabelled).
        node_11_start: First node of segment 11.
        labels (np.ndarray): Label array indexed by node id; modified in place.

    Returns:
        np.ndarray: The same `labels` array.

    Raises:
        ValueError: If one of the two expected bifurcations is missing or has
            fewer than two outgoing branches. `labels` may already be
            partially updated when this happens.
    """
    # Positional lookup: node ids are row positions, whatever the df index is.
    px = df['Px'].to_numpy()

    def _walk_to_bifurcation(first, blocked, label):
        # Label nodes from `first`, never crossing `blocked`, and stop at the
        # first node of degree > 2 (labelled too). Returns (node or None, seen).
        seen = {blocked}
        stack = [first]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            labels[node] = label
            if MST.degree(node) > 2:
                return node, seen
            stack.extend(n for n in MST.neighbors(node) if n not in seen)
        return None, seen

    def _label_subtree(first, blocked, label):
        # Label every node reachable from `first` without crossing `blocked`.
        seen = {blocked}
        stack = [first]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            labels[node] = label
            stack.extend(n for n in MST.neighbors(node) if n not in seen)

    def _lowest_and_highest_px(bifurcation, seen):
        # Outgoing branches sorted by Px; the sort is stable, so a tie still
        # returns two different nodes.
        children = sorted(
            (n for n in MST.neighbors(bifurcation) if n not in seen),
            key=lambda c: px[c],
        )
        if len(children) < 2:
            raise ValueError("Expected at least two branches after a segment 11 bifurcation.")
        return children[0], children[-1]

    # ------------------ Segment 11 (first stretch) ------------------
    node_11_bif1, seen_11 = _walk_to_bifurcation(node_11_start, node_bif_5, 11)
    if node_11_bif1 is None:
        raise ValueError("No bifurcation found on segment 11.")

    node_11_continue, node_17_start = _lowest_and_highest_px(node_11_bif1, seen_11)

    # ------------------ Segment 17 ------------------
    _label_subtree(node_17_start, node_11_bif1, 17)

    # ------------------ Segment 11 (continuation) ------------------
    node_11_bif2, seen_11_cont = _walk_to_bifurcation(node_11_continue, node_11_bif1, 11)
    if node_11_bif2 is None:
        raise ValueError("No second bifurcation found in segment 11.")

    node_13_start, node_12_start = _lowest_and_highest_px(node_11_bif2, seen_11_cont)

    # ------------------ Segments 12 and 13 ------------------
    _label_subtree(node_12_start, node_11_bif2, 12)
    _label_subtree(node_13_start, node_11_bif2, 13)

    return labels

In [None]:
# 1. Main function: labeling_lca
# Walks the left main trunk (segment 5) and then hands over to the LAD and LCx labelling functions.

def labeling_lca(MST, df):
    """
    Label the points of a left-coronary tree using coordinate-based heuristics.

    Steps:
      1. The leaf (degree-1 node) with the highest Pz is taken as the start.
      2. The path from that leaf to the first bifurcation (node of degree > 2,
         included) is labelled 5.
      3. Of the two branches leaving that bifurcation, the one whose first node
         has the lower Py is handed to `labeling_lca_lad` (segment 6 onwards)
         and the other to `labeling_lca_lcx` (segment 11 onwards). On a Py tie
         the first branch in neighbour order goes to the LAD side.

    Node ids of `MST` are interpreted as row positions of `df`.

    Args:
        MST: Tree exposing `degree()`, `degree(node)` and `neighbors(node)`.
        df (pd.DataFrame): Must contain columns 'Py' and 'Pz' (and whatever
            the LAD / LCx labelling functions read), one row per node.

    Returns:
        pd.DataFrame: The same `df` object, with a 'Branch ID' column added or
        overwritten in place (0 for nodes that were not reached).

    Raises:
        ValueError: If the graph has no leaf, if no bifurcation is reachable
            from the start leaf, or if that bifurcation does not have exactly
            two outgoing branches. Errors raised by the LAD / LCx labelling
            functions propagate unchanged.
    """
    # Positional lookups: node ids are row positions, whatever the df index is.
    pz = df['Pz'].to_numpy()
    py = df['Py'].to_numpy()

    # Step 1: start from the leaf with the highest Pz.
    leaf_nodes = [node for node, degree in MST.degree() if degree == 1]
    if not leaf_nodes:
        raise ValueError("The graph has no leaf nodes; cannot choose a start node.")
    start_node = max(leaf_nodes, key=lambda n: pz[n])

    # ------------------ Segment 5 ------------------
    # Breadth-first search until the first bifurcation. Starting from a leaf,
    # the queue never holds more than one entry before the bifurcation.
    visited = set()
    path5 = []
    queue = [(start_node, [start_node])]
    node_bif_5 = None

    while queue:
        current_node, current_path = queue.pop(0)
        if current_node in visited:
            continue
        visited.add(current_node)
        if MST.degree(current_node) > 2:
            node_bif_5 = current_node
            path5 = current_path
            break
        for neighbor in MST.neighbors(current_node):
            if neighbor not in visited:
                queue.append((neighbor, current_path + [neighbor]))

    # Compare with None explicitly: node id 0 is a valid bifurcation.
    if node_bif_5 is None or not path5:
        raise ValueError("No bifurcation found from highest Pz leaf.")

    labels = np.zeros(len(df), dtype=int)
    for node in path5:
        labels[node] = 5

    parent_5 = path5[-2] if len(path5) >= 2 else None
    children_5 = [n for n in MST.neighbors(node_bif_5) if n != parent_5]
    if len(children_5) != 2:
        raise ValueError("Expected exactly two branches after bifurcation 5.")

    # Lower Py -> segment 6 (LAD); higher Py -> segment 11 (LCx).
    # Stable sort, so a Py tie still yields two different start nodes.
    node_6_start, node_11_start = sorted(children_5, key=lambda c: py[c])

    labels = labeling_lca_lad(MST, df, node_bif_5, node_6_start, labels)
    labels = labeling_lca_lcx(MST, df, node_bif_5, node_11_start, labels)

    df['Branch ID'] = labels

    return df

#### **Heart chambers algorithm code**

In [None]:
def load_cavity_points(filepath):
    """
    Load a NIfTI volume and return its positive voxels as a point cloud.

    Every voxel with a value greater than 0 contributes one point, located at
    the voxel centre: (index + 0.5) multiplied by the voxel size reported by
    the image header. Only this scaling is applied; the image affine is not
    used.

    Args:
        filepath (str): Path of the image file (e.g. a .nii.gz cavity mask).

    Returns:
        np.ndarray: Float array with one row per positive voxel and one column
        per image axis.
    """
    image = nib.load(filepath)
    spacing = image.header.get_zooms()
    volume = image.get_fdata()

    # Integer indices of the positive voxels, one row per voxel.
    occupied = np.argwhere(volume > 0).astype(float)

    # Move to voxel centres and scale by the voxel size.
    return (occupied + 0.5) * spacing

In [None]:
# Plot colour for each heart cavity, keyed by cavity name.
# visualize_point_cloud rejects any cavity name that is not a key of this dict.
CAVITY_COLORS = {
    "heart_ventricle_right": "red",
    "heart_ventricle_left": "blue",
    "heart_atrium_right": "green",
    "heart_atrium_left": "purple",
}

def visualize_point_cloud(points, cavity_name, title="Point Cloud", sample_size=100000):
    """
    Show a cavity point cloud in an interactive 3D plot, coloured by cavity.

    If the cloud has more than `sample_size` points, a random subset of that
    size (without replacement) is drawn and a message reports how many points
    are shown. The sampling uses NumPy's global random state, so the subset
    changes between calls unless the seed is fixed by the caller.

    Args:
        points (np.ndarray): Point cloud with columns (x, y, z).
        cavity_name (str): Cavity name; must be a key of CAVITY_COLORS.
        title (str): Figure title.
        sample_size (int): Maximum number of points to draw.

    Raises:
        ValueError: If `cavity_name` is not a key of CAVITY_COLORS.
    """
    if cavity_name not in CAVITY_COLORS:
        raise ValueError(f"Nombre de cavidad no válido. Opciones: {', '.join(CAVITY_COLORS.keys())}")

    original_size = points.shape[0]
    if original_size > sample_size:
        keep = np.random.choice(original_size, sample_size, replace=False)
        points = points[keep]
        print(f"Visualizando {sample_size} de {original_size} puntos totales.")

    cloud = go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=2, color=CAVITY_COLORS[cavity_name], opacity=0.8),
        name=cavity_name  # legend entry
    )

    def _axis(label):
        # Shared look for the three scene axes.
        return dict(title=label, backgroundcolor="white", gridcolor="lightgrey", showbackground=True)

    fig = go.Figure(data=[cloud])
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis=_axis("Px"),
            yaxis=_axis("Py"),
            zaxis=_axis("Pz"),
            aspectmode="data"  # same scale on the three axes
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800,
        showlegend=True
    )

    fig.show()


In [None]:
def load_and_visualize_cavity(base_path, cavity_name):
    """
    Load the point cloud of one cavity and display it.

    The file "<base_path>/<cavity_name>.nii.gz" is loaded and plotted. If the
    file does not exist, a message is printed and nothing is plotted.

    Args:
        base_path (str): Folder that contains the cavity files.
        cavity_name (str): Cavity file name without extension.
    """
    filepath = os.path.join(base_path, f"{cavity_name}.nii.gz")

    if os.path.exists(filepath):
        visualize_point_cloud(
            load_cavity_points(filepath),
            cavity_name,
            title=f"Point Cloud: {cavity_name}",
        )
    else:
        print(f"El archivo {filepath} no existe.")

In [None]:
def build_kdtree_from_cavity(filepath):
    """
    Build a KDTree over the point cloud of a segmented cavity.

    Args:
        filepath (str): Path of the cavity file, passed to load_cavity_points.

    Returns:
        KDTree: Spatial index over the cavity points.
    """
    return KDTree(load_cavity_points(filepath))

def load_all_cavities_as_kdtrees(base_path, cavities):
    """
    Build one KDTree per cavity.

    Args:
        base_path (str): Folder that contains the "<cavity>.nii.gz" files.
        cavities (iterable of str): Cavity names (file names without extension).

    Returns:
        dict: {cavity_name: KDTree}, in the order given by `cavities`.
    """
    return {
        name: build_kdtree_from_cavity(os.path.join(base_path, f"{name}.nii.gz"))
        for name in cavities
    }

def assign_cavity_distances(df, cavity_kdtrees):
    """
    Add, for every cavity, a column with the distance from each point to it.

    For each entry of `cavity_kdtrees`, the nearest-neighbour distance from
    every ('Px', 'Py', 'Pz') point to the tree is stored in a column named
    'dist_to_<cavity_name>'. The DataFrame is modified in place.

    Args:
        df (pd.DataFrame): Must contain columns ['Px', 'Py', 'Pz'].
        cavity_kdtrees (dict): {cavity_name: KDTree}.

    Returns:
        pd.DataFrame: The same `df` object, with the new columns.
    """
    query_points = df[['Px', 'Py', 'Pz']].values

    for cavity_name in cavity_kdtrees:
        # query() returns (distances, indices); only the distances are kept.
        nearest = cavity_kdtrees[cavity_name].query(query_points)
        df[f'dist_to_{cavity_name}'] = nearest[0]

    return df

In [None]:
def get_apex_from_cavity(filepath):
    """
    Estimate the apex of a cavity from its point cloud.

    The apex is taken as the convex-hull vertex with the smallest Z; ties are
    broken by the smallest Y and then by the smallest X.

    Args:
        filepath (str): Path of the cavity file, passed to load_cavity_points.

    Returns:
        np.ndarray: Coordinates (x, y, z) of the estimated apex.
    """
    cloud = load_cavity_points(filepath)

    # Only the vertices of the convex hull are candidates.
    hull_points = cloud[ConvexHull(cloud).vertices]

    # lexsort uses the LAST key as the primary one: rows of the transposed
    # array are (x, y, z), so the order is by z, then y, then x.
    order = np.lexsort(hull_points.T)

    return hull_points[order[0]]

In [None]:
def assign_distance_to_apex(df, apex_point, label="ventricle_right"):
    """
    Add a column with the Euclidean distance from each point to the apex.

    The distances are stored in a column named 'dist_to_apex_<label>'. The
    DataFrame is modified in place.

    Args:
        df (pd.DataFrame): Must contain columns ['Px', 'Py', 'Pz'].
        apex_point (np.ndarray): Apex coordinates (x, y, z).
        label (str): Suffix used in the name of the new column.

    Returns:
        pd.DataFrame: The same `df` object, with the new column.
    """
    offsets = df[['Px', 'Py', 'Pz']].values - apex_point
    df[f'dist_to_apex_{label}'] = np.linalg.norm(offsets, axis=1)
    return df

In [None]:
def labeling_rca_chambers(MST, df):
    """
    Label the points of a right-coronary tree using the distance to the apex.

    Steps:
      1. The leaf (degree-1 node) with the highest Pz is taken as the start.
      2. The path from that leaf to the first bifurcation (node of degree > 2,
         included) is split into thirds labelled 1, 2 and 3.
      3. Of the two branches leaving the bifurcation, the one whose first node
         is closer to the right-ventricle apex is labelled 4 (PDA) and the
         other 16 (PL); each label is propagated to the whole subtree.

    Args:
        MST: Tree exposing `nodes`, `degree(node)` and `neighbors(node)`.
        df (pd.DataFrame): Must contain columns 'Pz',
            'dist_to_apex_ventricle_right', 'dist_to_heart_ventricle_left'
            and 'dist_to_heart_atrium_left', one row per node.

    Returns:
        pd.DataFrame: Copy of `df` with a fresh 0..n-1 index and a new
        'Branch ID' column. The input DataFrame is not modified.

    Raises:
        ValueError: If no bifurcation is reachable from the start leaf or it
            does not have exactly two outgoing branches.
    """
    df = df.reset_index(drop=True)
    df.index.name = None

    # Node id -> row label, restricted to nodes that have a row in df.
    mapping = dict((node, node) for node in MST.nodes if node in df.index)

    leaves = [node for node in MST.nodes if MST.degree(node) == 1]
    start_node = max(leaves, key=lambda n: df.loc[mapping[n], 'Pz'])
    print("start node:", start_node)

    # Breadth-first search from the start leaf up to the first bifurcation.
    seen = set()
    trunk = []
    bifurcation_node = None
    pending = [(start_node, [start_node])]

    while pending:
        node, route = pending.pop(0)
        if node in seen:
            continue
        seen.add(node)
        if MST.degree(node) > 2:
            bifurcation_node, trunk = node, route
            break
        pending.extend(
            (nxt, route + [nxt]) for nxt in MST.neighbors(node) if nxt not in seen
        )

    if bifurcation_node is None or not trunk:
        raise ValueError("No bifurcation found in the graph.")

    # Trunk in thirds: 1, 2, 3 (the bifurcation node falls in the last third).
    labels = np.zeros(len(df), dtype=int)
    first_cut = len(trunk) // 3
    second_cut = 2 * len(trunk) // 3

    for position, node in enumerate(trunk):
        if position < first_cut:
            third = 1
        elif position < second_cut:
            third = 2
        else:
            third = 3
        labels[mapping[node]] = third

    parent = trunk[-2] if len(trunk) >= 2 else None
    children = [n for n in MST.neighbors(bifurcation_node) if n != parent]

    if len(children) != 2:
        raise ValueError("Expected exactly two children after bifurcation.")

    # Score both branches at their first node.
    # NOTE(review): the left-side score (LV + LA distance) is computed but the
    # decision below only uses the apex distance — confirm this is intended.
    scored = []
    for child in children:
        row = df.loc[mapping[child]]
        apex_distance = row['dist_to_apex_ventricle_right']
        left_side_distance = row['dist_to_heart_ventricle_left'] + row['dist_to_heart_atrium_left']
        scored.append((child, apex_distance, left_side_distance))

    (first, first_apex, _), (second, second_apex, _) = scored
    if first_apex < second_apex:
        pda_node, pl_node = first, second
    else:
        pda_node, pl_node = second, first

    # Propagate 4 (PDA) and 16 (PL) to the whole subtree of each branch.
    for branch_start, segment_label in {pda_node: 4, pl_node: 16}.items():
        reached = {bifurcation_node}  # never walk back into the trunk
        stack = [branch_start]
        while stack:
            node = stack.pop()
            if node in reached:
                continue
            reached.add(node)
            labels[mapping[node]] = segment_label
            stack.extend(n for n in MST.neighbors(node) if n not in reached)

    df['Branch ID'] = labels
    return df

In [None]:
def visualize_segments_cavities_apex(df, cavity_points_dict, apex_point_right=None, apex_point_left=None, sample_size=10000):
    """
    Visualizes segmented coronary arteries, heart cavity point clouds, and apex points (left and right) in a single 3D plot.

    Note: a later cell of this notebook defines another function with this same name;
    whichever of the two cells ran last is the version that actually gets called.

    Parameters:
    - df (pd.DataFrame): DataFrame with artery points, including 'Px', 'Py', 'Pz', and 'Branch ID'.
    - cavity_points_dict (dict): Dictionary mapping cavity names to point clouds (np.ndarray of shape Nx3).
      Names that are not keys of the module-level CAVITY_COLORS are reported and skipped.
    - apex_point_right (np.ndarray or None): Apex of the right ventricle (x, y, z).
    - apex_point_left (np.ndarray or None): Apex of the left ventricle (x, y, z).
    - sample_size (int): Max number of cavity points to plot per cavity. Larger clouds are
      randomly subsampled without replacement (unseeded, so the subset differs between runs).

    Returns:
    - None. The figure is rendered with fig.show().
    """
    import plotly.graph_objects as go
    import numpy as np
    from plotly.colors import sample_colorscale

    # --- Segment colour setup ---
    branch_name_map = {
      1: "1 - pRCA",
      2: "2 - mRCA",
      3: "3 - dRCA",
      4: "4 - R-PDA",
      5: "5 - LM",
      6: "6 - pLAD",
      7: "7 - mLAD",
      8: "8 - dLAD",
      9: "9 - D1",
      10: "10 - D2",
      11: "11 - pCx",
      12: "12 - OM1",
      13: "13 - LCx",
      14: "14 - OM2",
      15: "15 - L-PDA",
      16: "16 - R-PLB",
      17: "17 - RI",
      18: "18 - L-PLB"
    }
    unique_branches = df['Branch ID'].unique()
    # One evenly spaced position on the viridis scale per segment, converted to concrete
    # colour strings. A bare float passed as marker.color is not mapped through the
    # colourscale, so the colours are resolved here instead.
    scale_positions = np.linspace(0, 1, len(unique_branches)).tolist()
    color_map = dict(zip(unique_branches, sample_colorscale('viridis', scale_positions)))

    fig = go.Figure()

    # --- Plot arteries ---
    for branch_id in unique_branches:
        branch_data = df[df['Branch ID'] == branch_id]
        branch_name = branch_name_map.get(branch_id, f"Branch {branch_id}")
        fig.add_trace(go.Scatter3d(
            x=branch_data['Px'],
            y=branch_data['Py'],
            z=branch_data['Pz'],
            mode='markers',
            marker=dict(size=3, color=color_map[branch_id]),
            name=branch_name
        ))

    # --- Plot cavity point clouds ---
    for cavity_name, points in cavity_points_dict.items():
        if cavity_name not in CAVITY_COLORS:
            print(f"Skipping unknown cavity: {cavity_name}")
            continue
        if points.shape[0] > sample_size:
            indices = np.random.choice(points.shape[0], sample_size, replace=False)
            points = points[indices]
        color = CAVITY_COLORS[cavity_name]
        fig.add_trace(go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(size=2, color=color, opacity=0.6),
            name=cavity_name
        ))

    # --- Plot apex points ---
    if apex_point_right is not None:
        fig.add_trace(go.Scatter3d(
            x=[apex_point_right[0]],
            y=[apex_point_right[1]],
            z=[apex_point_right[2]],
            mode='markers+text',
            marker=dict(size=10, color='black', symbol='cross'),
            name='Apex Right Ventricle',
            text=["Apex Right"],
            textposition='top center'
        ))

    if apex_point_left is not None:
        fig.add_trace(go.Scatter3d(
            x=[apex_point_left[0]],
            y=[apex_point_left[1]],
            z=[apex_point_left[2]],
            mode='markers+text',
            marker=dict(size=10, color='red', symbol='cross'),
            name='Apex Left Ventricle',
            text=["Apex Left"],
            textposition='top center'
        ))

    fig.update_layout(
        title='Coronary Artery Segments with Heart Cavities and Apex Points',
        scene=dict(
            xaxis=dict(title='Px', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(title='Py', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(title='Pz', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            aspectmode='data'
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800,
        showlegend=True
    )

    fig.show()


###### LAD

In [None]:
def labeling_lca_lad(MST, df, node_bif_5, node_6_start, labels):
    """
    Label the LAD side of the left coronary tree (segments 6, 7, 8, 9 and 10).

    Starting at `node_6_start`, the graph is walked away from `node_bif_5`:
    - nodes up to the first node of degree > 2 get label 6; that bifurcation node itself
      is then relabelled 9;
    - of the two branches leaving it, the one whose first node is closest to the right
      ventricle continues as segment 7 and the whole other subtree becomes segment 9;
    - segment 7 runs until the next node of degree > 2. If there is none, the second half
      of the walked path is relabelled 8. Otherwise that bifurcation node gets label 8,
      the branch closest to the right ventricle becomes segment 8 and the other one 10.

    Parameters:
    - MST: graph object exposing `degree(node)` and `neighbors(node)` (a networkx graph in
      this notebook).
    - df (pd.DataFrame): must provide 'dist_to_heart_ventricle_right', looked up with
      `df.loc[node, ...]`.
    - node_bif_5: bifurcation node that ends segment 5; it is never relabelled here.
    - node_6_start: first node of segment 6.
    - labels: array indexed by node id; modified in place.
      # assumes node ids coincide with both the positions in `labels` and the index of `df`
      # - TODO confirm against the code that builds the graph.

    Returns:
    - The same `labels` object that was passed in.

    Raises:
    - ValueError: if no bifurcation follows segment 6, or if a bifurcation does not have
      exactly two outgoing branches.
    """

    def _walk_to_bifurcation(start, blocked, label):
        # Labels nodes from `start` (never crossing `blocked`) until a node of degree > 2
        # is reached. Returns (walked path, bifurcation node or None, visited set).
        visited = {blocked}
        stack = [start]
        path = []
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            labels[node] = label
            path.append(node)
            if MST.degree(node) > 2:
                return path, node, visited
            stack.extend(n for n in MST.neighbors(node) if n not in visited)
        return path, None, visited

    def _label_subtree(start, blocked, label):
        # Labels every node reachable from `start` without crossing `blocked`.
        visited = {blocked}
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            labels[node] = label
            stack.extend(n for n in MST.neighbors(node) if n not in visited)

    def _split_by_right_ventricle(children):
        # Returns (child closest to the right ventricle, the other child) for a pair.
        main = min(children, key=lambda c: df.loc[c, 'dist_to_heart_ventricle_right'])
        (side,) = [c for c in children if c != main]
        return main, side

    # ------------------ Segment 6 ------------------
    _, node_6_end, visited_6 = _walk_to_bifurcation(node_6_start, node_bif_5, 6)
    if node_6_end is None:
        raise ValueError("No bifurcation found after segment 6.")

    # The bifurcation node itself is counted as the start of segment 9.
    labels[node_6_end] = 9
    children_after_6 = [n for n in MST.neighbors(node_6_end) if n not in visited_6]
    if len(children_after_6) != 2:
        raise ValueError("Expected two children after segment 6.")

    node_7_start, node_9_start = _split_by_right_ventricle(children_after_6)

    # ------------------ Segments 7 and 9 ------------------
    path7, node_7_end, visited_7 = _walk_to_bifurcation(node_7_start, node_6_end, 7)
    _label_subtree(node_9_start, node_6_end, 9)

    # ------------------ Segments 8 and 10 ------------------
    if node_7_end is None:
        # Special case: no bifurcation after segment 7, so the walked path is split in
        # half: the first half stays 7, the second half becomes 8.
        for node in path7[len(path7) // 2:]:
            labels[node] = 8
        return labels

    labels[node_7_end] = 8
    children_after_7 = [n for n in MST.neighbors(node_7_end) if n not in visited_7]
    if len(children_after_7) != 2:
        raise ValueError("Expected two branches after segment 7.")

    node_8_start, node_10_start = _split_by_right_ventricle(children_after_7)
    _label_subtree(node_8_start, node_7_end, 8)
    _label_subtree(node_10_start, node_7_end, 10)

    return labels

###### LCX

In [None]:
# Segment 17 is only assigned when its candidate branch starts fewer than 10 graph edges
# away from the start of segment 11; otherwise no segment 17 is labelled and the candidate
# branch is followed as the continuation of segment 11.
# Status: works for the minimum-distance rule for 17 and the chamber-distance rules for
# 17 / 11 and for 12 vs 13. A segment 14 is not handled.
def labeling_lca_lcx(MST, df, node_bif_5, node_11_start, labels, max_dist_17=10):
    """
    Label the circumflex side of the left coronary tree (segments 11, 12, 13 and 17).

    Steps:
    1. Walk from `node_11_start` away from `node_bif_5`, labelling 11, until the first
       node of degree > 2 (first bifurcation).
    2. Among the branches leaving it, the candidate for segment 17 is the one minimising
       dist(right ventricle) + dist(left ventricle) - dist(left atrium). It is accepted
       only if it lies fewer than `max_dist_17` edges from `node_11_start`; its whole
       subtree is then labelled 17 and segment 11 continues along the remaining branch
       closest to the left atrium.
    3. Segment 11 continues until a second node of degree > 2. There the branch closest to
       the left atrium becomes segment 13 and, among the rest, the branch closest to the
       left-ventricle apex becomes segment 12.

    Parameters:
    - MST: networkx graph of the centerline points (used via `degree`, `neighbors` and
      `nx.shortest_path_length`).
    - df (pd.DataFrame): must provide 'dist_to_heart_ventricle_right',
      'dist_to_heart_ventricle_left', 'dist_to_heart_atrium_left' and
      'dist_to_apex_ventricle_left', looked up with `df.loc[node, ...]`.
    - node_bif_5: bifurcation node that ends segment 5; never relabelled here.
    - node_11_start: first node of segment 11.
    - labels: array indexed by node id; modified in place.
      # assumes node ids coincide with both the positions in `labels` and the index of `df`
      # - TODO confirm against the code that builds the graph.
    - max_dist_17 (int): exclusive upper bound, in graph edges, on the distance between
      `node_11_start` and the segment-17 candidate. Defaults to 10.

    Returns:
    - The same `labels` object that was passed in.

    Raises:
    - ValueError: if the first or the second bifurcation of segment 11 cannot be found, or
      if a bifurcation offers no branch to choose from.
    """

    def _walk_to_bifurcation(start, blocked, label):
        # Labels nodes from `start` (never crossing `blocked`) until a node of degree > 2
        # is reached. Returns (bifurcation node or None, visited set).
        visited = {blocked}
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            labels[node] = label
            if MST.degree(node) > 2:
                return node, visited
            stack.extend(n for n in MST.neighbors(node) if n not in visited)
        return None, visited

    def _label_subtree(start, blocked, label):
        # Labels every node reachable from `start` without crossing `blocked`.
        visited = {blocked}
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            labels[node] = label
            stack.extend(n for n in MST.neighbors(node) if n not in visited)

    # ------------------ Segment 11 (up to the first bifurcation) ------------------
    node_11_bif1, visited_11 = _walk_to_bifurcation(node_11_start, node_bif_5, 11)
    if node_11_bif1 is None:
        raise ValueError("No bifurcation found on segment 11.")

    children_bif1 = [n for n in MST.neighbors(node_11_bif1) if n not in visited_11]

    # ---------- Anatomical-distance logic ----------
    child_metrics = {
        c: {
            'ventricle_right': df.loc[c, 'dist_to_heart_ventricle_right'],
            'ventricle_left': df.loc[c, 'dist_to_heart_ventricle_left'],
            'atrium_left': df.loc[c, 'dist_to_heart_atrium_left']
        }
        for c in children_bif1
    }

    # Segment-17 score: small when the branch is close to both ventricles and far from
    # the left atrium. The branch with the LOWEST score is the candidate.
    scores_17 = {
        c: (
            child_metrics[c]['ventricle_right'] +
            child_metrics[c]['ventricle_left'] -
            child_metrics[c]['atrium_left']
        )
        for c in children_bif1
    }
    cand_17 = min(scores_17, key=scores_17.get)

    remaining = [c for c in children_bif1 if c != cand_17]
    if not remaining:
        raise ValueError("Could not find a distinct node for segment 11.")

    # Continuation of segment 11: the remaining branch closest to the left atrium.
    cand_11cont = min(remaining, key=lambda c: child_metrics[c]['atrium_left'])

    # Accept the candidate as segment 17 only if it branches off early enough.
    dist_nod = nx.shortest_path_length(MST, source=node_11_start, target=cand_17)
    if dist_nod < max_dist_17:
        node_17_start = cand_17
        node_11_continue = cand_11cont
    else:
        # No segment 17: the candidate branch is followed as the continuation of 11.
        # TODO: in this case the first bifurcation of segment 11 should yield segments 12
        # and 13 directly. As written, the sibling branch (`cand_11cont`) is not labelled
        # by this function and keeps whatever value `labels` already held.
        node_17_start = None
        node_11_continue = cand_17

    # ------------------ Segment 17 ------------------
    if node_17_start is not None:
        _label_subtree(node_17_start, node_11_bif1, 17)

    # ------------------ Segment 11 (continuation to the second bifurcation) ------------------
    node_11_bif2, visited_11_cont = _walk_to_bifurcation(node_11_continue, node_11_bif1, 11)
    if node_11_bif2 is None:
        raise ValueError("No second bifurcation found in segment 11.")

    # ------------------ Segments 12 and 13 ------------------
    children_bif2 = [n for n in MST.neighbors(node_11_bif2) if n not in visited_11_cont]
    child_distances = {
        c: {
            'apex_left': df.loc[c, 'dist_to_apex_ventricle_left'],
            'atrium_left': df.loc[c, 'dist_to_heart_atrium_left']
        }
        for c in children_bif2
    }

    # Segment 13: closest to the left atrium.
    node_13_start = min(children_bif2, key=lambda c: child_distances[c]['atrium_left'])

    # Segment 12: among the other branches, closest to the left-ventricle apex.
    remaining_children = [c for c in children_bif2 if c != node_13_start]
    node_12_start = min(remaining_children, key=lambda c: child_distances[c]['apex_left'])

    _label_subtree(node_12_start, node_11_bif2, 12)
    _label_subtree(node_13_start, node_11_bif2, 13)

    return labels


###### LCA

In [None]:
# Main entry point: labeling_lca
# Walks the left main trunk (segment 5) and hands the two daughter branches over to the
# LAD and LCx labeling functions.

def labeling_lca(MST, df):
    """
    Assign a segment label to every point of the left coronary artery tree.

    The walk starts at the leaf (degree-1 node) with the highest 'Pz' and follows the graph
    to the first node of degree > 2; every node on that path, including the bifurcation
    node, is labelled 5. The two branches leaving the bifurcation are then passed to
    `labeling_lca_lcx` and `labeling_lca_lad`.

    Parameters:
    - MST: networkx graph of the centerline points.
    - df (pd.DataFrame): point table providing 'Pz', 'dist_to_heart_ventricle_right' and
      'dist_to_heart_atrium_left' (plus the columns read by the LAD / LCx functions).
      # assumes node ids coincide with both the index of `df` and positions 0..len(df)-1
      # - TODO confirm against the code that builds the graph.

    Returns:
    - The same `df` object, modified IN PLACE with a new 'Branch ID' column. Callers that
      still hold the input frame will see the labels there as well.

    Raises:
    - ValueError: if the graph has no leaf, no bifurcation is reachable from the start
      leaf, or the bifurcation does not have exactly two outgoing branches (errors raised
      by the LAD / LCx functions propagate too).
    """
    # Step 1: start from the leaf with the highest Pz.
    leaf_nodes = [node for node, degree in MST.degree() if degree == 1]
    if not leaf_nodes:
        raise ValueError("No leaf node found to start segment 5 from.")
    start_node = max(leaf_nodes, key=lambda n: df.loc[n, 'Pz'])

    # ------------------ Segment 5 ------------------
    visited = set()
    path5 = []
    queue = [(start_node, [start_node])]
    node_bif_5 = None

    while queue:
        current_node, current_path = queue.pop(0)
        if current_node in visited:
            continue
        visited.add(current_node)
        if MST.degree(current_node) > 2:
            node_bif_5 = current_node
            path5 = current_path
            break
        for neighbor in MST.neighbors(current_node):
            if neighbor not in visited:
                queue.append((neighbor, current_path + [neighbor]))

    # Compare against None explicitly: node id 0 is a valid bifurcation node but is falsy.
    if node_bif_5 is None:
        raise ValueError("No bifurcation found from highest Pz leaf.")

    labels = np.zeros(len(df), dtype=int)
    for node in path5:
        # Every node from the start leaf up to (and including) the bifurcation is segment 5.
        labels[node] = 5

    parent_5 = path5[-2] if len(path5) >= 2 else None
    children_5 = [n for n in MST.neighbors(node_bif_5) if n != parent_5]

    if len(children_5) != 2:
        raise ValueError("Expected exactly two branches after bifurcation 5.")

    # Anatomical distances of the first node of each daughter branch.
    child_metrics = {
        c: {
            'ventricle_right': df.loc[c, 'dist_to_heart_ventricle_right'],
            'atrium_left': df.loc[c, 'dist_to_heart_atrium_left']
        }
        for c in children_5
    }

    # The branch whose first node is closest to the right ventricle is treated as the
    # start of segment 11 (LCx); the other branch is the start of segment 6 (LAD).
    # NOTE(review): the original comments disagreed with each other here (one said the
    # branch closest to the right ventricle is the LAD, the next said it is the LCx).
    # The code implements the latter; confirm the rule against the labelled cases.
    node_11_start = min(child_metrics, key=lambda c: child_metrics[c]['ventricle_right'])

    remaining_nodes = [c for c in children_5 if c != node_11_start]
    node_6_start = min(remaining_nodes, key=lambda c: child_metrics[c]['atrium_left'])

    labels = labeling_lca_lad(MST, df, node_bif_5, node_6_start, labels)
    labels = labeling_lca_lcx(MST, df, node_bif_5, node_11_start, labels)

    df['Branch ID'] = labels

    return df

#### **New visualization**

In [None]:
# Fixed colour (hex string) for each coronary segment label, keyed by the integer 'Branch ID'.
SEGMENT_COLORS = {
    1: '#006400',   # pRCA → dark green
    2: '#FFFF00',   # mRCA → yellow
    3: '#8B4513',   # dRCA → brown
    4: '#0000FF',   # R-PDA → blue
    5: '#FF0000',   # LM → red
    6: '#FFA500',   # pLAD → orange
    7: '#6B8E23',   # mLAD → olive drab (muted, earthy)
    8: '#A52A2A',   # dLAD → reddish brown
    9: '#7CFC00',   # D1 → lime green / neon green
    10: '#2F1B0C',  # D2 → very dark brown
    11: '#F5F5DC',  # pLCx → beige
    12: '#DFFF00',  # OM1 → lemon/lime yellow
    13: '#C8A2C8',  # LCx → lilac
    14: '#FFFFE0',  # OM2 → very light yellow
    15: '#DAA520',  # L-PDA → goldenrod
    16: '#CCFF00',  # R-PLB → phosphorescent yellow
    17: '#00008B',  # RI → dark blue
    18: '#B22222',  # L-PLB → firebrick red
    19: '#FF4500'   # tertiary → reddish orange
}

# Colour of each heart-chamber point cloud, keyed by the chamber's segmentation name.
# The cavity plots skip any chamber whose name is not a key here.
CAVITY_COLORS = {
    "heart_ventricle_right": "#FF9999",   # pastel red
    "heart_ventricle_left": "#99CCFF",    # pastel blue
    "heart_atrium_right": "#90EE90",      # light green
    "heart_atrium_left": "#D8BFD8"        # thistle / light purple
}


# Human-readable legend name for each chamber, keyed by the chamber's segmentation name.
CAVITY_NAMES = {
    "heart_ventricle_right": "Right Ventricle",
    "heart_ventricle_left": "Left Ventricle",
    "heart_atrium_right": "Right Atrium",
    "heart_atrium_left": "Left Atrium",
}

In [None]:
def visualize_segments_names(df, marker_size=14):
    """
    Visualizes the segmented branches of the artery using a 3D scatter plot.
    Each segment is color-coded based on its assigned Branch ID and labeled with its proper name.

    Parameters:
    - df (pd.DataFrame): point table with 'Px', 'Py', 'Pz' and 'Branch ID' columns.
    - marker_size (int): marker size of the artery points. Defaults to 14; around 4 gives
      a small-marker view and around 12 a large-marker one.

    Returns:
    - None. The figure is rendered with fig.show().

    Branch IDs without an entry in SEGMENT_COLORS are drawn in grey; IDs without a known
    name are shown as "Branch <id>".
    """
    branch_name_map = {
        1: "1 - pRCA", 2: "2 - mRCA", 3: "3 - dRCA", 4: "4 - R-PDA", 5: "5 - LM",
        6: "6 - pLAD", 7: "7 - mLAD", 8: "8 - dLAD", 9: "9 - D1", 10: "10 - D2",
        11: "11 - pLCx", 12: "12 - OM1", 13: "13 - LCx", 14: "14 - OM2", 15: "15 - L-PDA",
        16: "16 - R-PLB", 17: "17 - RI", 18: "18 - L-PLB"
    }

    fig = go.Figure()

    # One trace per segment so that each segment gets its own legend entry.
    for branch_id in df['Branch ID'].unique():
        branch_data = df[df['Branch ID'] == branch_id]
        branch_name = branch_name_map.get(branch_id, f"Branch {branch_id}")
        color = SEGMENT_COLORS.get(branch_id, 'grey')

        fig.add_trace(go.Scatter3d(
            x=branch_data['Px'],
            y=branch_data['Py'],
            z=branch_data['Pz'],
            mode='markers',
            marker=dict(size=marker_size, color=color),
            name=branch_name
        ))

    fig.update_layout(
        title='Segmented Coronary Artery Visualization',
        scene=dict(
            xaxis=dict(title='Px', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(title='Py', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(title='Pz', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            aspectmode='data'
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800
    )

    fig.show()

In [None]:
def visualize_segments_cavities_apex(df, cavity_points_dict, apex_point_right=None, apex_point_left=None, sample_size=10000):
    """
    Visualizes segmented coronary arteries, heart cavity point clouds, and apex points (left and right) in a single 3D plot.

    Parameters:
    - df (pd.DataFrame): artery points with 'Px', 'Py', 'Pz' and 'Branch ID' columns.
    - cavity_points_dict (dict): maps cavity names to point arrays indexed as points[:, 0..2].
      Names that are not keys of CAVITY_COLORS are reported and skipped.
    - apex_point_right (sequence or None): (x, y, z) of the right-ventricle apex.
    - apex_point_left (sequence or None): (x, y, z) of the left-ventricle apex.
    - sample_size (int): maximum number of points drawn per cavity. Larger clouds are
      randomly subsampled without replacement (unseeded, so the subset differs between runs).

    Returns:
    - None. The figure is rendered with fig.show().
    """
    # Legend names per segment label. Previously the legend repeated the numeric id twice
    # (e.g. "4 - 4") instead of showing the segment name.
    branch_name_map = {
        1: "1 - pRCA", 2: "2 - mRCA", 3: "3 - dRCA", 4: "4 - R-PDA", 5: "5 - LM",
        6: "6 - pLAD", 7: "7 - mLAD", 8: "8 - dLAD", 9: "9 - D1", 10: "10 - D2",
        11: "11 - pLCx", 12: "12 - OM1", 13: "13 - LCx", 14: "14 - OM2", 15: "15 - L-PDA",
        16: "16 - R-PLB", 17: "17 - RI", 18: "18 - L-PLB"
    }

    fig = go.Figure()

    # --- Plot arteries ---
    for branch_id in df['Branch ID'].unique():
        branch_data = df[df['Branch ID'] == branch_id]
        branch_name = branch_name_map.get(branch_id, f"Branch {branch_id}")
        color = SEGMENT_COLORS.get(branch_id, 'grey')

        fig.add_trace(go.Scatter3d(
            x=branch_data['Px'],
            y=branch_data['Py'],
            z=branch_data['Pz'],
            mode='markers',
            marker=dict(size=4, color=color),
            name=branch_name
        ))

    # --- Plot cavities ---
    for cavity_name, points in cavity_points_dict.items():
        if cavity_name not in CAVITY_COLORS:
            print(f"Skipping unknown cavity: {cavity_name}")
            continue
        if points.shape[0] > sample_size:
            indices = np.random.choice(points.shape[0], sample_size, replace=False)
            points = points[indices]
        display_name = CAVITY_NAMES.get(cavity_name, cavity_name)
        fig.add_trace(go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(size=2, color=CAVITY_COLORS[cavity_name], opacity=0.6),
            name=display_name
        ))

    # --- Apex points ---
    if apex_point_right is not None:
        fig.add_trace(go.Scatter3d(
            x=[apex_point_right[0]],
            y=[apex_point_right[1]],
            z=[apex_point_right[2]],
            mode='markers+text',
            marker=dict(size=20, color='black', symbol='cross'),
            name='Apex Right Ventricle',
            text=["Apex Right"],
            textposition='top center'
        ))

    if apex_point_left is not None:
        fig.add_trace(go.Scatter3d(
            x=[apex_point_left[0]],
            y=[apex_point_left[1]],
            z=[apex_point_left[2]],
            mode='markers+text',
            marker=dict(size=20, color='red', symbol='cross'),
            name='Apex Left Ventricle',
            text=["Apex Left"],
            textposition='top center'
        ))

    fig.update_layout(
        title='Coronary Artery Segments with Heart Cavities and Apex Points',
        scene=dict(
            xaxis=dict(title='Px', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(title='Py', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(title='Pz', backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            aspectmode='data'
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800,
        showlegend=True
    )

    fig.show()

## **NORMAL 1**

### **Execution - RCA Normal 1**

##### **Load and clean the dataset**

In [None]:
# Parse the RCA centerline text export of case Normal 1 and preview its first rows.
df_rca_n1 = parse_and_export_centerlines(file_path_rca_n1)
df_rca_n1.head()

Unnamed: 0,Branch ID,Px,Py,Pz,Tx,Ty,Tz,Nx,Ny,Nz,...,BNz,Dfit,Dmin,Dmax,C,Dh,Xh,Scf,Area,E
0,16,150.4411,124.7483,28.0525,-0.841,0.5269,-0.1234,0.1021,0.3784,0.92,...,-0.372,0.4566,0.1851,0.7793,,0.3867,0.4962,1.5989,0.1546,0.6182
1,16,149.9095,125.084,27.9798,-0.8389,0.5338,-0.1063,0.1176,0.3686,0.9221,...,-0.372,0.9324,0.2362,1.5503,0.0292,0.8461,0.5458,3.1222,0.6604,0.8026
2,16,149.3632,125.4346,27.9166,-0.8358,0.5421,-0.0866,0.164,0.3971,0.903,...,-0.4208,1.2609,0.5736,2.0474,0.0373,1.1644,0.5687,4.1918,1.2202,0.6242
3,16,148.7591,125.832,27.8633,-0.8298,0.5549,-0.059,0.22,0.4225,0.8792,...,-0.4727,1.1607,0.8701,1.4998,0.0482,1.0757,0.7172,3.8404,1.0327,0.7187
4,16,148.1571,126.2426,27.8321,-0.8196,0.5723,-0.0249,0.2893,0.451,0.8443,...,-0.5352,1.161,0.6509,1.5376,0.0604,1.0726,0.6976,3.8314,1.0274,0.7226


In [None]:
# Rebind df_rca_n1 to the cleaned frame; in the output below only Px, Py and Pz remain.
# NOTE(review): the parsed frame is overwritten here, so its other columns are no longer
# reachable afterwards - consider a separate name if they are needed later.
df_rca_n1 = clean_dataframe(df_rca_n1)
df_rca_n1

Unnamed: 0,Px,Py,Pz
0,150.4411,124.7483,28.0525
1,149.9095,125.0840,27.9798
2,149.3632,125.4346,27.9166
3,148.7591,125.8320,27.8633
4,148.1571,126.2426,27.8321
...,...,...,...
428,88.6466,103.6300,25.5747
429,88.3045,103.5971,25.8131
430,87.9935,103.5645,26.0413
431,87.7353,103.5360,26.2380


##### **Minimum Spanning Tree Graph**

In [None]:
MST_rca_n1 = calculate_mst(df_rca_n1)

In [None]:
visualize_mst(df_rca_n1, MST_rca_n1)

##### **Heart Chambers**

In [None]:
# load_and_visualize_cavity(path_chamb_n1, cavities[0])

In [None]:
# 1. Load the KD-trees of the heart chambers
cavity_kdtrees_rca_n1 = load_all_cavities_as_kdtrees(path_chamb_n1, cavities)

# 2. Compute the distance from each centerline point to the chambers
#    (works on a copy so df_rca_n1 itself keeps only the coordinates)
df_rca_n1_chamb_info = df_rca_n1.copy()
df_rca_n1_chamb_info = assign_cavity_distances(df_rca_n1_chamb_info, cavity_kdtrees_rca_n1)

# 3. Estimate the apex of each ventricle
# right ventricle
ventricle_right_path = os.path.join(path_chamb_n1, "heart_ventricle_right.nii.gz")
apex_point_right_n1 = get_apex_from_cavity(ventricle_right_path)

# left ventricle
ventricle_left_path = os.path.join(path_chamb_n1, "heart_ventricle_left.nii.gz")
apex_point_left_n1 = get_apex_from_cavity(ventricle_left_path)

# 4. Compute the distance from each centerline point to both apex points
df_rca_n1_chamb_info = assign_distance_to_apex(df_rca_n1_chamb_info, apex_point_right_n1, "ventricle_right")
df_rca_n1_chamb_info = assign_distance_to_apex(df_rca_n1_chamb_info, apex_point_left_n1, "ventricle_left")


df_rca_n1_chamb_info

Unnamed: 0,Px,Py,Pz,dist_to_heart_ventricle_right,dist_to_heart_ventricle_left,dist_to_heart_atrium_right,dist_to_heart_atrium_left,dist_to_apex_ventricle_right,dist_to_apex_ventricle_left
0,150.4411,124.7483,28.0525,36.571322,8.558464,56.738653,46.315361,60.512694,41.531175
1,149.9095,125.0840,27.9798,36.407430,8.601129,56.295989,45.893261,60.671469,41.890470
2,149.3632,125.4346,27.9166,36.254903,8.660735,55.842215,45.455006,60.848865,42.273650
3,148.7591,125.8320,27.8633,36.107880,8.706975,55.341833,44.953634,61.065308,42.716953
4,148.1571,126.2426,27.8321,35.988210,8.765094,54.841469,44.437114,61.307315,43.183802
...,...,...,...,...,...,...,...,...,...
428,88.6466,103.6300,25.5747,4.697415,31.072905,6.973928,38.836652,58.810214,67.738877
429,88.3045,103.5971,25.8131,4.761104,31.190374,6.597201,38.768606,59.099982,68.070387
430,87.9935,103.5645,26.0413,4.744826,31.296010,6.247699,38.703464,59.365319,68.372559
431,87.7353,103.5360,26.2380,4.750161,31.383481,5.954022,38.647639,59.587169,68.624159


In [None]:
# Assign a segment label to every RCA point from the graph and the chamber/apex distance
# columns; the result carries them in the 'Branch ID' column shown below.
df_rca_n1_labeled_cham = labeling_rca_chambers(MST_rca_n1, df_rca_n1_chamb_info)
df_rca_n1_labeled_cham

start node: 328


Unnamed: 0,Px,Py,Pz,dist_to_heart_ventricle_right,dist_to_heart_ventricle_left,dist_to_heart_atrium_right,dist_to_heart_atrium_left,dist_to_apex_ventricle_right,dist_to_apex_ventricle_left,Branch ID
0,150.4411,124.7483,28.0525,36.571322,8.558464,56.738653,46.315361,60.512694,41.531175,16
1,149.9095,125.0840,27.9798,36.407430,8.601129,56.295989,45.893261,60.671469,41.890470,16
2,149.3632,125.4346,27.9166,36.254903,8.660735,55.842215,45.455006,60.848865,42.273650,16
3,148.7591,125.8320,27.8633,36.107880,8.706975,55.341833,44.953634,61.065308,42.716953,16
4,148.1571,126.2426,27.8321,35.988210,8.765094,54.841469,44.437114,61.307315,43.183802,16
...,...,...,...,...,...,...,...,...,...,...
427,88.6466,103.6300,25.5747,4.697415,31.072905,6.973928,38.836652,58.810214,67.738877,4
428,88.3045,103.5971,25.8131,4.761104,31.190374,6.597201,38.768606,59.099982,68.070387,4
429,87.9935,103.5645,26.0413,4.744826,31.296010,6.247699,38.703464,59.365319,68.372559,4
430,87.7353,103.5360,26.2380,4.750161,31.383481,5.954022,38.647639,59.587169,68.624159,4


In [None]:
# Sanity check: list the distinct segment labels assigned to the RCA points
# (Series.unique() keeps them in order of first appearance).
unique_values = df_rca_n1_labeled_cham['Branch ID'].unique()
print(unique_values)

[16  3  2  1  4]


In [None]:
visualize_segments_names(df_rca_n1_labeled_cham)

In [None]:
# Load the point set of each chamber, keyed by the stem of its segmentation file.
cavity_points_dict = {
    cavity_name: load_cavity_points(os.path.join(path_chamb_n1, cavity_name + ".nii.gz"))
    for cavity_name in (
        "heart_ventricle_right",
        "heart_ventricle_left",
        "heart_atrium_right",
        "heart_atrium_left",
    )
}

# Keep the individual per-chamber names bound as well.
ventricle_right_points_n1 = cavity_points_dict["heart_ventricle_right"]
ventricle_left_points_n1 = cavity_points_dict["heart_ventricle_left"]
atrium_right_points_n1 = cavity_points_dict["heart_atrium_right"]
atrium_left_points_n1 = cavity_points_dict["heart_atrium_left"]

In [None]:
visualize_segments_cavities_apex(df_rca_n1_labeled_cham, cavity_points_dict, apex_point_right_n1, apex_point_left_n1)

### **Execution - LCA Normal 1**

##### **Load and clean the dataset**

In [None]:
# Parse the LCA centerline text export of case Normal 1 and preview its first rows.
df_lca_n1 = parse_and_export_centerlines(file_path_lca_n1)
df_lca_n1.head()

Unnamed: 0,Branch ID,Px,Py,Pz,Tx,Ty,Tz,Nx,Ny,Nz,...,BNz,Dfit,Dmin,Dmax,C,Dh,Xh,Scf,Area,E
0,10,165.6963,105.4312,77.5512,-0.6097,0.2803,0.7414,-0.3427,-0.9367,0.0723,...,0.6671,0.958,0.5374,1.4355,,0.8015,0.5584,3.3391,0.6691,0.8144
1,10,165.5281,105.501,77.7533,-0.6262,0.2326,0.7441,-0.3115,-0.9496,0.0347,...,0.6671,1.3974,0.7427,1.9252,0.1859,1.1244,0.5841,5.2718,1.482,0.5607
2,10,165.3429,105.561,77.9707,-0.6432,0.1762,0.7451,-0.266,-0.964,-0.0016,...,0.6669,1.5941,0.7662,2.1823,0.218,1.342,0.615,5.791,1.9429,0.6037
3,10,165.1206,105.6082,78.2242,-0.6623,0.0971,0.7429,-0.2025,-0.9779,-0.0528,...,0.6673,1.4368,0.9247,1.8548,0.2609,1.2641,0.6815,5.156,1.6294,0.2421
4,10,164.8774,105.6258,78.4921,-0.6787,-0.0041,0.7344,-0.1182,-0.9863,-0.1147,...,0.669,1.4504,0.903,1.8502,0.3076,1.3192,0.713,4.8825,1.6102,0.7215


In [None]:
# Rebind df_lca_n1 to the cleaned frame; in the output below only Px, Py and Pz remain.
# NOTE(review): the parsed frame is overwritten here, so its other columns are no longer
# reachable afterwards - consider a separate name if they are needed later.
df_lca_n1 = clean_dataframe(df_lca_n1)
df_lca_n1

Unnamed: 0,Px,Py,Pz
0,165.6963,105.4312,77.5512
1,165.5281,105.5010,77.7533
2,165.3429,105.5610,77.9707
3,165.1206,105.6082,78.2242
4,164.8774,105.6258,78.4921
...,...,...,...
901,123.9599,130.4967,91.7940
902,123.5791,130.2197,91.8576
903,123.1971,129.9483,91.9276
904,122.8180,129.6789,92.0056


##### **Minimum Spanning Tree Graph**

In [None]:
MST_lca_n1 = calculate_mst(df_lca_n1)

In [None]:
visualize_mst(df_lca_n1, MST_lca_n1)

##### **Heart Chambers**

In [None]:
# 1. Load the KD-trees of the heart chambers
cavity_kdtrees_lca_n1 = load_all_cavities_as_kdtrees(path_chamb_n1, cavities)

# 2. Compute the distance from each centerline point to the chambers
#    (works on a copy so df_lca_n1 itself keeps only the coordinates)
df_lca_n1_chamb_info = df_lca_n1.copy()
df_lca_n1_chamb_info = assign_cavity_distances(df_lca_n1_chamb_info, cavity_kdtrees_lca_n1)

# 3. Estimate the apex of each ventricle
# right ventricle
ventricle_right_path = os.path.join(path_chamb_n1, "heart_ventricle_right.nii.gz")
apex_point_right_n1 = get_apex_from_cavity(ventricle_right_path)

# left ventricle
ventricle_left_path = os.path.join(path_chamb_n1, "heart_ventricle_left.nii.gz")
apex_point_left_n1 = get_apex_from_cavity(ventricle_left_path)


# 4. Compute the distance from each centerline point to both apex points
df_lca_n1_chamb_info = assign_distance_to_apex(df_lca_n1_chamb_info, apex_point_right_n1, "ventricle_right")
df_lca_n1_chamb_info = assign_distance_to_apex(df_lca_n1_chamb_info, apex_point_left_n1, "ventricle_left")

df_lca_n1_chamb_info

Unnamed: 0,Px,Py,Pz,dist_to_heart_ventricle_right,dist_to_heart_ventricle_left,dist_to_heart_atrium_right,dist_to_heart_atrium_left,dist_to_apex_ventricle_right,dist_to_apex_ventricle_left
0,165.6963,105.4312,77.5512,31.043167,6.528175,69.071896,48.246136,79.660270,61.278767
1,165.5281,105.5010,77.7533,30.846617,6.564447,68.948750,48.042981,79.786326,61.454360
2,165.3429,105.5610,77.9707,30.623778,6.612826,68.809523,47.829282,79.913985,61.637875
3,165.1206,105.6082,78.2242,30.344024,6.606769,68.642676,47.587406,80.050783,61.843680
4,164.8774,105.6258,78.4921,30.031514,6.554956,68.460365,47.342683,80.178355,62.049572
...,...,...,...,...,...,...,...,...,...
901,123.9599,130.4967,91.7940,20.061955,6.578319,40.405095,2.903991,99.765076,89.283013
902,123.5791,130.2197,91.8576,19.777715,6.538960,40.068030,2.813371,99.680809,89.314655
903,123.1971,129.9483,91.9276,19.497434,6.542137,39.732730,2.767705,99.607314,89.357147
904,122.8180,129.6789,92.0056,19.217516,6.586947,39.407866,2.684867,99.543142,89.408470


In [None]:
# Label every LCA point using the tree and the chamber/apex distance columns;
# the displayed result carries an additional 'Branch ID' column.
df_lca_n1_labeled_cham = labeling_lca(MST_lca_n1, df_lca_n1_chamb_info)
df_lca_n1_labeled_cham

Unnamed: 0,Px,Py,Pz,dist_to_heart_ventricle_right,dist_to_heart_ventricle_left,dist_to_heart_atrium_right,dist_to_heart_atrium_left,dist_to_apex_ventricle_right,dist_to_apex_ventricle_left,Branch ID
0,165.6963,105.4312,77.5512,31.043167,6.528175,69.071896,48.246136,79.660270,61.278767,10
1,165.5281,105.5010,77.7533,30.846617,6.564447,68.948750,48.042981,79.786326,61.454360,10
2,165.3429,105.5610,77.9707,30.623778,6.612826,68.809523,47.829282,79.913985,61.637875,10
3,165.1206,105.6082,78.2242,30.344024,6.606769,68.642676,47.587406,80.050783,61.843680,10
4,164.8774,105.6258,78.4921,30.031514,6.554956,68.460365,47.342683,80.178355,62.049572,10
...,...,...,...,...,...,...,...,...,...,...
901,123.9599,130.4967,91.7940,20.061955,6.578319,40.405095,2.903991,99.765076,89.283013,12
902,123.5791,130.2197,91.8576,19.777715,6.538960,40.068030,2.813371,99.680809,89.314655,12
903,123.1971,129.9483,91.9276,19.497434,6.542137,39.732730,2.767705,99.607314,89.357147,12
904,122.8180,129.6789,92.0056,19.217516,6.586947,39.407866,2.684867,99.543142,89.408470,12


In [None]:
# List the distinct 'Branch ID' labels assigned to the LCA points, in order of
# first appearance, as a quick check of which segments were detected.
unique_values = df_lca_n1_labeled_cham['Branch ID'].unique()
print(unique_values)

[10  8 11 13  7  6  5  9 17 12]


In [None]:
# Plot the LCA segments.
# NOTE(review): this passes df_lca_n1_chamb_info, whose displayed output above
# has no 'Branch ID' column, whereas the next cell plots df_lca_n1_labeled_cham.
# Confirm whether the labeled frame was intended here (or whether labeling_lca
# adds the column to its input in place).
visualize_segments(df_lca_n1_chamb_info)

In [None]:
# Plot the labeled LCA frame (the one carrying the 'Branch ID' column).
visualize_segments_names(df_lca_n1_labeled_cham)

In [None]:
# Load the four chamber segmentations of case Normal 1, in the same order as
# the names they are unpacked into.
(
    ventricle_right_points_n1,
    ventricle_left_points_n1,
    atrium_right_points_n1,
    atrium_left_points_n1,
) = (
    load_cavity_points(os.path.join(path_chamb_n1, mask_file))
    for mask_file in (
        "heart_ventricle_right.nii.gz",
        "heart_ventricle_left.nii.gz",
        "heart_atrium_right.nii.gz",
        "heart_atrium_left.nii.gz",
    )
)

# Chamber name -> loaded points, consumed by the combined visualisation cell.
cavity_points_dict = dict(
    heart_ventricle_right=ventricle_right_points_n1,
    heart_ventricle_left=ventricle_left_points_n1,
    heart_atrium_right=atrium_right_points_n1,
    heart_atrium_left=atrium_left_points_n1,
)

In [None]:
# Combined view: labeled LCA points, the four chamber point sets, and the
# estimated right/left ventricle apex points.
visualize_segments_cavities_apex(df_lca_n1_labeled_cham, cavity_points_dict, apex_point_right_n1, apex_point_left_n1)