In [9]:
%reload_ext autoreload
%autoreload 2

import vtk
from vtk.util import numpy_support

import numpy as np
import bloscpack as bp
import sys
import os

sys.path.append(os.path.join(os.environ['REPO_DIR'], 'utilities'))
from utilities2015 import *
from registration_utilities import *

from skimage.measure import marching_cubes, correct_mesh_orientation

from vis3d_utilities import *
from itertools import izip

%matplotlib inline

In [10]:
# Root directories used throughout the notebook (absolute paths on the author's machine).
# create_if_not_exists is a project helper; its return value is used later as the
# directory into which averaged meshes are written (presumably it creates the
# directory when missing and returns the path -- confirm in utilities2015).
mesh_rootdir = create_if_not_exists('/home/yuncong/CSHL_meshes')
# Directory holding the per-stack "*_limits.txt" files read below.
volume_dir = '/home/yuncong/CSHL_volumes/'
# Directory holding the per-stack "*_3dAlignParams.txt" files read below.
atlasAlignParams_dir = '/home/yuncong/CSHL_atlasAlignParams/'

In [11]:
from itertools import chain

# Landmark names without a left/right suffix.
volume_landmark_names_unsided = ['12N', '5N', '6N', '7N', '7n', 'AP', 'Amb', 'LC',
                                 'LRt', 'Pn', 'R', 'RtTg', 'Tz', 'VLL', 'sp5']
linear_landmark_names_unsided = ['outerContour']

labels_unsided = volume_landmark_names_unsided + linear_landmark_names_unsided
# Label index of every unsided name; indices start at 1 because background is always 0.
labels_unsided_indices = {name: index for index, name in enumerate(labels_unsided, 1)}

# Unsided name -> its sided instances.  Names with a single entry have no _L/_R variants.
labelMap_unsidedToSided = {'12N': ['12N'],
                           '5N': ['5N_L', '5N_R'],
                           '6N': ['6N_L', '6N_R'],
                           '7N': ['7N_L', '7N_R'],
                           '7n': ['7n_L', '7n_R'],
                           'AP': ['AP'],
                           'Amb': ['Amb_L', 'Amb_R'],
                           'LC': ['LC_L', 'LC_R'],
                           'LRt': ['LRt_L', 'LRt_R'],
                           'Pn': ['Pn_L', 'Pn_R'],
                           'R': ['R_L', 'R_R'],
                           'RtTg': ['RtTg'],
                           'Tz': ['Tz_L', 'Tz_R'],
                           'VLL': ['VLL_L', 'VLL_R'],
                           'sp5': ['sp5'],
                           'outerContour': ['outerContour']}

# Inverse lookup: sided name -> unsided name.
# .items() (rather than the Python-2-only .iteritems()) gives the same mapping on Python 2 and 3.
labelMap_sidedToUnsided = {n: nu for nu, ns in labelMap_unsidedToSided.items() for n in ns}

# Sided names, flattened in the order of labels_unsided, and their 1-based label indices.
labels_sided = list(chain.from_iterable(labelMap_unsidedToSided[name_u] for name_u in labels_unsided))
labels_sided_indices = {name: index for index, name in enumerate(labels_sided, 1)}  # BackG always 0

In [12]:
# Stack whose atlas-alignment parameters are loaded here.
stack = 'MD594'

# The parameter file is read as whitespace-separated numbers, one record per line:
#   line 0: global transform parameters
#   line 1: atlas x/y/z dimensions (not used)
#   line 2: atlas centroid
#   line 3: test-volume x/y/z dimensions (integers)
#   line 4: test-volume centroid
with open(atlasAlignParams_dir + '/%(stack)s/%(stack)s_3dAlignParams.txt' % {'stack': stack}, 'r') as f:
    lines = f.readlines()

# Explicit lists instead of map(): np.array(map(...)) only produces a numeric
# array where map returns a list (Python 2); a lazy map object becomes a 0-d object array.
global_params = np.array([float(tok) for tok in lines[0].split()])
# atlas_xdim, atlas_ydim, atlas_zdim = np.array([float(tok) for tok in lines[1].split()])
atlas_centroid = np.array([float(tok) for tok in lines[2].split()])
test_xdim, test_ydim, test_zdim = np.array([int(tok) for tok in lines[3].split()])
test_centroid = np.array([float(tok) for tok in lines[4].split()])

In [13]:
# Integer bounding-box limits (xmin, xmax, ymin, ymax, zmin, zmax) of the annotation
# volume and of the score volume for `stack`.
# dtype=int replaces the np.int alias, which is the same type but no longer exists
# in recent NumPy releases.
ann_xmin, ann_xmax, ann_ymin, ann_ymax, ann_zmin, ann_zmax = \
np.loadtxt(volume_dir + '%(stack)s/volume_%(stack)s_annotation_withOuterContour_limits.txt' % {'stack': stack}, dtype=int)

sco_xmin, sco_xmax, sco_ymin, sco_ymax, sco_zmin, sco_zmax = \
np.loadtxt(volume_dir + '%(stack)s/%(stack)s_scoreVolume_limits.txt' % {'stack': stack}, dtype=int)

In [14]:
def align_principle_axes(vertices_normalized0, vertices_normalized):
    """Estimate the matrix that maps the principal axes of one point set onto another's.

    Both inputs are (n_points, 3) arrays with one point per row.  The second-moment
    matrix ``X.T X / n`` of each set is decomposed with an SVD, so the points are
    expected to be centered already (as the "normalized" in the names suggests --
    the function itself does not center them).

    Args:
        vertices_normalized0: reference point set.
        vertices_normalized: point set to be aligned to the reference.

    Returns:
        3x3 orthogonal matrix ``R``; apply it to row-vector points with
        ``np.dot(vertices_normalized, R.T)``.
    """

    def _principal_axes(points):
        # Columns of the returned matrix are the principal directions.
        second_moment = np.dot(points.T, points) / points.shape[0]
        axes, _, _ = np.linalg.svd(second_moment)
        return axes

    reference_axes = _principal_axes(vertices_normalized0)
    moving_axes = _principal_axes(vertices_normalized)

    # An SVD fixes each axis only up to sign; flip every moving axis that points
    # away from the matching reference axis.
    for k in range(3):
        if np.dot(moving_axes[:, k], reference_axes[:, k]) < 0:
            moving_axes[:, k] = -moving_axes[:, k]

    # Orthogonal matrix closest to reference_axes * moving_axes^T.
    left, _, right_t = np.linalg.svd(np.dot(reference_axes, moving_axes.T))
    return np.dot(left, right_t)

In [15]:
from scipy.spatial import KDTree

def _procrustes_rotation(source, target, rotation_only):
    """Orthogonal matrix R minimising sum ||R * source_i - target_i||^2 for paired rows.

    With ``rotation_only`` the Kabsch determinant correction is applied so that
    det(R) = +1; otherwise R may be a reflection.
    """
    U, _, VT = np.linalg.svd(np.dot(source.T, target))
    if rotation_only:
        correction = np.ones(U.shape[0])
        if np.linalg.det(np.dot(U, VT)) < 0:
            # Flip the axis of the smallest singular value to turn the reflection into a rotation.
            correction[-1] = -1.0
        return np.dot(np.dot(U, np.diag(correction)), VT).T
    return np.dot(U, VT).T


def icp(template2, moving, num_iter=10, rotation_only=True):
    """Iterative closest point alignment of ``moving`` onto ``template2`` about the origin.

    Each iteration pairs every (currently transformed) moving point with its nearest
    template point and solves the orthogonal Procrustes / Kabsch problem for the pairs.
    No translation is estimated, so both point sets should already share an origin.

    References:
        https://www.wikiwand.com/en/Orthogonal_Procrustes_problem
        https://www.wikiwand.com/en/Kabsch_algorithm

    Args:
        template2: (n, d) array of fixed points.
        moving: (m, d) array of points to align; not modified.
        num_iter: maximum number of ICP iterations.
        rotation_only: if True the result is a proper rotation (det = +1);
            if False reflections are allowed.

    Returns:
        (d, d) matrix R mapping the ORIGINAL ``moving`` points onto the template,
        fitted against the correspondences found in the last iteration.  Apply with
        ``np.dot(moving, R.T)``.

    Progress (summed residual distance and seconds per iteration) is written to stderr.
    """
    moving2 = moving.copy()
    template = template2.copy()

    tree = KDTree(template)

    nns = None
    d_prev = None

    for i in range(num_iter):

        t = time.time()

        _, nns = tree.query(moving2)
        data = template[nns]
        R = _procrustes_rotation(moving2, data, rotation_only)

        moving2 = np.dot(moving2, R.T)
        d = np.sum(np.sqrt(np.sum((moving2 - data)**2, axis=1)))
        # Stop once the residual no longer changes at all.
        if i > 1 and d_prev == d:
            break
        d_prev = d

        sys.stderr.write('icp @ %d err %.2f: %.2f\n' % (i, d, time.time() - t))

    if nns is None:
        # No iteration ran (num_iter <= 0): fall back to a single nearest-neighbour
        # matching instead of failing on an undefined correspondence set.
        _, nns = tree.query(moving2)

    # Fit the final transform from the untouched input so the caller gets one matrix.
    # This must honour rotation_only too; otherwise a reflection could be returned
    # even though only rotations were requested.
    return _procrustes_rotation(moving, template[nns], rotation_only)

In [16]:
def average_shape(polydata_list, concensus_percentage=.5, num_simplify_iter=0, smooth=False):
    """Build a consensus ("average") shape from several meshes by voxel voting.

    Every mesh is voxelised with ``polydata_to_volume``, the volumes are placed in
    one common grid according to their origins, and a voxel belongs to the result
    when enough input shapes cover it.  The resulting binary volume is converted
    back to a mesh with ``volume_to_polydata``.

    Args:
        polydata_list: meshes accepted by ``polydata_to_volume``.
        concensus_percentage: fraction of the inputs that must cover a voxel
            (see the note on the threshold below).
        num_simplify_iter: forwarded to ``volume_to_polydata``.
        smooth: forwarded to ``volume_to_polydata``.

    Returns:
        (average_volume, average_polydata): boolean volume indexed (y, x, z) and
        the mesh generated from it, positioned at the common origin.

    Raises:
        ValueError: if ``polydata_list`` is empty.

    Timing information is written to stderr and the volume shape is printed.
    """

    volume_list = []
    origin_list = []

    for p in polydata_list:
        t = time.time()
        v, orig, _ = polydata_to_volume(p)
        sys.stderr.write('polydata_to_volume: %.2f\n' % (time.time() - t))

        volume_list.append(v)
        # builtin int is what the np.int alias referred to.
        origin_list.append(np.array(orig, int))

    if not volume_list:
        raise ValueError('average_shape needs at least one polydata')

    t = time.time()

    # Origins are (x, y, z); shift them so the smallest origin becomes (0, 0, 0).
    common_mins = np.min(origin_list, axis=0).astype(int)
    relative_origins = origin_list - common_mins

    # Volumes are indexed (y, x, z), hence shape[1] is the x extent.
    common_xdim, common_ydim, common_zdim = np.max([(v.shape[1]+o[0], v.shape[0]+o[1], v.shape[2]+o[2])
                                                    for v, o in zip(volume_list, relative_origins)], axis=0)

    # Accumulate the votes in a single array instead of keeping one padded
    # full-size integer copy per input shape.
    votes = np.zeros((common_ydim, common_xdim, common_zdim), np.int32)

    for v, (x0, y0, z0) in zip(volume_list, relative_origins):
        ydim, xdim, zdim = v.shape
        # Cast to uint8 before thresholding, matching the uint8 staging volume used previously.
        votes[y0:y0+ydim, x0:x0+xdim, z0:z0+zdim] += (v.astype(np.uint8) > 0)

    # NOTE(review): min(2, ...) caps the required number of votes at 2 no matter how
    # many shapes are averaged -- confirm that this cap is intended.
    average_volume = votes >= min(2, len(volume_list)*concensus_percentage)

    sys.stderr.write('find common: %.2f\n' % (time.time() - t))

    print(average_volume.shape)

    t = time.time()
    average_polydata = volume_to_polydata(average_volume, common_mins, num_simplify_iter=num_simplify_iter,
                                          smooth=smooth)
    sys.stderr.write('volume_to_polydata: %.2f\n' % (time.time() - t))

    return average_volume, average_polydata

In [17]:
from collections import defaultdict

# Results of this cell.
centroid_allLandmarks = defaultdict(list)   # sided name -> one centroid per stack that has a mesh
average_polydata_allLandmarks = {}          # unsided name -> averaged mesh
polydata_list_allLandmarks = {}             # unsided name -> list of mutually aligned meshes

# global_params, the centroids and the ann_*/sco_* limits loaded earlier in the
# notebook were read for this stack, so only its centroids can be mapped with them.
stack_alignedToAtlas = 'MD594'

mesh_fn_template = "/home/yuncong/CSHL_meshes/%(stack)s/%(stack)s_%(label)s_gaussianSmoothed.stl"


def _centroid_alignedToAtlas(centroid_prime):
    """Shift a centroid from annotation-volume to score-volume coordinates and apply
    transform_points_inverse with the loaded alignment parameters."""
    centroid_prime_alignedToScoreVolume = centroid_prime + (ann_xmin,ann_ymin,ann_zmin) - (sco_xmin,sco_ymin,sco_zmin)
    return transform_points_inverse(global_params, pts_prime=[centroid_prime_alignedToScoreVolume],
                                    c_prime=test_centroid, c=atlas_centroid)[0]


# for name_unsided in labels_unsided:
for name_unsided in ['outerContour']:

    print(name_unsided)

    ###### Load meshes ######

    vertices_list = []
    faces_list = []

    names = labelMap_unsidedToSided[name_unsided]

    # A dedicated loop variable is used so the notebook-level `stack` is not overwritten.
    for mesh_stack in ['MD589', 'MD594']:
#     for mesh_stack in ['MD594']:

        for name_sided in names:

            fn = mesh_fn_template % {'stack': mesh_stack, 'label': name_sided}

            if not os.path.exists(fn):
                continue

            vertices, faces = load_mesh_stl(fn)

            # Left instances of two-sided landmarks are mirrored about their mean z so
            # they can be averaged with the right instances.  Single-instance landmarks
            # are used as they are (an earlier variant that averaged each of them with
            # its own mirror image is disabled).
            if len(names) == 2 and '_L' in name_sided:
                zmean = vertices[:,2].mean(axis=0)
                vertices[:, 2] = - (vertices[:, 2] - zmean) + zmean

            vertices_list.append(vertices)
            faces_list.append(faces)

            centroid_prime = vertices.mean(axis=0)

            if mesh_stack == stack_alignedToAtlas:
                centroid_allLandmarks[name_sided].append(_centroid_alignedToAtlas(centroid_prime))
            else:
                # Centroids of the other stack are stored untransformed.
                centroid_allLandmarks[name_sided].append(centroid_prime)

    if not vertices_list:
        # Nothing to align or average; report it instead of failing further down.
        sys.stderr.write('no mesh found for %s, skipped\n' % name_unsided)
        continue

    ###### Align meshes ######

    vertices_normalized_aligned_list = []
    vertices_aligned_list = []
    centroid_list = []

    for i, vertices in enumerate(vertices_list):

        centroid = vertices.mean(axis=0)
        centroid_list.append(centroid)

        # Center and scale every mesh before alignment.
        scale = np.sqrt(((vertices - centroid)**2).mean())
        vertices_normalized = (vertices - centroid) / scale

        if i == 0:
            # The first mesh is the reference all others are rotated onto.
            vertices_normalized1 = vertices_normalized.copy()
            R = np.eye(3)
        else:
            t = time.time()
            R = icp(vertices_normalized1, vertices_normalized, num_iter=100)
            sys.stderr.write('icp: %.2f\n' % (time.time() - t))

        print(R)

        vertices_normalized_alignedTo1 = np.dot(vertices_normalized, R.T)
        vertices_normalized_aligned_list.append(vertices_normalized_alignedTo1)

        # Undo the scaling only; aligned meshes stay centered on the origin.
        vertices_alignedTo1 = vertices_normalized_alignedTo1 * scale
        vertices_aligned_list.append(vertices_alignedTo1)

    polydata_list = [mesh_to_polydata(vs, fs) for vs, fs in zip(vertices_aligned_list, faces_list)]

    polydata_list_allLandmarks[name_unsided] = polydata_list

    ######### Compute Average #########
    t = time.time()

    # The two largest structures get more simplification passes.
    num_simplify_iter = 5 if name_unsided in ('outerContour', 'sp5') else 3
    _, average_polydata = average_shape(polydata_list, num_simplify_iter=num_simplify_iter, smooth=True)

    sys.stderr.write('average shape: %.2f\n' % (time.time() - t))

    average_polydata_allLandmarks[name_unsided] = average_polydata

#     bp.pack_ndarray_file(average_volume, mesh_rootdir + "/%(name)s_average.bp" % {'name': name_unsided})

    save_mesh_stl(average_polydata, mesh_rootdir + "/%(name)s_average.stl" % {'name': name_unsided})

# From here on, looking up a landmark that was never filled in raises KeyError
# instead of silently creating an empty list.
centroid_allLandmarks.default_factory = None

outerContour
[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]]

icp @ 0 err 3254.18: 7.15
icp @ 1 err 3132.45: 6.89
icp @ 2 err 3056.80: 6.79
icp @ 3 err 3011.54: 6.70
icp @ 4 err 2982.09: 6.63
icp @ 5 err 2953.09: 6.60
icp @ 6 err 2928.58: 6.83
icp @ 7 err 2909.55: 6.56
icp @ 8 err 2896.79: 6.53
icp @ 9 err 2887.98: 6.69
icp @ 10 err 2881.01: 6.54
icp @ 11 err 2876.38: 6.78
icp @ 12 err 2873.00: 6.85
icp @ 13 err 2870.95: 6.65
icp @ 14 err 2869.71: 6.42
icp @ 15 err 2868.80: 7.69
icp @ 16 err 2868.42: 6.38
icp @ 17 err 2867.84: 6.35
icp @ 18 err 2867.38: 6.35
icp @ 19 err 2866.72: 6.34
icp @ 20 err 2866.45: 6.35
icp @ 21 err 2866.22: 6.34
icp @ 22 err 2865.90: 6.33
icp @ 23 err 2865.66: 6.35
icp @ 24 err 2865.53: 6.52
icp @ 25 err 2865.42: 6.54
icp @ 26 err 2865.31: 6.38
icp @ 27 err 2865.21: 6.33
icp @ 28 err 2865.12: 6.31
icp @ 29 err 2865.05: 6.39
icp @ 30 err 2865.02: 6.33
icp @ 31 err 2864.99: 6.33
icp @ 32 err 2864.95: 6.33
icp @ 33 err 2864.90: 6.32
icp @ 34 err 2864.79: 6.33
icp @ 35 err 2864.67: 6.43
icp @ 36 err 2864.52: 6.58
icp @ 37 er


[[ 0.97954637 -0.18393658  0.08158516]
 [ 0.19129926  0.97700858 -0.09412119]
 [-0.06239707  0.10780326  0.99221218]]

fill point array: 0.02 seconds
fill cell array: 0.00 seconds
fill point array: 0.02 seconds
fill cell array: 0.00 seconds
polydata_to_volume: 0.78
polydata_to_volume: 0.98
find common: 11.24



(543, 951, 545)
area: 2490494.54

marching cube: 29.27 seconds
compute surface area: 0.89 seconds
fill point array: 3.39 seconds
fill cell array: 0.18 seconds
mesh_to_polydata: 3.59 seconds
simplify 0 @ 2610385: 26.89 seconds
simplify 1 @ 522738: 36.61 seconds
simplify 2 @ 105387: 8.38 seconds
simplify 3 @ 21928: 1.45 seconds
simplify 4 @ 5230: 0.21 seconds





volume_to_polydata: 108.16
average shape: 121.25


In [18]:
# Interactive VTK viewer: overlays all aligned meshes of one landmark as wireframes.
name_to_show = 'outerContour'

######## show overlay list of meshes ########
# Renderer -> render window -> interactor chain for the overlay view.
ren1 = vtk.vtkRenderer()
renWin1 = vtk.vtkRenderWindow()
renWin1.AddRenderer(ren1)
iren1 = vtk.vtkRenderWindowInteractor()
iren1.SetRenderWindow(renWin1)

# RGB colors cycled over the meshes (blue, green, red, yellow).
colors = [(0,0,1), (0,1,0), (1,0,0), (1,1,0)]

# One wireframe actor per aligned mesh of the selected landmark.
for i, polydata in enumerate(polydata_list_allLandmarks[name_to_show]):
    
    m = vtk.vtkPolyDataMapper()
    m.SetInputData(polydata)

    a = vtk.vtkActor()
    a.SetMapper(m)
    a.GetProperty().SetRepresentationToWireframe()
    a.GetProperty().SetColor(colors[i % len(colors)])
    
    ren1.AddActor(a)

# add_axes is a project helper; its return value is kept in a variable
# (presumably so the axes widget stays alive while the window is open -- confirm).
axes_widget1 = add_axes(iren1)

renWin1.Render()
renWin1.SetWindowName('overlay')

# Install a fresh camera and fit it to the scene.
camera = vtk.vtkCamera()
ren1.SetActiveCamera(camera)
ren1.ResetCamera()

# Hands control to the VTK interactor (presumably blocks until the window is closed).
iren1.Start()

In [None]:
# Interactive VTK viewer with two windows: the overlay of all aligned meshes of one
# landmark, and the averaged mesh of the same landmark.
# NOTE(review): the averaging cell above currently processes only 'outerContour',
# so looking up 'sp5' below raises KeyError unless that loop is run over 'sp5' as well.
name_to_show = 'sp5'

######## show overlay list of meshes ########
# Renderer -> render window -> interactor chain for the overlay view.
ren1 = vtk.vtkRenderer()
renWin1 = vtk.vtkRenderWindow()
renWin1.AddRenderer(ren1)
iren1 = vtk.vtkRenderWindowInteractor()
iren1.SetRenderWindow(renWin1)

# RGB colors cycled over the meshes (red, green, blue, yellow).
colors = [(1,0,0), (0,1,0), (0,0,1), (1,1,0)]

# One wireframe actor per aligned mesh of the selected landmark.
for i, polydata in enumerate(polydata_list_allLandmarks[name_to_show]):
    
    m = vtk.vtkPolyDataMapper()
    m.SetInputData(polydata)

    a = vtk.vtkActor()
    a.SetMapper(m)
    a.GetProperty().SetRepresentationToWireframe()
    a.GetProperty().SetColor(colors[i % len(colors)])
    
    ren1.AddActor(a)

# add_axes is a project helper; its return value is kept in a variable
# (presumably so the axes widget stays alive while the window is open -- confirm).
axes_widget1 = add_axes(iren1)

renWin1.Render()
renWin1.SetWindowName('overlay')

######### show average mesh #########
# Second renderer / window / interactor for the averaged mesh.
ren2 = vtk.vtkRenderer()

renWin2 = vtk.vtkRenderWindow()
renWin2.AddRenderer(ren2)

iren2 = vtk.vtkRenderWindowInteractor()
iren2.SetRenderWindow(renWin2)

m2 = vtk.vtkPolyDataMapper()
m2.SetInputData(average_polydata_allLandmarks[name_to_show])

# The average mesh is drawn as a wireframe in the default actor color.
a2 = vtk.vtkActor()
a2.SetMapper(m2)
a2.GetProperty().SetRepresentationToWireframe()
# a.GetProperty().SetColor(colors[2])

ren2.AddActor(a2)
axes_widget2 = add_axes(iren2)

renWin2.Render()
renWin2.SetWindowName('average')

#####################################

# Both renderers use the same camera object (presumably so the two windows show
# the same viewpoint -- confirm).
camera = vtk.vtkCamera()
ren1.SetActiveCamera(camera)
ren2.SetActiveCamera(camera)
ren1.ResetCamera()
ren2.ResetCamera()

# The interactors are started one after the other (the second presumably starts
# only after the first returns -- confirm).
iren1.Start()
iren2.Start()