## Initial setup

In [None]:
import numpy as np
from scipy.interpolate import interp2d
from scipy.interpolate import RectBivariateSpline

import numba
import matplotlib.pyplot as plt
import time
import os
%matplotlib inline

In [None]:
def plot_warp(xphi, yphi, downsample='auto', **kwarg):
    """Plot a deformed grid as black curves on the current axes.

    Parameters
    ----------
    xphi, yphi : 2-D arrays of equal shape
        x and y coordinates of the deformed grid points.
    downsample : 'auto', 'no' or int
        Draw every ``downsample``-th grid line. ``'auto'`` picks a step of
        ``xphi.shape[0] // 32`` (at least 1), i.e. roughly 32 lines;
        ``'no'`` draws every line; any other value is used as the step
        after conversion with ``int()``.
    **kwarg
        Extra keyword arguments forwarded to ``plt.plot``.
    """
    if downsample == 'auto':
        # Integer division: slice steps must be integers (a float step
        # raises TypeError when indexing the arrays below).
        skip = max(xphi.shape[0] // 32, 1)
    elif downsample == 'no':
        skip = 1
    else:
        skip = int(downsample)
    # Columns of xphi/yphi trace one family of grid lines; the transposed
    # rows trace the other family.
    plt.plot(xphi[:, skip::skip], yphi[:, skip::skip], 'black',
             xphi[skip::skip, ::1].T, yphi[skip::skip, ::1].T, 'black',
             **kwarg)

def get_dir_name(I0name, I1name, sigma):
    """Return the output directory path for matching ``I1name`` to ``I0name``.

    The directory is placed next to ``I1name`` (the directory part of
    ``I0name`` is ignored). Its final component has the form
    ``"<I1 stem> to <I0 stem> with sigma <sigma>"``, where a stem is the
    file name with its extension removed.
    """
    stem1 = os.path.splitext(os.path.basename(I1name))[0]
    stem0 = os.path.splitext(os.path.basename(I0name))[0]
    label = stem1 + ' to ' + stem0 + ' with sigma ' + str(sigma)
    return os.path.join(os.path.dirname(I1name), label)
    

## Perform the function matching

In [None]:
import difforma_base
import scipy.ndimage as ndimage

In [None]:
# Input images. I0 is later passed as `source` and I1 as `target`.
I0name = 'Example1 box/box_at_4th.png'
I1name = 'Example1 box/box_at_2nd.png'

# Inertia-operator parameters.
alpha = 0.001
beta = 0.03

sigma = 0.05    # regularization weight: larger sigma -> more regularization
epsilon = 0.10  # step size
n_iter = 490    # number of iterations

# Load both images (I0 first, then I1) as float arrays.
I0, I1 = (plt.imread(path).astype('float') for path in (I0name, I1name))

# Optional pre-smoothing (disabled):
# I0 = ndimage.gaussian_filter(I0, sigma=6)
# I1 = ndimage.gaussian_filter(I1, sigma=6)

# Optional clamp of pixels darker than `cutoff` up to `cutoff` (disabled):
# cutoff = 0.3
# I0 = np.where(I0 < cutoff, cutoff, I0)

In [None]:
# Build the matching problem with I0 as source and I1 as target, using the
# alpha/beta/sigma values set in the parameter cell.
dm = difforma_base.DiffeoFunctionMatching(
    source=I0,
    target=I1,
    alpha=alpha,
    beta=beta,
    sigma=sigma,
)

In [None]:
%%time 
# Run n_iter iterations with step size epsilon; %%time reports the cell's timing.
dm.run(n_iter, epsilon=epsilon)

## Plot the results

In [None]:
def _draw_warp_panel(phix, phiy):
    """Draw the warp grid on the current axes with one shared square x/y range.

    Parameters
    ----------
    phix, phiy : 2-D arrays
        Grid-point coordinates forwarded to ``plot_warp``.
    """
    plot_warp(phix, phiy, downsample=4)
    plt.axis('equal')
    # Use a common range for both axes so the grid is not visually stretched.
    lo = min(phix.min(), phiy.min())
    hi = max(phix.max(), phiy.max())
    plt.axis([lo, hi, lo, hi])
    # Flip y so row 0 is at the top, matching the default imshow orientation.
    plt.gca().invert_yaxis()
    plt.gca().set_aspect('equal')
    plt.title('Warp')


# Figure 1: target, template and warped images plus the warp grid.
plt1 = plt.figure(1, figsize=(11.7, 9))
plt.clf()

# All image panels share the intensity range of dm.I0.
vmin, vmax = dm.I0.min(), dm.I0.max()
image_panels = [
    (dm.target, 'Target image'),
    (dm.source, 'Template image'),
    (dm.I, 'Warped image'),
]
for position, (image, title) in enumerate(image_panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(image, cmap='bone', vmin=vmin, vmax=vmax)
    plt.colorbar()
    plt.title(title)

plt.subplot(2, 2, 4)
# True: plot dm.phix/dm.phiy; False: plot dm.phiinvx/dm.phiinvy.
use_forward = True
if use_forward:
    phix, phiy = dm.phix, dm.phiy
else:
    phix, phiy = dm.phiinvx, dm.phiinvy
_draw_warp_panel(phix, phiy)
plt.grid()

# Figure 3: dm.E, presumably the energy per iteration.
plt3 = plt.figure(3, figsize=(8, 4.5))
plt.clf()
plt.plot(dm.E)
plt.grid()
plt.ylabel('Energy')

# Figure 4: the warp grid alone, without axes.
plt4 = plt.figure(4, figsize=(10, 10))
# Clear it like figures 1 and 3, so re-running the cell with a persistent
# figure backend does not draw a second grid over the previous one.
plt.clf()
_draw_warp_panel(phix, phiy)
plt.axis('off')


### Save the plots

In [None]:
from matplotlib.transforms import Bbox
def full_extent(ax, jac_colorbar, pad=0.0):
    """Return the window-space bounding box of an axes together with a colorbar.

    The box is the union of the extents of ``ax``, its title, its x/y tick
    labels, the colorbar's axes and the colorbar's x/y tick labels. Axis
    labels (``ax.xaxis.label`` / ``ax.yaxis.label``) are not included.

    Parameters
    ----------
    ax : matplotlib Axes
    jac_colorbar : matplotlib Colorbar
        Only its ``.ax`` attribute is used.
    pad : float
        The union box is scaled by ``1 + pad`` in width and height.
    """
    # Text extents are only defined after the figure has been drawn once.
    ax.figure.canvas.draw()
    cbar_ax = jac_colorbar.ax
    artists = [
        *ax.get_xticklabels(),
        *ax.get_yticklabels(),
        *cbar_ax.get_xticklabels(),
        *cbar_ax.get_yticklabels(),
        ax,
        ax.title,
        cbar_ax,
    ]
    extents = [artist.get_window_extent() for artist in artists]
    scale = 1.0 + pad
    return Bbox.union(extents).expanded(scale, scale)

In [None]:
# Set up the output directory next to the input images.
fig_dir_name = os.path.join(get_dir_name(I0name, I1name, sigma), 'figures')
if not os.path.exists(fig_dir_name):
    os.makedirs(fig_dir_name)
    print("Creating directory " + fig_dir_name)

# Overview figure (images + warp panel).
plt1.savefig(os.path.join(fig_dir_name, 'images.png'), dpi=300, bbox_inches='tight')

# Standalone warp-grid figure (plt4). Work on its axes directly rather than
# through plt.figure(n): with the inline backend the figure may already be
# closed, and plt.figure(n) would then create a new, empty figure.
warp_ax = plt4.gca()
warp_ax.axis('off')
warp_ax.set_title('')
plt4.savefig(os.path.join(fig_dir_name, 'warp.png'), dpi=150, bbox_inches='tight')

plt3.savefig(os.path.join(fig_dir_name, 'energy.png'), dpi=150, bbox_inches='tight')

# Create images

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

# Synthetic test images: a white 60x30 box (value 255) on a black 256x256
# canvas, shifted/rotated to produce four variants in the working directory.
arr = np.zeros([256, 256])
arr[30:90, 30:60] = 255
# Optional smoothing (disabled; `gauss` is not defined in this notebook):
# arr = gauss(arr, sigma=6)

Image.fromarray(arr).convert('L').save('box_at_1st.png')

# Shift the box 20 rows down.
arr = np.roll(arr, 20, axis=0)
Image.fromarray(arr).convert('L').save('box_at_2nd.png')

# Additionally shift it 20 columns right.
arr = np.roll(arr, 20, axis=1)
im = Image.fromarray(arr).convert('L')
im.save('box_at_3rd.png')

# Rotate the third image by 25 degrees.
im = im.rotate(25)
im.save('box_at_4th.png')

plt.imshow(im)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def _show_image(image):
    """Display *image* in a new figure with a colorbar."""
    plt.figure()
    plt.imshow(image)
    plt.colorbar()
    plt.show()


# Unit-intensity 60x30 box on a 256x256 canvas, shifted 100 rows then 100 columns.
img = np.zeros([256, 256])
img[30:90, 30:60] = 1.0
_show_image(img)

img = np.roll(img, 100, axis=0)
_show_image(img)

img = np.roll(img, 100, axis=1)
_show_image(img)

# Round-trip check of plt.imsave/plt.imread. plt.imsave colormaps a 2-D array,
# so the reloaded array is expected to carry colour channels. Use a scratch
# file name so the grayscale 'box_at_2nd.png' produced by the image-generation
# cell is not overwritten with this different, colormapped image.
imsave_check_name = 'box_imsave_check.png'
plt.imsave(imsave_check_name, img)

importedimg = plt.imread(imsave_check_name).astype('float')
print(importedimg.shape)

