In [None]:
import sys, os, ndreg, skimage
import matplotlib.pyplot as plt
from matplotlib import cm
import SimpleITK as sitk
import numpy as np
from intern.remote.boss import BossRemote
from intern.resource.boss.resource import *
import missing_data as mdmask
import ingest_tif_stack as ingest

In [None]:
import registerer_copy as reg
import preprocessor_copy as pre
import matplotlib
# Default size (inches) for every matplotlib figure drawn in this notebook.
matplotlib.rcParams.update({'figure.figsize': (10.0, 8.0)})

In [None]:
# Reference atlas; its size (SimpleITK GetSize() order: x, y, z) drives the
# half-atlas masks built in the next cells.
full_atlas = ndreg.imgRead('./atlas/ara_atlas.img')
(atlas_width, atlas_height, atlas_depth) = full_atlas.GetSize()

In [None]:
# Bottom-half atlas: a 0.5-fraction mask from mdmask.gen_frac_mask
# (dimension 0, side='right') built on the atlas geometry and applied to it.
bot_half_atlas_array = mdmask.gen_frac_mask(
    atlas_depth, atlas_height, atlas_width, 0.5, 0, side='right')
bot_half_atlas_mask = mdmask.convert_to_image(
    bot_half_atlas_array.astype('uint16'))
bot_half_atlas_mask.CopyInformation(full_atlas)  # take origin/spacing/direction from the atlas
bot_half_atlas = mdmask.mask_img(full_atlas, bot_half_atlas_mask)

In [None]:
# Top-half atlas: a 0.5-fraction mask from mdmask.gen_frac_mask
# (dimension 0, side='left') built on the atlas geometry and applied to it.
top_half_atlas_array = mdmask.gen_frac_mask(
    atlas_depth, atlas_height, atlas_width, 0.5, 0, side='left')
top_half_atlas_mask = mdmask.convert_to_image(
    top_half_atlas_array.astype('uint16'))
top_half_atlas_mask.CopyInformation(full_atlas)  # take origin/spacing/direction from the atlas
top_half_atlas = mdmask.mask_img(full_atlas, top_half_atlas_mask)

In [None]:
# Load the whole-insula image and reorient it from the image orientation
# ('lps') to the atlas orientation ('pir') so the two can be registered.
orientation_image = 'lps'
orientation_atlas = 'pir'
full_img = ndreg.imgReorient(ndreg.imgRead('./missing_insula/whole_insula.img'),
                             orientation_image,
                             orientation_atlas)

In [None]:
def run_experiment(atlas, img, missing_percentages, missing_dim, transformation_type=reg.register_rigid):
    """Register `atlas` against increasingly masked copies of `img`.

    For each fraction ``p`` in `missing_percentages`, a mask is generated with
    ``mdmask.gen_frac_mask(..., p, missing_dim, side='right')`` on `img`'s grid
    and applied to `img`. `transformation_type` is then called on the
    normalized atlas and the normalized masked image. The atlas is resampled
    with the resulting transform via ``reg.resample``, with the masked image
    passed as the reference argument. Each result is printed and displayed.

    Args:
        atlas: SimpleITK image to register and resample.
        img: Full target image; every masked copy is derived from it.
        missing_percentages: Iterable of fractions passed to
            ``mdmask.gen_frac_mask``; each one becomes a key of the result.
        missing_dim: Dimension index passed to ``mdmask.gen_frac_mask``.
        transformation_type: Registration callable with the call signature of
            ``reg.register_rigid`` (the default).

    Returns:
        dict mapping each ``p`` to ``(atlas_rigid, final_transform, missing_img)``.
    """
    # Use the `img` argument (not the notebook-global full_img) so the caller
    # actually controls which image is masked and registered.
    img_width, img_height, img_depth = img.GetSize()

    # Loop-invariant: both depend only on the atlas, not on p.
    atlas_n = sitk.Normalize(atlas)
    atlas_default_value = ndreg.imgPercentile(atlas, 0.01)

    percent_results = {}
    for p in missing_percentages:
        # Create the missing image: apply the fractional mask for p to `img`.
        missing_array = mdmask.gen_frac_mask(img_depth, img_height, img_width, p, missing_dim, side='right')
        missing_mask = mdmask.convert_to_image(missing_array.astype('uint16'))
        missing_mask.CopyInformation(img)
        missing_img = mdmask.mask_img(img, missing_mask)

        # Bias correction is disabled. To re-enable it, correct missing_img here
        # and normalize the corrected image instead:
        #   mask_dilation_radius = 10  # voxels
        #   mask_bc = sitk.BinaryDilate(pre.create_mask(missing_img, use_triangle=True), mask_dilation_radius)
        #   img_bc, bias = pre.correct_bias_field(missing_img, scale=0.25, spline_order=4, mask=mask_bc,
        #                                         num_control_pts=[5, 5, 5],
        #                                         niters=[50, 50, 50, 50])
        missing_img_n = sitk.Normalize(missing_img)

        # Compute the transform (rigid by default).
        # NOTE(review): missing_mask is defined on img's grid (CopyInformation
        # above) but is passed as fixed_mask while atlas_n is the first
        # positional argument -- confirm reg.register_rigid's argument order.
        final_transform = transformation_type(atlas_n,
                                              missing_img_n,
                                              fixed_mask=missing_mask,
                                              learning_rate=1e-1,
                                              grad_tol=4e-6,
                                              use_mi=False,
                                              iters=50,
                                              shrink_factors=[4, 2, 1],
                                              sigmas=[0.4, 0.2, 0.1],
                                              verbose=False)

        # Resample the atlas with the computed transform. The reference was
        # previously `img_bc`, which only existed while bias correction was
        # enabled and raised NameError otherwise; the masked image replaces it.
        atlas_rigid = reg.resample(atlas, final_transform, missing_img, default_value=atlas_default_value)

        print("Percentage of image missing: {}".format(p))
        ndreg.imgShow(atlas_rigid)
        percent_results[p] = (atlas_rigid, final_transform, missing_img)
    return percent_results

# Affine (or rigid) registration experiments

1. Increase the amount of missing data until rigid registration fails.
2. Try composing a translation transform followed by a rotation transform.
3. Look into the versor transform and the registration method functions:
   * http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/22_Transforms.html
4. Try removing the physical optimizer.
5. Inspect the computed transformation.


In [None]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual,FloatSlider
import ipywidgets as widgets

def lookup_result(res, x, tol=1e-6):
    """Return the entry of `res` whose key matches the slider value `x`.

    Slider values are floats and may differ from the stored percentage keys by
    floating-point round-off (e.g. 0.1 + 0.2 == 0.30000000000000004, not 0.3),
    which makes a plain ``res[x]`` raise KeyError. An exact key is used when
    present; otherwise the nearest key is accepted if it lies within `tol`.

    Raises:
        KeyError: if `res` is empty or no key lies within `tol` of `x`.
    """
    if x in res:
        return res[x]
    if res:
        nearest = min(res, key=lambda k: abs(k - x))
        if abs(nearest - x) <= tol:
            return res[nearest]
    raise KeyError("no result for {!r}; available keys: {}".format(x, sorted(res)))


def atlas_rigid_slider(x, res):
    """Interact callback: show the resampled atlas stored in `res` for fraction `x`."""
    plt.clf()
    atlas_rigid, _, _ = lookup_result(res, x)
    ndreg.imgShow(atlas_rigid)


def mse_plot_slider(x, res):
    """Interact callback: call mdmask.plot_mse for the result stored for fraction `x`.

    The normalized resampled atlas and the normalized masked image are compared
    on slice 50 of array axis 1. The side image is a pre.create_mask mask of the
    masked image, dilated by 10 voxels.
    """
    plt.clf()
    atlas_rigid, _, img = lookup_result(res, x)
    mask_dilation_radius = 10  # voxels
    mask_bc = sitk.BinaryDilate(pre.create_mask(img, use_triangle=True), mask_dilation_radius)
    # GetArrayFromImage returns arrays in (z, y, x) order, so [:, 50, :] fixes y.
    atlas_rigid_slice = sitk.GetArrayFromImage(sitk.Normalize(atlas_rigid))[:, 50, :]
    img_rigid_slice = sitk.GetArrayFromImage(sitk.Normalize(img))[:, 50, :]
    side_img_slice = sitk.GetArrayFromImage(mask_bc)[:, 50, :]
    mdmask.plot_mse(atlas_rigid_slice, img_rigid_slice, side_img=side_img_slice, color_blend=True)


def missing_img_slider(x, res):
    """Interact callback: show the masked image stored in `res` for fraction `x` (vmax=2500)."""
    plt.clf()
    _, _, img = lookup_result(res, x)
    ndreg.imgShow(img, vmax=2500)

## Anterior missing (disabled: the experiment cells in this section are commented out)

In [None]:
# Anterior experiment parameters (disabled). The slider cells below read
# `increment_factor`, so that name is used here instead of `increments`.
# starting_percent = 0.0; ending_percent = 0.50; increment_factor = 25
# missing_percentages = np.round(np.linspace(starting_percent, ending_percent, 
#                                            num=int(increment_factor*(ending_percent-starting_percent))+1), 2)
# missing_dim = 2
# print("Missing percentages: {}".format(list(missing_percentages)))

In [None]:
# NOTE(review): `atlas` is not defined in this notebook (full_atlas,
# bot_half_atlas and top_half_atlas are); choose one before re-enabling.
# ant_percent_results = run_experiment(atlas, full_img, missing_percentages, missing_dim)

In [None]:
# interact(mse_plot_slider, 
#          res=fixed(ant_percent_results),
#          x=FloatSlider(min=starting_percent, max=ending_percent+0.0001, step=1/float(increment_factor), continuous_update=True))


In [None]:
# interact(atlas_rigid_slider, 
#          res=fixed(ant_percent_results),
#          x=FloatSlider(min=starting_percent, max=ending_percent+0.0001, step=1/float(increment_factor), continuous_update=True))


In [None]:
# interact(missing_img_slider, 
#          res=fixed(ant_percent_results),
#          x=FloatSlider(min=starting_percent, max=ending_percent+0.0001, step=1/float(increment_factor), continuous_update=True))


## Hemisphere missing

In [None]:
# Hemisphere experiment: masked fractions from starting_percent to
# ending_percent in steps of 1 / increment_factor.
starting_percent = 0.25; ending_percent = 0.50; increment_factor = 20
# round() before int(): the product can fall just short of a whole number
# (e.g. 100 * 0.29 == 28.999999999999996), and int() alone would truncate it
# and silently drop the last percentage from the sweep.
missing_percentages = np.round(
    np.linspace(starting_percent, ending_percent,
                num=int(round(increment_factor * (ending_percent - starting_percent))) + 1),
    2)
missing_dim = 0  # dimension index handed to mdmask.gen_frac_mask by run_experiment
print("Missing percentages: {}".format(list(missing_percentages)))

In [None]:
# Hemisphere experiment: register the bottom-half atlas against masked copies of full_img.
hemi_percent_results = run_experiment(bot_half_atlas, full_img,
                                      missing_percentages=missing_percentages,
                                      missing_dim=missing_dim)

In [None]:
# Slider over the hemisphere results: MSE plot per missing fraction.
interact(mse_plot_slider,
         res=fixed(hemi_percent_results),
         # +0.0001 presumably keeps ending_percent selectable despite float round-off
         x=FloatSlider(min=starting_percent,
                       max=ending_percent + 0.0001,
                       step=1.0 / increment_factor,
                       continuous_update=True))


In [None]:
# Slider over the hemisphere results: resampled atlas per missing fraction.
interact(atlas_rigid_slider,
         res=fixed(hemi_percent_results),
         # +0.0001 presumably keeps ending_percent selectable despite float round-off
         x=FloatSlider(min=starting_percent,
                       max=ending_percent + 0.0001,
                       step=1.0 / increment_factor,
                       continuous_update=True))


In [None]:
# Slider over the hemisphere results: masked input image per missing fraction.
interact(missing_img_slider,
         res=fixed(hemi_percent_results),
         # +0.0001 presumably keeps ending_percent selectable despite float round-off
         x=FloatSlider(min=starting_percent,
                       max=ending_percent + 0.0001,
                       step=1.0 / increment_factor,
                       continuous_update=True))


# Create animations (MP4)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
def lookup_frame_result(results, x, tol=1e-6):
    """Return the entry of `results` for animation frame value `x`.

    Frame values are floats and may differ from the stored percentage keys by
    round-off; an exact key is used when present, otherwise the nearest key
    within `tol`.

    Raises:
        KeyError: if `results` is empty or no key lies within `tol` of `x`,
            listing the available keys so a mismatched frame list is obvious.
    """
    if x in results:
        return results[x]
    if results:
        nearest = min(results, key=lambda k: abs(k - x))
        if abs(nearest - x) <= tol:
            return results[nearest]
    raise KeyError("no result for frame {!r}; available keys: {}".format(x, sorted(results)))


def atlas_rigid_update(x):
    """FuncAnimation callback: draw the resampled atlas for frame `x` of the global result_dict."""
    plt.clf()
    atlas_rigid, _, _ = lookup_frame_result(result_dict, x)
    ndreg.imgShow(atlas_rigid)


def mse_anim_update(x):
    """FuncAnimation callback: mdmask.plot_mse for frame `x` of the global result_dict.

    Clears the figure first, like atlas_rigid_update and the slider callbacks,
    so artists from earlier frames cannot accumulate in the saved video.
    """
    plt.clf()
    atlas_rigid, _, img = lookup_frame_result(result_dict, x)
    mask_dilation_radius = 10  # voxels
    mask_bc = sitk.BinaryDilate(pre.create_mask(img, use_triangle=True), mask_dilation_radius)
    # GetArrayFromImage returns arrays in (z, y, x) order, so [:, 50, :] fixes y.
    atlas_rigid_slice = sitk.GetArrayFromImage(sitk.Normalize(atlas_rigid))[:, 50, :]
    img_rigid_slice = sitk.GetArrayFromImage(sitk.Normalize(img))[:, 50, :]
    side_img_slice = sitk.GetArrayFromImage(mask_bc)[:, 50, :]
    mdmask.plot_mse(atlas_rigid_slice, img_rigid_slice, side_img=side_img_slice, color_blend=True)


def missing_img_update(x):
    """FuncAnimation callback: draw the masked image for frame `x` of the global result_dict (vmax=2500)."""
    plt.clf()
    _, _, img = lookup_frame_result(result_dict, x)
    ndreg.imgShow(img, vmax=2500)

In [None]:
result_dict = hemi_percent_results
# Animate over exactly the percentages the experiment produced, so every frame
# has an entry in result_dict. An independent linspace over the same bounds
# drifts off the stored keys whenever the experiment's step size changes.
anim_fargs = np.array(sorted(result_dict))

In [None]:
# Show which missing-data fractions will become animation frames.
print(str(anim_fargs))

## Missing Hemisphere

In [None]:
# Render the MSE comparison for each frame (200 ms between frames) and save as MP4.
mse_plot_anim = animation.FuncAnimation(plt.figure(),
                                        mse_anim_update,
                                        frames=anim_fargs,
                                        interval=200)
mse_plot_anim.save('mse_anim_affine_mod_atlas_hemi.mp4')

In [None]:
# mse_plot_anim = animation.FuncAnimation(plt.figure(), atlas_rigid_update, frames=anim_fargs,interval=200)
# mse_plot_anim.save('atlas_rigid_hemi.mp4')

In [None]:
# mse_plot_anim = animation.FuncAnimation(plt.figure(), missing_img_update, frames=anim_fargs,interval=200)
# mse_plot_anim.save('missing_img_hemi.mp4')

## Missing Anterior

In [None]:
# ant_percent_results only exists once the commented-out "Anterior missing"
# experiment cells are re-enabled and run; fail with an explicit message
# instead of a bare NameError.
try:
    result_dict = ant_percent_results
except NameError:
    raise NameError("ant_percent_results is not defined: re-enable and run the "
                    "'Anterior missing' experiment cells first")
# Animate over exactly the anterior experiment's percentages. A linspace over
# starting_percent/ending_percent would reuse the hemisphere bounds set
# earlier and could request percentages that are not keys of result_dict.
anim_fargs = np.array(sorted(result_dict))

In [None]:
# Render the anterior MSE comparison for each frame (200 ms between frames) and save as MP4.
mse_plot_anim = animation.FuncAnimation(plt.figure(),
                                        mse_anim_update,
                                        frames=anim_fargs,
                                        interval=200)
mse_plot_anim.save('mse_anim_ant.mp4')

In [None]:
# Disabled: resampled-atlas animation for the anterior results. The output
# name was 'atlas_rigid_hemi.mp4' (copied from the hemisphere section), which
# would collide with that section's file; an anterior-specific name is used.
# mse_plot_anim = animation.FuncAnimation(plt.figure(), atlas_rigid_update, frames=anim_fargs,interval=200)
# mse_plot_anim.save('atlas_rigid_ant.mp4')

In [None]:
# Disabled: masked-image animation for the anterior results. The output name
# was 'missing_img_hemi.mp4' (copied from the hemisphere section), which would
# collide with that section's file; an anterior-specific name is used.
# mse_plot_anim = animation.FuncAnimation(plt.figure(), missing_img_update, frames=anim_fargs,interval=200)
# mse_plot_anim.save('missing_img_ant.mp4')