## Overview
This notebook loads the saved Statistical Shape Model (SSM) components (mean shape, principal components, and principal component variances) and provides examples of how to use them to deform and visualize the SSM. 

- Examples are provided for manipulating the model using numpy arrays & visualizing using matplotlib 3D. 

- Examples on creating new vtk surface meshes (polydata) from the manipulated/deformed meshes are provided. 

- Optional visualizations of the 3D surface meshes are provided (if the user has itkwidgets installed). 
    https://github.com/InsightSoftwareConsortium/itkwidgets#installation

### Load dependencies

In [1]:
import numpy as np
import vtk
import matplotlib.pyplot as plt
# the next line enables 3D plots to be interactive. Comment out to make static.
%matplotlib notebook 

### Define functions for loading data & manipulating SSM

In [2]:
def read_polydata(path):
    '''
    Function to read vtk polydata from filepath
    Input: 
        Path
    Returns vtk polydata object. 
    '''
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(path)
    reader.Update()

    return reader.GetOutput()

def get_mesh_physical_point_coords(mesh):
    '''
    Function to get ndarray for X/Y/Z positions of each
    point on the input vtk surface mesh (polydata)
    
    Inputs: 
    mesh: 
        A vtk polydata object for which to extract the xyz positions of its points. 
    
    Return:
    point_coordinates:
        3xN ndarray; N = number of points on the mesh surface. 
    '''
    point_coordinates = np.zeros((3, mesh.GetNumberOfPoints()))
    for pt_idx in range(mesh.GetNumberOfPoints()):
        point_coordinates[:,pt_idx] = mesh.GetPoint(pt_idx)
    return point_coordinates

def get_ssm_deformation(PCs, Vs, mean_coords, pc=0, n_sds=3):
    '''
    Function to compute the Statistical Shape Model (SSM) mean shape deformed along a given principal component.
    
    Inputs: 
    PCs: 
        3NxM ndarray; N = number of points on surface, M = number of principal components in model
        Each column is a principal component, ordered to match mean_coords.flatten().
    Vs: 
        M ndarray; M = number of principal components in model
        Each entry is the variance for the coinciding principal component in PCs
    mean_coords: 
        3xN ndarray; N = number of points on surface.         
    pc: 
        The principal component of the SSM to deform
    n_sds: 
        The number of standard deviations (sd) to deform the SSM. 
        This can be positive or negative to scale the model in either direction. 
    
    Return:
    deformed_coords:
        3xN ndarray; N=number of points on mesh surface. 
        This includes the x/y/z position of each surface node after deformation using the SSM and
        the specified characteristics (pc, n_sds)
    '''
    pc_vector = PCs[:, pc]
    pc_vector_scale = np.sqrt(Vs[pc]) * n_sds # convert Variances to SDs & multiply by n_sds (negative/positive important)
    coords_deformation = pc_vector * pc_vector_scale
    deformed_coords = (mean_coords.flatten() + coords_deformation).reshape(mean_coords.shape)
    return deformed_coords

def create_vtk_mesh_from_deformed_points(mean_mesh, new_points):
    '''
    Create new vtk mesh (polydata) from a set of points (ndarray) deformed using the SSM. 
    
    Inputs: 
    mean_mesh: 
        vtk polydata of the mean mesh
    new_points:
        3xN ndarray; N=number of points on mesh surface (same as number of points on mean_mesh).
        This is the x/y/z position that each surface node should be deformed to. 
    '''
    new_mesh = vtk.vtkPolyData()
    new_mesh.DeepCopy(mean_mesh)
    points = new_mesh.GetPoints()
    for pt_idx in range(new_points.shape[1]):
        points.SetPoint(pt_idx, new_points[:, pt_idx])
    points.Modified()  # mark the points as changed so VTK consumers see the update
    
    return new_mesh

def write_vtk(mesh, path_to_save):
    '''
    Save a VTK mesh to file
    Inputs: 
    mesh: 
        a vtk mesh (polydata) that should be written to disk
    path_to_save:
        a string of the path (including filename) to save the mesh.
    '''
    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(path_to_save)
    writer.SetInputData(mesh)
    writer.Write()
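
The arithmetic inside `get_ssm_deformation` can be checked by hand on a toy model. The sketch below uses made-up values (2 points, 1 principal component), not the real SSM, and assumes the layout implied by the code above: each principal component has length 3N and follows the order of `mean_coords.flatten()`, that is, all x values, then all y values, then all z values.

```python
import numpy as np

# Toy stand-ins (hypothetical values, not the real SSM): 2 points, 1 component.
mean_coords = np.array([[0.0, 1.0],   # x of point 0, x of point 1
                        [0.0, 0.0],   # y
                        [0.0, 0.0]])  # z

# One principal component of length 3N = 6, ordered x0, x1, y0, y1, z0, z1.
# This component only moves point 1 along x.
PCs = np.array([[0.0], [1.0], [0.0], [0.0], [0.0], [0.0]])
Vs = np.array([4.0])  # variance 4 -> standard deviation 2

n_sds = 3
# offset = n_sds * SD * component = 3 * 2 * component
offset = n_sds * np.sqrt(Vs[0]) * PCs[:, 0]
deformed = (mean_coords.flatten() + offset).reshape(mean_coords.shape)

# x of point 1 moves from 1.0 to 7.0; everything else is unchanged
print(deformed)
```

Passing `n_sds=-3` instead would move the same point by the same distance in the opposite direction.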

### Load Data
- PCs = principal components 
- Vs = variances (for the principal components)
- mean_mesh = vtk polydata of the mean mesh of the dataset

The variable `bone` specifies the bone being analyzed in this notebook. 

In [3]:
bone = 'femur' # set to 'femur' or 'tibia' to load and manipulate the corresponding SSM. 
PCs = np.loadtxt(f'./{bone}_PCs.txt')
Vs = np.loadtxt(f'./{bone}_Variances.txt')
mean_mesh = read_polydata(f'./mean_{bone}_mesh.vtk')
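
Before deforming anything, it can help to confirm that the loaded arrays agree with each other. The helper below, `check_ssm_shapes`, is a hypothetical name introduced here for illustration and is not part of the notebook. It assumes `PCs` has one row per flattened coordinate (3 per mesh point) and `Vs` has one entry per component. The example runs on placeholder arrays rather than the real files.

```python
import numpy as np

def check_ssm_shapes(PCs, Vs, n_points):
    """Hypothetical helper: verify that PCs, Vs, and the mesh size agree.

    Returns the number of principal components, or raises ValueError.
    """
    if PCs.ndim != 2:
        raise ValueError(f'PCs should be 2D, got {PCs.ndim}D')
    n_rows, n_components = PCs.shape
    if n_rows != 3 * n_points:
        raise ValueError(f'PCs has {n_rows} rows, expected 3 * {n_points}')
    if Vs.shape != (n_components,):
        raise ValueError(f'Vs has shape {Vs.shape}, expected ({n_components},)')
    return n_components

# Placeholder arrays standing in for the loaded data: 4 points, 3 components.
n_components = check_ssm_shapes(np.zeros((12, 3)), np.ones(3), n_points=4)
print(n_components)
```

With the real data, the call would look like `check_ssm_shapes(PCs, Vs, mean_mesh.GetNumberOfPoints())`.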

### Get mesh coordinates as numpy array & plot in 3D

In [4]:
mean_coords = get_mesh_physical_point_coords(mean_mesh)

In [13]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(mean_coords[0,:], mean_coords[1,:], mean_coords[2,:], s=0.1)


<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x17d542eb0>

### Use defined functions to deform the mean mesh using the SSM. 
- Deform along principal component 0 (the first one)
- Deform 3 standard deviations in both the positive and negative directions. 

In [6]:
# minus 3 standard deviations version (PC 0)
ssm_pc_0_minus_3_sd = get_ssm_deformation(PCs, Vs, mean_coords, pc=0, n_sds=-3)
# positive 3 standard deviations version (PC 0)
ssm_pc_0_plus_3_sd = get_ssm_deformation(PCs, Vs, mean_coords, pc=0, n_sds=3)
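
The same call can be repeated over a range of standard deviations to get a sequence of shapes, for example for stepping through a mode of variation. This is a minimal sketch on random placeholder data. The `deform` function repeats the arithmetic of `get_ssm_deformation` so the block runs on its own, and the array sizes and variances are arbitrary assumptions.

```python
import numpy as np

def deform(PCs, Vs, mean_coords, pc=0, n_sds=3):
    # Same arithmetic as get_ssm_deformation, repeated so this block is self-contained.
    offset = PCs[:, pc] * np.sqrt(Vs[pc]) * n_sds
    return (mean_coords.flatten() + offset).reshape(mean_coords.shape)

# Random placeholder model (not the real SSM): 5 points, 2 components.
rng = np.random.default_rng(0)
n_points = 5
mean_coords = rng.normal(size=(3, n_points))
PCs = rng.normal(size=(3 * n_points, 2))
Vs = np.array([2.0, 0.5])

# Shapes at -3, -2, ..., +3 standard deviations along PC 0.
sd_steps = np.linspace(-3, 3, 7)
shapes = np.stack([deform(PCs, Vs, mean_coords, pc=0, n_sds=s) for s in sd_steps])

print(shapes.shape)  # (number of steps, 3, number of points)
```

The middle entry (0 SD) is the mean shape, and the shapes at -k and +k SD are mirror images of each other about the mean.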

### Plot all three bones at the same time. 
- You can comment out the appropriate lines to show just one or two of the bones at a time. 

In [12]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# mean bone
ax.scatter(mean_coords[0,:], 
           mean_coords[1,:], 
           mean_coords[2,:],
           c='C0', # color of the mesh - specifying to be consistent (blue)
           s=0.1,  # size of the scatter points
           label=f'mean {bone}')
# minus 3 SD
ax.scatter(ssm_pc_0_minus_3_sd[0,:], 
           ssm_pc_0_minus_3_sd[1,:], 
           ssm_pc_0_minus_3_sd[2,:],
           c='C1', # color of the mesh - specifying to be consistent (orange)
           s=0.1,  # size of the scatter points 
           label='pc 0 -3SD')
# plus 3 SD
ax.scatter(ssm_pc_0_plus_3_sd[0,:], 
           ssm_pc_0_plus_3_sd[1,:], 
           ssm_pc_0_plus_3_sd[2,:],
           c='C2', # color of the mesh - specifying to be consistent (green)
           s=0.1,  # size of the scatter points
           label='pc 0 +3SD')
plt.legend()


<matplotlib.legend.Legend at 0x17cb046d0>
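
Overlaid scatter plots show where the shapes differ, but not by how much. One way to quantify the difference is the per-point Euclidean distance between a deformed shape and the mean. The sketch below uses three made-up points rather than the real meshes.

```python
import numpy as np

# Made-up 3xN coordinates (N = 3 points), standing in for mean_coords
# and one of the deformed shapes.
mean_coords = np.array([[0.0, 1.0, 2.0],
                        [0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0]])
deformed_coords = mean_coords + np.array([[3.0, 0.0, 0.0],
                                          [4.0, 0.0, 1.0],
                                          [0.0, 0.0, 0.0]])

# Distance each point moved: norm over the x/y/z axis (axis 0).
displacement = np.linalg.norm(deformed_coords - mean_coords, axis=0)
print(displacement)        # one distance per point
print(displacement.max())  # largest movement on the surface
```

With the real data, `displacement` could be passed as `c=displacement` to `ax.scatter` to colour the deformed shape by how far each point moved.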

### Create a new VTK mesh from the deformed points (from SSM). 

In [8]:
# create mesh for PC 0 & -3SD
ssm_pc_0_minus_3_sd_vtk_mesh = create_vtk_mesh_from_deformed_points(mean_mesh, ssm_pc_0_minus_3_sd)
# create mesh for PC 0 & +3SD
ssm_pc_0_plus_3_sd_vtk_mesh = create_vtk_mesh_from_deformed_points(mean_mesh, ssm_pc_0_plus_3_sd)


### Save the mesh(es) to disk

In [9]:
filename = 'test_mesh.vtk'
write_vtk(ssm_pc_0_minus_3_sd_vtk_mesh, 
          filename)

### If you have installed itkwidgets you can run the below to view the meshes in your browser
- https://github.com/InsightSoftwareConsortium/itkwidgets#installation
- `pip install itkwidgets`

Properties of the various meshes can be changed so you can try to visualize some differences between them. For example: 
- The colour can be changed
- The opacity can be changed to make one (or more than one) see through
- The surface can include edges or be made to show grid lines

In [10]:
from itkwidgets import view

In [11]:
view(geometries=[mean_mesh, 
                 ssm_pc_0_minus_3_sd_vtk_mesh,
                 ssm_pc_0_plus_3_sd_vtk_mesh])

Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…