In [None]:
# Aligns a score volume with an annotation volume

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

import sys
import os

sys.path.append(os.environ['REPO_DIR'] + '/utilities')
from utilities2015 import *
from registration_utilities import *

import matplotlib.pyplot as plt
%matplotlib inline

from joblib import Parallel, delayed
import time

import logging

In [3]:
# Root directory that holds one sub-directory of volume files per stack
# (the atlas annotation for MD589 and the score volumes of the test stack are
# read from here). Hard-coded absolute path; note the trailing slash, so the
# string concatenations below that start with '/' produce a harmless '//'.
volume_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_volumes/'

In [5]:
# Parse the label table that accompanies the MD589 annotation volume.
# Each non-empty line is "<sided label name> <integer index>".
labels_twoSides = []             # sided label names, in file order
labels_twoSides_indices = {}     # sided label name -> integer value used in the annotation volume
with open(volume_dir + '/MD589/volume_MD589_annotation_withOuterContour_labelIndices.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        # Skip blank lines (e.g. a trailing newline at the end of the file);
        # they would otherwise break the two-value unpacking below.
        if not line.strip():
            continue
        name, index = line.split()
        labels_twoSides.append(name)
        labels_twoSides_indices[name] = int(index)

# A sided name carries a two-character suffix after an underscore; dropping the
# last two characters yields the unsided name. Names without '_' map to themselves.
labelMap_sidedToUnsided = {name: name if '_' not in name else name[:-2] for name in labels_twoSides_indices.keys()}

# Unsided labels in a fixed order: background first, outer contour last,
# everything else sorted alphabetically in between.
labels_unsided = ['BackG'] + sorted(set(labelMap_sidedToUnsided.values()) - {'BackG', 'outerContour'}) + ['outerContour']
labels_unsided_indices = dict((j, i) for i, j in enumerate(labels_unsided))

from collections import defaultdict

# Inverse mapping: unsided name -> list of the sided names that collapse onto it.
labelMap_unsidedToSided = defaultdict(list)
for name_sided, name_unsided in labelMap_sidedToUnsided.items():
    labelMap_unsidedToSided[name_unsided].append(name_sided)
# Freeze the mapping: from here on, looking up an unknown label raises KeyError
# instead of silently creating an empty entry.
labelMap_unsidedToSided.default_factory = None

In [6]:
# Load the annotated atlas volume. Its array axes are ordered (y, x, z).
atlas_volume = bp.unpack_ndarray_file(os.path.join(volume_dir, 'MD589/volume_MD589_annotation_withOuterContour.bp'))
atlas_ydim, atlas_xdim, atlas_zdim = atlas_volume.shape

# Centre of the atlas volume, expressed in (x, y, z) order.
atlas_centroid = .5 * np.array([atlas_xdim, atlas_ydim, atlas_zdim])
print(atlas_centroid)

[ 419.   229.5  267. ]


In [7]:
def parallel_where(name, num_samples=None):
    """Return the voxel coordinates of one sided atlas label.

    Reads the module-level ``atlas_volume`` (axes ordered y, x, z) and
    ``labels_twoSides_indices``.

    Parameters
    ----------
    name : str
        Sided label name; must be a key of ``labels_twoSides_indices``.
    num_samples : int or None
        If given, at most this many voxels are drawn at random without
        replacement (using the global NumPy random state). If None, every
        voxel of the label is returned.

    Returns
    -------
    (n, 3) int16 array whose columns are x, y, z. A label that has no voxels
    yields an empty (0, 3) array.
    """

    ys, xs, zs = np.where(atlas_volume == labels_twoSides_indices[name])

    n = len(ys)
    # Skip sampling for an empty label: older NumPy releases raise from
    # np.random.choice on an empty population even when zero samples are requested.
    if num_samples is not None and n > 0:
        # Passing the integer n samples from arange(n) without first building
        # a Python list of n elements.
        sample_indices = np.random.choice(n, min(num_samples, n), replace=False)
        ys, xs, zs = ys[sample_indices], xs[sample_indices], zs[sample_indices]

    # Swap the first two axes so that columns come out as (x, y, z).
    return np.c_[xs.astype(np.int16), ys.astype(np.int16), zs.astype(np.int16)]

t = time.time()

# Sample up to 1e5 voxel coordinates for every sided label except the first
# entry of labels_twoSides, one joblib task per label, and key the results by
# label name.
atlas_nzs = dict(zip(labels_twoSides[1:],
                     Parallel(n_jobs=16)(delayed(parallel_where)(name, num_samples=int(1e5))
                                         for name in labels_twoSides[1:])))

sys.stderr.write('load atlas: %f seconds\n' % (time.time() - t)) #~ 7s

load atlas: 7.055487 seconds


In [9]:
# atlas_nzs_full = Parallel(n_jobs=16)(delayed(parallel_where)(name) for name in labels_twoSides[1:])
# atlas_nzs_full = {name: nzs for name, nzs in zip(labels_twoSides[1:], atlas_nzs_full)}

In [8]:
# Factor by which the lossless in-plane pixel size is multiplied to obtain the
# in-plane voxel size used below.
downsample_factor = 16

section_thickness = 20 # in um
xy_pixel_distance_lossless = 0.46 # presumably um per pixel at full resolution -- TODO confirm
xy_pixel_distance_tb = xy_pixel_distance_lossless * 32 # in um, thumbnail

# In-plane size of one downsampled voxel (same unit as xy_pixel_distance_lossless).
xy_pixel_distance_downsampled = xy_pixel_distance_lossless * downsample_factor
# Section thickness expressed in downsampled in-plane voxels (20 / 7.36 ~= 2.72);
# used further down to convert a section number into a z index.
z_xy_ratio_downsampled = section_thickness / xy_pixel_distance_downsampled

In [9]:
# Output locations for the optimisation logs and the fitted alignment parameters.
# create_if_not_exists comes from the star-imported utilities; presumably it
# creates the directory when missing and returns the path -- verify there.
atlasAlignOptLogs_dir = create_if_not_exists('/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignOptLogs')
atlasAlignParams_dir = create_if_not_exists('/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignParams')
# NOTE(review): "annotaionsPojectedViz" looks misspelled, but it may match the
# directory name that exists on disk -- confirm before renaming. Unlike the two
# directories above, this one is not passed through create_if_not_exists here.
annotationsViz_rootdir = '/oasis/projects/nsf/csd395/yuncong/CSHL_annotaionsPojectedViz'

In [10]:
# Notebook-wide logger; INFO and above are passed on to whatever handlers get attached later.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [11]:
# For every unsided label (background excluded): pool the sampled (x, y, z)
# atlas coordinates of all its sided variants and express them relative to the
# atlas centroid. The cast back to int16 truncates the fractional part that a
# half-integer centroid component (e.g. 229.5) introduces.
pts_centered = {name: (np.concatenate([atlas_nzs[n] for n in labelMap_unsidedToSided[name]]) - atlas_centroid).astype(np.int16) 
                         for name in labels_unsided[1:]}

In [12]:
# Per-label multiplier applied to each label's score and gradient contribution;
# the outer contour counts one tenth as much as any other label.
label_weights = {name: .1 if name == 'outerContour' else 1. for name in labels_unsided[1:]}

In [13]:
def compute_score_and_gradient(T):
    """Evaluate the alignment score of T together with its gradient.

    Reads the module-level ``labels_unsided``, ``pts_centered``,
    ``test_centroid``, ``test_xdim``/``test_ydim``/``test_zdim``,
    ``volume2_allLabels``, ``dSdxyz`` and ``label_weights``.

    Parameters
    ----------
    T : 12 transform parameters, forwarded unchanged to ``transform_points``.
        Callers in this notebook lay them out as a flattened 3x4 matrix
        (three rows of three linear terms followed by a translation).

    Returns
    -------
    (score, dMdA)
        ``score`` is the weighted sum, over all non-background labels, of the
        score-volume values (divided by 1e4) at the transformed atlas points
        that fall inside the test volume. ``dMdA`` is a length-12 array with
        the weighted sum of the per-point gradient terms, in the same layout
        as T.

    NOTE(review): the score is scaled by 1/1e4 but the gradient terms are not;
    presumably ``dSdxyz`` holds gradients of the unscaled volumes -- confirm
    whether the mismatch is intended.
    NOTE(review): the int16 cast truncates toward zero, so coordinates in
    (-1, 0) are treated as index 0 and pass the bounds test -- confirm that is
    acceptable.
    """
    total_score = 0
    grad = np.zeros((12,))

    for name in labels_unsided[1:]:
        centered = pts_centered[name]

        projected = transform_points(T, pts_centered=centered, c_prime=test_centroid)
        xs, ys, zs = projected.T.astype(np.int16)

        # Keep only the points that land inside the test volume.
        inside = (xs >= 0) & (ys >= 0) & (zs >= 0) & \
                 (xs < test_xdim) & (ys < test_ydim) & (zs < test_zdim)

        if not inside.any():
            continue

        xs_in, ys_in, zs_in = xs[inside], ys[inside], zs[inside]
        weight = label_weights[name]

        # Volumes are indexed (y, x, z).
        sampled = volume2_allLabels[name][ys_in, xs_in, zs_in] / 1e4
        total_score += weight * sampled.sum()

        Sx = dSdxyz[name][0, ys_in, xs_in, zs_in]
        Sy = dSdxyz[name][1, ys_in, xs_in, zs_in]
        Sz = dSdxyz[name][2, ys_in, xs_in, zs_in]

        # Centred atlas coordinates of the surviving points.
        dxs, dys, dzs = centered[inside].T

        # One row per point, one column per transform parameter.
        per_point = np.c_[Sx*dxs, Sx*dys, Sx*dzs, Sx,
                          Sy*dxs, Sy*dys, Sy*dzs, Sy,
                          Sz*dxs, Sz*dys, Sz*dzs, Sz]

        grad += weight * per_point.sum(axis=0)

    return total_score, grad

In [14]:
def compute_score(T):
    """Return the alignment score of the transform parameters T.

    For every non-background label the centred atlas points are mapped through
    ``transform_points`` (with the test-volume centroid as new origin), cast to
    int16 voxel indices, and the label's score volume is read at the indices
    that fall inside the test volume. The score is the sum over labels of
    ``label_weights[name]`` times the summed values divided by 1e4.

    Reads the module-level ``labels_unsided``, ``pts_centered``,
    ``test_centroid``, ``test_xdim``/``test_ydim``/``test_zdim``,
    ``volume2_allLabels`` and ``label_weights``.
    """
    total = 0
    for name in labels_unsided[1:]:

        projected = transform_points(T, pts_centered=pts_centered[name], c_prime=test_centroid)
        xs, ys, zs = projected.T.astype(np.int16)

        in_bounds = (xs >= 0) & (ys >= 0) & (zs >= 0) & \
                    (xs < test_xdim) & (ys < test_ydim) & (zs < test_zdim)

        # Volumes are indexed (y, x, z); an empty selection simply sums to zero.
        sampled = volume2_allLabels[name][ys[in_bounds], xs[in_bounds], zs[in_bounds]] / 1e4

        total += label_weights[name] * sampled.sum()

    return total

def compute_score_gradient(T):
    """Return the length-12 gradient accumulator for the transform parameters T.

    Same point projection and bounds test as ``compute_score``, but instead of
    reading the score volumes it reads the three gradient volumes in
    ``dSdxyz[name]`` and accumulates, per label and weighted by
    ``label_weights[name]``, the sum over points of
    [Sx*dx, Sx*dy, Sx*dz, Sx, Sy*dx, ..., Sz] where (dx, dy, dz) are the
    centred atlas coordinates. Labels with no point inside the test volume
    contribute nothing.
    """
    grad = np.zeros((12,))

    for name in labels_unsided[1:]:
        centered = pts_centered[name]

        projected = transform_points(T, pts_centered=centered, c_prime=test_centroid)
        xs, ys, zs = projected.T.astype(np.int16)

        in_bounds = (xs >= 0) & (ys >= 0) & (zs >= 0) & \
                    (xs < test_xdim) & (ys < test_ydim) & (zs < test_zdim)

        if not in_bounds.any():
            continue

        xs_in, ys_in, zs_in = xs[in_bounds], ys[in_bounds], zs[in_bounds]

        # Gradient volumes are stacked along axis 0 and indexed (y, x, z) after that.
        Sx, Sy, Sz = (dSdxyz[name][axis, ys_in, xs_in, zs_in] for axis in range(3))

        dxs, dys, dzs = centered[in_bounds].T

        per_point = np.c_[Sx*dxs, Sx*dys, Sx*dzs, Sx,
                          Sy*dxs, Sy*dys, Sy*dzs, Sy,
                          Sz*dxs, Sz*dys, Sz*dzs, Sz]

        grad += label_weights[name] * per_point.sum(axis=0)

    return grad

In [16]:
# Name of the test stack; it selects the score/gradient volume files that are
# loaded and names the log and parameter files written below.
stack = 'MD594'

In [17]:
################# LOAD TEST VOLUME ######################

# First and last section of this stack, as listed in section_range_lookup.
section_bs_begin, section_bs_end = section_range_lookup[stack]
print('%s %s' % (section_bs_begin, section_bs_end))

# Six integer limits of the score volume, read from its limits file.
(volume_xmin, volume_xmax, volume_ymin, volume_ymax, volume_zmin, volume_zmax) = np.loadtxt(
    os.path.join(volume_dir, '%(stack)s/%(stack)s_scoreVolume_limits.txt' % {'stack': stack}),
    dtype=np.int)

# Map every z index of the volume (offset by volume_zmin) to a section number.
# Neighbouring sections share their boundary index; the later section wins.
map_z_to_section = {}
for s in range(section_bs_begin, section_bs_end+1):
    z_begin = int(z_xy_ratio_downsampled*s) - volume_zmin
    z_end = int(z_xy_ratio_downsampled*(s+1)) - volume_zmin + 1
    for z in range(z_begin, z_end):
        map_z_to_section[z] = s


# One float16 score volume per non-background label.
volume2_allLabels = {}

for name in labels_unsided:

    if name != 'BackG':
        volume2_allLabels[name] = bp.unpack_ndarray_file(
            os.path.join(volume_dir, '%(stack)s/%(stack)s_scoreVolume_%(label)s.bp' % \
                         {'stack': stack, 'label': name})).astype(np.float16)

# Volume axes are ordered (y, x, z); the centroid is kept in (x, y, z) order.
test_ydim, test_xdim, test_zdim = list(volume2_allLabels.values())[0].shape
test_centroid = .5 * np.array([test_xdim, test_ydim, test_zdim])

print('%s %s %s' % (test_xdim, test_ydim, test_zdim))
print(test_centroid)

# test_xdim = volume_xmax - volume_xmin + 1
# test_ydim = volume_ymax - volume_ymin + 1
# test_zdim = volume_zmax - volume_zmin + 1

###################### Load Gradient #####################

# Per label, a (3, y, x, z) float16 array holding the gx, gy and gz volumes.
dSdxyz = {name: np.empty((3, test_ydim, test_xdim, test_zdim), dtype=np.float16) for name in labels_unsided[1:]}

t1 = time.time()

for name in labels_unsided:

    if name == 'BackG':
        continue

    t = time.time()

    path_fields = {'stack':stack, 'label':name}
    dSdxyz[name][0] = bp.unpack_ndarray_file(volume_dir + '/%(stack)s/%(stack)s_scoreVolume_%(label)s_gx.bp' % path_fields)
    dSdxyz[name][1] = bp.unpack_ndarray_file(volume_dir + '/%(stack)s/%(stack)s_scoreVolume_%(label)s_gy.bp' % path_fields)
    dSdxyz[name][2] = bp.unpack_ndarray_file(volume_dir + '/%(stack)s/%(stack)s_scoreVolume_%(label)s_gz.bp' % path_fields)

    sys.stderr.write('load gradient %s: %f seconds\n' % (name, time.time() - t))

sys.stderr.write('overall: %f seconds\n' % (time.time() - t1)) # 140s

93 364
844

load gradient 12N: 5.484035 seconds
load gradient 5N: 6.940750 seconds
load gradient 6N: 5.883514 seconds
load gradient 7N: 6.698527 seconds
load gradient 7n: 5.487856 seconds
load gradient AP: 5.461054 seconds
load gradient Amb: 5.355412 seconds
load gradient LC: 6.452136 seconds
load gradient LRt: 6.066488 seconds
load gradient Pn: 6.762748 seconds
load gradient R: 6.241635 seconds
load gradient RtTg: 7.202044 seconds
load gradient Tz: 5.602598 seconds
load gradient VLL: 6.871995 seconds
load gradient sp5: 5.807950 seconds
load gradient outerContour: 6.142578 seconds


 484 443
[ 422.   242.   221.5]


overall: 98.468625 seconds


In [21]:
# Attach the per-stack optimisation log file exactly once. Re-running this cell
# used to add another FileHandler each time, so every message was written to
# the file once per previous run.
log_path = atlasAlignOptLogs_dir + '/%(stack)s_atlasAlignOpt.log' % {'stack': stack}
handler = next((h for h in logger.handlers
                if isinstance(h, logging.FileHandler) and h.baseFilename == os.path.abspath(log_path)),
               None)
if handler is None:
    handler = logging.FileHandler(log_path)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

################# GRID SEARCH ######################
# Random search over pure translations (tx, ty, tz): every round draws samples
# uniformly in a box centred on the best translation found so far, and both the
# box size and the number of samples shrink by exp(-iteration/3).
# NOTE: the draws use the unseeded global NumPy random state, so results differ
# between runs.

grid_search_iteration_number = 10
# grid_search_iteration_number = 1

params_best_upToNow = (0, 0, 0)
score_best_upToNow = 0

init_n = 1000

for iteration in range(grid_search_iteration_number):

    logger.info('grid search iteration %d', iteration)

    init_tx, init_ty, init_tz  = params_best_upToNow

    n = int(init_n*np.exp(-iteration/3.))

    # Half-widths of the sampling box along each axis.
    sigma_tx = 300*np.exp(-iteration/3.)
    sigma_ty = 300*np.exp(-iteration/3.)
    sigma_tz = 100*np.exp(-iteration/3.)

    tx_grid = init_tx + sigma_tx * (2 * np.random.random(n) - 1)
    ty_grid = init_ty + sigma_ty * (2 * np.random.random(n) - 1)
    tz_grid = init_tz + sigma_tz * (2 * np.random.random(n) - 1)

    samples = np.c_[tx_grid, ty_grid, tz_grid]

    t = time.time()
    # num jobs * memory each job

    # Identity linear part plus the sampled translation, flattened row by row.
    scores = Parallel(n_jobs=8)(delayed(compute_score)(np.array([1, 0, 0, tx, 0, 1, 0, ty, 0, 0, 1, tz]))
                                 for tx, ty, tz in samples)

#     scores = []
#     for tx, ty, tz in samples:
#         scores.append(compute_score([1, 0, 0, tx, 0, 1, 0, ty, 0, 0, 1, tz]))

    sys.stderr.write('grid search: %f seconds\n' % (time.time() - t)) # ~23s

    best_index = np.argmax(scores)
    score_best = scores[best_index]

    tx_best, ty_best, tz_best = samples[best_index]

    # Only move the search centre when this round actually improved the score.
    if score_best > score_best_upToNow:
        logger.info('%f %f', score_best_upToNow, score_best)

        score_best_upToNow = score_best
        params_best_upToNow = tx_best, ty_best, tz_best

        logger.info('%f %f %f', tx_best, ty_best, tz_best)

    logger.info('\n')

INFO:__main__:grid search iteration 0
grid search: 23.652536 seconds
INFO:__main__:0.000000 22.896895
INFO:__main__:72.163303 -2.333709 4.358628
INFO:__main__:

INFO:__main__:grid search iteration 1
grid search: 19.867429 seconds
INFO:__main__:22.896895 31.645752
INFO:__main__:52.958007 -13.982644 -3.264277
INFO:__main__:

INFO:__main__:grid search iteration 2
grid search: 15.752052 seconds
INFO:__main__:31.645752 32.950571
INFO:__main__:57.643631 -2.465972 -1.885419
INFO:__main__:

INFO:__main__:grid search iteration 3
grid search: 11.736301 seconds
INFO:__main__:

INFO:__main__:grid search iteration 4
grid search: 8.945194 seconds
INFO:__main__:32.950571 34.997639
INFO:__main__:42.691690 -8.521901 -10.018324
INFO:__main__:

INFO:__main__:grid search iteration 5
grid search: 6.705439 seconds
INFO:__main__:34.997639 37.114959
INFO:__main__:49.579430 0.540843 -6.979395
INFO:__main__:

INFO:__main__:grid search iteration 6
grid search: 5.046976 seconds
INFO:__main__:

INFO:__main__:grid 

In [41]:
################# GRADIENT DESCENT ######################
# AdaGrad-style ascent on the 12 transform parameters, started from the best
# translation found by the grid search.

lr1, lr2 = (10., 1e-1)
# lr1, lr2 = (1., 1e-3)

# auto_corr = .95

max_iter_num = 1000
fudge_factor = 1e-6 #for numerical stability
history_len = 50    # window length (in iterations) used by the convergence test
dMdA_historical = np.zeros((12,))   # running sum of squared gradients

tx_best, ty_best, tz_best = params_best_upToNow
# Force a float array: if the grid search never improved on its integer
# starting point, np.r_ would build an integer array and the in-place float
# update below would fail.
T_best = np.r_[1,0,0, tx_best, 0,1,0, ty_best, 0,0,1, tz_best].astype(np.float64)

# Per-parameter step size: lr1 for the three translation entries (indices 3, 7, 11),
# lr2 for the nine linear entries.
lr = np.array([lr2, lr2, lr2, lr1, lr2, lr2, lr2, lr1, lr2, lr2, lr2, lr1])

score_best = 0
# Always defined, even if no iterate ever scores above zero.
best_gradient_descent_params = T_best.copy()

scores = []

for iteration in range(max_iter_num):

    logger.info('iteration %d', iteration)

#     t = time.time()
    s, dMdA = compute_score_and_gradient(T_best)
#     sys.stderr.write('compute_score_and_gradient: %f seconds\n' % (time.time() - t)) #~ 2s/iteration or ~.5s: 1e5 samples per landmark

    logger.info('score: %f', s)
    scores.append(s)

    logger.info('\n')

    if s > score_best:
        # Snapshot the parameters that produced `s`, taken BEFORE the update.
        # T_best is modified in place, so storing a bare reference would make
        # this variable track the current iterate instead of the best one.
        best_gradient_descent_params = T_best.copy()
        score_best = s

    dMdA_historical += dMdA**2
#     dMdA_historical = auto_corr * dMdA_historical + (1-auto_corr) * dMdA**2

    dMdA_adjusted = dMdA / (fudge_factor + np.sqrt(dMdA_historical))

    T_best += lr*dMdA_adjusted

    # Persist the score history every iteration, including the last one before
    # an early stop, so that an interrupted run can still be inspected.
    np.save(atlasAlignOptLogs_dir + '/%(stack)s_scoreEvolutions.npy' % {'stack':stack}, scores)

    # Converged when the mean score of the previous `history_len` iterations
    # differs from that of the window before it by less than 0.1.
    if iteration > 100:
        if np.abs(np.mean(scores[iteration-history_len:iteration]) - \
                  np.mean(scores[iteration-2*history_len:iteration-history_len])) < 1e-1:
            break

plt.plot(scores);

INFO:__main__:iteration 0
INFO:__main__:score: 37.223110
INFO:__main__:

INFO:__main__:iteration 1
INFO:__main__:score: 19.423907
INFO:__main__:

INFO:__main__:iteration 2
INFO:__main__:score: 14.992588
INFO:__main__:

INFO:__main__:iteration 3
INFO:__main__:score: 35.149701
INFO:__main__:

INFO:__main__:iteration 4
INFO:__main__:score: 43.225993
INFO:__main__:

INFO:__main__:iteration 5
INFO:__main__:score: 31.121026
INFO:__main__:

INFO:__main__:iteration 6
INFO:__main__:score: 54.053836
INFO:__main__:

INFO:__main__:iteration 7
INFO:__main__:score: 44.242153
INFO:__main__:

INFO:__main__:iteration 8
INFO:__main__:score: 42.196761
INFO:__main__:

INFO:__main__:iteration 9
INFO:__main__:score: 51.080771
INFO:__main__:

INFO:__main__:iteration 10
INFO:__main__:score: 56.038654
INFO:__main__:

INFO:__main__:iteration 11
INFO:__main__:score: 46.065938
INFO:__main__:

INFO:__main__:iteration 12
INFO:__main__:score: 61.802275
INFO:__main__:

INFO:__main__:iteration 13
INFO:__main__:score: 

In [18]:
# Read back the stored alignment for this stack. The second line of the file
# holds the transform parameters as whitespace-separated floats.
with open(atlasAlignParams_dir + '/%(stack)s/%(stack)s_3dAlignParams.txt' % {'stack': stack}, 'r') as f:
    lines = f.readlines()
T_final = np.array(list(map(float, lines[1].split())))

In [19]:
T_final

array([  9.89331000e-01,   2.79640000e-01,  -1.02790000e-02,
         4.27320680e+01,  -2.13326000e-01,   9.15999000e-01,
         8.35500000e-02,  -3.21320860e+01,   6.00710000e-02,
        -9.62790000e-02,   9.48397000e-01,   1.48596170e+01])

In [20]:
import numdifftools as nd
# https://media.readthedocs.org/pdf/numdifftools/latest/numdifftools.pdf

In [43]:
# Numerical gradient of compute_score at T_final, shown as a 3x4 matrix.
# Finite-difference step per row: 1e-1 for the three linear terms, 5 for the translation.
g = nd.Gradient(compute_score, step=np.tile([1e-1, 1e-1, 1e-1, 5], 3))
g(T_final).reshape((3,4))

array([[  6.00927734,  -3.41088867,  -8.94720459,  -0.04060059],
       [  2.46795654,   6.37304688,  13.78807068,   0.11897095],
       [ 26.5664978 ,  11.76049805,   4.54318309,  -0.18624512]])

In [42]:
# Numerical gradient of compute_score at T_final, shown as a 3x4 matrix.
# Finite-difference step per row: 1e-2 for the three linear terms, 5 for the translation.
g = nd.Gradient(compute_score, step=np.tile([1e-2, 1e-2, 1e-2, 5], 3))
g(T_final).reshape((3,4))

array([[  6.40136719e+00,  -1.84082031e+00,  -4.86450195e+00,
         -4.06005859e-02],
       [ -1.78344727e+01,   9.51049805e+00,   4.89379883e+00,
          1.18970947e-01],
       [  4.16870117e+01,   3.95385742e+00,  -2.40734863e+01,
         -1.86245117e-01]])

In [34]:
# Analytic gradient at T_final, rescaled for side-by-side comparison with the
# numerical gradients computed from compute_score.
# NOTE(review): the 1/1e4 presumably mirrors the /1e4 scaling that compute_score
# applies to the volume values; the extra factor 7 is unexplained -- confirm.
compute_score_and_gradient(T_final)[1] / 1e4 * 7

array([ 21.61148478,  -5.18398775,  -8.78326591,  -0.11946889,
       -41.65740612,  19.84589307,  11.66668153,   0.2068148 ,
        36.87631868,   0.40260411, -27.92260616,  -0.1375492 ])

In [46]:
# Numerical diagonal of the Hessian of compute_score at T_final, shown as a 3x4 matrix.
# Finite-difference step per row: 1e-1 for the three linear terms, 5 for the translation.
h = nd.Hessdiag(compute_score, step=np.tile([1e-1, 1e-1, 1e-1, 5], 3))
h(T_final).reshape((3,4))

array([[ -3.90701172e+03,  -8.38144531e+02,  -2.65509399e+03,
         -1.59375000e-01],
       [ -5.51532593e+03,  -2.22789551e+03,  -3.02422333e+03,
         -2.86456543e-01],
       [ -5.67038879e+03,  -2.03737305e+03,  -4.39688902e+03,
         -3.18267578e-01]])

In [47]:
# Numerical diagonal of the Hessian of compute_score at T_final, shown as a 3x4 matrix.
# Finite-difference step per row: 1e-2 for the three linear terms, 5 for the translation.
h = nd.Hessdiag(compute_score, step=np.tile([1e-2, 1e-2, 1e-2, 5], 3))
h(T_final).reshape((3,4))

array([[ -6.06542969e+03,  -1.24853516e+03,  -3.79125977e+03,
         -1.59375000e-01],
       [ -9.91210937e+03,  -3.15698242e+03,  -3.87182617e+03,
         -2.86456543e-01],
       [ -1.20551758e+04,  -3.13647461e+03,  -7.52709961e+03,
         -3.18267578e-01]])