In [1]:
# Mount Google Drive at /content/drive; the project directory WD_PATH lives under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Symbols

---


| Variable                                 	| Unit of Measurement 	|
|------------------------------------------	|----------------------	|
| $k$, permeability of the liver           	| $m^2$                	|
| $\mu_s$, viscosity of embolizing solution 	| $Pa \cdot s$         	|
| $\mu_o$, viscosity of blood               	| $Pa \cdot s$         	|
| distances                                	| $m$                  	|

| **Symbol**                  | **Description**                                | **Unit**           |
|-----------------------------|------------------------------------------------|--------------------|
| $ k $                     | Permeability of the liver                      | $ \mathrm{m}^2 $ |
| $ \mu_s $                 | Viscosity of embolizing solution               | $ \mathrm{Pa} \cdot \mathrm{s} $ |
| $ \mu_o $                 | Viscosity of blood                             | $ \mathrm{Pa} \cdot \mathrm{s} $ |
| $ \Omega $                | 3D Bulk Domain                                 | — (set)            |
| $ \Sigma $                | Cylindrical Vessel                             | — (set)            |
| $ \Gamma $                | Lateral Boundary of $ \Sigma $               | — (set)            |
| $ \Gamma_0, \Gamma_S $    | Top and Bottom Faces of $ \Sigma $           | — (set)            |
| $ \lambda(s) $            | $\mathcal{C}^2$-regular Curve (Centerline)    | $ \mathrm{m} $ (Parameter $ s $) |
| $ s $                     | Arc-length Parameter along $ \lambda $        | $ \mathrm{m} $   |
| $ \Lambda $               | Centerline of $ \Sigma $                      | — (set)            |
| $ \mathcal{D}(s) $        | Cross-section of $ \Sigma $                   | $ \mathrm{m}^2 $ |
| $ \partial \mathcal{D}(s) $| Boundary of Cross-section $ \mathcal{D}(s) $  | $ \mathrm{m} $   |
| $ u_{\oplus} $            | Fluid Potential in $ \Omega $ (Exterior)       | $ \mathrm{Pa} $  |
| $ u_{\ominus} $           | Fluid Potential in $ \Lambda $ (Interior)      | $ \mathrm{Pa} $  |
| $ \beta $                 | Coupling Coefficient                            | $ \frac{\mathrm{m}^3}{\mathrm{s} \cdot \mathrm{Pa}} $ |
| $ \kappa $                | Permeability Coefficient                        | $ \frac{\mathrm{m}^2}{\mathrm{s} \cdot \mathrm{Pa}} $ |
| $ g $                     | Source Term                                    | $ \frac{\mathrm{m}}{\mathrm{s}} $ |
| $ \alpha $                | Diffusion Coefficient                           | $ \frac{\mathrm{m}^2}{\mathrm{s} \cdot \mathrm{Pa}} $ |
| $ \Delta $                | Laplacian Operator                              | $ \frac{1}{\mathrm{m}^2} $ |
| $ \delta_{\Lambda} $      | Dirac Measure on $ \Lambda $                  | $ \frac{1}{\mathrm{m}^2} $ |
| $ f $                     | Forcing Term                                    | $ \frac{1}{\mathrm{s}} $ |
| $ \bar{u}_{\oplus} $      | Averaged Fluid Potential in $ \Omega $         | $ \mathrm{Pa} $  |
| $ \bar{u}_{\ominus} $     | Averaged Fluid Potential in $ \Lambda $        | $ \mathrm{Pa} $  |

In [2]:
# @title Install nonstandard libraries
%%capture
!pip install ipywidgets
!pip install vtk
!pip install meshio
!pip install pyvista

import os, re

def replace_in_file(file_path):
    """Rename every standalone ``ufl`` identifier in a file to ``ufl_legacy``.

    Used below to patch the cloned fenics_ii sources before ``pip install``,
    presumably because the FEniCS build installed by this cell exposes the
    legacy UFL package as ``ufl_legacy`` (TODO confirm).

    The pattern uses word boundaries on both sides, so it matches ``ufl`` and
    ``ufl.x`` but not ``ufl_legacy`` (``_`` is a word character) or
    ``myufl``. Patching an already patched file is therefore a no-op.

    Args:
        file_path: path of a UTF-8 text file; rewritten in place.

    Returns:
        The number of replacements made. The file is only rewritten when this
        is non-zero, so untouched files keep their contents and mtime.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    new_content, n_replacements = re.subn(r'\bufl\b', 'ufl_legacy', content)

    # Skip the write entirely when nothing matched.
    if n_replacements:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(new_content)
    return n_replacements


def process_directory(directory):
    """Apply ``replace_in_file`` to every ``.py`` file under ``directory``, recursively.

    Args:
        directory: root directory to walk; non-``.py`` files are left alone.
    """
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.py'):
                replace_in_file(os.path.join(root, file))

# Each package below is installed only if it cannot already be imported, so
# re-running this cell after a successful install skips the slow steps.

# dolfin: legacy FEniCS, installed with the FEM on Colab installer script
try:
    import dolfin
except ImportError:
    !wget "https://fem-on-colab.github.io/releases/fenics-install-real.sh" -O "/tmp/fenics-install.sh" && bash "/tmp/fenics-install.sh"

# block (cbc.block): cloning this URL creates a directory named `master`
# (the last path component), which the next line installs.
# NOTE(review): confirm this Bitbucket ".../src/master/" URL is cloneable;
# the repository root URL is the usual clone target.
try:
    import block
except ImportError:
    !git clone "https://bitbucket.org/fenics-apps/cbc.block/src/master/"
    !pip install master/

# fenics_ii (imported as `xii`): rename `ufl` -> `ufl_legacy` in every cloned
# .py file (process_directory / replace_in_file) before installing it.
try:
    import xii
except ImportError:
    !git clone "https://github.com/MiroK/fenics_ii"
    process_directory("fenics_ii/")
    !pip install fenics_ii/

# graphnics: star-imported later; presumably the source of FenicsGraph
try:
    import graphnics
except ImportError:
    !git clone "https://github.com/IngeborgGjerde/graphnics"
    !pip install graphnics/

In [10]:
# Project root on the mounted Google Drive. Input data (data/), the project's
# helper modules (modules/) and results (perfusion_results/) all live under it.
WD_PATH = "/content/drive/MyDrive/Research/3D-1D"

import sys, os
# Make the project's own modules (visualizer, daedalus, FEM*) importable.
sys.path.append(os.path.join(WD_PATH, 'modules'))

import visualizer
import daedalus
import FEMSensitivity
import FEMSink
import FEMEdge

import meshio
import scipy
import copy
import vtk
import json
import numpy as np
import matplotlib.pyplot as plt
import datetime
import networkx as nx
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import cKDTree
import importlib
from dolfin import *
from vtk.util.numpy_support import vtk_to_numpy
from xii import *
from graphnics import *

In [None]:
# @title Define G = serena data graph
# Build G from the 3D Slicer centerline markups Centerline_0 ... Centerline_28.
# Node numbering:
#   - control points k = 1..len-1 of each centerline get fresh consecutive ids;
#   - its first point (k = 0) re-uses an existing node, namely the end point
#     of an earlier centerline at the same position, or node 0 otherwise.
N_CENTERLINES = 29

# Multiplier applied to every coordinate read from the files. 1.0 keeps the
# coordinates exactly as stored. Set to 1e-3 to convert mm -> m, assuming the
# files are in mm (TODO confirm).
POS_SCALE = 1.0

G = FenicsGraph()
ind = 0             # running total of control points read so far
branch_points = {}  # node id -> position of the last point of each centerline

for n in range(N_CENTERLINES):
    file_path = os.path.join(WD_PATH, 'data', 'pv_json1', f'Centerline_{n}.mrk.json')
    with open(file_path) as f:
        data = json.load(f)

    # Coordinates and radius of every control point
    markup = data['markups'][0]
    points = [[coord * POS_SCALE for coord in p['position']]
              for p in markup['controlPoints']]
    # NOTE(review): measurement #3 is assumed to hold the per-point radius;
    # confirm against the markup files (lookup by measurement name would be safer).
    radius = markup['measurements'][3]['controlPointValues']
    if len(points) < 2:
        raise ValueError(f"{file_path}: a centerline needs at least 2 control points")

    # Attach the first point to a previous centerline's end point at the same position.
    # NOTE(review): the fallback node 0 means any centerline that does not start
    # at a known end point is attached to (and re-positions) node 0; confirm intended.
    v1 = next((key for key, val in branch_points.items() if points[0] == val), 0)
    node_ids = [v1] + [ind - n + k for k in range(1, len(points))]
    G.add_nodes_from(node_ids)

    # Store position and radius on every node of this centerline
    for k, node in enumerate(node_ids):
        G.nodes[node]["pos"] = points[k]
        G.nodes[node]["radius"] = radius[k]

    # Chain consecutive points. Each edge carries the radius of its first point.
    for k in range(len(points) - 1):
        G.add_edge(node_ids[k], node_ids[k + 1], radius=radius[k])

    # Remember the last point: a later centerline may start from it
    ind += len(points)
    branch_points[node_ids[-1]] = points[-1]

In [12]:
# @title Define G = .vtk domain read
import vtk

def read_vtk(file_path):
    """Read a legacy-VTK polydata vessel network into a FenicsGraph.

    Nodes:
        Every VTK point becomes node ``i`` with:
        - ``pos``: tuple of its coordinates;
        - ``damage``: the value of the point array "Damage", or None when the
          file has no such array.

    Edges:
        Every cell contributes one edge per pair of consecutive point ids,
        with ``radius`` taken from the cell array "Radius" (None if absent).

    Node radii:
        Each node also gets a ``radius`` attribute: the largest radius among
        its incident edges. ``cleanup_graph`` reads it. Nodes on no edge, or
        whose edges have no radius, get none.

    Args:
        file_path: path to a legacy ``.vtk`` polydata file.

    Returns:
        The populated FenicsGraph.
    """
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(file_path)
    reader.Update()
    output = reader.GetOutput()

    G = FenicsGraph()

    damage_array = output.GetPointData().GetArray("Damage")
    for i in range(output.GetNumberOfPoints()):
        damage_value = damage_array.GetValue(i) if damage_array is not None else None
        G.add_node(i, pos=tuple(output.GetPoint(i)), damage=damage_value)

    radius_array = output.GetCellData().GetArray("Radius")
    for i in range(output.GetNumberOfCells()):
        # One radius value per cell: look it up once, not once per segment
        radius_value = radius_array.GetValue(i) if radius_array is not None else None
        cell = output.GetCell(i)
        point_ids = [cell.GetPointId(j) for j in range(cell.GetNumberOfPoints())]
        for u, v in zip(point_ids, point_ids[1:]):
            G.add_edge(u, v, radius=radius_value)
            if radius_value is not None:
                for node in (u, v):
                    G.nodes[node]['radius'] = max(G.nodes[node].get('radius', 0), radius_value)

    return G

# Usage
file_path = os.path.join(WD_PATH, 'data', 'vtk', 'sortedDomain.vtk')
# file_path = WD_PATH + '/oncopigReferenceData/ZPAF23S018/20230531/vesselNetwork_upDated.vtk'
G = read_vtk(file_path)

In [None]:
def filter_small_components(G, threshold=20):
    """Delete from ``G`` every connected piece whose total edge length is under ``threshold``.

    Components are found on an undirected copy of ``G``; for a directed graph
    these are its weakly connected components. An edge's length is the
    Euclidean distance between the first three coordinates of its endpoints'
    ``'pos'`` attributes. Nodes of short components are removed from ``G``
    itself, and the function returns None.

    Args:
        G: networkx graph whose nodes carry a ``'pos'`` attribute.
        threshold: minimum total length, in the units of ``'pos'``, for a
            component to survive.
    """
    undirected = G.to_undirected()

    def edge_length(a, b):
        pa = undirected.nodes[a]['pos']
        pb = undirected.nodes[b]['pos']
        return sum((pa[k] - pb[k]) ** 2 for k in range(3)) ** 0.5

    for nodes in list(nx.connected_components(undirected)):
        total = sum(edge_length(a, b) for a, b in undirected.subgraph(nodes).edges())
        if total < threshold:
            G.remove_nodes_from(nodes)  # remove from the caller's graph, not the copy


# Apply the filtering
filter_small_components(G)

# @title Clean graph (only for raw .vtk domain reads)
def cleanup_graph(G, merge_tol=1.0e-4, default_radius=0.1):
    """Tidy a graph read from a raw .vtk file.

    1. Remove isolated nodes (from ``G`` itself), then relabel the remaining
       nodes as consecutive integers (this produces a new graph).
    2. Merge each surviving node with at most one other node lying within
       ``merge_tol`` of it, using the nodes' ``'pos'`` attributes.
    3. Give every node whose ``'radius'`` is 0 or missing the value
       ``default_radius``.

    Args:
        G: networkx graph whose nodes all carry a ``'pos'`` attribute.
        merge_tol: merge distance in the units of ``'pos'`` (default 1e-4).
        default_radius: radius assigned to nodes without a usable radius
            (default 0.1).

    Returns:
        The relabelled graph with close nodes merged.
    """
    G.remove_nodes_from(list(nx.isolates(G)))
    G = nx.convert_node_labels_to_integers(G)

    # Merge close nodes
    positions = nx.get_node_attributes(G, 'pos')
    node_ids = list(positions)
    merged_nodes = set()

    if node_ids:
        tree = cKDTree(np.array([positions[n] for n in node_ids]))
        for node in node_ids:
            # A node already absorbed into another one is no longer in G.
            # Contracting into it would silently re-create it without attributes.
            if node in merged_nodes:
                continue
            for idx in tree.query_ball_point(positions[node], merge_tol):
                other_node = node_ids[idx]
                if other_node != node and other_node not in merged_nodes:
                    # self_loops=False: merging two adjacent nodes must not leave
                    # a zero-length self-loop edge.
                    # copy=False: contract in place rather than copying the whole
                    # graph for every merge.
                    nx.contracted_nodes(G, node, other_node, self_loops=False, copy=False)
                    merged_nodes.add(other_node)
                    break  # Stop after merging with one node

    # Replace zero or missing radii
    for _, data in G.nodes(data=True):
        if not data.get('radius'):
            data['radius'] = default_radius

    return G

G = cleanup_graph(G)

In [None]:
# @title Define test graph
# Y-shaped bifurcation test case (coordinates as in the schematic further
# down), all with radius 1:
#   - inlet branch V1 = (8, 20, 15) -> C = (23, 20, 15);
#   - outlet branches C -> V2 = (33, 37, 15) and C -> V3 = (33, 3, 15).
#
# C is a single node (id C_ID = 9) shared by all three branches. Three separate
# nodes at the same position would leave the branches disconnected.
# Node ids are contiguous (0..27) and nodes are inserted in id order before
# any edge. This is presumably required if FenicsGraph uses node ids /
# insertion order as mesh vertex indices (TODO confirm against graphnics).
G = FenicsGraph()
C_ID = 9

node_data = {
    # Inlet branch V1 -> C (node 0 = V1, node 9 = C)
    0: {'pos': [8, 20, 15], 'radius': 1},
    1: {'pos': [10, 20, 15], 'radius': 1},
    2: {'pos': [11, 20, 15], 'radius': 1},
    3: {'pos': [13, 20, 15], 'radius': 1},
    4: {'pos': [15, 20, 15], 'radius': 1},
    5: {'pos': [16, 20, 15], 'radius': 1},
    6: {'pos': [18, 20, 15], 'radius': 1},
    7: {'pos': [20, 20, 15], 'radius': 1},
    8: {'pos': [21, 20, 15], 'radius': 1},
    9: {'pos': [23, 20, 15], 'radius': 1},
    # Outlet branch towards V2 (node 10 = V2, node 18 is adjacent to C)
    10: {'pos': [33, 37, 15], 'radius': 1},
    11: {'pos': [32, 35, 15], 'radius': 1},
    12: {'pos': [31, 33, 15], 'radius': 1},
    13: {'pos': [30, 31, 15], 'radius': 1},
    14: {'pos': [29, 29, 15], 'radius': 1},
    15: {'pos': [27, 28, 15], 'radius': 1},
    16: {'pos': [26, 26, 15], 'radius': 1},
    17: {'pos': [25, 24, 15], 'radius': 1},
    18: {'pos': [24, 22, 15], 'radius': 1},
    # Outlet branch towards V3 (node 19 = V3, node 27 is adjacent to C)
    19: {'pos': [33, 3, 15], 'radius': 1},
    20: {'pos': [32, 5, 15], 'radius': 1},
    21: {'pos': [31, 7, 15], 'radius': 1},
    22: {'pos': [30, 9, 15], 'radius': 1},
    23: {'pos': [29, 11, 15], 'radius': 1},
    24: {'pos': [27, 12, 15], 'radius': 1},
    25: {'pos': [26, 14, 15], 'radius': 1},
    26: {'pos': [25, 16, 15], 'radius': 1},
    27: {'pos': [24, 18, 15], 'radius': 1},
}

# Add nodes (with attributes) first, in id order
for node_id, attributes in node_data.items():
    G.add_node(node_id, **attributes)

# Inlet V1 -> C: 0 -> 1 -> ... -> 9
for i in range(C_ID):
    G.add_edge(i, i + 1)

# Outlet C -> V2: 9 -> 18 -> 17 -> ... -> 10
G.add_edge(C_ID, 18)
for i in range(10, 18):
    G.add_edge(i + 1, i)

# Outlet C -> V3: 9 -> 27 -> 26 -> ... -> 19
G.add_edge(C_ID, 27)
for i in range(19, 27):
    G.add_edge(i + 1, i)

# pos_t = {idx: data['pos'][:2] for idx, data in node_data.items()}  # 2D position (x, y)
# nx.draw(G, pos_t, with_labels=True, node_color='lightblue', node_size=500, font_size=10, font_color='black')

In [None]:
#@title Define test graph 2
# Branching test network:
#   - trunk A-B;
#   - B splits into C and D;
#   - C splits into E and G;
#   - D splits into F and H.
# Each edge has its own radius. Each node gets the largest radius of the
# edges touching it.

# Human-readable label for each integer node id (used only when printing)
node_mapping = {
    0: 'A',
    1: 'B',
    2: 'C',
    3: 'D',
    4: 'E',
    5: 'F',
    6: 'G',
    7: 'H'
}

# Coordinates for each node (x, y, z)
node_coords = {
    0: [0, 20, 15],    # A
    1: [10, 20, 15],   # B
    2: [22, 13, 15],   # C
    3: [22, 28, 15],   # D
    4: [15, 5, 15],    # E
    5: [15, 35, 15],   # F
    6: [38, 5, 15],    # G
    7: [38, 35, 15]    # H
}

# Edges as (u, v, radius)
edges_with_radii = [
    (0, 1, 4),  # AB
    (1, 2, 3),  # BC
    (1, 3, 3),  # BD
    (2, 4, 2),  # CE
    (2, 6, 3),  # CG
    (3, 5, 2),  # DF
    (3, 7, 3)   # DH
]

G = FenicsGraph()

# Add nodes with their positions, then the edges with their radii
for node_id, coord in node_coords.items():
    G.add_node(node_id, pos=coord)
for u, v, radius in edges_with_radii:
    G.add_edge(u, v, radius=radius)

# Node radius = maximum radius over incident edges (0 for a node with no edges)
node_radii = {node: 0 for node in G.nodes()}
for u, v, data in G.edges(data=True):
    for node in (u, v):
        node_radii[node] = max(node_radii[node], data['radius'])

for node, radius in node_radii.items():
    G.nodes[node]['radius'] = radius

# (Optional) If you want to visualize the graph, you can use the following code:
# import matplotlib.pyplot as plt
# pos_2d = {node: (data['pos'][0], data['pos'][1]) for node, data in G.nodes(data=True)}
# radii = [data['radius']*100 for node, data in G.nodes(data=True)]  # Scale radii for visualization
# nx.draw(G, pos_2d, with_labels=True, node_color='lightblue', node_size=radii, font_size=10, font_color='black')
# plt.show()

# Print the nodes with their attributes to verify
for node_id, attributes in G.nodes(data=True):
    print(f"Node {node_mapping[node_id]} (ID: {node_id}): Position = {attributes['pos']}, Radius = {attributes['radius']}")


Node A (ID: 0): Position = [0, 20, 15], Radius = 4
Node B (ID: 1): Position = [10, 20, 15], Radius = 4
Node C (ID: 2): Position = [22, 13, 15], Radius = 3
Node D (ID: 3): Position = [22, 28, 15], Radius = 3
Node E (ID: 4): Position = [15, 5, 15], Radius = 2
Node F (ID: 5): Position = [15, 35, 15], Radius = 2
Node G (ID: 6): Position = [38, 5, 15], Radius = 3
Node H (ID: 7): Position = [38, 35, 15], Radius = 3


In [None]:
# Re-import FEMSensitivity.py so edits to the module take effect without restarting the kernel.
importlib.reload(FEMSensitivity)

<module 'FEMSensitivity' from '/content/drive/MyDrive/Research/3D-1D/modules/FEMSensitivity.py'>

In [14]:
# Solve the FEMSink model on the current graph G and export the result as VTK.
# NOTE(review): the saved output shows this cell raised IndexError.
# NOTE(review): P_infty and P_sink are both 1.0e3 here; confirm intended.
# NOTE(review): "output_directory" is relative to the kernel's working
# directory, not WD_PATH, so on Colab it is presumably not saved to Drive.
fem_primal = FEMSink.FEMSink(
    G=G,
    kappa=1.0,
    alpha=9.6e-2,
    beta=1.45e4,
    gamma=1.0,
    P_infty=1.0e3,
    theta=1.0,
    P_sink=1.0e3
)

fem_primal.save_vtk("output_directory")

IndexError: list index out of range

In [15]:
# Parameters for the FEMEdge run. The commented-out k / mu_o pair is an
# alternative parameter set kept for reference.
# NOTE(review): the saved output shows this cell raised IndexError.
# k = 5.0e-12
# mu_o = 8.9e-4
k = 9.44e-3
mu_o = 4.5e-3
P_cvp = 5 #mmHg
P_in = 100 #mmHg
MMHG_TO_PA = 133.322  # pressure conversion: 1 mmHg = 133.322 Pa

test = FEMEdge.FEMEdge(
  G,
  kappa = 1.0,
  alpha = k / mu_o,
  beta = 6.0e+4 * (k / mu_o),
  gamma = 0.2,
  del_Omega = P_cvp * MMHG_TO_PA, # Pa
  P_infty = P_in * MMHG_TO_PA, # Pa
)

IndexError: list index out of range

In [None]:
# Re-import visualizer.py so edits to the module take effect without restarting the kernel.
importlib.reload(visualizer)

<module 'visualizer' from '/content/drive/MyDrive/Research/3D-1D/modules/visualizer.py'>

In [None]:
# Visualize fem_primal's 1D mesh/solution (Lambda, uh1d) together with its 3D
# mesh/solution (Omega, uh3d) at the given view angles and z_level.
# NOTE(review): the saved output shows LinAlgError: Singular matrix here.
visualizer.visualize(fem_primal.Lambda, fem_primal.uh1d, mesh3d=fem_primal.Omega, sol3d=fem_primal.uh3d, elev=20, azim=-80, z_level=0)

LinAlgError: Singular matrix

<Figure size 2000x1000 with 6 Axes>

In [None]:
# Export the solver object currently bound to `test` to
# WD_PATH/perfusion_results/test_rh_ex3. `test` is assigned by both the
# FEMEdge and FEMSensitivity cells, so this saves whichever of them ran last.
test.save_vtk(os.path.join(WD_PATH, 'perfusion_results', 'test_rh_ex3'))

Poster figure generators: the cells below produce the figures used for the poster.

In [None]:
# @title test schematic plotter
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np

# Vertices of the bounding box (mm): the rectangular prism [0, 34.5] x [0, 40] x [0, 30]
vertices = np.array([[0, 0, 0],
                     [34.5, 0, 0],
                     [34.5, 40, 0],
                     [0, 40, 0],
                     [0, 0, 30],
                     [34.5, 0, 30],
                     [34.5, 40, 30],
                     [0, 40, 30]])

# Faces of the prism as vertex quadruples
faces = [[vertices[j] for j in [0, 1, 5, 4]],  # Front face  (y = 0)
         [vertices[j] for j in [2, 3, 7, 6]],  # Back face   (y = 40)
         [vertices[j] for j in [0, 3, 7, 4]],  # Left face   (x = 0)
         [vertices[j] for j in [1, 2, 6, 5]],  # Right face  (x = 34.5)
         [vertices[j] for j in [0, 1, 2, 3]],  # Bottom face (z = 0)
         [vertices[j] for j in [4, 5, 6, 7]]]  # Top face    (z = 30)

# Create a figure and a 3D axis
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=20, azim=-80)

# Plot the rectangular prism (translucent so the bifurcation inside stays visible)
ax.add_collection3d(Poly3DCollection(faces, facecolors='lightgray', linewidths=1, edgecolors='black', alpha=0.3))

# Points of the Y-shaped bifurcation. These match the "Define test graph" cell
# and the description below: V1 = (8, 20, 15), V2, V3 and the bifurcation point C.
center = np.array([23, 20, 15])
point1 = np.array([8, 20, 15])
point2 = np.array([33, 37, 15])
point3 = np.array([33, 3, 15])
# Opposite corners of the bounding box, labelled for reference
dim0 = np.array([0,0,0])
dim1 = np.array([34.5, 40, 30])

# Plot the three branches C-V1, C-V2, C-V3
for end in (point1, point2, point3):
    ax.plot([center[0], end[0]], [center[1], end[1]], [center[2], end[2]], color='black', linewidth=2)

# Mark each labelled point with a hollow black diamond and print its
# coordinates slightly above it
vertices_with_center = np.vstack([point1, point2, point3, center, dim0, dim1])
elevation = 3  # vertical offset of the text labels (mm)
for v in vertices_with_center:
    ax.scatter(v[0], v[1], v[2], s=50, marker='D', edgecolors='black', facecolors='none')
    ax.text(v[0], v[1], v[2] + elevation, f'({(v[0]):.1f}, {int(v[1])}, {int(v[2])})', color='black', zorder=10)

# Adjust plot limits for padding
padding = 1
ax.set_xlim([0 - padding, 35 + padding])
ax.set_ylim([0 - padding, 40 + padding])
ax.set_zlim([0 - padding, 30 + padding])

# Set plot labels
ax.set_xlabel('X (mm)')
ax.set_ylabel('Y (mm)')
ax.set_zlabel('Z (mm)')

plt.tight_layout()
plt.show()

Bounding box (mm): $[0, 34.5] \times [0, 40] \times [0, 30] \subset \mathbb{R}^3$

Bifurcation vertices: $V_1 = (8, 20, 15)$, $V_2 = (33, 37, 15)$, $V_3 = (33, 3, 15)$, and the bifurcation point $C = (23, 20, 15)$.

Here, $C$ is connected to each of $V_1$, $V_2$, $V_3$ by its own branch.

In [None]:
import FEMSensitivity
# Parameters for the FEMSensitivity run (note beta = 1e-10 here). The
# commented-out k / mu_o pair is an alternative parameter set kept for reference.
# k = 5.0e-12
# mu_o = 8.9e-4
k = 9.44e-3
mu_o = 4.5e-3
P_cvp = 5 #mmHg
P_in = 100 #mmHg
MMHG_TO_PA = 133.322  # pressure conversion: 1 mmHg = 133.322 Pa

test = FEMSensitivity.FEMSensitivity(
  G,
  kappa = 1.0,
  alpha = k / mu_o,
  beta = 1e-10,
  gamma = 0.2,
  del_Omega = P_cvp * MMHG_TO_PA, # Pa
  P_infty = P_in * MMHG_TO_PA # Pa
)

Averaging over 29352 cells: 100%|██████████| 29352/29352 [00:51<00:00, 568.60it/s]


In [None]:
# @title generate kdtree spec

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

# Assuming G is already defined and contains node positions
# Extract positions into a NumPy array for easier manipulation
positions = np.array([G.nodes[node]['pos'] for node in G.nodes])

# Find the minimum position in each dimension
min_x, min_y, min_z = np.min(positions, axis=0)

# Shift all positions so the network's bounding box starts at the origin.
# NOTE(review): this overwrites 'pos' on the shared graph G, so every later
# use of G sees shifted coordinates; confirm that is intended in a figure cell.
for node in G.nodes:
    G.nodes[node]['pos'] = (
        G.nodes[node]['pos'][0] - min_x,
        G.nodes[node]['pos'][1] - min_y,
        G.nodes[node]['pos'][2] - min_z,
    )
# Apply the same shift to the plotting array. It was built before the loop
# above, so without this the figure would show the unshifted points.
positions = positions - np.array([min_x, min_y, min_z])

# Select a reproducible random subset of the points to draw split planes through
proportion = 0.05  # fraction of points used (5%); modify to change the density
subset_size = int(proportion * len(positions))
rng = np.random.default_rng(0)  # fixed seed: same subset (and figure) on every run
subset_indices = rng.choice(len(positions), subset_size, replace=False)
subset_positions = positions[subset_indices]

# Plot the subset as a scatter plot, then the planes
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=20, azim=190)

ax.scatter(subset_positions[:, 0], subset_positions[:, 1], subset_positions[:, 2], c='blue', s=5)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Each plane spans [0, PLANE_EXTENT] in its two free coordinates (10 x 10 grid)
PLANE_EXTENT = 100
plane_a, plane_b = np.meshgrid(np.linspace(0, PLANE_EXTENT, 10), np.linspace(0, PLANE_EXTENT, 10))

# Illustrative planes: the split axis cycles x -> y -> z with the point's index
# in the random subset. `depth` is that index, not a depth in a real k-d tree.
# In every case the fixed coordinate is constant over the whole grid while the
# other two vary independently, so each surface is a genuine plane.
for depth, (x, y, z) in enumerate(subset_positions):
    if depth % 3 == 0:  # Plane parallel to YZ (fix x)
        ax.plot_surface(np.full_like(plane_a, x), plane_a, plane_b, color='r', alpha=0.2)
    elif depth % 3 == 1:  # Plane parallel to XZ (fix y)
        ax.plot_surface(plane_a, np.full_like(plane_a, y), plane_b, color='g', alpha=0.2)
    else:  # Plane parallel to XY (fix z)
        ax.plot_surface(plane_a, plane_b, np.full_like(plane_a, z), color='b', alpha=0.3)

plt.show()


In [None]:
# @title generate tree spec

# Build the tree from the same random subset whose split planes were drawn in
# the previous cell (`subset_positions` is defined there).
used_points = subset_positions

# Now, let's create a simple KD-Tree-like structure manually and visualize the graph

class KDNode:
    """One node of a k-d tree: a split point, its two subtrees, and the split axis."""

    def __init__(self, point=None, left=None, right=None, axis=0):
        self.point, self.left, self.right, self.axis = point, left, right, axis

def build_kd_tree(points, depth=0):
    """Recursively build a k-d tree from an (n, k) NumPy array of points.

    At each level:
      - the split axis is ``depth % k``;
      - the points are ordered along that axis;
      - the point at index ``n // 2`` becomes the node, and the points before
        and after it form the left and right subtrees.

    Returns:
        The root ``KDNode``, or None for an empty input.
    """
    if not len(points):
        return None

    axis = depth % points.shape[1]
    ordered = points[np.argsort(points[:, axis])]
    mid = len(ordered) // 2
    return KDNode(
        point=ordered[mid],
        left=build_kd_tree(ordered[:mid], depth + 1),
        right=build_kd_tree(ordered[mid + 1:], depth + 1),
        axis=axis,
    )

# Build the KD-Tree with the subset of points used to generate the planes
# (kd_tree_manual is the root KDNode, or None if the subset is empty)
kd_tree_manual = build_kd_tree(used_points)

# Now let's visualize the tree structure using matplotlib
def plot_kd_tree_graph_formatted(ax, node, depth=0, x_pos=0, y_pos=0, dx=4, parent_pos=None):
    """Draw ``node`` and its subtrees on ``ax`` as labelled boxes joined by lines.

    Layout:
      - the node's point, rounded to 3 decimals, is written at (x_pos, y_pos);
      - children are placed 2 units lower, shifted left/right by
        ``dx / (depth + 1)``;
      - when ``parent_pos`` is given, a black line joins the node to it.

    Nodes are drawn in pre-order (node, left subtree, right subtree).
    """
    if node is None:
        return

    label = f'{np.round(node.point, 3)}'
    ax.text(x_pos, y_pos, label, ha='center', va='center',
            bbox=dict(facecolor='white', edgecolor='black'))

    if parent_pos:
        ax.plot([x_pos, parent_pos[0]], [y_pos, parent_pos[1]], 'k-')

    offset = dx / (depth + 1)
    child_y = y_pos - 2
    here = (x_pos, y_pos)
    for child, child_x in ((node.left, x_pos - offset), (node.right, x_pos + offset)):
        plot_kd_tree_graph_formatted(ax, child, depth + 1, child_x, child_y, dx, here)

# Plot the manually built KD-tree structure
fig, ax = plt.subplots(figsize=(12, 8))
# Raw string: '\L' is an invalid escape sequence in a normal string literal
# (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
ax.set_title(r'k-d Tree Structure for $\Lambda$')
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 2)  # shows the root (y = 0) and the five levels below it
ax.axis('off')

plot_kd_tree_graph_formatted(ax, kd_tree_manual)

# NOTE(review): children are already drawn below their parent (y - 2), so
# inverting the y-axis puts the root at the bottom; confirm this is intended.
plt.gca().invert_yaxis()
plt.show()