# Python notebook for pre-processing an ROI image stack.
# Image Stabilization
Assumes the following directory structure:
<pre><code>  IMAGING
    image_stacks
    notebooks
    results
</code></pre>
Execute the cells in order, one at a time, using Shift+Return.
#### Initialize.

In [1]:
%matplotlib widget

import glob
from ipyfilechooser import FileChooser
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import os
from scipy.interpolate import splprep, splev
from scipy.spatial import cKDTree
from skimage import color, data, exposure, filters, io
from skimage.draw import disk, circle_perimeter
from skimage.feature import canny
from skimage.morphology import binary_erosion, binary_dilation
from skimage.morphology import remove_small_objects
from skimage.restoration import denoise_bilateral, denoise_wavelet
from skimage.util import img_as_ubyte, img_as_int, img_as_float
import skimage.transform as tf
from sklearn.cluster import DBSCAN

# global variables
# path separator used to strip directories from the file names returned by glob
FILE_SEP = "\\" if os.name == "nt" else "/"


#### Select an image stack file and set options.

In [10]:
%matplotlib widget

# global variables
image_stack = ""    # the selected image stack (a file name inside ../image_stacks)
image_bits = 10     # bits per pixel (can be found in the oir meta data)
high_magnification = True

s = {'description_width':'200px'} # a default widget style

# create image files widget (bare file names, sorted case-insensitively)
image_files = sorted([f.split(FILE_SEP)[-1] for f in glob.glob("../image_stacks/*.tif", recursive=False)], key=str.casefold)
image_widget = widgets.Select(options=image_files, description='Image stack', 
                            disabled=False, layout=widgets.Layout(width='400px'))
# create image bits widget
image_bits_widget = widgets.BoundedIntText(value=image_bits, min=8, max=16, step=1,
                    description='Image data bits', disabled=False, layout={'width':'270px'}, style=s)
# create high magnification widget
high_magnification_widget = widgets.Checkbox(value=high_magnification, description='High magnification image?',
                 disabled=False, indent=True)

def f(w1,w2,w3):
  """Copy the current widget values into the notebook-level option globals.

  Used as the ``widgets.interactive`` callback; the argument values are not
  read directly, the widgets are queried instead.
  """
  # image_bits must be declared global too, otherwise the assignment below only
  # creates a local and the loader keeps using the default bit depth
  global image_stack, image_bits, high_magnification
  image_stack = image_widget.value
  image_bits = image_bits_widget.value
  high_magnification = high_magnification_widget.value
display(widgets.interactive(f, w1=image_widget, w3=image_bits_widget, w2=high_magnification_widget))


interactive(children=(Select(description='Image stack', layout=Layout(width='400px'), options=('MistGcamp-2_4x…

#### Get an image stack.

In [13]:
# Load the selected stack and divide by 2**image_bits so full-scale data lies in [0, 1]
images = io.imread("../image_stacks/" + image_stack)
images = np.float32(images/(2.0**image_bits))
zdepth = images.shape[0] # number of frames in the stack

# Two-line moving average, applied in place to every frame at once: each line
# becomes the mean of itself and the line below it (the last line is unchanged).
# The right-hand side is fully evaluated before assignment, so every average uses
# the original line values.
images[:, :-1] = (images[:, :-1] + images[:, 1:]) / 2.0


#### OPTIONAL: Use this code block to interactively explore landmark nuclei detection parameters.

In [14]:
plt.close('all')
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8,8))

@interact(
  gn=widgets.FloatSlider(description='image gain',min=1.0, max=5.0, step=0.1, value=1.0),
  sr=widgets.IntRangeSlider(description='stack range',min=0, max=zdepth, step=1, value=[0,8]), 
  bs=widgets.FloatSlider(description='BILATERAL sigma',min=0.0, max=4.0, step=0.1, value=1.0), 
  cs=widgets.FloatSlider(description='CANNY sigma',min=1.0, max=4.0, step=0.1, value=1.8), 
  ct=widgets.IntRangeSlider(description='threshold',min=0, max=100, step=1, value=[9,22]),
  hr=widgets.IntRangeSlider(description='HOUGH radii',min=3, max=25, step=1, value=[x*(2 if high_magnification else 1) for x in[5,8]]),
  hd=widgets.IntSlider(description='distance',min=5, max=50, step=1, value=10),
  hp=widgets.IntSlider(description='peaks',min=50, max=500, step=10, value=270),
  ht=widgets.FloatSlider(description='threshold',min=0.0, max=1.0, step=0.01, value=0.12),
  cr=widgets.FloatSlider(description='circle ratio',min=1.0, max=2.0, step=0.01, value=1.2))
def f(gn, sr, bs, cs, ct, hr, hd, hp, ht, cr):
  """Detect landmark nuclei in the mean of the selected frames and overlay them.

  The detection steps are the same as in the batch cell ("Find all landmark
  nuclei") so parameters tuned here carry over directly. The gain only
  brightens the displayed image; it does not change what is detected.

  Returns a summary string that ``interact`` shows below the sliders.
  """
  # guard against an empty range (equal slider handles), which would give a NaN mean
  start = min(sr[0], zdepth - 1)
  stop = max(sr[1], start + 1)
  A = np.mean(images[start:stop], axis=0) # the static image
  A0 = A / np.amax(A) # normalize to [0, 1]
  # display image: apply the gain here (multiplying before the normalization above
  # would cancel out) and clip so bright pixels saturate instead of wrapping
  imageA = color.gray2rgb(img_as_ubyte(np.clip(gn * A0, 0.0, 1.0)))

  # identify nuclei (circles)
  A = denoise_bilateral(A0, sigma_spatial=bs)
  edges = canny(img_as_ubyte(A), sigma=cs, low_threshold=ct[0], high_threshold=ct[1])
  hough_radii = np.arange(hr[0], hr[1], 1) # the range of radii to use in search
  hough_res = tf.hough_circle(edges, hough_radii) # look for circles
  accums, cx, cy, radii = tf.hough_circle_peaks(hough_res, hough_radii, min_xdistance=hd, 
                                           min_ydistance=hd, total_num_peaks=hp, 
                                           threshold=ht, normalize=False)

  # remove false positives (bright disks with dark perimeter); the intensity ratio
  # is measured on the normalized float image, exactly as in the batch cell
  pix = [] # (row, col) of the remaining center pixels
  for center_y, center_x, radius in zip(cy, cx, radii):
    c = disk((center_y, center_x), radius, shape=A0.shape) # central disk
    cp = circle_perimeter(center_y, center_x, radius+1, shape=A0.shape) # perimeter ring
    if (np.mean(A0[cp]) / np.mean(A0[c])) > cr:
      pix.append((center_y, center_x)) # dark disks with bright perimeter are OK

  # remove duplicates (close center pixels); reshape keeps an empty result 2-D
  pix = np.array(pix, dtype=int).reshape(-1, 2)
  keep = np.ones(len(pix), dtype=bool) # "keep" flags
  if len(pix) > 1:
    rows_to_fuse = list(cKDTree(pix).query_pairs(r=8.0))
    if rows_to_fuse:
      keep[np.array(rows_to_fuse)[:,0]] = False # drop the first of every close pair
  centers = pix[keep] # the remaining center pixels

  # draw nuclei centre pixels
  for row, col in centers:
    imageA[disk((row, col), 1.1, shape=A0.shape)] = (255,0,0)

  ax.cla()
  ax.imshow(imageA, norm=None)
  plt.show()
  return str(len(centers)) + " nuclei identified"


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=1.0, description='image gain', max=5.0, min=1.0), IntRangeSlider(value…

#### Find all landmark nuclei in the image stack.
NOTE: Can take several minutes to run.

In [4]:
# landmark detection parameters (use the values chosen with the explorer above)
bs = 1.0     # bilateral filter spatial sigma
cs = 1.8     # Canny edge detector sigma
ct = [9,22]  # Canny low/high thresholds (applied to the 8-bit image)
hr = [x*(2 if high_magnification else 1) for x in[5,8]] # Hough radius search range [min, max)
hd = 10      # minimum x/y distance between Hough peaks
hp = 270     # maximum number of Hough peaks per frame
ht = 0.12    # Hough accumulator threshold
cr = 1.2     # minimum perimeter-to-disk mean intensity ratio

pixx = []          # all landmark nuclei centers as (x, y, frame) tuples
frame_counts = []  # number of nuclei kept in each processed frame

hough_radii = np.arange(hr[0], hr[1], 1) # the range of radii to use in search (loop-invariant)

print("Processing frame: ", end = '')
for i in range(3,images.shape[0]-3): # use moving average over seven frames
  A = np.mean(images[i-3:i+4], axis=0)
  A0 = A / np.amax(A) # normalized

  # identify nuclei (circles)
  A = denoise_bilateral(A0, sigma_spatial=bs)
  edges = canny(img_as_ubyte(A), sigma=cs, low_threshold=ct[0], high_threshold=ct[1])
  hough_res = tf.hough_circle(edges, hough_radii) # look for circles
  accums, cx, cy, radii = tf.hough_circle_peaks(hough_res, hough_radii, min_xdistance=hd, 
                                           min_ydistance=hd, total_num_peaks=hp, 
                                           threshold=ht, normalize=False)

  # remove false positives (bright disks with dark perimeter)
  pix = [] # (x, y) of the remaining center pixels
  for center_y, center_x, radius in zip(cy, cx, radii):
    c = disk((center_y, center_x), radius, shape=A0.shape) # central disk
    cp = circle_perimeter(center_y, center_x, radius+1, shape=A0.shape) # perimeter ring
    if (np.mean(A0[cp]) / np.mean(A0[c])) > cr:
      pix.append((center_x, center_y)) # dark disks with bright perimeter are OK

  # remove duplicates (close center pixels); reshape keeps a frame with no
  # detections as a (0, 2) array instead of crashing the steps below
  pix = np.array(pix, dtype=float).reshape(-1, 2)
  keep = np.ones(len(pix), dtype=bool) # "keep" flags
  if len(pix) > 1:
    rows_to_fuse = list(cKDTree(pix).query_pairs(r=8.0))
    if rows_to_fuse:
      keep[np.array(rows_to_fuse)[:,0]] = False # drop the first of every close pair
  kept = pix[keep]

  # append (x, y, frame) to the landmark list; plain float() replaces np.float,
  # which was removed in NumPy 1.24
  pixx += [(x, y, float(i)) for x, y in kept]
  frame_counts.append(len(kept))
  print(str(i) + ", ", end = '')
print("DONE.")

min_n = min(frame_counts, default=0) # the least number of nuclei identified in a frame
max_n = max(frame_counts, default=0) # the most number of nuclei identified in a frame
print("Range of per frame nuclei identified:", str(min_n) + '-' + str(max_n))


Processing frame: 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 2

#### OPTIONAL: Plot all landmark nuclei centers.

In [5]:
# plot landmarks
plt.close() # frees up memory
fig = plt.figure()
fig.suptitle("landmark nuclei centers")
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
# (the plot would come out empty); create the 3D axes through the figure instead
ax = fig.add_subplot(projection='3d')
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("frame")

tp = np.array(pixx) # columns: x, y, frame
ax.scatter(tp[:,0],tp[:,1],tp[:,2])

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …


#### Identify and plot landmark "threads" to use for image stabilization.

In [6]:
# identify and plot landmark threads

plt.close() # frees up memory
fig = plt.figure()
fig.suptitle("landmark threads")
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib;
# create the 3D axes through the figure instead
ax = fig.add_subplot(projection='3d')
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("frame")

# distance based spatial clustering: nearby detections across frames are grouped
# into clusters ("threads"); DBSCAN labels unclustered points as -1 (noise)
tp = np.array(pixx) # columns: x, y, frame
tpp = tp * [1.0,1.0,0.5] # compress the z scale
db = DBSCAN(eps=10, min_samples=10).fit(tpp)
labels = db.labels_
core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
core_samples_mask[db.core_sample_indices_] = True

# get cluster and noise counts
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = list(labels).count(-1)
print('Number of landmark threads: %d' % n_clusters_)
print('Number of deleted noise points: %d' % n_noise_)

# plot the core points of every cluster in its own color (noise points are never
# core samples, so they are not drawn)
unique_labels = set(labels)
colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for k, col in zip(unique_labels, colors):
  class_member_mask = (labels == k)
  xy = tp[class_member_mask & core_samples_mask]
  ax.scatter(xy[:, 0], xy[:, 1], xy[:, 2], color=tuple(col))

plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Number of landmark threads: 126
Number of deleted noise points: 1334


#### OPTIONAL: Identify and plot a sample thread.

In [None]:
# get the longest thread: the largest cluster, ignoring the DBSCAN noise label -1
# (which may or may not be present, so it cannot be skipped by position)
unique, counts = np.unique(labels, return_counts=True)
is_thread = unique != -1
thread = unique[is_thread][np.argmax(counts[is_thread])]

plt.close() # frees up memory
fig = plt.figure()
fig.suptitle("sample thread")

# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
ax = fig.add_subplot(projection='3d')
# x is the column coordinate (image axis 2), y the row coordinate (image axis 1)
ax.set_xlim3d(0,images.shape[2])
ax.set_ylim3d(0,images.shape[1])

tpp = tp[labels==thread] # points of the sample thread: x, y, frame
ax.plot(tpp[:,0],tpp[:,1],tpp[:,2])
plt.show()


#### OPTIONAL: Check thread smoothing parameters on the sample thread.

In [None]:
x = tpp[:,0]
y = tpp[:,1]
z = tpp[:,2]

# smooth the thread with a cubic parametric spline (s=4000 is the smoothing
# factor; the stabilization cell uses the same value)
tckp,u = splprep([x,y,z],k=3,nest=-1,s=4000)
xnew,ynew,znew = splev(np.linspace(0,1,images.shape[0]),tckp)

plt.close() # frees up memory
fig = plt.figure()
fig.suptitle("sample thread - smoothed")

# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
ax = fig.add_subplot(projection='3d')
# axis limits follow the image size instead of assuming 512 x 512 frames
ax.set_xlim3d(0,images.shape[2])
ax.set_ylim3d(0,images.shape[1])

ax.plot(xnew,ynew,znew)
plt.show()

#### Stabilize the image stack using piece-wise affine transformation warping. 

In [8]:
# get a copy of the original stack
A = img_as_float(io.imread("../image_stacks/" + image_stack)) # convert to float
out = np.copy(A)
n_frames, height, width = A.shape[0], A.shape[1], A.shape[2]

# find threads that span the stack
# NOTE: there are no landmarks in the first or last three frames (detection used a
# seven-frame moving average), so a spanning thread runs from frame 3 to frame n_frames-4
tcount = 0     # spanning thread count
ls = set(labels)
ls.discard(-1) # -1 is DBSCAN noise, not a thread (may be absent)
lxnew = []
lynew = []
for ll in ls:
  tpp = tp[labels==ll] # rows are in frame order because pixx was built frame by frame
  first_z, last_z = tpp[0,2], tpp[-1,2]
  if first_z==3 and last_z==(n_frames-4):
    tcount = tcount + 1
    tckp,u = splprep([tpp[:,0],tpp[:,1],tpp[:,2]],s=4000,k=3,nest=-1)
    # NOTE(review): the spline is sampled at n_frames evenly spaced parameter values
    # and sample k is used as the thread position in frame k; znew is not used to
    # align samples to frames — confirm this approximation is acceptable
    xnew,ynew,znew = splev(np.linspace(0,1,n_frames),tckp)
    lxnew.append(xnew)
    lynew.append(ynew)
print("Found " + str(tcount) + " spanning threads.")
if tcount == 0:
  raise RuntimeError("No spanning landmark threads found; adjust the detection or clustering parameters.")
lxnew = np.array(lxnew) # (tcount, n_frames) x positions
lynew = np.array(lynew) # (tcount, n_frames) y positions

# displacement of each spanning thread relative to its first sample
dx = lxnew - lxnew[:,0][:,None]
dy = lynew - lynew[:,0][:,None]

# find image cropping values (to eliminate black borders caused by translation):
# crop the low side by the largest negative shift and the high side by the largest
# positive shift. Plain int()/max() replace np.int (removed in NumPy 1.24) and
# np.max(x, 0) (whose 0 is an axis, not a lower bound). XR/YR are positive end
# indices so that a zero shift keeps the full extent (an end index of 0 would
# produce an empty crop).
XL = max(-int(np.floor(np.min(dx))), 0)
XR = width - max(int(np.ceil(np.max(dx))), 0)
YL = max(-int(np.floor(np.min(dy))), 0)
YR = height - max(int(np.ceil(np.max(dy))), 0)

# translate the frame corners using the average of the spanning thread translations,
# so the piece-wise affine mesh covers the whole frame
transx = np.mean(dx, axis=0)
transy = np.mean(dy, axis=0)
cornersx = np.full((lxnew.shape[1],4),[0,0,width-1,width-1]) + transx[:, None]
cornersy = np.full((lynew.shape[1],4),[0,height-1,0,height-1]) + transy[:, None]
lxnew = np.concatenate((lxnew, np.transpose(cornersx)))
lynew = np.concatenate((lynew, np.transpose(cornersy)))
lnew = np.array([lxnew, lynew]) # (2, n_points, n_frames): x, y of every control point per frame

# piece-wise affine transformation warping: map each frame's control points onto
# their positions at the first sample
print("Warping frame:", end = '')
for i in range(3, out.shape[0]-3):
  print(' ' + str(i) + ',', end = '')
  tform = tf.PiecewiseAffineTransform()
  tform.estimate(np.transpose(lnew[:,:,i]), np.transpose(lnew[:,:,0]))
  out[i] = tf.warp(A[i], tform.inverse)
print(" DONE.")

# save the stabilized image stack
for i in range(3): # duplicate the first and last three frames
  out[i] = out[3]
  out[-(1+i)] = out[-4]
# NOTE(review): img_as_int rescales the [0, 1] float data to int16, while the optional
# "_orig.tif" cell saves the raw stack in its original dtype — confirm downstream
# code expects the two files to use different scales
io.imsave("../image_stacks/" + image_stack[0:-4] + "_stab.tif", 
    img_as_int(out[:,YL:YR,XL:XR]), check_contrast=False)  # axis 1 = rows (y), axis 2 = columns (x)


Found 6 spanning threads.
Warping frame: 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215

#### OPTIONAL: Save a cropped copy of the original image stack.

In [9]:
# save a copy of the raw (unstabilized) stack cropped to the same window as the
# stabilized one, so the two can be compared side by side
A = io.imread("../image_stacks/" + image_stack)
cropped = A[:,YL:YR,XL:XR] # axis 1 = rows (y), axis 2 = columns (x)
io.imsave("../image_stacks/" + image_stack[0:-4] + "_orig.tif", cropped, check_contrast=False)
print("DONE.")

DONE.
