In [3]:
import locale
locale.getpreferredencoding()
locale.getpreferredencoding = lambda: "UTF-8"

In [4]:
!pip install --upgrade pip

Collecting pip
  Downloading pip-23.3.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-23.3.1


In [5]:
!apt update
!apt install -y libsm6 libxext6 libxrender-dev

[33m0% [Working][0m            Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]
[33m0% [Waiting for headers] [1 InRelease 5,484 B/110 kB 5%] [Waiting for headers] [Connecting to ppa.la[0m                                                                                                    Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
[33m0% [Waiting for headers] [1 InRelease 5,484 B/110 kB 5%] [2 InRelease 0 B/3,626 B 0%] [Connecting to[0m[33m0% [Waiting for headers] [1 InRelease 14.2 kB/110 kB 13%] [Connecting to ppa.launchpadcontent.net (1[0m                                                                                                    Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
[33m0% [Waiting for headers] [1 InRelease 14.2 kB/110 kB 13%] [Connected to ppa.launchpadcontent.net (18[0m                                                                    

In [62]:
!pip install trimesh
!pip install open3d
!pip install scikit-learn
!pip install scipy

[0m

In [7]:
!git clone https://github.com/grapergrape/AMS_izziv.git

Cloning into 'AMS_izziv'...
remote: Enumerating objects: 135, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (122/122), done.[K
remote: Total 135 (delta 24), reused 119 (delta 11), pack-reused 0[K
Receiving objects: 100% (135/135), 15.27 MiB | 10.62 MiB/s, done.
Resolving deltas: 100% (24/24), done.


In [8]:
import trimesh
import numpy as np
import cupy as cp
import open3d
import sklearn
from sklearn.neighbors import KNeighborsClassifier

# Load the mesh
mesh = trimesh.load_mesh('/content/AMS_izziv/aneurysms/ru_22_CTA_PT00014_20200906.obj')
mesh.show()

In [9]:
import numpy as np

# Create a copy of the original mesh
deformed_mesh = mesh.copy()

# Get the vertices of the mesh
vertices = deformed_mesh.vertices

# Define a bounding box for one of the bunny's ears
ear_min = np.array([-0.05, 0.15, 0])
ear_max = np.array([0.05, 0.3, 0.1])

# Select vertices that fall within the bounding box
ear_vertices = ((vertices > ear_min) & (vertices < ear_max)).all(axis=1)

# Apply a non-linear transformation to the selected ear vertices
vertices[ear_vertices, 1] = vertices[ear_vertices, 1] + 0.05 * np.sin(10 * vertices[ear_vertices, 0])

# Replace the vertices in the deformed mesh
deformed_mesh.vertices = vertices

deformed_mesh.show()


# Coherent Point Drift implementation

In [10]:
def simplify_mesh(mesh, steps):
    """
    Simplify a 3D mesh by reducing the number of faces.

    Args:
        mesh (trimesh.Trimesh): The input 3D mesh.
        steps (int): The number of simplification steps.

    Returns:
        trimesh.Trimesh: The simplified mesh.
    """
    for _ in range(steps):
        # Apply quadratic decimation to reduce the face count by half in each step.
        mesh = mesh.simplify_quadric_decimation(mesh.faces.shape[0] // 2)
    return mesh

def downsample_with_knn(mesh, steps):
    """
    Downsample a 3D mesh while preserving color information using a K-nearest neighbors (KNN) approach.

    Args:
        mesh (trimesh.Trimesh): The input 3D mesh with color information.
        steps (int): The number of downsampling steps.

    Returns:
        trimesh.Trimesh: The downsampled mesh with predicted colors.
    """
    # Step 1: Train a KNN model on how the input mesh is colored.

    # Get the vertices of the original mesh
    original_mesh_vertices = mesh.vertices

    # Assuming color information is stored in vertex_colors
    original_mesh_colors = mesh.visual.vertex_colors

    # Create a KNN classifier with 3 neighbors
    neigh = KNeighborsClassifier(n_neighbors=3)

    # Train the KNN model on the original mesh's vertices and their colors
    neigh.fit(original_mesh_vertices, original_mesh_colors)

    # Step 2: Clone the original mesh
    mesh2 = mesh.copy()

    # Step 3: Perform mesh simplification to reduce the face count
    mesh2 = simplify_mesh(mesh2, steps)

    # Step 4: Get the vertices of the downsampled mesh
    downsampled_mesh_vertices = mesh2.vertices

    # Step 5: Predict color labels for the downsampled mesh using the KNN model
    predicted_labels = neigh.predict(downsampled_mesh_vertices)

    # Step 6: Add the predicted labels to the downsampled mesh
    mesh2.visual.vertex_colors = predicted_labels

    return mesh2

def initialize_nonrigid(X, Y, beta, lambda_, w):
    """
    This function initializes the parameters to be used in the CPD algorithm.

    X : ndarray
        Source point cloud.
    Y : ndarray
        Target point cloud.
    beta : float
        Gaussian filter standard deviation.
    lambda_ : float
        Weight of the regularization term.
    w : float
        Weight for the uniform distribution in the EM algorithm.

    Returns
    -------
    W : ndarray
        Weight matrix for transforming each point.
    G : ndarray
        Gaussian Affinity matrix.
    X_hat : ndarray
        Initial transformed source point cloud.
    sigma2 : float
        Initial estimate of the variance.
    """
    N, D = X.shape
    M = Y.shape[0]

    # Compute initial sigma2
    sigma2 = (cp.sum(cp.sum(cp.square(X - cp.mean(X, axis=0)))) + cp.sum(cp.sum(cp.square(Y - cp.mean(Y, axis=0))))) / (D * (M + N))
    # The initial estimate of the variance is the sum of variances of X and Y.

    # Initialize W
    W = cp.zeros((M, D))
    # W is initialized to zero, meaning that the initial transformation is the identity.

    # Create Gaussian affinity matrix G
    diff = X[:, None] - X
    G = cp.exp(-cp.sum(diff**2, axis=2) / (2 * beta ** 2))
    # G is computed as the Gaussian of the pairwise distances in X.

    # X_hat is initially X
    X_hat = X
    # The transformed source point cloud starts as a copy of the original.

    return W, G, X_hat, sigma2

def calculate_P_nonrigid(X, Y, sigma2, beta, w):
    """
    This function calculates the matrix P for the CPD algorithm.

    X : ndarray
        Transformed source point cloud.
    Y : ndarray
        Target point cloud.
    sigma2 : float
        Estimate of the variance.
    beta : float
        Gaussian filter standard deviation.
    w : float
        Weight for the uniform distribution in the EM algorithm.

    Returns
    -------
    P : ndarray
        The matrix P.
    """
    N, D = X.shape
    M = Y.shape[0]
    P = cp.empty((N, M))

    diff = X[:, None] - Y
    P = cp.exp(-cp.sum(diff**2, axis=2) / (2 * sigma2 ** 2)) / (sigma2 ** D * (2 * cp.pi) ** (D / 2))
    # P is computed as a Gaussian of the pairwise distances, scaled by the estimated variance.

    cP = w / ((1 - w) * (2 * cp.pi * beta) ** (D / 2))
    den = cp.sum(P, axis=0)
    den = den + cP
    P = cp.divide(P, den[:, cp.newaxis])
    # P is normalized to sum to 1, to form a valid probability distribution.

    return P

def maximize_nonrigid(P, X, Y, G, beta, lambda_):
    """
    This function maximizes the parameters for the CPD algorithm.

    P : ndarray
        The matrix P.
    X : ndarray
        Transformed source point cloud.
    Y : ndarray
        Target point cloud.
    G : ndarray
        Gaussian Affinity matrix.
    beta : float
        Gaussian filter standard deviation.
    lambda_ : float
        Weight of the regularization term.

    Returns
    -------
    W : ndarray
        The updated weight matrix.
    sigma2 : float
        The updated estimate of the variance.
    """
    N, D = X.shape
    M = Y.shape[0]

    # Compute mu_x and mu_y
    ones_N = cp.ones(N)
    Pt1 = cp.dot(P.T, ones_N)
    Px = cp.dot(P.T, X)
    mu_x = cp.sum(Px, axis=0) / cp.sum(Pt1, axis=0)
    mu_y = cp.mean(Y, axis=0)
    # mu_x and mu_y are the centroids of the source and target point clouds, respectively, weighted by the probabilities in P.

    # Compute X_hat and Y_hat
    X_hat = X - mu_x
    Y_hat = Y - mu_y
    # X_hat and Y_hat are the source and target point clouds, translated to have zero centroid.

    P_diag = cp.zeros((M, M))
    cp.fill_diagonal(P_diag, Pt1)
    # P_diag is the diagonal matrix of column sums of P.

    # Compute C
    C = cp.dot(G.T, cp.dot(P_diag, G)) + lambda_ * beta * cp.eye(G.shape[0])
    # C is a regularized version of a weighted covariance matrix of the transformed point cloud.

    # Compute W
    W = cp.dot(cp.linalg.inv(C), Px - cp.dot(G.T, cp.dot(P_diag, X_hat)))
    # W is the solution of a linear system that minimizes the objective function.

    # Compute sigma2
    sigma2 = (cp.trace(X.T.dot(P_diag).dot(X)) + cp.trace(W.T.dot(G).dot(W)) - 2 * cp.trace(Px.T.dot(W))) / (N * D)
    # sigma2 is the updated estimate of the variance, based on the current transformation.

    return W, sigma2

def fit_nonrigid(X, Y, beta=2, lambda_=2, w=0.1, max_iterations=100, tol=1e-5):
    """
    This function implements the non-rigid variant of the Coherent Point Drift (CPD) algorithm.

    X : ndarray
        Source point cloud.
    Y : ndarray
        Target point cloud.
    beta : float
        Gaussian filter standard deviation.
    lambda_ : float
        Weight of the regularization term.
    w : float
        Weight for the uniform distribution in the EM algorithm.
    max_iterations : int
        Maximum number of iterations.
    tol : float
        Tolerance for stopping criterion.

    Returns
    -------
    X_hat : ndarray
        The final transformed source point cloud.
    """
    N, D = X.shape
    M = Y.shape[0]
    W, G, X_hat, sigma2 = initialize_nonrigid(X, Y, beta, lambda_, w)

    for i in range(max_iterations):
        P = calculate_P_nonrigid(X_hat + cp.dot(G, W), Y, sigma2, beta, w)
        W_new, sigma2_new = maximize_nonrigid(P, X_hat, Y, G, beta, lambda_)
        T = X + cp.dot(G, W_new)
        # T is the transformed point cloud after applying the current estimated transformation.

        if cp.linalg.norm(W_new - W) < tol:
            break
        # If the change in W is small, we have converged and can stop early.

        W, sigma2 = W_new, sigma2_new
        # Update the transformation parameters for the next iteration.

        X_hat = T
        # X_hat is updated to be the currently estimated transformed point cloud.

    return X_hat

def get_rmse(deformed_mesh, aligned_mesh):
    return cp.sqrt(cp.mean((cp.asarray(deformed_mesh.vertices) - cp.asarray(aligned_mesh.vertices))**2))

def align_mesh(original_mesh, original_deformed_mesh, threshold=0.01, max_downsample=5):
    rmse = float('inf')
    for downsample_level in range(max_downsample, -1, -1):
        print(f"Downsample level: {downsample_level}")
        # Downsample the original mesh
        intra_mesh = original_mesh.copy()
        intra_deformed_mesh = original_deformed_mesh.copy()

        intra_mesh = simplify_mesh(intra_mesh, downsample_level) # Adjust voxel_size as needed
        intra_deformed_mesh = simplify_mesh(intra_deformed_mesh, downsample_level) # Downsample deformed_mesh too

        # Transfer your data to the GPU
        mesh_gpu = cp.asarray(intra_mesh.vertices)
        deformed_mesh_gpu = cp.asarray(intra_deformed_mesh.vertices)
        try:
          # Call the fit function
            aligned_vertices_gpu = fit_nonrigid(mesh_gpu, deformed_mesh_gpu)
        except ValueError:
          #Skip axis mismatch error stemming from simplify
            print("Non conforming downsampling level")
            continue
        # Transfer the aligned vertices back to the CPU
        aligned_vertices = cp.asnumpy(aligned_vertices_gpu)

        # Apply the transformation to deformed_mesh
        aligned_mesh = deformed_mesh.copy()
        aligned_mesh.vertices = aligned_vertices

        # Calculate the RMSE
        rmse = get_rmse(intra_deformed_mesh, aligned_mesh)
        print(f"RMSE at downsampling level {downsample_level}: {rmse}")

        if rmse <= threshold:
            print(f"Final RMSE: {rmse}")
            return aligned_mesh

    print("Stopped as downsampling level 0 reached.")
    return None

In [11]:
# Apply the modified downsampling function with a specified number of steps
downsampled_mesh = downsample_with_knn(mesh, 5)
downsampled_mesh.show()

## Validation

Threshold should be changed in based on the quality of input image, example: for bunny res 4 it shouldnt be that low

In [12]:
aligned_mesh = align_mesh(mesh, deformed_mesh)
if aligned_mesh is not None:
    aligned_mesh.show()

Downsample level: 5
RMSE at downsampling level 5: 1.9055442770617985e-07
Final RMSE: 1.9055442770617985e-07


###Training models with curvature

In [43]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import os
import numpy as np
import trimesh
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
# Define the directory path
dir_path = '/content/AMS_izziv/Registration Cases'



# Get a list of all .obj files in the directory
file_names = [os.path.join(dir_path, file_name) for file_name in os.listdir(dir_path) if file_name.endswith('.obj')]

# Define curvature calculation functions
def calc_angle(mesh, v1, v2):
    dot_product = np.dot(mesh.vertex_normals[v1], mesh.vertex_normals[v2])
    # Fix potential issue of arccos input value out of range due to precision
    dot_product = np.clip(dot_product, -1.0, 1.0)
    return np.arccos(dot_product)

def calc_curvature(mesh):
    curvatures = []
    for i in range(len(mesh.vertices)):
        curvatures.append(np.mean([calc_angle(mesh, i, neighbor) for neighbor in mesh.vertex_neighbors[i]]))
    return np.array(curvatures)

# Function to threshold and binarize colors
def binarize_colors(colors):
    r = colors[:, 0]
    r_binarized = np.where(r < 50, 0, 1)
    return r_binarized


def calc_gaussian_curvature(mesh):
    # Initialize an array to store the Gaussian curvature of each vertex
    gaussian_curvatures = np.zeros(len(mesh.vertices))

    # Iterate over all the vertices
    for i in range(len(mesh.vertices)):
        # Get the adjacent vertices
        adjacent_vertices = mesh.vertex_neighbors[i]

        # Compute the total angle around the vertex
        total_angle = sum([calc_angle(mesh, i, j) for j in adjacent_vertices])

        # Calculate the Gaussian curvature
        gaussian_curvature = (2*np.pi - total_angle) / len(adjacent_vertices)

        gaussian_curvatures[i] = gaussian_curvature

    return gaussian_curvatures

def calc_curvature_derivative(mesh, curvatures):
    curvature_derivative = []
    for i in range(len(mesh.vertices)):
        neighbors = mesh.vertex_neighbors[i]
        curvature_diffs = [curvatures[i] - curvatures[n] for n in neighbors]
        curvature_derivative.append(np.mean(curvature_diffs))
    return np.array(curvature_derivative)

def calc_curvature_derivative2(mesh, curvature_derivatives):
    curvature_derivative2 = []
    for i in range(len(mesh.vertices)):
        neighbors = mesh.vertex_neighbors[i]
        curvature_diffs = [curvature_derivatives[i] - curvature_derivatives[n] for n in neighbors]
        curvature_derivative2.append(np.mean(curvature_diffs))
    return np.array(curvature_derivative2)

curvature_features = []
gaussian_curvatures = []  # New list to store Gaussian curvatures
segmentation_masks = []

# Iterate over all file names
for file_name in file_names:
    # Load the original mesh
    original_mesh = trimesh.load_mesh(file_name)

    # Binarize R channel colors
    r_binarized = binarize_colors(original_mesh.visual.vertex_colors)

    # Calculate curvatures
    curvatures = calc_curvature(original_mesh)

    # Calculate Gaussian Curvatures
    g_curvatures = calc_gaussian_curvature(original_mesh)  # New line
    gaussian_curvatures.append(g_curvatures)  # New line

    # Calculate the mean and median of curvatures and the difference of each curvature from them
    mean_curv = np.mean(curvatures)
    median_curv = np.median(curvatures)

     # Calculate first and second curvature derivatives
    curvature_derivatives = calc_curvature_derivative(original_mesh, curvatures)
    curvature_derivatives2 = calc_curvature_derivative2(original_mesh, curvature_derivatives)

    # Stack curvature features
    curvature_features.append(np.column_stack((curvatures,
                                               curvatures - mean_curv,
                                               curvatures - median_curv,
                                               g_curvatures ,
                                               g_curvatures - mean_curv,
                                               g_curvatures - median_curv,
                                               curvature_derivatives,
                                               curvature_derivatives2,
                                               )))


    # Add binarized colors to list
    segmentation_masks.append(r_binarized)

# Modify the neural network architecture
model = models.Sequential([
    layers.Dense(1024, activation='tanh', input_shape=(8,)),  # Double the number of neurons in the input layer
    layers.Dropout(0.5),  # Add dropout for regularization
    layers.Dense(512, activation='tanh'),  # Double the number of neurons in the second layer
    layers.Dropout(0.5),  # Add dropout for regularization
    layers.Dense(256, activation='tanh'),
    layers.Dense(128, activation='tanh'),
    layers.Dense(64, activation='tanh'),   # Add an extra layer
    layers.Dense(2, activation='softmax')
])

# Specify learning rate when creating the optimizer
optimizer = Adam(learning_rate=0.000001)

# Compile the model with the new optimizer
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model
for features, mask in zip(curvature_features, segmentation_masks):
    # Check for NaN values in features and mask
    if np.isnan(features).any() or np.isnan(mask).any():
        print("NaN values detected in features or mask")
        continue

    # Calculate class weights for each set of features and masks
    class_weights = {i: len(mask) / (2 * np.sum(mask == i)) for i in [0, 1]}

    # Convert segmentation masks to categorical (one-hot encoding)
    Y = to_categorical(mask)

    # Train the model on this set of features and masks, and provide the class weights
    model.fit(features, Y, epochs=10, batch_size=16, class_weight=class_weights)

Found GPU at: /device:GPU:0
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8

In [70]:
import numpy as np
from scipy.spatial import ConvexHull, distance
import trimesh
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import os

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
# Define the directory path
dir_path = '/content/AMS_izziv/Registration Cases'
# Get a list of all .obj files in the directory
file_names = [os.path.join(dir_path, file_name) for file_name in os.listdir(dir_path) if file_name.endswith('.obj')]

diameter_features = []
segmentation_masks = []

# Function to threshold and binarize colors
def binarize_colors(colors):
    r = colors[:, 0]
    r_binarized = np.where(r > 50, 1, 0)
    return r_binarized

def convex_hull_diameter(pts):
    hull = ConvexHull(pts)
    coords = hull.points[hull.vertices]
    max_dist = 0
    max_pair = None

    for i, pt1 in enumerate(coords[:-1]):
        for pt2 in coords[i+1:]:
            dist = distance.euclidean(pt1, pt2)
            if dist > max_dist:
                max_dist = dist
                max_pair = (pt1, pt2)

    return max_dist

def calculate_diameters(mesh):
    diameters = []
    for vertex in mesh.vertices:
        neighbors = mesh.vertex_neighbors[vertex]
        neighbor_coordinates = mesh.vertices[neighbors]
        diameter = convex_hull_diameter(neighbor_coordinates)
        diameters.append(diameter)

    return np.array(diameters)

def calc_diameter_derivative(diameters):
    return np.gradient(diameters)

def calc_diameter_derivative2(diameters):
    return np.gradient(calc_diameter_derivative(diameters))
# Calculate the median diameter
def calc_med_diameter(diameters):
    return np.median(diameters)

diameter_features = []
segmentation_masks = []

# Iterate over all file names
for file_name in file_names:
    # Load the original mesh
    original_mesh = trimesh.load_mesh(file_name)
    # Binarize R channel colors
    r_binarized = binarize_colors(original_mesh.visual.vertex_colors)
    segmentation_masks.append(r_binarized)
    # Calculate diameters
    dia = calculate_diameters(original_mesh)
    median_dia = calc_med_diameter(dia)
    # Calculate first and second diameter derivatives
    dia_derivatives = calc_diameter_derivative(dia)
    dia_derivatives2 = calc_diameter_derivative2(dia_derivatives)
    diameter_features.append(np.column_stack((dia,
                                               dia - median_dia,
                                               dia_derivatives,
                                               dia_derivatives2,
                                               )))
# Define the model - a simple feed forward network
model = models.Sequential([
    layers.Dense(1024, activation='tanh', input_shape=(4,)),
    layers.Dropout(0.5),
    layers.Dense(512, activation='tanh'),
    layers.Dropout(0.5),
    layers.Dense(256, activation='tanh'),
    layers.Dense(128, activation='tanh'),
    layers.Dense(64, activation='tanh'),
    layers.Dense(2, activation='softmax')
])

# Compile the model
model.compile(optimizer=Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model
for features, mask in zip(diameter_features, segmentation_masks):
    Y = to_categorical(mask)
    if np.isnan(features).any() or np.isnan(mask).any():
        print("NaN values detected in features or mask")
        continue
    class_weights = {i: len(mask) / (2 * np.sum(mask == i)) for i in [0, 1]}
    model.fit(features, Y, epochs=10, batch_size=16, class_weight=class_weights)


Found GPU at: /device:GPU:0


  neighbors = np.array(mesh.vertex_neighbors)[vertex.astype(int)]


IndexError: ignored

In [58]:
import os
import glob
def save_model(model):
    # Get list of all "version_*.h5" files in the directory
    version_files = glob.glob('/content/AMS_izziv/version_*.h5')

    # Parse version numbers from file names and find the highest version number
    versions = [int(os.path.basename(file).split('_')[1].split('.')[0]) for file in version_files]
    latest_version = max(versions) if versions else 0

    # Define the file name for the next model
    next_model_file = f'/content/AMS_izziv/version_{latest_version + 1}.h5'

    # Save the model
    model.save(next_model_file)

    return latest_version + 1  # Return the latest version number

def delete_previous_model(previous_version):
    # If there is a previous model, delete it
    if previous_version != 0:
        prev_model_file = f'/content/AMS_izziv/version_{previous_version}.h5'
        if os.path.exists(prev_model_file):
            os.remove(prev_model_file)
output_test_mesh = trimesh.load_mesh('/content/AMS_izziv/aneurysms/ru_24_CTA_PT00016_20200925.obj')
comparison = output_test_mesh.copy()
output_test_mesh = downsample_with_knn(output_test_mesh, 5)
comparison = downsample_with_knn(comparison, 5)
# Save the model
latest_version = save_model(model)

# Delete the previous version
delete_previous_model(latest_version - 1)
# Extract the vertices
test_vertices = output_test_mesh.vertices

# Calculate the curvatures for the test mesh
test_curvatures = calc_curvature(output_test_mesh)

# Calculate Gaussian Curvatures for the test mesh
test_g_curvatures = calc_gaussian_curvature(output_test_mesh)

# Calculate first and second curvature derivatives for the test mesh
test_curvature_derivatives = calc_curvature_derivative(output_test_mesh, test_curvatures)
test_curvature_derivatives2 = calc_curvature_derivative2(output_test_mesh, test_curvature_derivatives)

# Calculate the mean and median of curvatures and the difference of each curvature from them
mean_curv_test = np.mean(test_curvatures)
median_curv_test = np.median(test_curvatures)

# Calculate the mean and median of Gaussian curvatures and the difference of each curvature from them
mean_g_curv_test = np.mean(test_g_curvatures)
median_g_curv_test = np.median(test_g_curvatures)

test_features = np.column_stack((test_curvatures, test_curvatures - mean_curv_test, test_curvatures - median_curv_test,
                                  test_g_curvatures, test_g_curvatures - mean_g_curv_test, test_g_curvatures - median_g_curv_test,
                                  test_curvature_derivatives, test_curvature_derivatives2))  # Include all eight features

# Perform prediction using the model
Y_test_pred = model.predict(test_features)
unique_values_class_0 = np.unique(Y_test_pred[:, 0])
unique_values_class_1 = np.unique(Y_test_pred[:, 1])

print('Unique Values for Class 0:', unique_values_class_0)
print('Unique Values for Class 1:', unique_values_class_1)
factor = 1 # Adjust this value to suit your needs

# Multiply the predicted probabilities for class 1 by the factor
Y_test_pred_adjusted = Y_test_pred.copy()  # Create a copy to avoid modifying the original predictions
Y_test_pred_adjusted[:, 1] *= factor
# Convert the adjusted probabilities to class labels
Y_test_pred_binary_adjusted = np.argmax(Y_test_pred_adjusted, axis=1)

# Map binary labels to RGB colors
colors = np.zeros((len(Y_test_pred_binary_adjusted), 3))
colors[Y_test_pred_binary_adjusted == 0] = [0, 0, 0]  # Set label 0 to black
colors[Y_test_pred_binary_adjusted == 1] = [255, 0, 0]  # Set label 1 to red

# Assign colors to the vertices of the new mesh
output_test_mesh.visual.vertex_colors = colors

# Visualize the mesh
output_test_mesh.show()

Unique Values for Class 0: [0.39528367 0.43842128 0.4719551  0.48871493 0.49528384 0.5040179
 0.50553703 0.5100355  0.5114658  0.517495   0.5188723  0.522811
 0.5324662  0.5338122  0.53853697 0.54644865 0.54844594 0.5486884
 0.55126196 0.551361   0.55174124 0.5551646  0.5565747  0.55724484
 0.558773   0.5591165  0.56158096 0.5655573  0.5668582  0.5673674
 0.56739306 0.5712843  0.57165724 0.574127   0.57442385 0.57581
 0.576493   0.5767467  0.58278877 0.5837934  0.5838817  0.58606744
 0.58868426 0.59306073 0.5935192  0.5936945  0.594555   0.59588116
 0.5975678  0.59966946 0.6007101  0.6026192  0.60319954 0.60389364
 0.607753   0.6077575  0.60983765 0.61230767 0.61276877 0.613925
 0.6144076  0.616098   0.6173951  0.61855704 0.6212932  0.6246781
 0.62622917 0.626854   0.6285677  0.6285755  0.62893385 0.6296506
 0.6299313  0.6301262  0.6314713  0.6316937  0.6319107  0.63280565
 0.6364101  0.63648695 0.63903683 0.6396951  0.64046466 0.64225256
 0.64550096 0.64762795 0.647732   0.65147156 0.

  saving_api.save_model(


In [59]:
comparison.show()