In [None]:
# Align the per-label score volumes of a test stack (MD594) to the annotated atlas volume (MD589) by optimizing a 12-parameter affine transform

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

import sys
import os

sys.path.append(os.environ['REPO_DIR'] + '/utilities')
from utilities2015 import *

import matplotlib.pyplot as plt
%matplotlib inline

from joblib import Parallel, delayed
import time

import logging

In [None]:
# Structure labels; position 0 is the background.
labels = ['BackG', '5N', '7n', '7N', '12N', 'Pn', 'VLL',
          '6N', 'Amb', 'R', 'Tz', 'RtTg', 'LRt', 'LC', 'AP', 'sp5']
n_labels = len(labels)

# label name -> position in `labels`
labels_index = {name: idx for idx, name in enumerate(labels)}

# '<label>_surround' -> '<label>' for every non-background label
labels_from_surround = {l + '_surround': l for l in labels[1:]}

# Non-background labels followed by their '_surround' variants, plus a set view
# and a name -> position lookup over that combined list.
labels_surroundIncluded_list = labels[1:] + [l + '_surround' for l in labels[1:]]
labels_surroundIncluded = set(labels_surroundIncluded_list)
labels_surroundIncluded_index = {name: idx for idx, name in enumerate(labels_surroundIncluded_list)}

# Fixed colour palette (presumably one row per label); the background row is set to 1.
colors = np.loadtxt(os.environ['REPO_DIR'] + '/visualization/100colors.txt')
colors[labels_index['BackG']] = 1.

In [None]:
# Directory holding the atlas annotation volume, the test stack's per-label score
# volumes and their derivative volumes (all read by the cells below).
# NOTE(review): cluster-specific absolute path.
volume_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_volumes/'

In [None]:
volume1 = bp.unpack_ndarray_file(os.path.join(volume_dir, 'volume_MD589_annotation.bp'))
atlas_ydim, atlas_xdim, atlas_zdim = volume1.shape
print atlas_xdim, atlas_ydim, atlas_zdim

In [None]:
def parallel_where(l):
    w = np.where(volume1 == l)
    return [w[1], w[0], w[2]]

t = time.time()

atlas_nzs = Parallel(n_jobs=16)(delayed(parallel_where)(l) for l in range(1, n_labels))

print time.time() - t, 'seconds'

atlas_xmin, atlas_ymin, atlas_zmin = np.min([np.min(atlas_nzs[l-1], axis=1) for l in range(1, n_labels)], axis=0)
atlas_xmax, atlas_ymax, atlas_zmax = np.max([np.max(atlas_nzs[l-1], axis=1) for l in range(1, n_labels)], axis=0)
print atlas_xmin, atlas_xmax, atlas_ymin, atlas_ymax, atlas_zmin, atlas_zmax

atlas_centroid = np.array([.5*atlas_xmin+.5*atlas_xmax, .5*atlas_ymin+.5*atlas_ymax, .5*atlas_zmin+.5*atlas_zmax])
print atlas_centroid

atlas_cx, atlas_cy, atlas_cz = atlas_centroid

In [None]:
# Resolution bookkeeping; distances are in micrometres (um).
downsample_factor = 16                                   # relative to lossless images

section_thickness = 20                                   # um
xy_pixel_distance_lossless = 0.46                        # um per lossless pixel
xy_pixel_distance_tb = xy_pixel_distance_lossless * 32   # um per thumbnail pixel

xy_pixel_distance_downsampled = xy_pixel_distance_lossless * downsample_factor
# One section's thickness expressed in downsampled xy pixels.
z_xy_ratio_downsampled = section_thickness / xy_pixel_distance_downsampled

In [None]:
# Output directory for alignment-optimization logs; created on first run.
# NOTE(review): not referenced by any visible cell — confirm it is still needed.
atlasAlignOptLogs_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignOptLogs'
if not os.path.exists(atlasAlignOptLogs_dir):
    os.makedirs(atlasAlignOptLogs_dir)

In [None]:
# Output directory for fitted alignment parameters; created on first run.
# NOTE(review): not referenced by any visible cell — confirm it is still needed.
atlasAlignParams_dir = '/oasis/projects/nsf/csd395/yuncong/CSHL_atlasAlignParams'
if not os.path.exists(atlasAlignParams_dir):
    os.makedirs(atlasAlignParams_dir)

In [None]:
# Root directory for visualizations of projected annotations.
# NOTE(review): the spelling ("annotaions", "Pojected") presumably matches the
# directory on disk — do not correct it without renaming the directory.
# Not referenced by any visible cell.
annotationsViz_rootdir = '/oasis/projects/nsf/csd395/yuncong/CSHL_annotaionsPojectedViz'

In [None]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Unless some handler is configured (here or on an ancestor logger), INFO records
# are dropped: Python's last-resort handler only emits WARNING and above.
# Attach one stderr handler; the guard keeps re-runs of this cell from stacking
# duplicate handlers. (If the root logger is also configured, messages may
# additionally appear there.)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

In [None]:
# Placeholder for the test stack's per-label score volumes. The score/gradient
# functions read this global; the stack-loading cell replaces it with a list.
volume2_allLabels = None

In [None]:
# Atlas voxel coordinates of each non-background label, as a (3, N) array of
# (x, y, z) rows re-centred on the atlas bounding-box centre.
ds = [np.array(nz) - atlas_centroid[:, np.newaxis] for nz in atlas_nzs[:n_labels - 1]]

In [None]:
def compute_score_and_gradient(T):
    """Alignment score of affine transform ``T`` and its gradient.

    Parameters
    ----------
    T : array-like of 12 floats
        Row-major 3x4 matrix [A | t]. Each centred atlas voxel ``d`` (from
        ``ds``) maps to ``A.dot(d) + t + (test_cx, test_cy, test_cz)``, which is
        truncated to integer indices into the test volumes (indexed [y, x, z]).
        Voxels landing outside the test volume are ignored.

    Returns
    -------
    score : float
        Sum over labels of the test score values (divided by 1e4) at the
        valid transformed voxels.
    dMdA : ndarray, shape (12,)
        Gradient w.r.t. ``T`` in the same row-major order, built from the
        first-derivative volumes in ``dSdxyz``.

    Only labels 1..K are used, where K = min(n_labels - 1, len(dSdxyz)). This keeps
    score and gradient over the same labels and avoids indexing past the end of
    ``dSdxyz`` when only a subset of derivative volumes has been loaded.

    NOTE(review): the score is divided by 1e4 but ``dSdxyz`` is used unscaled —
    confirm the derivative volumes are stored on the same scale as the score.
    """
    Tm = np.reshape(T, (3, 4))
    tx, ty, tz = Tm[:, 3]
    A = Tm[:, :3]

    # Translation plus test-volume centre, as a column broadcast over all voxels.
    offset = np.array([tx + test_cx, ty + test_cy, tz + test_cz])[:, np.newaxis]

    n_fit_labels = min(n_labels - 1, len(dSdxyz))

    score = 0
    dMdA = np.zeros((12,))

    for l in range(1, n_fit_labels + 1):

        xs_prime, ys_prime, zs_prime = (np.dot(A, ds[l-1]) + offset).astype(int)

        valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
                (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)

        if np.count_nonzero(valid) > 0:

            xs_prime_valid = xs_prime[valid]
            ys_prime_valid = ys_prime[valid]
            zs_prime_valid = zs_prime[valid]

            voxel_probs_valid = volume2_allLabels[l-1][ys_prime_valid, xs_prime_valid, zs_prime_valid] / 1e4

            score += voxel_probs_valid.sum()

            Sx = dSdxyz[l-1][0][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sy = dSdxyz[l-1][1][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sz = dSdxyz[l-1][2][ys_prime_valid, xs_prime_valid, zs_prime_valid]

            dxs, dys, dzs = ds[l-1][:, valid]

            # Chain rule: d(score)/d(A[r, c]) = sum S_r * d_c, d(score)/d(t_r) = sum S_r.
            q = np.c_[Sx*dxs, Sx*dys, Sx*dzs, Sx,
                      Sy*dxs, Sy*dys, Sy*dzs, Sy,
                      Sz*dxs, Sz*dys, Sz*dzs, Sz]

            dMdA += q.sum(axis=0)

            # Free per-label temporaries before the next label allocates its own.
            del voxel_probs_valid, q, Sx, Sy, Sz, dxs, dys, dzs, xs_prime_valid, ys_prime_valid, zs_prime_valid

        del valid, xs_prime, ys_prime, zs_prime

    return score, dMdA

In [None]:
def compute_score_and_gradient_and_hessian(T):
    """Alignment score of affine transform ``T`` with its gradient and Hessian.

    Parameters
    ----------
    T : array-like of 12 floats
        Row-major 3x4 matrix [A | t]. Each centred atlas voxel ``d`` (from
        ``ds``) maps to ``A.dot(d) + t + (test_cx, test_cy, test_cz)``, truncated
        to integer indices into the test volumes (indexed [y, x, z]); voxels
        outside the test volume are ignored.

    Returns
    -------
    score : float
        Sum of test score values (divided by 1e4) at the valid voxels.
    dMdA : ndarray, shape (12,)
        Gradient w.r.t. ``T`` from the first-derivative volumes ``dSdxyz``.
    d2MdT2 : ndarray, shape (12, 12)
        Second derivatives w.r.t. ``T`` from ``d2Sdxyz2``. Entry
        (4r+c, 4r'+c') is sum S_{rr'} * u_c * u_{c'} with u = (dx, dy, dz, 1).

    Only labels with both first- and second-derivative volumes loaded are used:
    labels 1..K with K = min(n_labels - 1, len(dSdxyz), len(d2Sdxyz2)).
    """
    Tm = np.reshape(T, (3, 4))
    tx, ty, tz = Tm[:, 3]
    A = Tm[:, :3]

    # Translation plus test-volume centre, as a column broadcast over all voxels.
    offset = np.array([tx + test_cx, ty + test_cy, tz + test_cz])[:, np.newaxis]

    n_fit_labels = min(n_labels - 1, len(dSdxyz), len(d2Sdxyz2))

    score = 0
    dMdA = np.zeros((12,))
    d2MdT2 = np.zeros((12, 12))

    for l in range(1, n_fit_labels + 1):

        xs_prime, ys_prime, zs_prime = (np.dot(A, ds[l-1]) + offset).astype(int)

        valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
                (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)

        if np.count_nonzero(valid) > 0:

            xs_prime_valid = xs_prime[valid]
            ys_prime_valid = ys_prime[valid]
            zs_prime_valid = zs_prime[valid]

            voxel_probs_valid = volume2_allLabels[l-1][ys_prime_valid, xs_prime_valid, zs_prime_valid] / 1e4

            score += voxel_probs_valid.sum()

            Sx = dSdxyz[l-1][0][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sy = dSdxyz[l-1][1][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sz = dSdxyz[l-1][2][ys_prime_valid, xs_prime_valid, zs_prime_valid]

            dxs, dys, dzs = ds[l-1][:, valid]

            q = np.c_[Sx*dxs, Sx*dys, Sx*dzs, Sx,
                      Sy*dxs, Sy*dys, Sy*dzs, Sy,
                      Sz*dxs, Sz*dys, Sz*dzs, Sz]

            dMdA += q.sum(axis=0)

            # Second derivatives of the score volume at the sampled voxels,
            # in the order xx, xy, xz, yx, yy, yz, zx, zy, zz.
            Sxx, Sxy, Sxz, Syx, Syy, Syz, Szx, Szy, Szz = \
                [S_full[ys_prime_valid, xs_prime_valid, zs_prime_valid] for S_full in d2Sdxyz2[l-1]]

            rx = np.c_[Sxx*dxs, Sxx*dys, Sxx*dzs, Sxx, Sxy*dxs, Sxy*dys, Sxy*dzs, Sxy, Sxz*dxs, Sxz*dys, Sxz*dzs, Sxz]
            ry = np.c_[Syx*dxs, Syx*dys, Syx*dzs, Syx, Syy*dxs, Syy*dys, Syy*dzs, Syy, Syz*dxs, Syz*dys, Syz*dzs, Syz]
            rz = np.c_[Szx*dxs, Szx*dys, Szx*dzs, Szx, Szy*dxs, Szy*dys, Szy*dzs, Szy, Szz*dxs, Szz*dys, Szz*dzs, Szz]

            # Hessian rows in parameter order: for each output row r (x, y, z),
            # multiply by dx, dy, dz and then 1 (the translation column).
            hessian_rows = []
            for R in (rx, ry, rz):
                for u in (dxs, dys, dzs):
                    hessian_rows.append((R * u[:, None]).sum(axis=0))
                hessian_rows.append(R.sum(axis=0))

            d2MdT2 += np.vstack(hessian_rows)

            # Free per-label temporaries before the next label allocates its own.
            del voxel_probs_valid, q, Sx, Sy, Sz, dxs, dys, dzs, xs_prime_valid, ys_prime_valid, zs_prime_valid
            del Sxx, Sxy, Sxz, Syx, Syy, Syz, Szx, Szy, Szz, rx, ry, rz, hessian_rows

        del valid, xs_prime, ys_prime, zs_prime

    return score, dMdA, d2MdT2

In [None]:
def compute_score_minus(T):
    """Negated ``compute_score(T)``, so scipy minimizers maximize the score."""
    score = compute_score(T)
    return -score

def compute_score_gradient_minus(T):
    """Negated ``compute_score_gradient(T)``."""
    gradient = compute_score_gradient(T)
    return -gradient

def compute_score_hessian_minus(T):
    """Negated ``compute_score_hessian(T)``."""
    hessian = compute_score_hessian(T)
    return -hessian

def compute_score_and_gradient_minus(T):
    """Negated (score, gradient) pair from ``compute_score_and_gradient(T)``."""
    score, gradient = compute_score_and_gradient(T)
    return (-score, -gradient)

def compute_score_and_gradient_and_hessian_minus(T):
    """Negated (score, gradient, Hessian) from ``compute_score_and_gradient_and_hessian(T)``."""
    score, gradient, hessian = compute_score_and_gradient_and_hessian(T)
    return (-score, -gradient, -hessian)

In [None]:
def compute_score(T, label_ids=None):
    """Alignment score of affine transform ``T``.

    Parameters
    ----------
    T : array-like of 12 floats
        Row-major 3x4 matrix [A | t]. Each centred atlas voxel ``d`` (from
        ``ds``) maps to ``A.dot(d) + t + (test_cx, test_cy, test_cz)``, truncated
        to integer indices into the test volumes (indexed [y, x, z]).
    label_ids : iterable of int, optional
        Labels (positions in ``labels``, 1-based) to include. Defaults to
        labels 1-4, the subset this notebook currently fits.

    Returns
    -------
    float
        Sum of the test score values (divided by 1e4) at the transformed atlas
        voxels that fall inside the test volume; voxels outside are ignored.
    """
    if label_ids is None:
        label_ids = range(1, 5)

    Tm = np.reshape(T, (3, 4))
    tx, ty, tz = Tm[:, 3]
    A = Tm[:, :3]

    # Translation plus test-volume centre, as a column broadcast over all voxels.
    offset = np.array([tx + test_cx, ty + test_cy, tz + test_cz])[:, np.newaxis]

    score = 0
    for l in label_ids:

        xs_prime, ys_prime, zs_prime = (np.dot(A, ds[l-1]) + offset).astype(int)
        valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
            (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)
        voxel_probs_valid = volume2_allLabels[l-1][ys_prime[valid], xs_prime[valid], zs_prime[valid]] / 1e4

        score += voxel_probs_valid.sum()

        # Free this label's temporaries before the next label allocates its own.
        del voxel_probs_valid, valid, xs_prime, ys_prime, zs_prime

    return score

def compute_score_gradient(T):
    """Gradient of the alignment score w.r.t. the 12 affine parameters ``T``.

    Parameters
    ----------
    T : array-like of 12 floats
        Row-major 3x4 matrix [A | t]. Each centred atlas voxel ``d`` (from
        ``ds``) maps to ``A.dot(d) + t + (test_cx, test_cy, test_cz)``, truncated
        to integer indices into the test volumes (indexed [y, x, z]); voxels
        outside the test volume are ignored.

    Returns
    -------
    ndarray, shape (12,)
        d(score)/dT in row-major order, built from ``dSdxyz``.

    Only labels 1..K are used, K = min(n_labels - 1, len(dSdxyz)), so this does
    not index past the end of ``dSdxyz`` when only some labels were loaded.
    """
    Tm = np.reshape(T, (3, 4))
    tx, ty, tz = Tm[:, 3]
    A = Tm[:, :3]

    # Translation plus test-volume centre, as a column broadcast over all voxels.
    offset = np.array([tx + test_cx, ty + test_cy, tz + test_cz])[:, np.newaxis]

    n_fit_labels = min(n_labels - 1, len(dSdxyz))

    dMdA = np.zeros((12,))

    for l in range(1, n_fit_labels + 1):

        xs_prime, ys_prime, zs_prime = (np.dot(A, ds[l-1]) + offset).astype(int)

        valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
            (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)

        if np.count_nonzero(valid) > 0:

            xs_prime_valid = xs_prime[valid]
            ys_prime_valid = ys_prime[valid]
            zs_prime_valid = zs_prime[valid]

            Sx = dSdxyz[l-1][0][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sy = dSdxyz[l-1][1][ys_prime_valid, xs_prime_valid, zs_prime_valid]
            Sz = dSdxyz[l-1][2][ys_prime_valid, xs_prime_valid, zs_prime_valid]

            dxs, dys, dzs = ds[l-1][:, valid]
            # Chain rule: d(score)/d(A[r, c]) = sum S_r * d_c, d(score)/d(t_r) = sum S_r.
            dMdA += np.c_[Sx*dxs, Sx*dys, Sx*dzs, Sx,
                          Sy*dxs, Sy*dys, Sy*dzs, Sy,
                          Sz*dxs, Sz*dys, Sz*dzs, Sz].sum(axis=0)

    return dMdA


def compute_score_hessian(T):
    """Second derivatives of the alignment score w.r.t. the 12 affine parameters.

    Parameters
    ----------
    T : array-like of 12 floats
        Row-major 3x4 matrix [A | t]. Each centred atlas voxel ``d`` (from
        ``ds``) maps to ``A.dot(d) + t + (test_cx, test_cy, test_cz)``, truncated
        to integer indices into the test volumes (indexed [y, x, z]); voxels
        outside the test volume are ignored.

    Returns
    -------
    ndarray, shape (12, 12)
        Entry (4r+c, 4r'+c') is sum S_{rr'} * u_c * u_{c'} over valid voxels,
        with u = (dx, dy, dz, 1) and S_{rr'} read from ``d2Sdxyz2``.

    Only labels 1..K are used, K = min(n_labels - 1, len(d2Sdxyz2)). A label with
    no voxel inside the test volume contributes nothing.
    """
    Tm = np.reshape(T, (3, 4))
    tx, ty, tz = Tm[:, 3]
    A = Tm[:, :3]

    # Translation plus test-volume centre, as a column broadcast over all voxels.
    offset = np.array([tx + test_cx, ty + test_cy, tz + test_cz])[:, np.newaxis]

    n_fit_labels = min(n_labels - 1, len(d2Sdxyz2))

    d2MdT2 = np.zeros((12, 12))

    for l in range(1, n_fit_labels + 1):

        xs_prime, ys_prime, zs_prime = (np.dot(A, ds[l-1]) + offset).astype(int)

        valid = (xs_prime >= 0) & (ys_prime >= 0) & (zs_prime >= 0) & \
            (xs_prime < test_xdim) & (ys_prime < test_ydim) & (zs_prime < test_zdim)

        # Accumulate only for labels that have valid voxels; nothing from an
        # earlier label may be carried over into this one.
        if np.count_nonzero(valid) == 0:
            continue

        xs_prime_valid = xs_prime[valid]
        ys_prime_valid = ys_prime[valid]
        zs_prime_valid = zs_prime[valid]

        dxs, dys, dzs = ds[l-1][:, valid]

        # Second derivatives at the sampled voxels, order xx, xy, xz, yx, yy, yz, zx, zy, zz.
        Sxx, Sxy, Sxz, Syx, Syy, Syz, Szx, Szy, Szz = \
            [S_full[ys_prime_valid, xs_prime_valid, zs_prime_valid] for S_full in d2Sdxyz2[l-1]]

        rx = np.c_[Sxx*dxs, Sxx*dys, Sxx*dzs, Sxx, Sxy*dxs, Sxy*dys, Sxy*dzs, Sxy, Sxz*dxs, Sxz*dys, Sxz*dzs, Sxz]
        ry = np.c_[Syx*dxs, Syx*dys, Syx*dzs, Syx, Syy*dxs, Syy*dys, Syy*dzs, Syy, Syz*dxs, Syz*dys, Syz*dzs, Syz]
        rz = np.c_[Szx*dxs, Szx*dys, Szx*dzs, Szx, Szy*dxs, Szy*dys, Szy*dzs, Szy, Szz*dxs, Szz*dys, Szz*dzs, Szz]

        # Rows in parameter order: for each output row r (x, y, z), multiply by
        # dx, dy, dz and then 1 (the translation column).
        hessian_rows = []
        for R in (rx, ry, rz):
            for u in (dxs, dys, dzs):
                hessian_rows.append((R * u[:, None]).sum(axis=0))
            hessian_rows.append(R.sum(axis=0))

        d2MdT2 += np.vstack(hessian_rows)

    return d2MdT2

In [None]:
stack = 'MD594'

section_bs_begin, section_bs_end = section_range_lookup[stack]
print section_bs_begin, section_bs_end

(volume_xmin, volume_xmax, volume_ymin, volume_ymax, volume_zmin, volume_zmax) = \
np.loadtxt(os.path.join(volume_dir, 'volume_%(stack)s_scoreMap_limits.txt' % {'stack': stack}), dtype=np.int)

map_z_to_section = {}
for s in range(section_bs_begin, section_bs_end+1):
    for z in range(int(z_xy_ratio_downsampled*s) - volume_zmin, int(z_xy_ratio_downsampled*(s+1)) - volume_zmin + 1):
        map_z_to_section[z] = s

global volume2_allLabels
volume2_allLabels = []

for l in labels[1:]:
    
    volume2 = bp.unpack_ndarray_file(os.path.join(volume_dir, 'volume_%(stack)s_scoreMap_%(label)s.bp' % \
                                                  {'stack': stack, 'label': l}))

    volume2_cropped = volume2[volume_ymin:volume_ymax+1, volume_xmin:volume_xmax+1]
    # copy is important, because then you can delete the large array

    volume2_allLabels.append(volume2_cropped.copy())

    del volume2, volume2_cropped


test_ydim, test_xdim, test_zdim = volume2_allLabels[0].shape
test_centroid = (.5*test_xdim, .5*test_ydim, .5*test_ydim)
test_cx, test_cy, test_cz = test_centroid

print test_xdim, test_ydim, test_zdim
print test_centroid

In [None]:
grid_search_iteration_number = 1

params_best_upToNow = (0, 0, 0)
score_best_upToNow = 0

for iteration in range(grid_search_iteration_number):

    logger.info('grid search iteration %d', iteration)

    init_tx, init_ty, init_tz  = params_best_upToNow

    n = int(1000*np.exp(-iteration/3.))

    sigma_tx = 300*np.exp(-iteration/3.)
    sigma_ty = 300*np.exp(-iteration/3.)
    sigma_tz = 100*np.exp(-iteration/3.)

    tx_grid = init_tx + sigma_tx * (2 * np.random.random(n) - 1)
    ty_grid = init_ty + sigma_ty * (2 * np.random.random(n) - 1)
    tz_grid = init_tz + sigma_tz * (2 * np.random.random(n) - 1)

    samples = np.c_[tx_grid, ty_grid, tz_grid]

    import time
    t = time.time()

    scores = Parallel(n_jobs=16)(delayed(compute_score)([1, 0, 0, tx, 0, 1, 0, ty, 0, 0, 1, tz]) 
                                 for tx, ty, tz in samples)

#     scores = [compute_score([1, 0, 0, tx, 0, 1, 0, ty, 0, 0, 1, tz]) for tx, ty, tz in samples]

    print time.time() - t, 'seconds'

    score_best = np.max(scores)

    tx_best, ty_best, tz_best = samples[np.argmax(scores)]

    if score_best > score_best_upToNow:
        logger.info('%f %f', score_best_upToNow, score_best)

        score_best_upToNow = score_best
        params_best_upToNow = tx_best, ty_best, tz_best

        logger.info('%f %f %f', tx_best, ty_best, tz_best)
        logger.info('\n')

In [None]:
# def parallel_where(l):
#     w = np.where(volume2_allLabels[l-1] > .5)
#     return [w[1], w[0], w[2]]

# t = time.time()

# test_nzs = Parallel(n_jobs=16)(delayed(parallel_where)(l) for l in range(1, n_labels))

# print time.time() - t, 'seconds'

In [None]:
# Visual check: z = 0 slice of the cropped score volume of labels[2]
# (volume2_allLabels is built from labels[1:], so index 1 is labels[2]).
# NOTE(review): compute_score divides scores by 1e4; if the stored values are on
# that scale, vmax=1 saturates the image — confirm the intended display range.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(volume2_allLabels[1][..., 0], vmin=0, vmax=1)
ax.set_title('%s score map, z = 0' % labels[2])
ax.set_xlabel('x (voxels)')
ax.set_ylabel('y (voxels)');

In [None]:
dSdxyz = []

# for l in range(1, n_labels):
for l in range(1, 5):
    
    print labels[l]
    
    t = time.time()
    
#     gx, gy, gz = Parallel(n_jobs=3)(delayed(load_hdf)(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_%(g)s.hdf' % \
#                                          {'stack': stack, 'lab': labels[l], 'g': g} )
#                                     for g in ['gx', 'gy', 'gz'])

    gx = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gx.hdf' % {'stack': stack, 'lab': labels[l]})
    gy = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gy.hdf' % {'stack': stack, 'lab': labels[l]})
    gz = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gz.hdf' % {'stack': stack, 'lab': labels[l]})
    
    print time.time() - t
    
    dSdxyz.append([gx, gy, gz])

In [None]:
d2Sdxyz2 = []

# for l in range(1, n_labels):
for l in range(1, 5):
    
    print labels[l]
    
    t = time.time()
    
    gxx = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gxx.hdf' % {'stack': stack, 'lab': labels[l]})
    gxy = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gxy.hdf' % {'stack': stack, 'lab': labels[l]})
    gxz = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gxz.hdf' % {'stack': stack, 'lab': labels[l]})
    gyx = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gyx.hdf' % {'stack': stack, 'lab': labels[l]})
    gyy = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gyy.hdf' % {'stack': stack, 'lab': labels[l]})
    gyz = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gyz.hdf' % {'stack': stack, 'lab': labels[l]})
    gzx = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gzx.hdf' % {'stack': stack, 'lab': labels[l]})
    gzy = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gzy.hdf' % {'stack': stack, 'lab': labels[l]})
    gzz = load_hdf(volume_dir + '/volume_%(stack)s_scoreMap_%(lab)s_gzz.hdf' % {'stack': stack, 'lab': labels[l]})
    
    print time.time() - t
    
    d2Sdxyz2.append([gxx, gxy, gxz, gyx, gyy, gyz, gzx, gzy, gzz])

In [None]:
from scipy.optimize import fmin_bfgs, fmin_ncg, fmin_l_bfgs_b, fmin_cg

In [None]:
# import collections, functools

# def func_wrapper(f, cache_size=10):
#     evals = {}
#     last_points = collections.deque()

#     def get(pt, which):
#         s = pt.tostring() # get binary string of numpy array, to make it hashable
#         if s not in evals:
#             evals[s] = f(pt)
#             last_points.append(s)
#             if len(last_points) >= cache_size:
#                 del evals[last_points.popleft()]
#         return evals[s][which]

#     return functools.partial(get, which=0), functools.partial(get, which=1)

In [None]:
import collections, functools

def func_wrapper(f, cache_size=10):
    """Split a tuple-returning function into per-element callables sharing one cache.

    ``f(pt)`` is evaluated at most once per distinct point while that point is
    cached. The returned callables give elements 0, 1 and 2 of its result, so an
    optimizer can call ``f`` and ``fprime`` at the same point for the price of a
    single evaluation.

    Parameters
    ----------
    f : callable
        Takes a parameter vector and returns an indexable result; calling the
        third returned callable requires at least 3 elements.
    cache_size : int
        Maximum number of distinct points kept (oldest evicted first).

    Returns
    -------
    tuple of 3 callables
        ``(get0, get1, get2)``; each takes a point and returns the matching
        element of ``f(point)``.
    """
    evals = {}
    last_points = collections.deque()

    def get(pt, which):
        # Key on the raw bytes of the point so arrays can be dict keys
        # (tobytes replaces the deprecated tostring; asarray also accepts lists).
        s = np.asarray(pt).tobytes()
        if s not in evals:
            evals[s] = f(pt)
            last_points.append(s)
            # Evict only when over capacity, so the entry just computed is never
            # evicted before it is returned.
            if len(last_points) > cache_size:
                del evals[last_points.popleft()]
        return evals[s][which]

    return functools.partial(get, which=0), functools.partial(get, which=1), functools.partial(get, which=2)

In [None]:
tx_best, ty_best, tz_best = params_best_upToNow
# Start from the identity linear part plus the best grid-search translation,
# flattened row-major as [a11 a12 a13 tx a21 a22 a23 ty a31 a32 a33 tz].
T_best = np.array([1, 0, 0, tx_best,
                   0, 1, 0, ty_best,
                   0, 0, 1, tz_best])

t = time.time()

# Score, gradient and Hessian share one cached evaluation per point; fmin_cg
# only consumes f_ and fprime. (Alternatives tried: fmin_ncg with fhess,
# fmin_bfgs, fmin_l_bfgs_b with approx_grad.)
f_, fprime, fhess = func_wrapper(compute_score_and_gradient_and_hessian_minus)

res = fmin_cg(f=f_, x0=T_best, fprime=fprime, maxiter=10)

sys.stderr.write('optimize: %f seconds\n' % (time.time() - t))

In [None]:
tx_best, ty_best, tz_best = params_best_upToNow
# Row-major 3x4 affine [A | t]: identity linear part plus the best grid-search
# translation. Cast to float so the in-place update below also works when the
# translation is all integers (np.r_ would then build an int array).
T_best = np.r_[1,0,0, tx_best, 0,1,0, ty_best, 0,0,1, tz_best].astype(float)

# AdaGrad step sizes: lr1 for the translation entries, lr2 for the linear part.
lr1, lr2 = (10., 1e-3)
max_iter_num = 5000

fudge_factor = 1e-6 #for numerical stability

# Convergence test: after min_iter_before_stop iterations, stop once the mean score
# of the last history_len iterations differs from the window before it by < convergence_tol.
history_len = 50
min_iter_before_stop = 200
convergence_tol = 1e-2

dMdA_historical = np.zeros((12,))

# Per-parameter learning rate; translation entries are at indices 3, 7, 11.
lr = np.array([lr2, lr2, lr2, lr1, lr2, lr2, lr2, lr1, lr2, lr2, lr2, lr1])

score_best = 0
best_gradient_descent_params = T_best.copy()

scores = []

for iteration in range(max_iter_num):

    logger.info('iteration %d\n', iteration)

    t = time.time()

    # Score and gradient at the current parameters (before this iteration's update).
    s, dMdA = compute_score_and_gradient(T_best)

    # Record the parameters that produced this score. Copy, because T_best is
    # updated in place below.
    if s > score_best:
        best_gradient_descent_params = T_best.copy()
        score_best = s

    # AdaGrad: scale each component by the root of its accumulated squared gradient.
    dMdA_historical += dMdA**2
    dMdA_adjusted = dMdA / (fudge_factor + np.sqrt(dMdA_historical))

    # Gradient ascent: the score is maximized.
    T_best += lr*dMdA_adjusted

    logger.info('score: %f', s)
    scores.append(s)

    logger.info('\n')

    if iteration > min_iter_before_stop:
        recent_mean = np.mean(scores[iteration-history_len:iteration])
        previous_mean = np.mean(scores[iteration-2*history_len:iteration-history_len])
        if np.abs(recent_mean - previous_mean) < convergence_tol:
            break