In [28]:
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np

import pymeshlab
import polyscope as ps
import open3d as o3d
import matplotlib.pyplot as plt

### LDDMM ResNets using L2 Loss

This is code for using the ResNet-LDDMM approach with an L2 loss.

When using this, ensure that the points of the start and target point clouds are listed in matching order, so that the i-th start point corresponds to the i-th target point.

I.e., if you want to register a point cloud of a hand, you need to know exactly where each point of that hand ends up in the target image, and keep the same point ordering when describing the input and output as vectors.

If this isn't doable, use Chamfer's distance instead, implemented lower down in this code.

#### Layers

In [4]:
class DenseEulerFBlock(tf.keras.Model):
    """A Dense neural network block (n+relu x n x 2).

    The LDDMM paper considers a resnet where each residual passthrough occurs
    after 3 dense layers; this block is those three layers. The size
    parameters affect how the net learns in an intuitive manner:

    * ``d1`` (``widths[0]`` units, ReLU) learns a partitioning of the input
      space into distinct polyhedra. ReLU is what allows for the hard boundary
      between cells.
    * ``d2`` (``widths[1]`` units, linear) learns the contribution of each
      polyhedron.
    * ``d3`` (``widths[2]`` units, linear, no bias) learns vectors with
      contributions from the above. For use with 2D registrations set
      ``widths[2]`` to 2. With 3D, 3, and so on.

    With this, it can be seen that each block learns tangent vectors within
    each polyhedron, with respect to contributions from other polyhedra.

    Args:
        widths: Indexable holding the unit counts of ``d1``, ``d2`` and ``d3``
            at positions 0, 1 and 2.
        kernel_initializer: Optional Keras initializer for the kernels of all
            three Dense layers. Prefer a string identifier such as
            ``'he_normal'`` so every layer builds its own initializer. When
            left as ``None`` the layers keep Keras' default initializer, which
            is what this block has always done.
    """
    def __init__(self, widths, kernel_initializer=None):
        super().__init__()
        # Historically this HeNormal instance was created but never handed to
        # a layer. The attribute is kept so existing attribute access still
        # works; use `kernel_initializer` to actually change the initialisation.
        self.initialiser = tf.keras.initializers.HeNormal()

        # Only override the Dense default when the caller asked for it, so the
        # default construction path is unchanged.
        init_kwargs = {}
        if kernel_initializer is not None:
            init_kwargs['kernel_initializer'] = kernel_initializer

        self.d1 = tf.keras.layers.Dense(widths[0], activation='relu', **init_kwargs)
        self.d2 = tf.keras.layers.Dense(widths[1], activation=None, **init_kwargs)
        self.d3 = tf.keras.layers.Dense(widths[2], activation=None, use_bias=False, **init_kwargs)

    def call(self, input_tensor, training=False):
        """Applies ``d1``, ``d2`` and ``d3`` in sequence; ``training`` is not used."""
        return self.d3(self.d2(self.d1(input_tensor)))


class DenseEulerMergeBlock(tf.keras.Model):
    """Parameter-free block that applies a ReLU to its input.

    In this notebook it is called on the sum of a state and the matching
    ``DenseEulerFBlock`` output.
    """
    def __init__(self):
        super().__init__()

    def call(self, input_tensor, training=False):
        """Returns the element-wise ReLU of ``input_tensor``; ``training`` is not used."""
        rectified = tf.nn.relu(input_tensor)
        return rectified


def DenseCombinedLoss(d1, d2, d3, d4, d5, d6, m6, truth, sigma=0.1):
    """The loss function: kinetic energy minimisation subject to correct registration.

    The loss is the sum of two terms:

    * a regularisation term, ``0.5 * (sum of tf.norm(d_i)) / 6``;
    * a data term, ``0.5 * tf.norm(m6 - truth) / sigma``.

    Args:
        d1, d2, d3, d4, d5, d6: The outputs of each DenseEulerFBlock in the
            net. Six blocks are used, hence six arguments.
        m6: The final net output.
        truth: The expected output.
        sigma (float, optional): Regularisation parameter setting the relative
            weight of correct registration against kinetic energy
            minimisation. Defaults to 0.1.

    Returns:
        The regularisation term plus the data term.
    """
    # NOTE(review): the norms here are not squared, and the data term divides
    # by sigma whereas DenseCombinedCDLoss divides by sigma**2 -- confirm
    # this difference is intended.
    norm_total = tf.norm(d1)
    for block_output in (d2, d3, d4, d5, d6):
        norm_total = norm_total + tf.norm(block_output)
    regularisation_loss = 0.5*norm_total/6

    residual = m6 - truth
    data_term = 0.5*tf.norm(residual)/sigma
    return regularisation_loss + data_term

#### Example Network

In [7]:
# Reset Keras' global state before (re)building the network.
tf.keras.backend.clear_session()
initializer = tf.keras.initializers.HeNormal()

input0 = tf.keras.Input(shape=(1, 2))

# Six residual steps: each adds a DenseEulerFBlock output to the current state
# and passes the sum through a DenseEulerMergeBlock.
_states = [input0]
_velocities = []
for _step in range(6):
    _velocities.append(DenseEulerFBlock((500,500,2))(_states[-1]))
    _states.append(DenseEulerMergeBlock()(_states[-1] + _velocities[-1]))

# Keep the per-step names used elsewhere in the notebook.
d1, d2, d3, d4, d5, d6 = _velocities
m1, m2, m3, m4, m5, m6 = _states[1:]

# The target is a second model input so the loss can be attached with add_loss.
true0 = tf.keras.Input(shape=(1, 2))
model = tf.keras.Model([input0, true0], [*_states, true0])
model.add_loss(DenseCombinedLoss(d1, d2, d3, d4, d5, d6, m6, true0, sigma=0.1))

opt = tf.keras.optimizers.Adam(learning_rate=1e-4,beta_1=0.9,beta_2=0.999, epsilon=1e-07)
model.compile(optimizer=opt, loss=None)

### LDDMM ResNets using Chamfer's Distance

#### Loss definition

In [17]:
def chamfers_distance(y_pred, y_true):
    """Computes Chamfer's distance between two point clouds.

    For every point of ``y_pred`` the Euclidean distance to its nearest point
    in ``y_true`` is summed; the same is done from ``y_true`` to ``y_pred``,
    and the two sums are added.

    The number of points is taken from the tensors themselves, so the function
    does not depend on a notebook-level ``batch_size`` variable.

    Args:
        y_pred: Tensor whose last axis holds point coordinates, e.g. shape
            ``(n_points, 1, 3)``. Leading axes are flattened into a list of
            points.
        y_true: Tensor laid out like ``y_pred``; it may hold a different
            number of points.

    Returns:
        Scalar tensor: the sum of the two directed nearest-neighbour distance
        totals.
    """
    # Prefer the static coordinate count; fall back to the dynamic shape.
    point_dim = y_pred.shape[-1]
    if point_dim is None:
        point_dim = tf.shape(y_pred)[-1]

    pred_points = tf.reshape(y_pred, (-1, point_dim))
    true_points = tf.reshape(y_true, (-1, point_dim))

    # sq_dists[i, j] = squared distance between pred point i and true point j.
    # Both directions read the same matrix, so it is built only once.
    sq_dists = tf.math.reduce_sum(
        tf.math.square(tf.expand_dims(pred_points, 1) - tf.expand_dims(true_points, 0)),
        axis=-1)

    # NOTE(review): sqrt has an unbounded derivative at 0, so an exact
    # coincidence of a predicted and a target point may give non-finite
    # gradients -- confirm whether a small epsilon is wanted here.
    cd1 = tf.math.reduce_sum(tf.math.sqrt(tf.math.reduce_min(sq_dists, axis=1)))
    cd2 = tf.math.reduce_sum(tf.math.sqrt(tf.math.reduce_min(sq_dists, axis=0)))
    return cd1 + cd2


def DenseCombinedCDLoss(d1, d2, d3, d4, d5, d6, m6, truth, sigma=0.1):
    """Loss combining a norm penalty on the block outputs with a Chamfer data term.

    Args:
        d1, d2, d3, d4, d5, d6: Outputs of the six DenseEulerFBlock instances.
        m6: The final net output.
        truth: The target point cloud.
        sigma (float, optional): The data term is divided by ``sigma**2``.
            Defaults to 0.1.

    Returns:
        ``0.5 * chamfers_distance(m6, truth) / sigma**2`` plus
        ``0.5 * (sum of tf.norm(d_i)) / 6``.
    """
    norm_total = tf.norm(d1)
    for block_output in (d2, d3, d4, d5, d6):
        norm_total = norm_total + tf.norm(block_output)
    regularisation_loss = 0.5*norm_total/6

    cloud_mismatch = chamfers_distance(m6, truth)
    data_term = 0.5*cloud_mismatch/(sigma**2)
    return data_term + regularisation_loss

#### Data import

In [18]:
def convert_obj_to_numpy(obj):
    """Load a mesh file with pymeshlab and return its vertex matrix.

    Written for use with SHREC'19 ``.obj`` scans.

    Args:
        obj: Path of the mesh file to load.

    Returns:
        The vertex matrix of the loaded mesh, as given by
        ``current_mesh().vertex_matrix()``.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.load_new_mesh(obj)
    return mesh_set.current_mesh().vertex_matrix()


def _rescale_to_unit_range(points):
    """Affinely maps all entries of ``points`` onto [0, 1] using its global min and max."""
    lowest = np.min(points)
    span = np.max(points) - lowest
    if span == 0:
        raise ValueError("Cannot normalise a point cloud whose coordinates are all identical.")
    return (points - lowest)/span


def collect_source_and_target(sourcefile, targetfile):
    """Loads and returns normalised sourcefile.obj and
       targetfile.obj objects as tensors.

    Each vertex matrix is reshaped to ``(n_points, 1, 3)`` and rescaled with
    ``(x - min) / (max - min)``, where min and max are taken over every
    coordinate of that cloud (a single scalar each, not per axis). The two
    clouds are normalised independently of each other.

    The previous implementation applied ``(x - min) / max`` twice. That gives
    the same mapping when ``max > 0`` (up to floating-point rounding) but
    divides by zero when ``max <= 0``; the single-pass form has no such
    restriction.

    Args:
        sourcefile: Path of the source mesh file.
        targetfile: Path of the target mesh file.

    Returns:
        Tuple ``(source_shape, target_shape)`` of tensors, each of shape
        ``(n_points, 1, 3)`` with values in [0, 1].

    Raises:
        ValueError: If every coordinate of a cloud has the same value, so no
            rescaling is defined.
    """
    source_shape = convert_obj_to_numpy(sourcefile)
    target_shape = convert_obj_to_numpy(targetfile)

    # One point per row, with a singleton middle axis to match the network
    # inputs declared as shape (1, 3).
    source_shape = source_shape.reshape((source_shape.shape[0], 1, 3))
    target_shape = target_shape.reshape((target_shape.shape[0], 1, 3))

    source_shape = tf.convert_to_tensor(_rescale_to_unit_range(source_shape))
    target_shape = tf.convert_to_tensor(_rescale_to_unit_range(target_shape))

    return source_shape, target_shape


# Load the two scans to register; the file names are relative to the working directory.
source_shape, target_shape = collect_source_and_target(
    sourcefile='scan_015.obj',
    targetfile='scan_016.obj',
)
# Number of points in the source cloud; used as the training batch size below.
batch_size = source_shape.shape[0]

#### Example net

In [19]:
# Reset Keras' global state before (re)building the network.
tf.keras.backend.clear_session()
input0 = tf.keras.Input(shape=(1, 3))

# Six residual steps: each adds a DenseEulerFBlock output to the current state
# and passes the sum through a DenseEulerMergeBlock.
_states = [input0]
_velocities = []
for _step in range(6):
    _velocities.append(DenseEulerFBlock((800,800,3))(_states[-1]))
    _states.append(DenseEulerMergeBlock()(_states[-1] + _velocities[-1]))

# Keep the per-step names used elsewhere in the notebook.
d1, d2, d3, d4, d5, d6 = _velocities
m1, m2, m3, m4, m5, m6 = _states[1:]

# The target is a second model input so the loss can be attached with add_loss.
true0 = tf.keras.Input(shape=(1, 3))
model = tf.keras.Model([input0, true0], [*_states, true0])
model.add_loss(DenseCombinedCDLoss(d1, d2, d3, d4, d5, d6, m6, true0, sigma=0.1))

opt = tf.keras.optimizers.Adam(learning_rate=5e-6,beta_1=0.9,beta_2=0.999, epsilon=1e-07)
model.compile(optimizer=opt, loss=None)

#### Running the net

In [None]:
# 2 - running the net
# The loss is attached to the model via add_loss, so no separate targets are passed (y=None).
_fit_history = model.fit(
    x=[source_shape, target_shape],
    y=None,
    epochs=1000,
    verbose=1,
    batch_size=batch_size,
)
_loss_curve = _fit_history.history['loss']
plt.plot(_loss_curve);
print(_loss_curve[-1])

#### Flow reconstruction

In [29]:
def reconstruct_numpy_to_points_and_vertices(np_arr):
    """Constructs vertices and faces from a numpy array of points.

    The points are wrapped in a ``pymeshlab.Mesh``, added to a ``MeshSet``,
    and meshed with the ball-pivoting surface reconstruction filter. Filters
    and ``current_mesh()`` are ``MeshSet`` operations, which is why the mesh
    has to be placed in a set first rather than calling them on the ``Mesh``.

    Args:
        np_arr: Array of point coordinates, one point per row.

    Returns:
        ``[v_matrix, f_matrix]``: the vertex matrix and the face matrix of the
        reconstructed mesh.
    """
    # Network predictions are typically float32; convert to a contiguous
    # float64 array before handing the points to pymeshlab.
    vertices = np.ascontiguousarray(np_arr, dtype=np.float64)

    ms = pymeshlab.MeshSet()
    ms.add_mesh(pymeshlab.Mesh(vertex_matrix=vertices), "point_cloud")
    ms.generate_surface_reconstruction_ball_pivoting()

    m = ms.current_mesh()
    v_matrix = m.vertex_matrix()
    f_matrix = m.face_matrix()
    return [v_matrix, f_matrix]


def make_pointed_flow(model, source_shape, target_shape, separator=50, scale=100):
    """Constructs a numpy array characterising the time-series
       flow of source to target shape according to a trained LDDMM resnet.

    Every output of ``model.predict`` is flattened to one point per row,
    multiplied by ``scale``, and shifted along the first coordinate axis by
    ``i * separator`` (``i`` being the output's index), so that consecutive
    outputs sit side by side. The pieces are then stacked row-wise.

    The coordinate count is read from the predictions, so 2D and 3D networks
    are both supported.

    Args:
        model: Object with a ``predict`` method accepting
            ``[source_shape, target_shape]`` and returning a sequence of arrays.
        source_shape: Source point cloud passed to ``model.predict``.
        target_shape: Target point cloud passed to ``model.predict``.
        separator: Spacing along the first axis between consecutive outputs.
        scale: Factor applied to every predicted coordinate. Defaults to 100.

    Returns:
        Array of shape ``(total_points, n_coordinates)`` holding all outputs
        in order.
    """
    prediction = model.predict([source_shape, target_shape])
    if isinstance(prediction, np.ndarray):
        # A single-output model returns a bare array rather than a list.
        prediction = [prediction]

    frames = []
    for i, output in enumerate(prediction):
        points = np.reshape(output, (output.shape[0], -1))*scale
        if i > 0:
            # Shift only the first coordinate so the frames line up in a row.
            offset = np.zeros(points.shape[1])
            offset[0] = i*separator
            points = points + offset
        frames.append(points)

    # Stack once instead of re-allocating the whole array on every iteration.
    return np.concatenate(frames, axis=0)


_MESH_TYPES = ('ball', 'poisson', 'alpha')


def _build_triangle_mesh(obj_points, mesh_type):
    """Builds an open3d triangle mesh from one object's points using ``mesh_type``."""
    # Load the points into an open3d point cloud and estimate its normals.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(obj_points)
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=5, max_nn=100))

    if mesh_type == 'ball':
        # Pivoting radius = mean nearest-neighbour distance of the cloud.
        radius = np.mean(pcd.compute_nearest_neighbor_distance())
        return o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd,o3d.utility.DoubleVector([radius, radius]))

    if mesh_type == 'poisson':
        # Only the first element of the returned value (the mesh) is used.
        return o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=11, width=0, scale=1.5, linear_fit=True)[0]

    if mesh_type == 'alpha':
        alpha_val = 0.1 # adjust as necessary
        return o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(pcd, alpha_val)

    raise ValueError("Unknown mesh_type {!r}; expected one of {}.".format(mesh_type, _MESH_TYPES))


def visualise_flow(point_cloud, pts_per_obj=None, view_style='turntable', mesh_type='ball', mesh=True):
    """Visualise the network's generated point cloud flow using polyscope and open3d libraries.

    With ``mesh`` falsy the whole array is shown as a single point cloud.
    Otherwise the array is cut into consecutive chunks of ``pts_per_obj``
    rows; each chunk is meshed with open3d, written to
    ``full_flow_meshed_pivoted<i>.obj`` in the working directory, read back
    with pymeshlab and registered with polyscope as a surface mesh. Rows left
    over after the last full chunk are not shown.

    Arguments are validated before polyscope is initialised, so a bad call
    fails without opening a viewer or writing files.

    Args:
        point_cloud: Array of points, one per row.
        pts_per_obj: Number of rows per object. Required when meshing.
        view_style: Navigation style passed to polyscope.
        mesh_type: ``'ball'``, ``'poisson'`` or ``'alpha'``.
        mesh: Whether to mesh each object (truthy) or show raw points (falsy).

    Raises:
        ValueError: If meshing is requested without ``pts_per_obj`` or with an
            unknown ``mesh_type``.
    """
    if mesh:
        if pts_per_obj is None:
            raise ValueError("As you are meshing, you need to specify `pts_per_obj`. Generally this is point_cloud.shape[0]/(resnet_lddmm timesteps + 2)")
        if mesh_type not in _MESH_TYPES:
            raise ValueError("Unknown mesh_type {!r}; expected one of {}.".format(mesh_type, _MESH_TYPES))

    ps.init()
    ps.set_navigation_style(view_style)

    if not mesh:
        # visualise non-meshed flow (just a point-cloud)
        ps.register_point_cloud("pointed_flow", point_cloud)
    else:
        for i in range(point_cloud.shape[0]//pts_per_obj):
            obj = point_cloud[pts_per_obj*i:pts_per_obj*(i+1)]
            meshes = _build_triangle_mesh(obj, mesh_type)

            # write mesh to .obj file, which will then be read by pymeshlab and shown by polyscope
            obj_path = "full_flow_meshed_pivoted" + "{}".format(i) + ".obj"
            o3d.io.write_triangle_mesh(obj_path, meshes)

            # reload the mesh with pymeshlab to obtain its vertex and face arrays
            ms = pymeshlab.MeshSet()
            ms.load_new_mesh(obj_path)
            m = ms.current_mesh()
            v_matrix = m.vertex_matrix()
            f_matrix = m.face_matrix()

            # visualise with polyscope
            b=ps.register_surface_mesh("full_flow_meshed {}".format(i), v_matrix, f_matrix, smooth_shade=True)
            b.set_back_face_policy('identical')
    ps.show()

#### Learning the model and visualising its flow:

In [25]:
# np.save('full_flow', make_pointed_flow(model, source_shape, target_shape))
# point_cloud = np.load('full_flow.npy')
# visualise_flow(point_cloud, 1001)