In [12]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import numpy as np

import sys
import os

sys.path.append(os.environ['REPO_DIR'] + '/utilities')
from utilities2015 import *

import matplotlib.pyplot as plt
%matplotlib inline

from joblib import Parallel, delayed
import time

import logging

In [14]:
from quarternion import *

In [15]:
# Structure names. A name's position in this list is its integer label id; 'BackG' sits at
# index 0 and the per-label loops in this notebook start from index 1.
# (An earlier version used only the first ten entries.)
labels = [
    'BackG', '5N', '7n', '7N', '12N', 'Gr', 'LVe', 'Pn', 'SuVe', 'VLL',
    '6N', 'Amb', 'R', 'Tz', 'Sol', 'RtTg', 'LRt', 'LC', 'AP', 'sp5',
]

n_labels = len(labels)

# Reverse lookup: structure name -> integer label id.
label_dict = dict(zip(labels, range(n_labels)))

In [16]:
# Directory that holds the annotation volume and the per-stack score-map volumes read by this
# notebook; the atlas-projected volume is also written here.
volume_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_volumes/'

In [17]:
# Load the MD589 annotation volume. `bp` is not imported directly in this notebook
# (presumably it comes from the `utilities2015` star-import — TODO confirm).
volume1 = bp.unpack_ndarray_file(os.path.join(volume_dir, 'volume_MD589_annotation.bp'))
# The array axes are unpacked in (y, x, z) order.
atlas_ydim, atlas_xdim, atlas_zdim = volume1.shape
print atlas_xdim, atlas_ydim, atlas_zdim

809 405 536


In [18]:
def parallel_where(l):
    """
    Locate every voxel of the module-level ``volume1`` whose value equals ``l``.

    ``volume1`` is indexed as (y, x, z); the index arrays are returned reordered
    as ``[xs, ys, zs]``.
    """
    ys, xs, zs = np.where(volume1 == l)
    return [xs, ys, zs]

t = time.time()

atlas_nzs = Parallel(n_jobs=16)(delayed(parallel_where)(l) for l in range(1, n_labels))

print time.time() - t, 'seconds'

atlas_xmin, atlas_ymin, atlas_zmin = np.min([np.min(atlas_nzs[l-1], axis=1) for l in range(1, n_labels)], axis=0)
atlas_xmax, atlas_ymax, atlas_zmax = np.max([np.max(atlas_nzs[l-1], axis=1) for l in range(1, n_labels)], axis=0)
print atlas_xmin, atlas_xmax, atlas_ymin, atlas_ymax, atlas_zmin, atlas_zmax

atlas_centroid = np.array([.5*atlas_xmin+.5*atlas_xmax, .5*atlas_ymin+.5*atlas_ymax, .5*atlas_zmin+.5*atlas_zmax])
print atlas_centroid

atlas_cx, atlas_cy, atlas_cz = atlas_centroid

2.20171308517 seconds
0 808 0 404 0 535
[ 404.   202.   267.5]


In [19]:
# Physical resolution of the image data.
section_thickness = 20 # in um
xy_pixel_distance_lossless = 0.46 # presumably also in um per pixel — TODO confirm

# Pixel size after the downsampling used for the score volumes, and for thumbnails.
downsample_factor = 16
xy_pixel_distance_downsampled = downsample_factor * xy_pixel_distance_lossless
xy_pixel_distance_tb = 32 * xy_pixel_distance_lossless # in um, thumbnail

# Section thickness expressed in downsampled pixels, i.e. z indices spanned by one section.
z_xy_ratio_downsampled = section_thickness / xy_pixel_distance_downsampled

In [20]:
# Directory for the per-stack optimisation logs and score-evolution arrays; created on first use.
atlasAlignOptLogs_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignOptLogs'
if not os.path.exists(atlasAlignOptLogs_dir):
    os.makedirs(atlasAlignOptLogs_dir)

In [21]:
# Directory for the per-stack alignment parameter text files; created on first use.
atlasAlignParams_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignParams'
if not os.path.exists(atlasAlignParams_dir):
    os.makedirs(atlasAlignParams_dir)

In [22]:
# Root directory for the projected-annotation overlay images (one sub-directory per stack).
# NOTE(review): the directory name reads "annotaionsPojectedViz", which looks misspelled; it is
# left untouched because it is the on-disk location — confirm before renaming.
annotationsViz_rootdir = '/oasis/projects/nsf/csd395/yuncong/CSHL_annotaionsPojectedViz'

In [23]:
# Colour table indexed by integer label id (it is indexed with the projected label volume below).
colors = np.loadtxt(os.environ['REPO_DIR'] + '/visualization/100colors.txt')
# Overwrite the 'BackG' entry with 1. (presumably white, if the table holds RGB in [0, 1] — TODO confirm).
colors[label_dict['BackG']] = 1.

In [24]:
# Logger for this notebook; a per-stack FileHandler is attached and removed inside the stack loop.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [25]:
# Placeholder for the list of per-label score volumes of the current stack. It is filled inside
# the stack loop and read as a module-level name by score_transform.
volume2_allLabels = None

In [26]:
def score_transform(tx, ty, tz, qr, qx, qy, qz):
    """
    Score a rigid transform of the atlas against the test score volumes.

    For every label id 1..n_labels-1 the atlas voxels of that label are rotated about
    ``atlas_centroid`` by the quaternion (qr, qx, qy, qz), shifted by (tx, ty, tz) plus the
    test centroid (test_cx, test_cy, test_cz), truncated to integer indices and used to
    sample that label's volume in ``volume2_allLabels``. Voxels that land outside the
    volume are ignored.

    Reads these names from the enclosing module: ``quarternion_to_matrix``, ``n_labels``,
    ``atlas_nzs``, ``atlas_centroid``, ``test_cx``/``test_cy``/``test_cz`` and
    ``volume2_allLabels``.

    Args:
        tx, ty, tz: translation along x, y, z.
        qr, qx, qy, qz: rotation quaternion (scalar part first).

    Returns:
        Sum over all labels of the sampled values, each divided by 1e4.
    """

    R = quarternion_to_matrix(qr, qx, qy, qz)

    # The translation is identical for every label, so build it once.
    shift = np.asarray([tx + test_cx,
                        ty + test_cy,
                        tz + test_cz])[:, np.newaxis]

    scores = np.empty((n_labels-1,))
    for l in range(1, n_labels):

        label_volume = volume2_allLabels[l-1]

        # Builtin int replaces the np.int alias, which newer NumPy releases removed;
        # the truncation behaviour is the same.
        test_xs, test_ys, test_zs = (np.dot(R, np.array(atlas_nzs[l-1]) - atlas_centroid[:, np.newaxis])
                                     + shift).astype(int)

        # Score volumes are indexed as (y, x, z).
        ydim, xdim, zdim = label_volume.shape

        valid = (test_xs >= 0) & (test_ys >= 0) & (test_zs >= 0) & \
                (test_ys < ydim) & (test_xs < xdim) & (test_zs < zdim)

        voxel_probs_valid = label_volume[test_ys[valid], test_xs[valid], test_zs[valid]] / 1e4

        scores[l-1] = voxel_probs_valid.sum()

    # No explicit `del` of the loop temporaries: they are released on return, and deleting
    # them would raise NameError if the label loop ran zero times.
    return np.sum(scores)

In [None]:
# for stack in ['MD593', 'MD592', 'MD590', 'MD591', 'MD595', 'MD598', 'MD602', 'MD585']:
for stack in ['MD585']:
    
    section_bs_begin, section_bs_end = section_range_lookup[stack]
    print section_bs_begin, section_bs_end

    (volume_xmin, volume_xmax, volume_ymin, volume_ymax, volume_zmin, volume_zmax) = \
    np.loadtxt(os.path.join(volume_dir, 'volume_%(stack)s_scoreMap_limits.txt' % {'stack': stack}), dtype=np.int)

    map_z_to_section = {}
    for s in range(section_bs_begin, section_bs_end+1):
        for z in range(int(z_xy_ratio_downsampled*s) - volume_zmin, int(z_xy_ratio_downsampled*(s+1)) - volume_zmin + 1):
            map_z_to_section[z] = s

            
    global volume2_allLabels
    volume2_allLabels = []

    for l in labels[1:]:
        
        print l

        volume2 = bp.unpack_ndarray_file(os.path.join(volume_dir, 'volume_%(stack)s_scoreMap_%(label)s.bp' % \
                                                      {'stack': stack, 'label': l}))

        volume2_cropped = volume2[volume_ymin:volume_ymax+1, volume_xmin:volume_xmax+1]
        # copy is important, because then you can delete the large array

        volume2_allLabels.append(volume2_cropped.copy())

        del volume2, volume2_cropped


    test_ydim, test_xdim, test_zdim = volume2_allLabels[0].shape

    print test_xdim, test_ydim, test_zdim

    test_centroid = (.5*test_xdim, .5*test_ydim, .5*test_ydim)
    print test_centroid

    test_cx, test_cy, test_cz = test_centroid

    
    handler = logging.FileHandler(atlasAlignOptLogs_dir + '/%(stack)s_atlasAlignOpt.log' % {'stack': stack})
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)


    ########## Grid Search ##########

    grid_search_iteration_number = 10

    params2_best_upToNow = (0, 0, 0, 0,0, np.deg2rad(0))
    score_best_upToNow = 0

    for iteration in range(grid_search_iteration_number):

        logger.info('grid search iteration %d', iteration)

        init_tx, init_ty, init_tz, init_vx, init_vy, init_theta  = params2_best_upToNow

        n = int(1000*np.exp(-iteration/3.))

        sigma_tx = 300*np.exp(-iteration/3.)
        sigma_ty = 300*np.exp(-iteration/3.)
        sigma_tz = 200*np.exp(-iteration/3.)

        sigma_v = .1 # axis of rotation
        sigma_theta = np.deg2rad(30*np.exp(-iteration/3.))

        tx_grid = init_tx + sigma_tx * (2 * np.random.random(n) - 1)
        ty_grid = init_ty + sigma_ty * (2 * np.random.random(n) - 1)
        tz_grid = init_tz + sigma_tz * (2 * np.random.random(n) - 1)
        vx_grid = init_vx + sigma_v * (2 * np.random.random(n) - 1)
        vy_grid = init_vy + sigma_v * (2 * np.random.random(n) - 1)
        theta_grid = init_theta + sigma_theta * (2 * np.random.random(n) - 1)

        qr_grid = np.cos(theta_grid/2)
        qx_grid = np.sin(theta_grid/2)*vx_grid
        qy_grid = np.sin(theta_grid/2)*vy_grid

        vz_grid = np.sqrt(1-vx_grid**2-vy_grid**2)
        qz_grid = np.sin(theta_grid/2)*vz_grid

        samples = np.c_[tx_grid, ty_grid, tz_grid, qr_grid, qx_grid, qy_grid, qz_grid]

        q_norm = np.sqrt(qr_grid**2 + qx_grid**2 + qy_grid**2 + qz_grid**2)
        qr_grid = qr_grid / q_norm 
        qx_grid = qx_grid / q_norm 
        qy_grid = qy_grid / q_norm 
        qz_grid = qz_grid / q_norm 

        import time
        t = time.time()

        scores = Parallel(n_jobs=16)(delayed(score_transform)(tx, ty, tz, qr, qx, qy, qz ) 
                                     for tx, ty, tz, qr, qx, qy, qz in samples)

        print time.time() - t, 'seconds'

        score_best = np.max(scores)

        tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best = samples[np.argmax(scores)]

        if score_best > score_best_upToNow:
            logger.info('%f %f', score_best_upToNow, score_best)

            score_best_upToNow = score_best
            params_best_upToNow = tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best

            v_best = np.array([qx_best,qy_best,qz_best])/np.sqrt(qx_best**2+qy_best**2+qz_best**2)
            theta_best = np.arccos(qr_best)

            params2_best_upToNow = tx_best, ty_best, tz_best, v_best[0], v_best[1], theta_best

            logger.info('%f %f %f (%f %f %f) %f', 
                         tx_best, ty_best, tz_best, v_best[0], v_best[1], v_best[2], np.rad2deg(theta_best))
            logger.info(' '.join(['%f']*7) % params_best_upToNow)
            logger.info('\n')

    ########## Compute score volume gradient ##########

    dSdyxz = []
    for l in range(1, n_labels):
        print labels[l]

        t = time.time()
        dSdyxz.append(np.gradient(volume2_allLabels[l-1], 10, 10, 10))
        print time.time() - t, 'seconds'


    ########## Gradient descent ##########

    def optimal_global_rigid_params(init_params, iter_num=100, return_scores=False, lr=(.01, 1e-6)):
        """
        Refine a rigid atlas-to-test transform by gradient ascent on the overlap score.

        Every iteration accumulates, over all label ids 1..n_labels-1, the derivative of the
        sampled score volumes with respect to the translation and the quaternion. Each
        component is rescaled by the root of its running sum of squared gradients
        (Adagrad-style), a step is taken, the quaternion is renormalised and
        ``score_transform`` of the new parameters is recorded.

        Uses names from the enclosing scope: ``n_labels``, ``atlas_nzs``, ``atlas_centroid``,
        ``test_cx``/``test_cy``/``test_cz``, ``test_xdim``/``test_ydim``/``test_zdim``,
        ``dSdyxz``, ``logger``, ``score_transform``, ``quarternion_to_matrix`` and
        ``quarternion_normalization_jacobian``.

        Args:
            init_params: starting point (tx, ty, tz, qr, qx, qy, qz).
            iter_num: maximum number of iterations.
            return_scores: if True, also return the list of per-iteration scores.
            lr: (lr1, lr2), step sizes for the translation and the quaternion components.

        Returns:
            ``best_params``, or ``(best_params, scores)`` when ``return_scores`` is True.
            ``best_params`` is the 7-tuple with the highest recorded score; it equals
            ``init_params`` when no iterate scores above 0.
        """

        fudge_factor = 1e-6 #for numerical stability
        history_len = 50

        lr1, lr2 = lr
        step_sizes = np.array([lr1, lr1, lr1, lr2, lr2, lr2, lr2])

        # Parameters and gradients are kept in the order (tx, ty, tz, qr, qx, qy, qz).
        params = np.array(init_params, dtype=float)

        # Running sum of squared gradients, one entry per parameter.
        grad_sq_historical = np.zeros(7)

        score_best = 0
        # Defined up front so the return below cannot hit an unbound name when no iterate
        # scores above 0.
        best_params = tuple(init_params)

        scores = []

        for iteration in range(iter_num):

            logger.info('iteration %d\n', iteration)

            tx_best, ty_best, tz_best, qr, qx, qy, qz = params

            # These depend only on the current parameters, not on the label.
            R_best = quarternion_to_matrix(qr, qx, qy, qz)
            qn_jac = quarternion_normalization_jacobian(qr, qx, qy, qz)
            shift = np.asarray([tx_best + test_cx,
                                ty_best + test_cy,
                                tz_best + test_cz])[:, np.newaxis]

            grad = np.zeros(7)

            for l in range(1, n_labels):

                # Atlas voxel coordinates of this label, relative to the atlas centroid.
                ds = np.array(atlas_nzs[l-1]) - atlas_centroid[:, np.newaxis]

                # Builtin int replaces the np.int alias, which newer NumPy releases removed.
                xs_prime, ys_prime, zs_prime = (np.dot(R_best, ds) + shift).astype(int)

                valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
                    (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)

                if np.count_nonzero(valid) == 0:
                    continue

                sample_index = (ys_prime[valid], xs_prime[valid], zs_prime[valid])

                # Volumes are indexed (y, x, z), so gradient component 1 is d/dx and 0 is d/dy.
                Sx = dSdyxz[l-1][1][sample_index]
                Sy = dSdyxz[l-1][0][sample_index]
                Sz = dSdyxz[l-1][2][sample_index]

                grad[0] += Sx.sum()
                grad[1] += Sy.sum()
                grad[2] += Sz.sum()

                xs, ys, zs = ds[:, valid]

                # Partial derivatives of the rotated coordinates with respect to the quaternion
                # components (assuming quarternion_to_matrix builds the standard rotation
                # matrix — TODO confirm).
                dxdqr = 2*(-qz*ys+qy*zs)
                # The factor 2 applies to both terms, matching dydqy and dzdqz below; the
                # previous expression 2*qy*ys+qz*zs left the second term unscaled.
                dxdqx = 2*(qy*ys+qz*zs)
                dxdqy = 2*(-2*qy*xs+qx*ys+qr*zs)
                dxdqz = 2*(-2*qz*xs-qr*ys+qx*zs)

                dydqr = 2*(qz*xs-qx*zs)
                dydqx = 2*(qy*xs-2*qx*ys-qr*zs)
                dydqy = 2*(qx*xs+qz*zs)
                dydqz = 2*(qr*xs-2*qz*ys+qy*zs)

                dzdqr = 2*(-qy*xs+qx*ys)
                dzdqx = 2*(qz*xs+qr*ys-2*qx*zs)
                dzdqy = 2*(-qr*xs+qz*ys-2*qy*zs)
                dzdqz = 2*(qx*xs+qy*ys)

                # Chain through the quaternion-normalisation Jacobian; rows are voxels,
                # columns are (qr, qx, qy, qz).
                dxdq = np.dot(np.c_[dxdqr, dxdqx, dxdqy, dxdqz], qn_jac)
                dydq = np.dot(np.c_[dydqr, dydqx, dydqy, dydqz], qn_jac)
                dzdq = np.dot(np.c_[dzdqr, dzdqx, dzdqy, dzdqz], qn_jac)

                grad[3:] += np.dot(Sx, dxdq) + np.dot(Sy, dydq) + np.dot(Sz, dzdq)


            logger.info('(dMdu, dMdv, dMdw): %f %f %f', *grad[:3])
            logger.info('(dMdqr, dMdqx, dMdqy, dMdqz): %f %f %f %f', *grad[3:])

            grad_sq_historical += grad**2
            grad_adjusted = grad / (fudge_factor + np.sqrt(grad_sq_historical))

            logger.info('lr1: %f, lr2: %f', lr1, lr2)
            # Ascent step: the score is being maximised.
            params = params + step_sizes*grad_adjusted

            logger.info('(dMdu, dMdv, dMdw) adjusted: %f %f %f', *grad_adjusted[:3])
            logger.info('(dMdqr, dMdqx, dMdqy, dMdqz) adjusted: %f %f %f %f', *grad_adjusted[3:])

            logger.info('(tx_best, ty_best, tz_best):  %f %f %f', *params[:3])
            logger.info('(qr_best, qx_best, qy_best, qz_best): %f %f %f %f', *params[3:])

            # Bring the quaternion back to unit length after the step.
            params[3:] /= np.sqrt(np.sum(params[3:]**2))

            tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best = params

            # Axis-angle form, for logging only. qr = cos(theta/2), hence the factor 2.
            v_best = np.array([qx_best, qy_best, qz_best])/np.sqrt(qx_best**2+qy_best**2+qz_best**2)
            theta_best = 2*np.arccos(np.clip(qr_best, -1., 1.))
            logger.info('(v_best, theta_best): (%f %f %f) %f', v_best[0], v_best[1], v_best[2], np.rad2deg(theta_best))

            s = score_transform(tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best)
            logger.info('score: %f', s)
            scores.append(s)

            logger.info('\n')

            # Record the best iterate before testing for convergence, so the iteration that
            # triggers the break is still considered.
            if s > score_best:
                logger.info('Current best')
                best_params = (tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best)
                score_best = s

            # After iteration 200, stop once the mean score of the last history_len iterations
            # differs from that of the history_len before them by less than 1e-3.
            if iteration > 200:
                if np.abs(np.mean(scores[iteration-history_len:iteration]) - \
                          np.mean(scores[iteration-2*history_len:iteration-history_len])) < 1e-3:
                    break

        if return_scores:
            return best_params, scores
        else:
            return best_params



    # Start the gradient-based refinement from the best grid-search result.
    init_params = params_best_upToNow
#     learning_rate = (10., 1e-3)
    # (step size for the translation, step size for the quaternion)
    learning_rate = (1., 1e-3)
    # Upper bound on iterations; the optimiser may stop earlier through its convergence test.
    iteration_number = 4000

    t = time.time()

    best_global_params, scores = optimal_global_rigid_params(init_params=init_params, 
                                                             iter_num=iteration_number, 
                                                             return_scores=True,
                                                            lr=learning_rate)
    print best_global_params

    print time.time() - t, 'seconds'

    # Plot the score recorded at every iteration.
    plt.plot(scores);
    plt.title('improvement of overlap score');
    plt.xlabel('iteration');
    plt.ylabel('overlap score');
    plt.show();

    # Keep the per-iteration scores next to the optimisation log of this stack.
    np.save(atlasAlignOptLogs_dir + '/%(stack)s_scoreEvolutions.npy' % {'stack':stack}, scores)
    
    ########## Project atlas to test images using found alignment matrix ########## 

    tx_best, ty_best, tz_best, qr_best, qx_best, qy_best, qz_best = best_global_params
    R_best = quarternion_to_matrix(qr_best, qx_best, qy_best, qz_best)

    atlas_nzs_projected_to_test = [(np.dot(R_best, vs - atlas_centroid[:, np.newaxis]) + \
                                                np.asarray([tx_best + test_cx, 
                                                            ty_best + test_cy, 
                                                            tz_best + test_cz])[:,np.newaxis]).astype(np.int)
                                    for vs in atlas_nzs]

    print np.min(atlas_nzs_projected_to_test[0], axis=1)
    print np.max(atlas_nzs_projected_to_test[0], axis=1)

    test_volume_atlas_projected = np.zeros_like(volume2_allLabels[0], np.int)

    for l in range(1, n_labels):

        test_xs, test_ys, test_zs = atlas_nzs_projected_to_test[l-1].astype(np.int)

        valid = (test_xs >= 0) & (test_ys >= 0) & (test_zs >= 0) & \
            (test_xs < test_xdim) & (test_ys < test_ydim) & (test_zs < test_zdim)

        atlas_xs, atlas_ys, atlas_zs = atlas_nzs[l-1]

        test_volume_atlas_projected[test_ys[valid], test_xs[valid], test_zs[valid]] = \
        volume1[atlas_ys[valid], atlas_xs[valid], atlas_zs[valid]]

        
    del atlas_nzs_projected_to_test
        
    bp.pack_ndarray_file(test_volume_atlas_projected, 
                         volume_dir + '/%(stack)s_volume_atlasProjected.bp'%{'stack':stack})


    # Record the alignment of this stack, one item per line: grid-search result, refined result,
    # learning rates, iteration limit. `write` is used because each item is a single string;
    # `writelines` on a string would emit it one character at a time.
    with open(os.path.join(atlasAlignParams_dir, '%(stack)s_3dAlignParams.txt' % {'stack':stack}), 'w') as f:
        f.write(' '.join(['%f']*len(params_best_upToNow)) % tuple(params_best_upToNow) + '\n')
        f.write(' '.join(['%f']*len(best_global_params)) % tuple(best_global_params) + '\n')
        f.write(' '.join(['%f']*len(learning_rate)) % tuple(learning_rate) + '\n')
        f.write('%d' % iteration_number + '\n')

    # Per-stack output directory for the overlay images.
    annotationsViz_dir = annotationsViz_rootdir + '/' + stack
    if not os.path.exists(annotationsViz_dir):
        os.makedirs(annotationsViz_dir)

    # Export an overlay for every 10th z slice of the test volume.
    for z in range(0, test_zdim, 10):
        print z

        # map_z_to_section[z] raises KeyError if this z index was not covered when the map was built.
        # `DataManager` is not defined in this notebook (presumably from `utilities2015` — TODO confirm).
        dm = DataManager(stack=stack, section=map_z_to_section[z])
        dm._load_image(versions=['rgb-jpg'])
        # Subsample the section image by downsample_factor, then crop it to the score-map limits.
        viz1 = dm.image_rgb_jpg[::downsample_factor, ::downsample_factor][volume_ymin:volume_ymax+1, volume_xmin:volume_xmax+1]

        # Colour every pixel of this slice by its projected atlas label.
        viz2 = colors[test_volume_atlas_projected[...,z]]
        # NOTE(review): alpha_blending is an external helper; .2 and 1. look like the opacities
        # of the label colours and of the image — confirm against its definition.
        viz = alpha_blending(viz2, viz1[...,:3], .2, 1.)

        # Channels are reordered to [2,1,0,3] before writing (presumably RGBA -> BGRA for OpenCV — TODO confirm).
        cv2.imwrite(annotationsViz_dir + '/%(stack)s_%(sec)04d_annotationsProjectedViz_z%(z)04d.jpg' % \
                    {'stack': stack, 'sec': map_z_to_section[z], 'z': z}, 
                    img_as_ubyte(viz[..., [2,1,0,3]]))
        
        del  viz1, viz2, viz
        
        
    del test_volume_atlas_projected
    
    # Detach this stack's log file so the next stack does not write into it.
    logger.removeHandler(handler)

In [None]:
from skimage.measure import find_contours

def find_contour_points(labelmap):
    '''
    Sample boundary points for every labelled region of a 2-D label map.

    For each region the hole-filled mask is zero-padded, its contours are traced at
    level .5, contours with 10 or fewer vertices are dropped, and every 10th vertex
    of the first remaining contour is kept. A region without any remaining contour
    is reported on stderr and left out of the result.

    `regionprops` is not imported in this cell (presumably it comes from an earlier
    star-import — TODO confirm).

    return is a dict {region label: array of (x,y) points in labelmap coordinates}
    '''

    pad = 5
    sample_step = 10

    regions = regionprops(labelmap)

    contour_points = {}

    for r in regions:

        # Only the top-left corner of the bounding box is needed to shift points back.
        min_row, min_col = r.bbox[0], r.bbox[1]

        # Padding lets contours close around regions that touch the bounding-box border.
        padded = np.pad(r.filled_image, ((pad,pad),(pad,pad)), mode='constant', constant_values=0)

        contours = find_contours(padded, .5, fully_connected='high')
        # Builtin int replaces the np.int alias, which newer NumPy releases removed.
        contours = [cnt.astype(int) for cnt in contours if len(cnt) > 10]

        if len(contours) == 0:
            # Nothing to sample; indexing contours[0] below would raise IndexError.
            sys.stderr.write('no contour is found\n')
            continue
        if len(contours) > 1:
            # Only the first contour is used in this case.
            sys.stderr.write('region has more than one part\n')

        # Undo the padding offset; vertices are (row, col) within the bounding box.
        pts = contours[0] - (pad, pad)

        pts_sampled = pts[::sample_step]

        # Swap (row, col) -> (x, y) and shift from bounding-box to labelmap coordinates.
        contour_points[r.label] = pts_sampled[:, ::-1] + (min_col, min_row)

    return contour_points