## Rigid Alignment with ICP

### Import the Data

In [None]:
#import utils

import numpy as np
from pyFM.mesh import TriMesh
import os
import matplotlib.pyplot as plt

import itertools
import copy
import networkx as nx
import pyFM.spectral as spectral
import seaborn as sns
from pyFM.FMN import FMN
from stl import mesh
import pyFM
import torch
from stl import mesh
import sys

In [None]:
from sklearn.neighbors import KDTree, NearestNeighbors
import scipy
import scipy.spatial
import pickle
from tqdm.auto import tqdm
from os.path import join


def load_ints(path, from_matlab=False):
    """
    Read a text file of integers into a numpy array.

    Parameters
    ----------
    path        : file name or file-like object accepted by np.loadtxt
    from_matlab : if True, subtract 1 from every value so that 1-based
                  (MATLAB-style) indices become 0-based

    Returns
    -------
    Integer numpy array holding the loaded (and possibly shifted) values.
    """
    indices = np.loadtxt(path, dtype=int)

    if not from_matlab:
        return indices

    # The array was freshly created by loadtxt, so shifting in place is safe.
    indices -= 1
    return indices


def save_ints(path, vals, to_matlab=False):
    """
    Write integer values to a text file, one row per line ('%d' format).

    Parameters
    ----------
    path      : file name or file handle accepted by np.savetxt
    vals      : array-like of integers
    to_matlab : if True, the values written are shifted by +1 so that
                0-based indices become 1-based (MATLAB-style)

    The input `vals` is never modified.
    """
    vals = np.asarray(vals)

    if to_matlab:
        # Build a new array instead of `vals += 1`, which would silently
        # shift the caller's array as well (and fails on plain lists).
        vals = vals + 1

    np.savetxt(path, vals, fmt='%d')


def save_pickle(path, vals):
    """
    Serialize `vals` with pickle and write it to the file at `path`.

    The file is opened in binary write mode and closed as soon as the dump
    is finished (also when pickling raises).
    """
    with open(path, "wb") as handle:
        pickle.dump(vals, handle)


def load_pickle(path):
    """
    Load and return the object pickled in the file at `path`.

    The file is opened in binary read mode and closed once loading is done.

    NOTE: unpickling can execute arbitrary code; only use this on files
    coming from a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def knn_query(X, Y, k=1, return_distance=False, use_scipy=False, dual_tree=False, n_jobs=1):
    """
    For every point of Y, find its k nearest neighbors among the points of X.

    Parameters
    ----------
    X               : (n1, d) points the KD-tree is built on
    Y               : (n2, d) query points
    k               : number of neighbors to return per query point
    return_distance : if True, also return the distances
    use_scipy       : use scipy.spatial.KDTree instead of sklearn's
                      NearestNeighbors (kd_tree algorithm, leaf_size=40)
    dual_tree       : unused, kept for interface compatibility
    n_jobs          : number of parallel workers given to the backend

    Returns
    -------
    dists   : (n2,) if k == 1 else (n2, k) - only if return_distance is True
    matches : (n2,) if k == 1 else (n2, k) - indices into X
    """
    if use_scipy:
        tree = scipy.spatial.KDTree(X)
        dists, matches = tree.query(Y, k=k, workers=n_jobs)
    else:
        tree = NearestNeighbors(n_neighbors=k, leaf_size=40, algorithm="kd_tree", n_jobs=n_jobs)
        tree.fit(X)
        dists, matches = tree.kneighbors(Y)

    if k == 1:
        # Flatten to one entry per query point. reshape (rather than squeeze)
        # keeps a length-1 array when Y holds a single point, so the result
        # can always be used as an index array.
        dists = np.reshape(dists, -1)
        matches = np.reshape(matches, -1)

    if return_distance:
        return dists, matches
    return matches


def knn_query_normals(X, Y, normals1, normals2, k_base=30, return_distance=False, n_jobs=1, verbose=False):
    """
    Nearest-neighbor query from Y into X that only accepts matches whose
    normals point in a compatible direction (positive dot product).

    A first batched query retrieves k_base candidates per point of Y; the
    closest candidate with a compatible normal is kept. Points of Y for which
    none of the k_base candidates is compatible are handled one by one with a
    KD-tree restricted to the compatible points of X. If no point of X is
    compatible at all, the plain nearest neighbor is used.

    Parameters
    ----------
    X, Y               : (n1, d) and (n2, d) point sets
    normals1, normals2 : (n1, d) and (n2, d) normals attached to X and Y
    k_base             : number of candidates of the first batched query
    return_distance    : if True, also return the distances
    n_jobs             : forwarded to knn_query
    verbose            : print how many points needed the slow path

    Returns
    -------
    final_dists   : (n2,) distances - only if return_distance is True
    final_matches : (n2,) indices into X
    """
    n_queries = Y.shape[0]
    final_matches = np.zeros(n_queries, dtype=int)
    final_dists = np.zeros(n_queries)

    # Batched candidate search: (n2, k_base) distances and indices
    dists, matches = knn_query(X, Y, k=k_base, return_distance=True, n_jobs=n_jobs)

    # isvalid[n, k] is True when candidate k of query n has a compatible normal
    isvalid = np.einsum('nkp,np->nk', normals1[matches], normals2) > 0  # (n2, k)
    has_candidate = isvalid.any(axis=1)

    valid_inds = np.flatnonzero(has_candidate)  # (n',)
    invalid_inds = np.flatnonzero(~has_candidate)  # (n2-n')

    if verbose:
        print(f'{valid_inds.size} direct matches and {invalid_inds.size} specific indices')

    # Candidates are sorted by distance, so argmax over the boolean row picks
    # the closest compatible one.
    first_valid = isvalid[valid_inds].argmax(axis=1)
    final_matches[valid_inds] = matches[valid_inds, first_valid]
    if return_distance:
        final_dists[valid_inds] = dists[valid_inds, first_valid]

    # Slow path: one restricted KD-tree per remaining query point
    for vert_ind in invalid_inds:
        possible_inds = np.flatnonzero(normals1 @ normals2[vert_ind] > 0)

        if possible_inds.size == 0:
            # No compatible normal anywhere: fall back to the plain nearest neighbor
            final_matches[vert_ind] = matches[vert_ind, 0]
            final_dists[vert_ind] = dists[vert_ind, 0]
            continue

        restricted_tree = KDTree(X[possible_inds])
        temp_dist, temp_match_red = restricted_tree.query(Y[None, vert_ind], k=1, return_distance=True)

        # Map the index in the restricted set back to an index in X
        final_matches[vert_ind] = possible_inds[temp_match_red.item()]
        final_dists[vert_ind] = temp_dist.item()

    if return_distance:
        return final_dists, final_matches
    return final_matches


def rotation(theta, axis):
    """
    Build the 3x3 matrix of a rotation of angle `theta` (radians) that leaves
    coordinate `axis` (0, 1 or 2) unchanged.

    The 2x2 block [[cos, -sin], [sin, cos]] is written on the two remaining
    coordinates taken in increasing order.
    NOTE: for axis=1 this places -sin at entry [0, 2], i.e. the transpose of
    the usual right-handed rotation about y.
    """
    rot = np.zeros((3, 3))
    rot[axis, axis] = 1

    # The two coordinates that are actually rotated
    first, second = (idx for idx in range(3) if idx != axis)

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    rot[first, first] = cos_t
    rot[first, second] = -sin_t
    rot[second, first] = sin_t
    rot[second, second] = cos_t

    return rot


def rotx(theta):
    """Return the 3x3 matrix produced by `rotation` for the x axis (index 0)."""
    return rotation(theta, axis=0)


def roty(theta):
    """Return the 3x3 matrix produced by `rotation` for the y axis (index 1)."""
    return rotation(theta, axis=1)


def rotz(theta):
    """Return the 3x3 matrix produced by `rotation` for the z axis (index 2)."""
    return rotation(theta, axis=2)


def rigid_alignment(X1, X2, p2p_12=None, weights=None, return_params=False, return_deformed=True):
    """
    Solve optimal R and t so that
    || X1@R.T + t - X2[p2p_12] || is minimized
    (closed-form least-squares solution via SVD of the cross-covariance).

    Parameters
    ----------
    X1              : (n1,3) points to move
    X2              : (n2,3) target points
    p2p_12          : (n1,) point to point from X1 to X2. If None, X1 and X2
                      are matched row by row.
    weights         : (n1,) optional per-point weights. They are normalized
                      to sum to 1 on an internal copy; the input is not
                      modified.
    return_params   : if True, return R and t
    return_deformed : if True, return the deformed X1

    Returns
    -------
    X_new : (n1,3) deformed X1 - only if return_deformed is True
    R     : (3,3) rotation     - only if return_params is True
    t     : (3,) translation   - only if return_params is True

    Raises
    ------
    ValueError : if both return_params and return_deformed are False
    """
    if not (return_params or return_deformed):
        raise ValueError("Choose something to return")

    X = X1  # (n1,3)
    Y = X2[p2p_12] if p2p_12 is not None else X2  # (n1,3)

    if weights is None:
        X_cent = X.mean(axis=0)
        Y_cent = Y.mean(axis=0)
    else:
        # Normalize on a float copy: an in-place `/=` would rescale the
        # caller's array and raises on integer weights.
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()
        X_cent = (weights[:, None]*X).sum(axis=0)
        Y_cent = (weights[:, None]*Y).sum(axis=0)

    X_bar = X - X_cent
    Y_bar = Y - Y_cent

    # (Weighted) cross-covariance between the centered point sets
    if weights is None:
        H = X_bar.T @ Y_bar  # (3,3)
    else:
        H = X_bar.T @ (weights[:, None]*Y_bar)  # (3,3)

    U, _, VT = scipy.linalg.svd(H)
    theta = VT.T @ U.T

    # The SVD solution may be a reflection (det = -1); flipping the singular
    # vector of the smallest singular value turns it into a proper rotation.
    if np.isclose(scipy.linalg.det(theta), -1):
        U[:, -1] *= -1
        theta = VT.T @ U.T

    t = Y_cent - X_cent@theta.T

    if not return_deformed:
        return theta, t

    X_new = X@theta.T + t

    if not return_params:
        return X_new

    else:
        return X_new, theta, t


def icp_align(X1, X2, p2p_12=None, weights=None, return_params=False, n_iter=50, epsilon=1e-8, n_jobs=1, verbose=False):
    """
    Solve optimal R and t so that
    || X1@R.T + t - X2 || is minimized
    using ICP: alternate between a rigid fit to the current correspondences
    and a nearest-neighbor update of the correspondences.

    Parameters
    ----------
    X1            : (n1,3) points to move
    X2            : (n2,3) target points
    p2p_12        : (n1,) initial point to point from X1 to X2. If None, it
                    is initialized with nearest neighbors.
    weights       : (n1,) optional per-point weights for the rigid fit
                    (normalized on an internal copy)
    return_params : if True, also return the accumulated R and t
    n_iter        : maximum number of ICP iterations
    epsilon       : stop when the update || X_new - X_curr || falls below it
    n_jobs        : number of parallel jobs for the nearest-neighbor search
    verbose       : show a progress bar and print the number of iterations

    Returns
    -------
    X_new : (n1,3) deformed X1 (a copy of X1 if n_iter is 0)
    R     : (3,3) - only if return_params is True
    t     : (3,)  - only if return_params is True
            such that X_new = X1@R.T + t
    """
    tree = NearestNeighbors(n_neighbors=1, leaf_size=40, algorithm="kd_tree", n_jobs=n_jobs)
    tree.fit(X2)

    if p2p_12 is None:
        _, p2p_12 = tree.kneighbors(X1)
        # kneighbors returns one column per neighbor; take the column rather
        # than squeeze so a single-point X1 still yields an index array.
        p2p_12 = p2p_12[:, 0]

    if weights is not None:
        # Normalize once on a private copy so the caller's array is untouched
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()

    X_curr = X1.copy()
    X_new = X_curr  # returned as is when n_iter == 0
    theta_curr = np.eye(3)
    t_curr = np.zeros(3)
    n_done = 0

    iterable = tqdm(range(n_iter)) if verbose else range(n_iter)
    for _ in iterable:
        res_icp = rigid_alignment(X_curr, X2, p2p_12=p2p_12, weights=weights, return_params=return_params)

        if return_params:
            X_new, theta, t = res_icp
            # Compose with the transform accumulated so far:
            # (X@Rc.T + tc)@R.T + t = X@(R@Rc).T + (R@tc + t)
            theta_curr = theta @ theta_curr
            t_curr = theta @ t_curr + t
        else:
            X_new = res_icp
        n_done += 1

        # Update the correspondences from the newly deformed points
        _, p2p_12 = tree.kneighbors(X_new)
        p2p_12 = p2p_12[:, 0]

        criteria = np.linalg.norm(X_new - X_curr)
        X_curr = X_new.copy()

        if criteria < epsilon:
            break

    if verbose:
        print(f'Aligned using ICP in {n_done} iterations')

    if return_params:
        return X_new, theta_curr, t_curr

    return X_new

Functions to plot the meshes

In [None]:
import plotly.graph_objects as go

def plot_mesh(mesh,color=None, colormap='Sunset',reverse=True):
    '''
    Display a triangle mesh as an interactive Plotly Mesh3d figure.

    Input:  mesh: object exposing `vertices` (n,3) and `faces` (m,3) arrays
            color: optional per-vertex values used as intensity
            colormap: name of the Plotly colorscale
            reverse: whether the colorscale is reversed
    '''
    verts = mesh.vertices
    tris = mesh.faces

    surface = go.Mesh3d(
        x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
        i=tris[:, 0], j=tris[:, 1], k=tris[:, 2],
        intensity=color,  # array/list of values mapped through the colorscale
        colorscale=colormap,
        reversescale=reverse,
        opacity=1,
    )

    go.Figure(data=[surface]).show()

def plot_pointcloud(shape,color=None, colormap='Sunset',reverse=True):
    """
    Show a 3D point cloud as an interactive Plotly scatter plot.

    Parameters:
    ----------
    shape : numpy.ndarray
        Array of shape (n, 3) holding the x, y, z coordinates of the n points.
    color : array-like, optional
        Values passed as marker color, mapped through the colorscale.
        If None, no color values are given to the markers.
    colormap : str, optional
        Name of the Plotly colorscale. Default is 'Sunset'.
    reverse : bool, optional
        Whether the colorscale is reversed. Default is True.

    Returns:
    -------
    None
        The figure is built and displayed with `fig.show()`.

    Example:
    -------
    >>> import numpy as np
    >>> shape = np.random.rand(100, 3)
    >>> color = np.random.rand(100)
    >>> plot_pointcloud(shape, color)
    """
    marker_style = dict(
        color=color,
        size=2,
        colorscale=colormap,
        reversescale=reverse,
        opacity=1,
    )

    cloud = go.Scatter3d(
        x=shape[:, 0],
        y=shape[:, 1],
        z=shape[:, 2],
        mode='markers',
        marker=marker_style,
    )

    go.Figure(data=[cloud]).show()


def pick_points(pointcloud):
    """
    Show a 3D point cloud with Plotly and attach each point's index as text,
    so that landmarks can be picked by reading the index of a point.

    Markers are colored by their z-coordinate ('Viridis' colorscale) and the
    scene uses the "data" aspect mode.

    Parameters:
    ----------
    pointcloud : sequence of (x, y, z) triplets
        The points to display; `len(pointcloud)` gives the number of indices.

    Returns:
    -------
    None
        The figure is built and displayed with `fig.show()`.

    Example:
    -------
    >>> pointcloud = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
    >>> pick_points(pointcloud)
    """
    # Split the triplets into coordinate sequences
    xs, ys, zs = zip(*pointcloud)

    marker_style = dict(
        size=5,
        color=zs,  # color gradient along z
        colorscale='Viridis',
        opacity=0.8,
    )
    cloud = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style)

    fig = go.Figure(data=[cloud], layout=go.Layout(scene=dict(aspectmode="data")))

    # One label per point, in input order
    fig.data[0]['text'] = [f'Index: {idx}' for idx in range(len(pointcloud))]

    fig.show()

Import data

In [None]:
# Takes about 1 min 11 s to run.
import trimesh
# Number of meshes to import:
N_meshes = 2
meshlist = []
for i in range(N_meshes):
    # Load the raw mesh from its .off file
    raw_mesh = TriMesh(fr'D:\Make_Galileo_great_again\data\giorgio_\0000{i:02d}_tumoredbrain\0000{i:02d}_tumoredbrain.off')

    # Split into connected components and keep the one with the largest volume.
    # The local name deliberately avoids `mesh`, which would overwrite the
    # `from stl import mesh` import at module level.
    # NOTE(review): components are split with only_watertight=False; check that
    # `volume` is a reliable size measure for non-watertight components.
    tri_mesh = trimesh.Trimesh(raw_mesh.vertices, raw_mesh.faces)
    components = tri_mesh.split(only_watertight=False)
    largest_component = max(components, key=lambda comp: comp.volume)
    print(f"Number of vertices before LC for mesh{i:02d}:", raw_mesh.vertices.shape)

    # Rebuild a TriMesh from the largest component (LC) and process it
    # (k=200, intrinsic=True); the result of process() is what is stored.
    meshlist.append(TriMesh(largest_component.vertices, largest_component.faces).process(k=200, intrinsic=True))
    print(f"Number of vertices after LC for mesh{i:02d}:", meshlist[i].vertices.shape)



ICP

In [None]:
# Sample 10 vertex indices on each of the two meshes with extract_fps
# (geodesic=False); row i of subsample_list holds the indices for meshlist[i].
subsample_list = np.zeros((2, 10), dtype=int)
for i in tqdm(range(2)):
    subsample_list[i] = meshlist[i].extract_fps(10, geodesic=False, verbose=False)
# Sampled vertex indices on the first and on the second mesh
fps1 = subsample_list[0]
fps2 = subsample_list[1]

In [None]:
# Get initial correspondences with KNN: for each sampled point of mesh 1
# (second mesh), the index of a normal-consistent nearest sample of mesh 0.
# The returned indices refer to positions inside fps1, not to mesh vertices.
p2p_21_init_sub = knn_query_normals(meshlist[0].vertlist[fps1], meshlist[1].vertlist[fps2],
                                            meshlist[0].vertex_normals[fps1], meshlist[1].vertex_normals[fps2],
                                            k_base=5, n_jobs=20, verbose=False)
# ICP-align the samples of mesh 1 onto the samples of mesh 0, starting from
# the correspondences above. Only the transform (R, t) is kept; the deformed
# samples are discarded.
_, R, t = icp_align(meshlist[1].vertlist[fps2], meshlist[0].vertlist[fps1],
                            p2p_12=p2p_21_init_sub,
                            return_params=True, n_jobs=20, epsilon=1e-4, verbose=False)

# The transform is applied to the full second mesh in a later cell.
#mesh2.rotate(R);
#mesh2.translate(t);


Plot before the alignment

In [None]:
def double_plot(myMesh1,myMesh2,cmap1=None,cmap2=None):
    """
    Display two meshes one after the other with plot_mesh.

    cmap1 and cmap2 are the optional per-vertex values passed as the `color`
    argument of plot_mesh for the first and the second mesh respectively.
    Nothing is returned.
    """
    for shape, vertex_values in ((myMesh1, cmap1), (myMesh2, cmap2)):
        plot_mesh(shape, vertex_values)

    return
# Show both meshes before the ICP transform is applied to the second one
double_plot(meshlist[0], meshlist[1])

Apply the ICP transform (R, t) to the second mesh

In [None]:
# Apply the transform estimated by ICP to the whole second mesh: rotation
# first, then translation (same order as X@R.T + t in icp_align).
# NOTE: this modifies meshlist[1]; re-running the cell applies the transform again.
meshlist[1].rotate(R)
meshlist[1].translate(t)

Plot after ICP

In [None]:
# Show both meshes again, now that the second one has been rotated and translated
double_plot(meshlist[0],meshlist[1])

## ICP + PYFMAP

Subsampling

### Select 10 landmarks with KNN

The subsample

In [None]:
# Sample 10 vertex indices on each mesh with extract_fps (geodesic=False);
# row i of subsample_list holds the indices for meshlist[i].
subsample_list = np.zeros((2, 10), dtype=int)
for i in tqdm(range(2)):
    subsample_list[i] = meshlist[i].extract_fps(10, geodesic=False, verbose=False)
# fps1 is used as the landmarks of the first mesh. fps2 is overwritten in the
# next cell by the nearest neighbors of those landmarks on the second mesh.
fps1 = subsample_list[0]
fps2 = subsample_list[1]

KNN

In [None]:
# For each landmark of the first mesh, take the index of the closest vertex of
# the second mesh; fps2[i] is then the counterpart of fps1[i].
fps2 = knn_query(meshlist[1].vertlist, meshlist[0].vertlist[fps1])

Landmarks are expected as a two-dimensional array of shape (N, 2): the first column contains the N landmark vertex indices on the first mesh, and the second column contains the N landmark vertex indices on the second mesh. Row i therefore pairs landmark i of the first mesh with its corresponding landmark on the second mesh, which is why the stacked array `[fps1, fps2]` is transposed below.

In [None]:
# Stack the two index lists and transpose: one row per landmark pair,
# column 0 = index on the first mesh, column 1 = index on the second mesh.
landmarks=np.array([fps1,fps2]).T

In [None]:
# Display the landmark pairs
landmarks

In [None]:
from pyFM.functional import FunctionalMapping

# Preprocessing options forwarded as keyword arguments to model.preprocess
process_params = {
        'n_ev': (20,20),  # Number of eigenvalues on source and target
        'landmarks': landmarks,  # landmark pairs built above, one row per pair
        'subsample_step': 1,  # Descriptor subsampling step (1 presumably keeps all of them - confirm in pyFM docs)
        'n_descr': 40,  # Number of descriptors
        'descr_type': 'HKS',  # WKS or HKS
    }
# Functional map model between the first and the second mesh
model = FunctionalMapping(meshlist[0],meshlist[1])
model.preprocess(**process_params,verbose=False)

# Fit the functional map with the given term weights (w_orient=0 presumably
# disables the orientation term - confirm in pyFM docs)
model.fit(w_descr= 1e-1, w_lap= 1e-3, w_dcomm= 1,w_orient= 0, verbose=False)
fmap12=model.FM       # C_{XY}, or A_{YX}^T

# Point-to-point map derived from the fitted functional map.
# NOTE(review): it is used below to index per-vertex values of the first mesh,
# so it presumably maps vertices of mesh 2 to vertices of mesh 1 - confirm.
#p2p=model.get_p2p(fmap12, model.mesh1.eigenvectors[:,:k],model.mesh2.eigenvectors[:,:k],adj, bijective)
p2p=model.get_p2p()


### Apply Functional Map

In [None]:
# Show the fitted functional map matrix as an image
plt.imshow(fmap12)

In [None]:
def visu(vertices):
    """
    Min-max normalize the values over the vertices, column by column.

    Parameters
    ----------
    vertices : (n, d) array of per-vertex values (e.g. coordinates)

    Returns
    -------
    cmap : (n, d) array with every column rescaled to [0, 1]. A column whose
           values are all equal is mapped to 0 instead of producing NaN.
    """
    vertices = np.asarray(vertices)

    min_coord = np.min(vertices, axis=0, keepdims=True)
    max_coord = np.max(vertices, axis=0, keepdims=True)

    # A constant column has zero extent; divide by 1 there to avoid 0/0
    extent = max_coord - min_coord
    extent = np.where(extent == 0, 1, extent)

    cmap = (vertices - min_coord) / extent
    return cmap



In [None]:
# One scalar per vertex of the first mesh: mean of its normalized coordinates
cmap1 = np.mean(visu(meshlist[0].vertlist),axis=1)
# Transfer the values to the second mesh through the point-to-point map.
# (The `_wks` suffix is only a name: the descriptors configured above are HKS.)
cmap2_wks = cmap1[p2p]
# Corresponding regions should show the same color on both meshes
double_plot(meshlist[0],meshlist[1],(cmap1),(cmap2_wks))

In [None]:
#zoomout
from pyFM.refine.zoomout import mesh_zoomout_refine_p2p
# Refine the point-to-point map with ZoomOut, starting from k_init=20 and
# running nit=16 iterations with step=5; returns the refined functional map
# and the refined point-to-point map (return_p2p=True).
FM_12_wks_zo, p2p_21_wks_zo = mesh_zoomout_refine_p2p(p2p_21=p2p, mesh1=meshlist[0], mesh2=meshlist[1], k_init=20, nit=16, step=5, return_p2p=True, n_jobs=10, verbose=True)

# Same color transfer as before, now through the refined map
cmap1 = np.mean(visu(meshlist[0].vertlist),axis=1)
cmap2_wks = cmap1[p2p_21_wks_zo]

double_plot(meshlist[0],meshlist[1],(cmap1),(cmap2_wks))

In [None]:
# Show the ZoomOut-refined functional map matrix as an image
plt.imshow(FM_12_wks_zo)