In [None]:
import sys
from pathlib import Path

sys.path.append('..')  # presumably so par_segmentation can be imported from the parent directory

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import animation
from matplotlib_polyroi import RoiJupyter

from par_segmentation import load_image, offset_coordinates, interp_roi, straighten, interp_1d_array
from par_segmentation.model import ImageQuantGradientDescent, create_offsets_spline
%matplotlib notebook

In [None]:
# Input image for the segmentation (the "_af_corrected" suffix suggests autofluorescence correction was already applied — confirm).
img = load_image('nwg338_af_corrected.tif')

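### Rough manual ROI

Draw a rough ROI by hand and save it to a text file (the two commented-out cells, needed only once); the cell after them reloads the saved ROI.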

In [None]:
# r1 = RoiJupyter(img, periodic=True, spline=True)
# r1.run()

In [None]:
# np.savetxt('nwg338_ROI_for_animation.txt', r1.roi)

In [None]:
# Reload the rough ROI saved by the manual-drawing step above.
roi = np.loadtxt(fname='nwg338_ROI_for_animation.txt')

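### Run

Fit the ROI by gradient descent, then rebuild the ROI, straightened image and membrane profile from every 10th saved training step. These per-frame lists drive both animations below; the second and third iterations are optional and commented out.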

In [None]:
# Per-frame accumulators, appended to by the iteration cells below.
straights, rois_new, mems = [], [], []

In [None]:
# Gradient-descent ROI fit.
# - save_training=True: presumably what populates iq.saved_vars, which the next cell samples.
# - freedom: scales tanh(offsets) when the ROI is rebuilt, bounding how far it can move.
iq = ImageQuantGradientDescent(
    img=img,
    roi=roi,
    sigma=3.5,
    descent_steps=400,
    rol_ave=5,
    lr=0.01,
    iterations=1,
    fit_outer=True,
    roi_knots=20,
    nfits=None,
    save_training=True,
    zerocap=False,
    freedom=25,
)

In [None]:
# First iteration: fit the ROI, then rebuild intermediate states from the saved training steps.
# NOTE: this appends to straights / rois_new / mems, so re-running it adds duplicate frames;
# re-run the list-initialisation cell first.
iq.run()
iq.adjust_roi()
roi2 = iq.roi[0]  # adjusted ROI; the (commented-out) second iteration starts from it
# np.savetxt('nwg338_ROI_mid.txt', roi2)

FRAME_STRIDE = 10    # keep every 10th saved training step as an animation frame
STRAIGHT_WIDTH = 50  # width argument passed to straighten()

for saved_step in iq.saved_vars[::FRAME_STRIDE]:
    # Rebuild the ROI implied by this step's offsets; freedom * tanh(.) bounds each offset to ±iq.freedom.
    offsets_spline = create_offsets_spline(saved_step['offsets'], iq.roi_knots, iq.periodic,
                                           iq.n, iq.nfits, [roi]).numpy()
    roi_new = offset_coordinates(roi, iq.freedom * tf.math.tanh(offsets_spline[0]))
    roi_new_interp = interp_roi(roi_new)
    rois_new.append(roi_new_interp)
    straights.append(straighten(img, roi_new_interp, STRAIGHT_WIDTH))
    # Resample the membrane profile to the interpolated ROI's point count
    # (assumes interp_1d_array(arr, n) resamples arr to n values — confirm).
    mems.append(interp_1d_array(saved_step['mems'][0], roi_new_interp.shape[0]))

In [None]:
# # Second iteration
# iq.run()
# iq.adjust_roi()
# roi3 = iq.roi[0]

# for i in iq.saved_vars[0::10]:
#     a = tf.concat((i['offsets'], i['offsets'][:, :1]), axis=1)
#     offsets_spline = create_offsets_spline(i['offsets'], iq.roi_knots, iq.periodic, iq.n, iq.nfits, [roi2,]).numpy()
#     roi_new = offset_coordinates(roi2, iq.freedom * tf.math.tanh(offsets_spline[0]))
#     roi_new_interp = interp_roi(roi_new)
#     rois_new.append(roi_new_interp)
#     straights.append(straighten(img, roi_new_interp, 50))
#     mems.append(interp_1d_array(i['mems'][0], roi_new_interp.shape[0]))

In [None]:
# # Third iteration
# iq.run()

# for i in iq.saved_vars[0::10]:
#     a = tf.concat((i['offsets'], i['offsets'][:, :1]), axis=1)
#     offsets_spline = create_offsets_spline(i['offsets'], iq.roi_knots, iq.periodic, iq.n, iq.nfits, [roi3,]).numpy()
#     roi_new = offset_coordinates(roi3, iq.freedom * tf.math.tanh(offsets_spline[0]))
#     roi_new_interp = interp_roi(roi_new)
#     rois_new.append(roi_new_interp)
#     straights.append(straighten(img, roi_new_interp, 50))

### Segmentation animation

In [None]:
from matplotlib.lines import Line2D

class LineDataUnits(Line2D):
    """Line2D whose ``linewidth`` (or ``lw``) is given in data units rather than points.

    The width in points is recomputed from the axes' current data-to-display transform
    every time it is read, measured along the y axis. When the line is not attached to
    any axes, the width is 1 point.
    """

    def __init__(self, *args, **kwargs):
        # Pop the ``lw`` alias as well. Otherwise Line2D would apply it during __init__,
        # and the assignment of ``_lw_data`` below would then overwrite it with the default 1.
        lw_alias = kwargs.pop("lw", None)
        _lw_data = kwargs.pop("linewidth", 1 if lw_alias is None else lw_alias)
        super().__init__(*args, **kwargs)
        self._lw_data = _lw_data

    def _get_lw(self):
        """Return the width in points for the current transform, or 1 when not on an axes."""
        if self.axes is not None:
            ppd = 72./self.axes.figure.dpi  # points per display pixel
            trans = self.axes.transData.transform
            # Display-space height of a segment `_lw_data` data units tall. abs() because an
            # inverted y axis (as set up by imshow) makes this difference negative.
            return abs(((trans((1, self._lw_data))-trans((0, 0)))*ppd)[1])
        else:
            return 1

    def _set_lw(self, lw):
        # Line2D.set_linewidth assigns to _linewidth; keep that value as a data-unit width.
        self._lw_data = lw

    _linewidth = property(_get_lw, _set_lw)

In [None]:
max_width = max(s.shape[1] for s in straights)
max_intensity = max(np.max(s) for s in straights)

fig, ax = plt.subplots(2, 1)
@widgets.interact(t=(0, len(straights)-1, 1))
def update1(t=0):
    """Draw frame ``t``: ROI overlay on the cropped image (top) and its straightened strip (bottom).

    Reads the module-level ``ax``, ``img``, ``rois_new``, ``straights``, ``max_width`` and
    ``max_intensity``. The segmentation-animation cell reuses this function on a fresh ``fig, ax``.
    """
    ax[0].clear()
    # The image is cropped by (rows 170, cols 130) at the top-left, so the ROI is shifted by the same amounts.
    ax[0].imshow(img[170:-110, 130:-100], cmap='gray', vmin=0, vmax=max_intensity)
    ax[0].axis('off')
    ax[0].annotate('', xy=(0.55, -0.2), xycoords='axes fraction', xytext=(0.55, -0.05),
                   arrowprops=dict(arrowstyle="->", color='k', linewidth=2))
    ax[0].annotate('', xy=(0.45, -0.2), xycoords='axes fraction', xytext=(0.45, -0.05),
                   arrowprops=dict(arrowstyle="<-", color='k', linewidth=2))
    line = LineDataUnits(rois_new[t][:, 0] - 130, rois_new[t][:, 1] - 170, c='tab:cyan', linewidth=5, alpha=0.3)
    ax[0].add_line(line)

    ax[1].clear()
    # Centre the strip in max_width columns so every frame has the same width.
    # The padding is filled with max_intensity (white under this colormap).
    # It uses the strip's own height, not a hard-coded 50 that must match straighten().
    pad_full = max_width - straights[t].shape[1]
    pad_left = pad_full - pad_full // 2  # == ceil(pad_full / 2) for non-negative ints
    pad_right = pad_full // 2
    straight_padded = np.pad(straights[t], ((0, 0), (pad_left, pad_right)),
                             mode='constant', constant_values=max_intensity)
    ax[1].imshow(straight_padded, cmap='gray', vmin=0, vmax=max_intensity)
    ax[1].axis('off')

fig.set_size_inches(6, 3.5)
fig.subplots_adjust(hspace=-0.1, bottom=-0.1)

In [None]:
fig, ax = plt.subplots(2, 1)
fig.set_size_inches(6, 3.5)
fig.subplots_adjust(hspace=-0.1, bottom=-0.1)

# update1 (defined in the interactive cell above) draws on the global `ax`, so it renders onto this figure.
# Frame 0 is listed twice, presumably because FuncAnimation's initial draw (no init_func) consumes
# the first item of the iterator. The remaining items cover every index of `straights`,
# including the last one, which arange(0, len(straights)-1) used to drop.
frames = np.r_[[0], np.arange(len(straights))]
anim = animation.FuncAnimation(fig, update1, frames=iter(frames), save_count=len(frames))
writer = animation.writers['ffmpeg'](fps=24, bitrate=2000)
Path('Figs').mkdir(exist_ok=True)  # make sure the output directory exists before writing
anim.save('Figs/animation.gif', writer=writer, dpi=200)

### Quantification animation

In [None]:
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(8, 3)
fig.subplots_adjust(wspace=0.4, bottom=0.2)

# Axis limits shared by every frame, so the profile panel never rescales.
ymin = min(np.min(m) for m in mems)
ymax = max(np.max(m) for m in mems)
xmax = 1.1 * max(len(m) for m in mems)

@widgets.interact(t=(0, len(mems)-1, 1))
def update2(t=0):
    """Draw frame ``t``: ROI on the cropped image (left) and its membrane profile (right).

    A red dot marks the first ROI point on the image and both ends of the profile.
    Reads the module-level ``ax``, ``img``, ``rois_new``, ``mems``, ``max_intensity``,
    ``xmax``, ``ymin`` and ``ymax``.
    """
    img_ax, prof_ax = ax
    roi_t = rois_new[t]
    mem_t = mems[t]
    n_pts = len(mem_t)
    col0, row0 = 130, 170  # top-left corner of the displayed crop

    # Left panel: cropped image, ROI overlay, start marker, arrow towards the profile.
    img_ax.clear()
    img_ax.imshow(img[row0:-110, col0:-100], cmap='gray', vmin=0, vmax=max_intensity)
    img_ax.axis('off')
    img_ax.annotate('', xy=(1.2, 0.55), xycoords='axes fraction', xytext=(1.05, 0.55),
                    arrowprops=dict(arrowstyle="->", color='k', linewidth=2))
    img_ax.add_line(LineDataUnits(roi_t[:, 0] - col0, roi_t[:, 1] - row0,
                                  c='tab:cyan', linewidth=5, alpha=0.3))
    img_ax.scatter(roi_t[0, 0] - col0, roi_t[0, 1] - row0, c='r', edgecolors='k', zorder=10)

    # Right panel: profile centred inside a fixed-width x range.
    prof_ax.clear()
    left = (xmax - n_pts) / 2
    right = left + n_pts
    prof_ax.plot(np.linspace(left, right, n_pts), mem_t, c='tab:cyan')
    prof_ax.set_ylim(ymin, ymax)
    prof_ax.set_xlim(-10, xmax + 10)
    for x_end, y_end in ((left, mem_t[0]), (right, mem_t[-1])):
        prof_ax.scatter(x_end, y_end, c='r', edgecolors='k', zorder=10)
    prof_ax.set_xticks([])
    prof_ax.set_xlabel('Position\n(clockwise from posterior)', labelpad=10)
    prof_ax.set_yticks([])
    prof_ax.set_ylabel('Membrane concentration', labelpad=10)
    prof_ax.axhline(0, linestyle='--', c='tab:gray', zorder=-10, linewidth=1)

In [None]:
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(8, 3)
fig.subplots_adjust(wspace=0.4, bottom=0.2)

# Shared limits read by update2 (defined in the interactive cell above).
# Recomputed here so they match the current contents of `mems`.
ymax, ymin = max(np.max(m) for m in mems), min(np.min(m) for m in mems)
xmax = max(len(m) for m in mems) * 1.1

# update2 draws on the global `ax`, so it renders onto this figure.
# Frame 0 is listed twice, presumably because FuncAnimation's initial draw (no init_func) consumes
# the first item of the iterator. The remaining items cover every index of `mems`,
# including the last one, which arange(0, len(mems)-1) used to drop.
frames = np.r_[[0], np.arange(len(mems))]
anim = animation.FuncAnimation(fig, update2, frames=iter(frames), save_count=len(frames))
writer = animation.writers['ffmpeg'](fps=24, bitrate=2000)
Path('Figs').mkdir(exist_ok=True)  # make sure the output directory exists before writing
anim.save('Figs/animation2.gif', writer=writer, dpi=200)