In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import PiecewiseAffineTransform, warp
from skimage import data
import skimage.io

# Reference: scikit-image piecewise affine transform example
# http://scikit-image.org/docs/dev/auto_examples/transform/plot_piecewise_affine.html

In [None]:
def deform(image1, image2, points=10, distort=5.0, rng=None):
    """Apply one random piecewise-affine deformation to a pair of images.

    A regular ``points`` x ``points`` grid of control points is laid over
    ``image1``. Every control point is displaced by a random offset, a
    ``PiecewiseAffineTransform`` is estimated from the regular grid to the
    displaced grid, and both images are warped with that same transform.

    Parameters
    ----------
    image1, image2 : ndarray
        Images to warp. The control grid and the output shape come from
        ``image1.shape[:2]``; ``image2`` is warped to that same shape.
    points : int
        Number of control points along each axis.
    distort : float
        Each coordinate offset is a standard-normal sample multiplied by a
        uniform sample drawn from ``[0, distort)``.
    rng : object, optional
        Random source exposing ``normal`` and ``uniform`` (for example a
        seeded ``numpy.random.RandomState`` or ``numpy.random.Generator``).
        When omitted, the global ``numpy.random`` state is used.

    Returns
    -------
    out1, out2 : ndarray
        The warped images as returned by ``warp``.
    params : dict
        ``"tform"`` (the estimated transform), ``"src"`` (the undistorted
        control points, one ``(col, row)`` pair per line), ``"out_cols"``
        and ``"out_rows"``.

    Raises
    ------
    RuntimeError
        If ``tform.estimate`` reports that the estimation failed.
    """
    if rng is None:
        # The module-level functions draw from numpy's global random state.
        rng = np.random

    rows, cols = image1.shape[0], image1.shape[1]

    # Regular control grid spanning the whole image, stored as (col, row).
    grid_rows, grid_cols = np.meshgrid(
        np.linspace(0, rows, points),
        np.linspace(0, cols, points),
    )
    src = np.column_stack([grid_cols.ravel(), grid_rows.ravel()])

    # Displace every control point. The draw order (rows before columns,
    # normal before uniform) is fixed so a seeded source is reproducible.
    n_ctrl = src.shape[0]
    dst_rows = src[:, 1] + rng.normal(size=n_ctrl) * rng.uniform(0.0, distort, size=n_ctrl)
    dst_cols = src[:, 0] + rng.normal(size=n_ctrl) * rng.uniform(0.0, distort, size=n_ctrl)
    dst = np.column_stack([dst_cols, dst_rows])

    tform = PiecewiseAffineTransform()
    estimated = tform.estimate(src, dst)
    # estimate() signals failure through its return value; do not warp with
    # a transform that was not estimated successfully.
    if estimated is not None and not estimated:
        raise RuntimeError(
            "piecewise affine transform could not be estimated "
            "(points={}, distort={})".format(points, distort)
        )

    # NOTE(review): both images use warp's default interpolation; if image2
    # holds discrete labels, confirm that interpolated values are acceptable.
    out_shape = (rows, cols)
    out1 = warp(image1, tform, output_shape=out_shape, mode="symmetric")
    out2 = warp(image2, tform, output_shape=out_shape, mode="symmetric")

    return out1, out2, {"tform": tform, "src": src, "out_cols": cols, "out_rows": rows}

In [None]:
def display(im1, im2, p, d, params=None):
    """Show two images side by side, optionally with the control points.

    Parameters
    ----------
    im1, im2 : array-like
        Images drawn on the left and right axes with ``imshow``.
    p, d : object
        Values inserted into the left axis title
        (``'Points {} Distort {}'``).
    params : dict, optional
        Mapping with the keys ``"tform"``, ``"src"``, ``"out_cols"`` and
        ``"out_rows"``. When given, ``tform.inverse(src)`` is plotted as
        blue dots on both axes, and both axes are limited to
        ``(0, out_cols)`` horizontally and ``(out_rows, 0)`` vertically.
    """
    fig, ax = plt.subplots(1, 2, figsize=(18, 12))
    ax[0].imshow(im1)
    ax[1].imshow(im2)
    ax[0].set_title('Points {} Distort {}'.format(p, d))

    if params is not None:
        # Map the control points once; both axes show the same overlay.
        mapped = params["tform"].inverse(params["src"])
        xs, ys = mapped[:, 0], mapped[:, 1]
        extent = (0, params["out_cols"], params["out_rows"], 0)
        for axis in ax:
            axis.plot(xs, ys, '.b')
            axis.axis(extent)

    plt.show()

In [None]:
# Load the x/y image pair. Both files share one name and differ only in the
# x/ or y/ sub-directory, so the directory and the name are spelled once to
# keep the pair in sync.
DATA_DIR = "/data1/image-segmentation/BBBC022/unet"
IMAGE_NAME = "IXMtest_G01_s2_w1FBE52723-8BDF-4346-89BB-216A4A69ED1C.png"

x = skimage.io.imread("{}/x/{}".format(DATA_DIR, IMAGE_NAME))
y = skimage.io.imread("{}/y/{}".format(DATA_DIR, IMAGE_NAME))

# Show the undeformed pair side by side for reference.
fig, ax = plt.subplots(1, 2, figsize=(18, 12))
ax[0].imshow(x)
ax[1].imshow(y);

In [None]:
# Sweep the grid density (outer loop) and the distortion strength (inner
# loop); every combination is deformed and plotted with its control points.
for i in range(10, 12, 2):
    for j in range(15, 20):
        out1, out2, params = deform(x, y, points=i, distort=j)
        display(out1, out2, i, j, params=params)

In [None]:
# Benchmark a single deform call on a 20 x 20 control grid (points=20).
%timeit out1, out2, _ = deform(x, y, points=20, distort=6)