In [1]:
import pims
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal, ndimage, stats
from scipy.ndimage.filters import uniform_filter
from scipy.interpolate import splprep, splev
from scipy.integrate import simps
from skimage import filters, feature, morphology, draw, measure, segmentation
from skimage.morphology import disk
from skimage.color import rgb2gray
from BKlib import write_video, KalmanSmoother2D
from snakes import fit_snake

%pylab


def correct_orientation(im):
    """Correct the frame orientations (to match imagej)"""
    return np.fliplr(im)


def tiff_to_ndarray(fn):
    """
    Load a tiff stack as 3D numpy array. 
    You must have enough RAM to hold the whole movie in memory.
    """
    frames = pims.TiffStack(fn)
    num_frames = len(frames)
    sz = frames.frame_shape
    arr = np.empty((num_frames, sz[0], sz[1]), dtype=frames.pixel_type)
    for frame_num, frame in enumerate(frames):
        arr[frame_num, :, :] = np.fliplr(np.swapaxes(frame, 0, 1))
    return arr


def enhance_ridges(frame):
    """A ridge detection filter (larger hessian eigenvalue)"""
    blurred = filters.gaussian_filter(frame, 2)
    sigma = 4.5
    Hxx, Hxy, Hyy = feature.hessian_matrix(blurred, sigma=sigma, mode='nearest')
    ridges = feature.hessian_matrix_eigvals(Hxx, Hxy, Hyy)[0]
    return np.abs(ridges)


def create_mask(frame):
    """"Create a big mask that encompasses all the cells"""
    
    # detect ridges
    ridges = enhance_ridges(frame)

    # threshold ridge image
    thresh = filters.threshold_otsu(ridges)
    thresh_factor = 1.1
    prominent_ridges = ridges > thresh_factor*thresh
    prominent_ridges = morphology.remove_small_objects(prominent_ridges, min_size=128)

    # the mask contains the prominent ridges
    mask = morphology.convex_hull_image(prominent_ridges)
    mask = morphology.binary_erosion(mask, disk(10))
    return mask


def frame_to_distance_images(frame):
    """
    Compute the skeleton of the cell boundaries, return the 
    distance transform of the skeleton and its branch points.
    """
    
    # distance from ridge midlines
    frame = rgb2gray(frame)
    ridges = enhance_ridges(frame)
    thresh = filters.threshold_otsu(ridges)
    prominent_ridges = ridges > 0.8*thresh
    skeleton = morphology.skeletonize(prominent_ridges)
    edge_dist = ndimage.distance_transform_edt(-skeleton)
    edge_dist = filters.gaussian_filter(edge_dist, sigma=2)

    # distance from skeleton branch points (ie, ridge intersections)
    blurred_skeleton = uniform_filter(skeleton.astype(float), size=3)
    corner_im = blurred_skeleton > 4./9
    corner_dist = ndimage.distance_transform_edt(-corner_im)
    
    return edge_dist, corner_dist


def mask_to_boundary_pts(mask, pt_spacing=5):
    """
    Convert a binary image containing a single object to a set
    of 2D points that are equally spaced along the object's contour.
    """

    # interpolate boundary
    boundary_pts = measure.find_contours(mask, 0)[0]
    tck, u = splprep(boundary_pts.T, u=None, s=0.0, per=1)
    u_new = np.linspace(u.min(), u.max(), 1000)
    x_new, y_new = splev(u_new, tck, der=0)

    # get equi-spaced points along spline-interpolated boundary
    x_diff, y_diff = np.diff(x_new), np.diff(y_new)
    S = simps(np.sqrt(x_diff**2 + y_diff**2))
    N = int(round(S/pt_spacing))

    u_equidist = np.linspace(0, 1, N+1)
    x_equidist, y_equidist = splev(u_equidist, tck, der=0)
    return np.array(zip(x_equidist, y_equidist))


def segment_cells(frame, mask=None):
    """
    Compute the initial segmentation based on ridge detection + watershed.
    This works reasonably well, but is not robust enough to use by itself.
    """
    
    blurred = filters.gaussian_filter(frame, 2)
    ridges = enhance_ridges(frame)
    
    # threshold ridge image
    thresh = filters.threshold_otsu(ridges)
    thresh_factor = 0.6
    prominent_ridges = ridges > thresh_factor*thresh
    prominent_ridges = morphology.remove_small_objects(prominent_ridges, min_size=256)
    prominent_ridges = morphology.binary_closing(prominent_ridges)
    prominent_ridges = morphology.binary_dilation(prominent_ridges)
    
    # skeletonize
    ridge_skeleton = morphology.medial_axis(prominent_ridges)
    ridge_skeleton = morphology.binary_dilation(ridge_skeleton)
    ridge_skeleton *= mask
    ridge_skeleton -= mask
    
    # label
    cell_label_im = measure.label(ridge_skeleton)
    
    # morphological closing to fill in the cracks
    for cell_num in range(1, cell_label_im.max()+1):
        cell_mask = cell_label_im==cell_num
        cell_mask = morphology.binary_closing(cell_mask, disk(3))
        cell_label_im[cell_mask] = cell_num
    
    return cell_label_im 


class CellSelectorGUI:
    """
    This class displays a labelled image and allows the user to select
    a region of interest with the mouse. All the labels that are clicked
    on are stored in the list "cell_labels". 
    """
    
    def __init__(self, cell_labels):
        
        cell_mask = np.ma.masked_where(np.ones_like(cell_labels), np.ones(cell_labels.shape))
        
        grid_sz = 7
        self.fig = plt.figure()
        plt.subplot2grid((grid_sz,grid_sz), (0,0), colspan=grid_sz, rowspan=grid_sz-1)
        plt.imshow(cell_labels, cmap='jet')
        self.mask_im = plt.imshow(cell_mask, cmap='gray')
        plt.title('Click to select a cell to track')
        plt.axis('off')
        
        button_ax = plt.subplot2grid((grid_sz,grid_sz ), (grid_sz-1,grid_sz/2))
        self.done_button = plt.Button(button_ax, 'Done')
        self.done_button.on_clicked(self.on_button_press)
        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.on_mouse_click)
        
        self.cell_labels = cell_labels
        self.selected_cell_labels = []
        self.cell_mask = None
        self.fig.canvas.start_event_loop(timeout=-1)
        
    def on_mouse_click(self, event):
        try:
            i, j = int(round(event.ydata)), int(round(event.xdata))
            label = self.cell_labels[i,j]
            if label != 0:
                self.cell_mask = self.cell_labels==label
                self.cell_mask = morphology.binary_dilation(self.cell_mask)
                cell_mask = np.ma.masked_where(self.cell_mask==0, 255*np.ones(self.cell_mask.shape))
                self.mask_im.set_data(cell_mask)
                plt.draw()
                self.selected_cell_labels.append(label)
        except TypeError:  
            pass  # clicked outside axes
        
    def on_button_press(self, event):
        plt.close()
        self.fig.canvas.stop_event_loop()

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy


In [2]:
### load raw image ###

input_dir_name = '/home/brian/Projects/Munro-CellCrawling/input_movies/better_movies/'
fn = input_dir_name + '27Oct2015_mChSqh_grCM_20s_e_1.tif'
frames = tiff_to_ndarray(fn).astype(float)

# the mask (same for all frames)
mask_frame = frames[0,:,:]
mask = create_mask(mask_frame)


In [3]:
### Use a GUI the simple segmentation to select a cell for scrutiny ###
   
# segment first frame via ridge detection + watershed
frame = frames[0]
cell_labels = segment_cells(frame, mask)

# show the GUI and get its output
plt.ion()
cell_selector = CellSelectorGUI(cell_labels)
selected_label = cell_selector.selected_cell_labels[-1]
cell_mask = cell_selector.cell_mask
print 'You selected cell %i' % selected_label

You selected cell 22


  cbook._putmask(xa, xa<0.0, -1)


In [7]:
### Compute the snake-based contour for each frame ###
from skimage import img_as_uint

# This is the slowest cell; takes ~80 sec on my laptop.
tsta = time.clock()
 
boundary_pts = mask_to_boundary_pts(cell_mask, pt_spacing=6)
all_boundary_pts = np.empty((len(frames), boundary_pts.shape[0], boundary_pts.shape[1]))
masks = np.empty((len(frames), frames.shape[1], frames.shape[2]), dtype=bool)


for frame_num, frame in enumerate(frames):
    
    # print progress
    if frame_num%10 == 0:
        print 'computing frame %i of %i' % (frame_num, len(frames))
        sys.stdout.flush()
    
    # compute distance transforms and fit snake
    edge_dist, corner_dist = frame_to_distance_images(frame)
    boundary_pts = fit_snake(boundary_pts, edge_dist, alpha=0.1, beta=0.1, nits=40)
    
    # check if it went off the edge
    single_cell_mask = measure.grid_points_in_poly(frame.shape, boundary_pts)
    all_cells_mask = create_mask(mask_frame)    
    if np.any(np.logical_and(~all_cells_mask, single_cell_mask)):
        print 'cell went off edge on frame %i' % frame_num
        all_boundary_pts = np.delete(all_boundary_pts, np.s_[frame_num+1:], 0)
        break
    
    # store results in big array
    all_boundary_pts[frame_num,:,:] = boundary_pts
    masks[frame_num,:,:] = measure.grid_points_in_poly(frame.shape, boundary_pts)
    
    

print 'elapsed time:', time.clock() - tsta
print 'shape of all_boundary_pts', all_boundary_pts.shape
# write masks to file
fn = 'cell%i_masks.npy' % selected_label
print 'Saving masks to', fn
np.save(fn, masks)

computing frame 0 of 91
computing frame 10 of 91
computing frame 20 of 91
computing frame 30 of 91
computing frame 40 of 91
computing frame 50 of 91
computing frame 60 of 91
computing frame 70 of 91
computing frame 80 of 91
cell went off edge on frame 87
elapsed time: 158.75
shape of all_boundary_pts (87, 37, 2)
Saving masks to cell22_masks.npy


In [8]:
### visualize cell motion to check the results ###


def interpolate_boundary_pts(pts, N=200):
    
    # interpolate boundary
    tck, u = splprep(pts.T, u=None, s=0.0, per=1)
    u_new = np.linspace(u.min(), u.max(), 1000)
    x_new, y_new = splev(u_new, tck, der=0)

    # get equi-spaced points along spline-interpolated boundary
    x_diff, y_diff = np.diff(x_new), np.diff(y_new)
    S = simps(np.sqrt(x_diff**2 + y_diff**2))
    
    u_dense = np.linspace(0, 1, N+1)
    x_dense, y_dense = splev(u_dense, tck, der=0)
    return x_dense, y_dense


boundary_pts = all_boundary_pts[0]
boundary_x, boundary_y = interpolate_boundary_pts(boundary_pts)
mask = measure.grid_points_in_poly(frame.shape, boundary_pts)
center = measure.regionprops(mask)[0].centroid

fig = plt.figure()
implot = plt.imshow(frames[0], cmap='gray')
boundary_plot, = plt.plot(boundary_y, boundary_x, 'ro', markeredgecolor='none', markersize=3)
boundary_pts_plot, = plt.plot(boundary_pts[:,1], boundary_pts[:,0], 'ro', markeredgecolor='none', markersize=4)
center_point, = plt.plot(center[1], center[0], 'ro', markeredgecolor='r', markersize=7)
plt.axis('off')
fig.canvas.draw()


for frame_num, (frame, boundary_pts) in enumerate(zip(frames, all_boundary_pts)):
    
    print
    print 'frame_num %i' % frame_num
    print 'first point', boundary_pts[0]
    print 'last point', boundary_pts[-1]
    boundary_x, boundary_y = interpolate_boundary_pts(boundary_pts)
    mask = measure.grid_points_in_poly(frame.shape, boundary_pts)
    center = measure.regionprops(mask)[0].centroid
    
    implot.set_data(frame)
    boundary_plot.set_data(boundary_y, boundary_x)
    boundary_pts_plot.set_data(boundary_pts[:,1], boundary_pts[:,0])
    center_point.set_data(center[1], center[0])
    plt.title('frame %i' % frame_num)
    fig.canvas.draw()
    
#     time.sleep(0.2)



frame_num 0
first point [ 285.20959768  327.86165546]
last point [ 285.20959768  327.86165546]

frame_num 1
first point [ 285.9995176   327.75287836]
last point [ 283.98148221  331.95548455]

frame_num 2
first point [ 287.00681066  327.81594801]
last point [ 284.87410606  331.98751257]

frame_num 3
first point [ 286.99727802  327.41915035]
last point [ 285.04010939  331.79549635]

frame_num 4
first point [ 287.9894451   327.23733958]
last point [ 285.84258205  331.73385039]

frame_num 5
first point [ 287.99779782  327.0307406 ]
last point [ 285.93677604  331.99778129]

frame_num 6
first point [ 289.10644892  327.02385213]
last point [ 286.80451055  331.98694066]

frame_num 7
first point [ 289.28779991  326.80835816]
last point [ 287.00842348  332.00574693]

frame_num 8
first point [ 291.01477967  326.7403742 ]
last point [ 288.98417539  331.99844297]

frame_num 9
first point [ 291.86919942  326.81150621]
last point [ 289.99264703  331.9927279 ]

frame_num 10
first point [ 292.01805278

In [6]:
### save the visualization as a movie ###
from BKlib import write_video

def to_rgb(im):
    im /= im.max()
    im = np.round(255*im)
    im = im.astype(np.uint8)
    
    w, h = im.shape
    ret = np.empty((w, h, 3), dtype=np.uint8)
    ret[:, :, 0] = im
    ret[:, :, 1] = im
    ret[:, :, 2] = im
    return ret


rgb_frame_sz = (frames.shape[0], frames.shape[1], frames.shape[2], 3)
movie_frames = np.empty(rgb_frame_sz, dtype=np.uint8)

for frame_num, (frame, mask) in enumerate(zip(frames.copy(), masks)):
    
    frame = to_rgb(frame)
    
    # colorize boundary
    boundary = segmentation.find_boundaries(mask)
    boundary = morphology.binary_dilation(boundary)
    r, g, b = frame[:,:,0], frame[:,:,1], frame[:,:,2]
    r[boundary>0] = 255
    g[boundary>0] = 0
    b[boundary>0] = 0
    
    movie_frames[frame_num,:,:,:] = correct_orientation(frame)
  

fn = 'cell%i_save.avi' % selected_label
print 'Saving visualization to', fn
write_video(movie_frames, fn)

Saving visualization to cell6_save.avi


In [5]:
### smooth and plot cell trajectory ###

centers = np.array([measure.regionprops(mask)[0].centroid for mask in masks])

# estimate initial velocity via regression on first few timepoints
init_pts = 5
vx0, _, _, _, _ = stats.linregress(range(init_pts), centers[:init_pts, 0])
vy0, _, _, _, _ = stats.linregress(range(init_pts), centers[:init_pts, 1])
initial_state = np.array([centers[0,0], centers[0,1], vx0, vy0])

# smooth the cell center positions
position_noise = 15.0  # higher noise -> heavier smoothing
smoother = KalmanSmoother2D(position_noise, position_noise)
smoother.set_initial_state(initial_state)
smoother.set_measurements(centers)
smooth_cell_centers = smoother.get_smoothed_measurements()
cell_velocities = smoother.get_velocities()

# plot trajectory
plt.figure()
plt.plot(centers[:,1], centers[:,0], 'bx')
plt.plot(smooth_cell_centers[:,1], smooth_cell_centers[:,0],'r-')
plt.axis('equal')
plt.title('Trajectory for cell %i' % selected_label)
plt.gca().invert_yaxis()
plt.xlabel('x (pixels)')
plt.ylabel('y (pixels)')

# show histogram of velocity
plt.figure()
plt.hist(cell_velocities[:,1], bins=25)
plt.hist(cell_velocities[:,0], bins=25)
plt.legend(['x speed', 'y speed'], numpoints=1)
plt.title('Speed distribution for cell %i' % selected_label)
plt.xlabel('speed (pixels/frame)')
plt.ylabel('frame count')

<matplotlib.text.Text at 0x56f01d0>

In [6]:
### save a "zoomed in" movie that follows the cell ###

# compute window size (max mask size + padding)
padding = 20
bboxes = [measure.regionprops(mask)[0].bbox for mask in masks]
bbox_widths = [(bbox[2]-bbox[0]) for bbox in bboxes]
bbox_heights = [(bbox[3]-bbox[1]) for bbox in bboxes]
win_size = (max(bbox_widths)+2*padding, max(bbox_heights)+2*padding)

# make it square (apparently avconv doesn't work correctly with arbitrary frame size...)
win_size = (max(win_size), max(win_size))

windows = np.empty((len(frames), win_size[0], win_size[1]))
for frame_num, (frame, center, mask) in enumerate(zip(frames, smooth_cell_centers, masks)):
    
#     # make boundary bright
#     boundary = segmentation.find_boundaries(mask)
#     frame += 0.5*frame.max()*boundary

    # extract window centered on mask centroid
    i, j = int(round(center[0])), int(round(center[1]))
    ista, jsta = i-win_size[0]/2, j-win_size[1]/2
    iend, jend = ista+win_size[0], jsta+win_size[1]
    win = frame[ista:iend, jsta:jend]
    windows[frame_num,:,:] = correct_orientation(win)
    
    if frame_num == 0:
        mask_win = mask[ista:iend, jsta:jend]
        np.save('mask.npy', mask_win)

    
    plt.imshow(win)
    
    break

    
np.save('window.npy', windows[0])
# # save to file
# fn = 'cell%i_zoom.avi' % selected_label
# print 'Saving to', fn
# write_video(windows, fn)