# Element Parameter Detection


## Setup


In [None]:
%load_ext autoreload

import numpy as np
import math
import random
import os

import torch
import pickle

from tqdm.notebook import tqdm
from ipywidgets import interact 
import gc
import open3d as o3d

from src.elements import *
from src.ifc import *
from src.preparation import *
from src.visualisation import *
from src.chamfer import *
from src.utils import *
from src.morph import *

# Seed Python's `random` module; `random.seed` is a function and must be
# called (assigning to it would replace the function and seed nothing).
# NOTE(review): numpy and torch RNGs are not seeded here.
random.seed(42)


In [None]:
from pytorch3d.ops import utils as pytorch3d_utils
from pytorch3d.ops.points_normals import _disambiguate_vector_directions
from pytorch3d.ops.knn import knn_points
from typing import Tuple, TYPE_CHECKING, Union
from pytorch3d.common.workaround import symeig3x3


def get_point_covariances_relative(
    points_padded: torch.Tensor,
    targets_padded: torch.Tensor,
    num_points_per_cloud: torch.Tensor,
    neighborhood_size: int,
    num_targets_per_cloud: Union[torch.Tensor, None] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes, for every target point, the covariance matrix of the 3D
    locations of its K nearest neighbors taken from a reference point cloud.

    Args:
        **points_padded**: Reference point clouds that neighbors are drawn
            from, as a padded tensor of shape `(minibatch, num_points, dim)`.
        **targets_padded**: Query point clouds as a padded tensor of shape
            `(minibatch, num_targets, dim)`; one covariance is produced per
            target point.
        **num_points_per_cloud**: Number of valid points per reference cloud
            of shape `(minibatch,)`.
        **neighborhood_size**: Number of nearest neighbors for each target
            point used to estimate the covariance matrices.
        **num_targets_per_cloud**: Number of valid points per target cloud of
            shape `(minibatch,)`. `None` treats every target point as valid.

    Returns:
        **covariances**: A batch of per-target-point covariance matrices
            of shape `(minibatch, num_targets, dim, dim)`.
        **k_nearest_neighbors**: The `neighborhood_size` nearest reference
            points of each target point, of shape
            `(minibatch, num_targets, neighborhood_size, dim)`.
    """
    # Query = targets, reference = points. lengths1 must describe the query
    # (target) clouds and lengths2 the reference clouds; reusing the reference
    # lengths for the targets is wrong whenever the two clouds differ in size.
    k_nearest_neighbors = knn_points(
        targets_padded,
        points_padded,
        lengths1=num_targets_per_cloud,
        lengths2=num_points_per_cloud,
        K=neighborhood_size,
        return_nn=True,
    ).knn
    # mean of each neighborhood
    pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
    # offsets of the neighbors from their neighborhood mean
    central_diff = k_nearest_neighbors - pt_mean
    # outer product of every centred neighbor with itself: (..., K, dim, dim)
    per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3)
    # average over the K neighbors -> one covariance per target point
    covariances = per_pt_cov.mean(2)

    return covariances, k_nearest_neighbors


# calculate eigen values and eigen vectors relative to points from a second point cloud
def estimate_pointcloud_local_coord_frames_relative(
    pointclouds: Union[torch.Tensor, "Pointclouds"],
    targets: Union[torch.Tensor, "Pointclouds"],
    neighborhood_size: int = 50,
    disambiguate_directions: bool = True,
    *,
    use_symeig_workaround: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimates local principal directions (which include normals) of the
    surface sampled by `pointclouds`, evaluated at the points of `targets`.

    For every target point, the `neighborhood_size` nearest neighbors are
    taken from `pointclouds`, and the eigen-decomposition of their covariance
    matrix gives the local frame. The eigenvector of the smallest eigenvalue
    corresponds to the normal direction; the other two are the remaining
    principal directions of the neighborhood.

    Each principal direction is only defined up to sign. With
    `disambiguate_directions` the sign is fixed following the SHOT
    descriptors' sign disambiguation [1], using each target point as the
    centre of its neighborhood.

    Args:
      **pointclouds**: Batch of 3-dimensional points that neighbors are drawn
        from, of shape `(minibatch, num_point, 3)` or a `Pointclouds` object.
      **targets**: Batch of 3-dimensional points at which the frames are
        evaluated, of shape `(minibatch, num_target, 3)` or a `Pointclouds`
        object.
      **neighborhood_size**: The size of the neighborhood used to estimate the
        geometry around each target point. Must be smaller than the number of
        points in every cloud of `pointclouds`.
      **disambiguate_directions**: If `True`, uses the algorithm from [1] to
        ensure sign consistency of the normals of neighboring points.
      **use_symeig_workaround**: If `True`, uses pytorch3d's `symeig3x3`
        instead of `torch.linalg.eigh`.

    Returns:
      **curvatures**: The three eigenvalues of each target point's neighborhood
        covariance, of shape `(minibatch, num_target, 3)`.
      **local_coord_frames**: The three principal directions around each
        target point, of shape `(minibatch, num_target, 3, 3)`. The principal
        directions are stored in columns of the output, e.g.
        `local_coord_frames[i, j, :, 0]` is the normal at the `j`-th target
        point of the `i`-th cloud.

    Raises:
      ValueError: if `pointclouds` or `targets` is not 3-dimensional, or if
        `neighborhood_size` is not smaller than every cloud of `pointclouds`.

    References:
      [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
      Local Surface Description, ECCV 2010.
    """
    points_padded, num_points = pytorch3d_utils.convert_pointclouds_to_tensor(pointclouds)
    targets_padded, target_num_points = pytorch3d_utils.convert_pointclouds_to_tensor(targets)

    _, _, dim = points_padded.shape
    if dim != 3:
        raise ValueError(
            "The pointclouds argument has to be of shape (minibatch, N, 3)"
        )
    if targets_padded.dim() != 3 or targets_padded.shape[2] != 3:
        raise ValueError(
            "The targets argument has to be of shape (minibatch, M, 3)"
        )

    # Neighbors are drawn from `pointclouds`, so each of them needs more
    # points than the neighborhood size.
    if (num_points <= neighborhood_size).any():
        raise ValueError(
            "The neighborhood_size argument has to be"
            + " < size of each of the point clouds."
        )

    # undo global mean for stability; both clouds are shifted by the same
    # offset so their relative positions are unchanged
    pcl_mean = points_padded.sum(1) / num_points[:, None]
    points_centered = points_padded - pcl_mean[:, None, :]
    targets_centered = targets_padded - pcl_mean[:, None, :]

    # per-target-point covariance and the neighbors used to compute it
    cov, knns = get_point_covariances_relative(
        points_centered,
        targets_centered,
        num_points,
        neighborhood_size,
        num_targets_per_cloud=target_num_points,
    )

    # Principal directions of each covariance, in ascending order of their
    # eigenvalues (the smallest eigenvalue's eigenvector is the normal).
    if use_symeig_workaround:
        curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
    else:
        curvatures, local_coord_frames = torch.linalg.eigh(cov)

    if disambiguate_directions:
        # `knns` are neighborhoods of the *target* points, so the target
        # points are the centres the sign test must be measured from.
        n = _disambiguate_vector_directions(
            targets_centered, knns, local_coord_frames[:, :, :, 0]
        )
        # disambiguate the main curvature
        z = _disambiguate_vector_directions(
            targets_centered, knns, local_coord_frames[:, :, :, 2]
        )
        # the secondary direction completes the frame as the cross product
        y = torch.cross(n, z, dim=2)
        # principal directions as columns
        local_coord_frames = torch.stack((n, y, z), dim=3)

    return curvatures, local_coord_frames

### Sphere morphing


In [None]:
# visualise a list of point clouds as an animation using open3d
# use ctrl+c to copy and ctrl+v to set camera and zoom inside visualiser
def create_point_cloud_animation(cloud_list, loss_func, save_image=False, colours=None,
                                 stops=(9, 39, 99, 299, 999)):
    """Play a sequence of point clouds as an animation in an Open3D window.

    Args:
        cloud_list: indexable sequence of point arrays, one per frame; each is
            passed to ``o3d.utility.Vector3dVector``.
        loss_func: name used as the prefix of saved screenshot files.
        save_image: when True, capture ``sphere/<loss_func><i>.jpg`` for every
            frame index ``i`` contained in ``stops``.
        colours: optional per-frame colour arrays aligned with ``cloud_list``.
        stops: frame indices at which screenshots are captured.

    The Open3D verbosity is raised to Debug while the window is open and is
    reset to Info afterwards, and the window is destroyed, even if an error
    occurs part-way through.
    """
    import time  # not among this file's visible imports

    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug)
    try:
        vis = o3d.visualization.Visualizer()
        vis.create_window()
        try:
            point_cloud = o3d.geometry.PointCloud()
            point_cloud.points = o3d.utility.Vector3dVector(cloud_list[0])
            if colours is not None:
                point_cloud.colors = o3d.utility.Vector3dVector(colours[0])
            vis.add_geometry(point_cloud)

            for i, cloud in enumerate(cloud_list):
                # early frames are shown longer, later ones speed up
                time.sleep(0.01 + 0.05 / (i / 10 + 1))
                point_cloud.points = o3d.utility.Vector3dVector(cloud)
                if colours is not None:
                    point_cloud.colors = o3d.utility.Vector3dVector(colours[i])
                vis.update_geometry(point_cloud)
                vis.poll_events()
                vis.update_renderer()
                if save_image and i in stops:
                    vis.capture_screen_image("sphere/" + loss_func + str(i) + ".jpg", do_render=True)
        finally:
            vis.destroy_window()
    finally:
        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Info)


In [None]:
# cld1_name = "sphere/chair.pcd"
# loss_func = "chamfer"
# run_morph(cld1_name, loss_func)

In [None]:
from pytorch3d.ops import points_normals
from eindex import eindex
import einops

# calculate chamfer loss based on local curvature
def calc_curvature_loss_tensor(x, y, k=32, return_assignment=False, *,
                               eigenvalue_weight=0.0, verbose=False):
    """Chamfer distance plus a penalty on misaligned local principal directions.

    Local frames are estimated with
    ``estimate_pointcloud_local_coord_frames_relative`` for both clouds,
    gathered through the nearest-neighbour indices returned by
    ``ChamferDistance``, and compared per principal direction with
    ``1 - |<u, v>|``. The absolute value is used because an eigenvector's sign
    is arbitrary.

    Args:
        x, y: batched point clouds passed to ``ChamferDistance`` and to the
            frame estimator (presumably ``(batch, points, 3)`` — confirm).
        k: neighbourhood size used for frame estimation.
        return_assignment: accepted for interface compatibility; currently
            unused, and the second return value is always ``None``.
        eigenvalue_weight: weight of an optional squared-eigenvalue mismatch
            term. ``0.0`` (default) disables it and skips its computation.
        verbose: print debugging summaries of the direction and distance terms.

    Returns:
        ``(loss, None)``. ``loss`` is the summed direction mismatch / 1000, plus
        the optional eigenvalue term, plus the summed first-neighbour ``dists``
        reported by ``ChamferDistance`` in both directions.
    """
    chamfer_dist = ChamferDistance()
    nn = chamfer_dist(x, y, bidirectional=True, return_nn=True)
    # NOTE(review): assumes nn[0] holds x->y and nn[1] holds y->x matches —
    # confirm against src.chamfer.
    idx_0 = torch.squeeze(nn[0].idx, dim=-1)
    idx_1 = torch.squeeze(nn[1].idx, dim=-1)

    # The second argument is the cloud the frames are evaluated at: these are
    # x's local frames at y's points, and y's local frames at x's points.
    eig_vals_x, eig_vects_x = estimate_pointcloud_local_coord_frames_relative(
        x, y, neighborhood_size=k)
    eig_vals_y, eig_vects_y = estimate_pointcloud_local_coord_frames_relative(
        y, x, neighborhood_size=k)

    # NOTE(review): each frame tensor is indexed by the points of its target
    # cloud but is gathered here with Chamfer indices; shapes only line up
    # when x and y have the same number of points — confirm intent.
    corresponding_y_vects = eindex(eig_vects_y, idx_0, "batch [batch points] eigenvalues eigenvects")
    corresponding_x_vects = eindex(eig_vects_x, idx_1, "batch [batch points] eigenvalues eigenvects")

    # Eigenvectors are the *columns* of each 3x3 frame (last axis), so the dot
    # product of matching eigenvectors sums over the component axis p and keeps
    # one value per eigenvector q.
    dot_y = einops.einsum(eig_vects_x, corresponding_y_vects, "b n p q, b n p q -> b n q")
    dot_x = einops.einsum(eig_vects_y, corresponding_x_vects, "b n p q, b n p q -> b n q")
    # ignore direction and take absolute value
    misalign_y = 1. - torch.abs(dot_y)
    misalign_x = 1. - torch.abs(dot_x)
    if verbose:
        print("dot", misalign_x[0, :3])

    curvature_loss = (torch.sum(misalign_x) + torch.sum(misalign_y)) / 1000

    if eigenvalue_weight:
        corresponding_y_vals = eindex(eig_vals_y, idx_0, "batch [batch points] eigenvalues")
        corresponding_x_vals = eindex(eig_vals_x, idx_1, "batch [batch points] eigenvalues")
        eigen_val_dist_y = torch.sum(torch.square(eig_vals_x - corresponding_y_vals))
        eigen_val_dist_x = torch.sum(torch.square(eig_vals_y - corresponding_x_vals))
        curvature_loss = curvature_loss + (eigen_val_dist_x + eigen_val_dist_y) * eigenvalue_weight

    forward_dist = torch.sum(nn[0].dists[:, :, 0])
    backward_dist = torch.sum(nn[1].dists[:, :, 0])
    bidirectional_dist = forward_dist + backward_dist
    if verbose:
        print("d", backward_dist.item(), forward_dist.item())

    return curvature_loss + bidirectional_dist, None

In [None]:
# Evaluate the curvature loss once between two stored clouds.
cld1_name = "sphere/plane1.pcd"
cld2_name = "sphere/chair.pcd"

cuda = torch.device("cuda")
cld1 = np.array(o3d.io.read_point_cloud(cld1_name).points)
cld2 = np.array(o3d.io.read_point_cloud(cld2_name).points)
# Add the batch axis with numpy indexing: torch.tensor([ndarray]) goes through
# PyTorch's slow list-of-ndarrays conversion path.
pcd1_tensor = torch.tensor(cld1[None], device=cuda)
pcd2_tensor = torch.tensor(cld2[None], device=cuda)

l = calc_curvature_loss_tensor(pcd1_tensor, pcd2_tensor)
print(l)

In [None]:
%autoreload 2

# Run the morph for each selected loss and cache per-frame density colours.
cld1_name = "sphere/plane1.pcd"
visualise = True
#loss_funcs = [ "chamfer", "emd", "balanced", "reverse", "single"]
#loss_funcs = [ "balanced", "single"]
loss_funcs = ["curvature"]
for loss_func in loss_funcs:
    print(loss_func)
    run_morph(cld1_name, loss_func, lr=.4)
    if visualise:
        with open("sphere/" + loss_func + ".pkl", "rb") as f:
            morphed = pickle.load(f)
        colours = visualise_density(morphed, 'plasma_r')
        with open("sphere/" + loss_func + "_dens.pkl", "wb") as f:
            pickle.dump(colours, f)
        #create_point_cloud_animation(morphed, loss_func)
    # Collect unreachable objects first so their CUDA tensors are freed; only
    # then can empty_cache() return that memory to the driver.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def view_density(loss_func):
    """Replay a saved morph for ``loss_func`` coloured by its cached density.

    Loads the frames from ``sphere/<loss_func>.pkl`` and the colours from
    ``sphere/<loss_func>_dens.pkl``, then animates them with screenshot
    capture enabled.
    """
    frames_path = "sphere/" + loss_func + ".pkl"
    density_path = "sphere/" + loss_func + "_dens.pkl"
    with open(frames_path, "rb") as frames_file:
        frames = pickle.load(frames_file)
    with open(density_path, "rb") as density_file:
        density = pickle.load(density_file)

    # keep only the first three colour channels (presumably dropping alpha)
    rgb = density[:, :, :3]
    create_point_cloud_animation(frames, loss_func, save_image=True, colours=rgb)

interact(view_density, loss_func=["curvature", "density", "balanced", "infocd", "single", "chamfer", "reverse", "emd", "direct"]);


In [None]:
# Collect garbage first so that empty_cache() can actually release the CUDA
# blocks held by objects that were only reachable through reference cycles.
gc.collect()
torch.cuda.empty_cache()

# visualise animation
loss_func = "emd"
with open("sphere/" + loss_func + ".pkl", "rb") as f:
    morphed = pickle.load(f)
print(morphed.shape, loss_func)
# one entry per frame along the first axis
cloud_list = list(morphed)
#create_point_cloud_animation(cloud_list, loss_func)


In [None]:
# Compute per-frame density colours for the loaded morph and cache them.
colours = visualise_density(morphed, 'plasma_r')
density_path = "sphere/" + loss_func + "_dens.pkl"
with open(density_path, "wb") as f:
    pickle.dump(colours, f)

### Batch optimisation


In [None]:
# Locations of the ShapeNet completion test set and its downsampled copy.
_shapenet_root = "../experiments/ICCV2023-HyperCD/ShapeNetCompletion/"
shapenet_path = _shapenet_root + "test/complete/"
downsampled_path = _shapenet_root + "downsample/"

# One-off downsampling recipe (kept commented out). Indices are drawn per file,
# after its points are loaded; np.random.choice samples with replacement, so
# duplicate points are possible.
# folders = os.listdir(shapenet_path)
# for fl in tqdm(folders):
#     os.makedirs(downsampled_path + fl, exist_ok=True)
#     cloud = o3d.geometry.PointCloud()
#     for cl in os.listdir(shapenet_path + fl):
#         points = np.array(o3d.io.read_point_cloud(shapenet_path + fl + "/" + cl).points)
#         choices = np.random.choice(len(points), 4096)
#         cloud.points = o3d.utility.Vector3dVector(points[choices])
#         o3d.io.write_point_cloud(downsampled_path + fl + "/" + cl, cloud)

In [None]:
# Evaluate the "balanced" loss on the downsampled set (save=False) and report
# the last entry of each returned metric series.
loss_func = "balanced"
metrics = sphere_morph_metrics(loss_func, downsampled_path, save=False)
cd, emd = metrics
print(cd[-1], emd[-1])

In [None]:
# Run sphere_morph_metrics with its default arguments for the chosen loss.
# Alternative choice: loss_func = "emd"
loss_func = "balanced"
sphere_morph_metrics(loss_func, downsampled_path)

In [None]:
# plot losses on same axis
def plot_losses(losses, labels, title, figsize=(30, 6)):
    """Overlay several distance series on one set of axes and show the figure.

    Args:
        losses: sequence of 1-D series; series ``i`` is plotted against its own
            index range ``0 .. len(losses[i]) - 1``.
        labels: legend label for each series, in the same order as ``losses``.
        title: figure title.
        figsize: figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    for i, loss in enumerate(losses):
        # Per-series x range: one x built from losses[0] made plt.plot raise
        # whenever another series had a different length.
        plt.plot(np.arange(len(loss)), loss, label=labels[i])

    plt.xlabel("point cloud index")
    plt.ylabel("distance")
    plt.title(title)
    plt.legend()
    plt.show()

In [None]:
# create plots
loss_funcs = ["chamfer",]
chamfer_list, emd_list = [], []
for loss_func in loss_funcs:
    with open("sphere/" + loss_func + "_metrics.pkl", "rb") as f:
        chamfer, emd, assignments = pickle.load(f)
    chamfer_list.append(chamfer)
    emd_list.append(emd)

plot_losses(chamfer_list, loss_funcs, "chamfer")
plot_losses(emd_list, loss_funcs, "EMD")

print(emd_list[-1], chamfer_list[-1])

In [None]:
# Load the saved losses for each loss type and overlay them in one plot.
loss_types = ["reverse", "chamfer", "emd", "pair"]
losses = []
for loss_func in loss_types:
    loss_path = "sphere/loss_" + loss_func + ".pkl"
    with open(loss_path, "rb") as f:
        curve = pickle.load(f)
    losses.append(curve)

plot_losses(losses, loss_types, "loss function comparison")


#### Correspondence consistency


In [None]:
# measure consistency between forward and backward correspondences for chamfer distance
# optionally compare against the ideal assignment, as measured by EMD
def measure_assignment_consistency(assignment, emd=None):
    """Count mutual nearest-neighbour correspondences.

    Args:
        assignment: pair ``(forward, backward)`` of 1-D integer tensors.
            ``backward[j]`` indexes into ``forward`` and ``forward[i]`` is
            compared with positions of ``backward``, i.e. ``forward`` maps
            cloud A -> cloud B and ``backward`` maps cloud B -> cloud A.
        emd: optional 1-D tensor holding a reference forward assignment
            (e.g. from EMD), compared element-wise with ``forward``.

    Returns:
        ``(consistency, emd_consistency)``: the number of indices ``j`` with
        ``forward[backward[j]] == j``, and the number of positions where
        ``emd`` equals ``forward`` (``None`` when ``emd`` is not given).
        The first value and the number of distinct targets in each direction
        are also printed.
    """
    forward, backward = assignment[0], assignment[1]
    # round trip B -> A -> B for every point of cloud B
    reverse_assignment = torch.gather(forward, 0, backward)
    # One expected index per round trip, on the same device as the inputs
    # (the round trip has len(backward) entries, not len(forward)).
    expected = torch.arange(reverse_assignment.shape[0], device=reverse_assignment.device)
    consistency = torch.sum(torch.eq(expected, reverse_assignment).long()).item()
    print("consistency", consistency, len(torch.unique(forward)), len(torch.unique(backward)))

    emd_consistency = None
    if emd is not None:
        emd_consistency = torch.sum(torch.eq(emd, forward).long()).item()

    return consistency, emd_consistency


In [None]:
# Check forward/backward consistency of the stored correspondences per step.
# TODO: include consistency in top5 matches?
loss_func = "emd"
emd_path = "sphere/assignments_" + "emd" + ".pkl"
with open(emd_path, "rb") as f:
    emd_assignment = pickle.load(f)

with open("sphere/assignments_" + loss_func + ".pkl", "rb") as f:
    assignment = pickle.load(f)

for i, ass in enumerate(assignment):
    # To also compare against the EMD assignment use:
    # measure_assignment_consistency(ass, emd_assignment[i][0])
    measure_assignment_consistency(ass)


In [None]:
# find the morphed middle ground between two different clouds: pair points via
# the calc_emd assignment and write the midpoint of every pair to a new cloud
cld1_name = "sphere/plane1.pcd"
cld2_name = "sphere/chair.pcd"

cuda = torch.device("cuda")
cld1 = np.array(o3d.io.read_point_cloud(cld1_name).points)
cld2 = np.array(o3d.io.read_point_cloud(cld2_name).points)
# Add the batch axis with numpy indexing: torch.tensor([ndarray]) goes through
# PyTorch's slow list-of-ndarrays conversion path.
pcd1_tensor = torch.tensor(cld1[None], device=cuda)
pcd2_tensor = torch.tensor(cld2[None], device=cuda)

print(pcd1_tensor.shape, pcd2_tensor.shape)
loss, assignment = calc_emd(pcd1_tensor, pcd2_tensor, 0.05, 50)
assignment = assignment.detach().cpu().numpy()
print(assignment.shape)

# points of cloud 2 assigned to each point of cloud 1 (first batch entry)
matched_points = cld2[assignment[0]]
print(matched_points.shape)

morphed_points = (cld1 + matched_points) / 2
morphed_cloud = o3d.geometry.PointCloud()
morphed_cloud.points = o3d.utility.Vector3dVector(morphed_points)
o3d.io.write_point_cloud("sphere/chair_plane_morph.pcd", morphed_cloud)