In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import cv2
import PIL
from PIL import Image

# Render inline figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
# Dark plotting theme with the axes grid disabled.
plt.style.use('dark_background')
plt.rcParams["axes.grid"] = False

# Notebook-wide plot theme. NOTE(review): jtplot.style runs after the rcParams set
# above and may override some of them (e.g. axes.grid) — confirm; some functions
# below set axes.grid back to False themselves.
from jupyterthemes import jtplot
jtplot.style(theme='monokai', context='notebook')

In [None]:
import os

# Directory holding the exported Phenix tiles for this measurement.
# The default is a machine-specific absolute path. To point the notebook at
# another copy of the data without editing this cell, set the environment
# variable PHENIX_DATA_PATH.
# The value always ends with a path separator, so it also works with code that
# builds file names by plain string concatenation.
data_path = os.path.join(
    os.environ.get(
        'PHENIX_DATA_PATH',
        '/Volumes/HB-ExSSD-T5/2021-04-01 Phenix/HKB 4W SN CAR__2021-03-28T16_00_04-Measurement 1/Images/',
    ),
    '',
)

### Image & Timecourse Retrieval

In [None]:
# Default field-of-view layout for the 11 x 11 tile grid.
# Each entry is the 1-indexed field number (the "f" part of the file name) shown
# at that grid position; -1 marks a position with no tile.
# The mapping is irregular: field 1 sits at the centre, and the remaining fields
# snake row by row from the top. It is therefore written out explicitly.
PHENIX_TILING_MAP = (
    (-1, -1, -1,  2,  3,  4,   5,   6, -1, -1, -1),
    (-1, 15, 14, 13, 12, 11,  10,   9,  8,  7, -1),
    (-1, 16, 17, 18, 19, 20,  21,  22, 23, 24, -1),
    (35, 34, 33, 32, 31, 30,  29,  28, 27, 26, 25),
    (36, 37, 38, 39, 40, 41,  42,  43, 44, 45, 46),
    (56, 55, 54, 53, 52,  1,  51,  50, 49, 48, 47),
    (57, 58, 59, 60, 61, 62,  63,  64, 65, 66, 67),
    (78, 77, 76, 75, 74, 73,  72,  71, 70, 69, 68),
    (-1, 79, 80, 81, 82, 83,  84,  85, 86, 87, -1),
    (-1, 96, 95, 94, 93, 92,  91,  90, 89, 88, -1),
    (-1, -1, -1, 97, 98, 99, 100, 101, -1, -1, -1),
)


def _read_phenix_tile(path, row, col, field, channel, timepoint):
    """Return one exported tile as a numpy array, or None if its file is missing or unreadable."""
    import os

    # File nomenclature: r<row>c<col>f<field>p01-ch<channel>sk<timepoint>fk1fl1.tiff
    fname = os.path.join(path, 'r%02dc%02df%02dp01-ch%isk%ifk1fl1.tiff' %
                         (row, col, field, channel, timepoint))
    if not os.path.isfile(fname):
        return None
    try:
        # Copy the pixels out so the file handle is released immediately.
        with Image.open(fname) as im:
            return np.array(im)
    except OSError:
        # Corrupt or unrecognised file: treated like a missing tile by the caller.
        return None


def _flatfield_correct(tile, kernel=None, blur=None):
    """Apply the optional morphological and blur background corrections to one tile array."""
    corrected = tile
    if kernel is not None:
        # Black-hat transform (morphological closing minus the tile). The public
        # parameter is called `tophat`, but this is the transform it has always applied.
        corrected = cv2.morphologyEx(corrected, cv2.MORPH_BLACKHAT, kernel)
    if blur is not None:
        # Subtract a box-blurred background estimate; int32 keeps negative residuals.
        background = cv2.blur(corrected, (blur, blur))
        corrected = corrected.astype(np.int32) - background.astype(np.int32)
    return corrected


def retrieve_and_stitch_image(path, row, col, channel, timepoint, tophat=None, blur=None,
                              tiling_map=None, tile_size=1080):
    """
    Retrieve all the tiles for one well, channel and timepoint and stitch them into one image.

    Parameters
    ----------
    path : str or path-like
        Directory containing the exported tiles.
    row, col : int
        Position in the multi-well plate (1-indexed).
    channel : int
        Index of the fluorescence/brightfield channel (acquisition specific).
    timepoint : int
        Index of the timepoint to retrieve (1-indexed).
    tophat : int, optional
        Size of the elliptical kernel for the morphological (black-hat) flatfield
        correction. None means no correction.
    blur : int, optional
        Box-blur window size for the background-subtraction flatfield correction.
        None means no correction. It is applied after the `tophat` correction when
        both are given.
    tiling_map : sequence of sequences of int, optional
        Grid of 1-indexed field numbers giving each tile's position. -1 marks an
        empty position. Defaults to PHENIX_TILING_MAP.
    tile_size : int, optional
        Edge length in pixels of the square field-of-view tiles (default 1080).

    Returns
    -------
    numpy.ndarray
        int32 array of shape (n_rows * tile_size, n_cols * tile_size).
        Empty grid positions, and tiles whose file is missing or unreadable, are
        filled with a constant placeholder tile of ones. The placeholder passes
        through the same corrections as real tiles. A warning lists any missing fields.

    Raises
    ------
    ValueError
        If `tiling_map` is not a 2-D grid.
    """
    import warnings

    grid = np.asarray(PHENIX_TILING_MAP if tiling_map is None else tiling_map, dtype=int)
    if grid.ndim != 2:
        raise ValueError('tiling_map must be a 2-D grid of field numbers')
    n_rows, n_cols = grid.shape

    # Build the structuring element once, not once per tile.
    kernel = None
    if tophat is not None:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (tophat, tophat))

    # Placeholder for missing tiles and empty grid positions.
    blank = _flatfield_correct(np.ones((tile_size, tile_size), dtype=np.uint16), kernel, blur)

    # Only read the fields that actually appear in the map.
    tiles = {}
    missing = []
    for field in sorted(set(grid[grid > 0].tolist())):
        raw = _read_phenix_tile(path, row, col, field, channel, timepoint)
        if raw is None:
            missing.append(field)
            tiles[field] = blank
        else:
            tiles[field] = _flatfield_correct(raw, kernel, blur)

    if missing:
        warnings.warn(
            'r%02dc%02d ch%i t%i: %d of %d tiles missing or unreadable (fields %s); '
            'filled with placeholder.'
            % (row, col, channel, timepoint, len(missing), len(tiles), missing),
            stacklevel=2)

    stitched = np.empty((n_rows * tile_size, n_cols * tile_size), dtype=np.int32)
    for (grid_y, grid_x), field in np.ndenumerate(grid):
        y0, x0 = grid_y * tile_size, grid_x * tile_size
        stitched[y0:y0 + tile_size, x0:x0 + tile_size] = tiles.get(int(field), blank)
    return stitched

def retrieve_timecourse(path, row, col, channel, start, end, tophat=None, blur=None):
    """
    Stitch every timepoint of one well and channel by calling retrieve_and_stitch_image.

    Parameters
    ----------
    path, row, col, channel, tophat, blur :
        Passed through unchanged to retrieve_and_stitch_image.
    start : int
        First timepoint to fetch (1-indexed).
    end : int
        Last timepoint to fetch (inclusive).

    Returns
    -------
    list
        One stitched image (a 2-D numpy array) per timepoint, in increasing time
        order. A tqdm progress bar tracks the retrieval.
    """
    timepoints = range(start, end + 1)
    return [retrieve_and_stitch_image(path, row, col, channel, t, tophat, blur)
            for t in tqdm(timepoints)]

### Movie Generation

In [None]:
from matplotlib.animation import FuncAnimation
from matplotlib import animation, rc
import matplotlib.colors

def render_movie_for_timecourse(tc, filename, size=(2250,2000), ref_frame=24):
    """
    Render a timecourse (e.g. from retrieve_timecourse) to '<filename>.mp4'.

    Each frame is resized with nearest-neighbour interpolation and drawn with the
    viridis colormap. The colour limits are fixed for the whole movie at
    mean ± 2σ of one reference frame, so brightness is comparable across frames.

    Parameters
    ----------
    tc : sequence of 2-D arrays
        Frames in time order.
    filename : str
        Output path without the '.mp4' extension.
    size : (int, int)
        Frame size passed to cv2.resize as `dsize`, i.e. (width, height).
    ref_frame : int
        0-based index of the frame used to set the colour limits (default 24).
        Clamped to the last frame for shorter timecourses.

    Raises
    ------
    ValueError
        If `tc` is empty.
    """
    data = [cv2.resize(np.asarray(img), dsize=size, interpolation=cv2.INTER_NEAREST)
            for img in tc]
    if not data:
        raise ValueError('tc must contain at least one frame')

    # Normalize the colormap on a single reference frame.
    ref = data[min(ref_frame, len(data) - 1)]
    _vmin = ref.mean() - 2 * ref.std()
    _vmax = ref.mean() + 2 * ref.std()

    # matplotlib applies axes.grid when the axes are created, so disable it first.
    plt.rcParams["axes.grid"] = False
    fig, ax = plt.subplots(figsize=(10, 10))
    plot = ax.imshow(data[0], vmax=_vmax, vmin=_vmin, cmap='viridis')
    text = ax.text(.1, -.1, '', fontsize=20, transform=ax.transAxes)

    def init():
        plot.set_data(data[0])
        text.set_text('t=0h')
        return [plot, text]

    def update(j):
        plot.set_data(data[j])
        text.set_text('t=%ih' % (j + 1))
        # Return every changed artist so blitting also redraws the time label.
        return [plot, text]

    anim = FuncAnimation(fig, update, init_func=init, frames=len(data), interval=500, blit=True)
    anim.save('%s.mp4' % (filename), dpi=300)
    # Release the (large) figure once the movie is written.
    plt.close(fig)
    print('Animation rendered successfully: %s' % filename)
    
def render_overlay_movie(tc_base, tc_overlay, filename, size=(2250,2000), ref_frame=24):
    """
    Overlay tc_overlay on tc_base (e.g. GFP over brightfield) and save '<filename>.mp4'.

    Processing steps:
    - Every frame of both timecourses is resized to `size` with nearest-neighbour
      interpolation.
    - In each overlay frame, pixels below that frame's (μ - 1σ) are set to 0.
    - The base is drawn in grayscale.
    - The overlay is drawn on top in a transparent-to-green colormap at 50% opacity.
    - Colour limits for each layer are fixed at mean ± 2σ of one reference frame.

    Parameters
    ----------
    tc_base, tc_overlay : sequences of 2-D arrays
        Frames in time order. They are paired with zip(), so the movie is as long
        as the shorter timecourse.
    filename : str
        Output path without the '.mp4' extension.
    size : (int, int)
        Frame size passed to cv2.resize as `dsize`, i.e. (width, height).
    ref_frame : int
        0-based index of the frame used for the colour limits (default 24).
        Clamped to the last frame for shorter timecourses.

    Raises
    ------
    ValueError
        If either timecourse is empty.
    """
    base_data = []
    overlay_data = []
    for base_img, overlay_img in zip(tc_base, tc_overlay):
        base_data.append(cv2.resize(np.asarray(base_img), dsize=size,
                                    interpolation=cv2.INTER_NEAREST))
        overlay = np.array(cv2.resize(np.asarray(overlay_img), dsize=size,
                                      interpolation=cv2.INTER_NEAREST))
        # Drop overlay pixels below μ - 1σ of this frame.
        overlay[overlay < overlay.mean() - overlay.std()] = 0
        overlay_data.append(overlay)

    if not base_data:
        raise ValueError('tc_base and tc_overlay must each contain at least one frame')

    def _limits(frame):
        # Colour limits at mean ± 2σ of the frame.
        return [frame.mean() - 2 * frame.std(), frame.mean() + 2 * frame.std()]

    ref = min(ref_frame, len(base_data) - 1)
    baselim = _limits(base_data[ref])
    overlaylim = _limits(overlay_data[ref])

    cmap_base = matplotlib.colors.LinearSegmentedColormap.from_list("", ["black", "white"])
    cmap_overlay = matplotlib.colors.LinearSegmentedColormap.from_list("", [(0,0,0,0), "green"])

    fig, ax = plt.subplots(figsize=(10, 10))
    base_plot = ax.imshow(base_data[0], clim=baselim, cmap=cmap_base, alpha=1)
    overlay_plot = ax.imshow(overlay_data[0], clim=overlaylim, cmap=cmap_overlay, alpha=.5)
    text = ax.text(.1, -.1, '', fontsize=20, transform=ax.transAxes)

    def init():
        base_plot.set_data(base_data[0])
        overlay_plot.set_data(overlay_data[0])
        text.set_text('t=0h')
        return [base_plot, overlay_plot, text]

    def update(j):
        # Update the existing artists in place instead of rebuilding the figure each frame.
        base_plot.set_data(base_data[j])
        overlay_plot.set_data(overlay_data[j])
        text.set_text('t=%ih' % (j + 1))
        return [base_plot, overlay_plot, text]

    anim = FuncAnimation(fig, update, init_func=init, frames=len(base_data), interval=500, blit=True)
    anim.save('%s.mp4' % (filename), dpi=300)
    plt.close(fig)
    print('Overlay animation rendered successfully: %s' % filename)

### Helper Functions (Visualization, etc.)

In [None]:
def viz(img, dpi=300, save=None, cmap=None):
    """
    Display `img` with colour limits at mean ± 2σ of its pixel values.

    Parameters
    ----------
    img : array-like
        Image to show.
    dpi : int
        Figure resolution; also used when saving.
    save : str, optional
        If given, write the figure to '<save>.png' and close it.
    cmap : str or Colormap, optional
        Colormap to use; 'viridis' when omitted.
    """
    plt.rcParams["axes.grid"] = False

    pixels = np.asarray(img)
    centre = np.mean(pixels)
    spread = np.std(pixels)
    lower, upper = centre - 2 * spread, centre + 2 * spread

    plt.figure(figsize=(10, 10), dpi=dpi)
    plt.imshow(img, vmax=upper, vmin=lower, cmap='viridis' if cmap is None else cmap)

    if save is None:
        return
    plt.savefig('%s.png' % save, dpi=dpi)
    plt.close()

### Validation - Image Retrieval, Animation, Segmentation

In [None]:
def make_overlay(gfp, bf, filename, color):
    """
    Save '<filename>.tiff' at 1000 dpi with `gfp` drawn over the brightfield `bf`.

    `bf` is shown in grayscale. `gfp` is shown on top in a transparent-to-`color`
    colormap at 75% opacity. Each layer's colour limits are its own mean ± 2σ.
    The figure is closed after saving.
    """
    gray = matplotlib.colors.LinearSegmentedColormap.from_list("", ["black", "white"])
    tint = matplotlib.colors.LinearSegmentedColormap.from_list("", [(0,0,0,0), color])

    bf_px = np.asarray(bf)
    bf_mu, bf_sd = np.mean(bf_px), np.std(bf_px)
    gfp_px = np.asarray(gfp)
    gfp_mu, gfp_sd = np.mean(gfp_px), np.std(gfp_px)

    plt.figure(figsize=(10, 10), dpi=1000)
    plt.imshow(bf, vmax=bf_mu + 2 * bf_sd, vmin=bf_mu - 2 * bf_sd, cmap=gray)
    plt.imshow(gfp, vmax=gfp_mu + 2 * gfp_sd, vmin=gfp_mu - 2 * gfp_sd, cmap=tint, alpha=.75)
    plt.savefig('%s.tiff' % filename, dpi=1000)
    plt.close()
def make_triple_overlay(gfp, mCh, bf, filename):
    """
    Save '<filename>.tiff' at 1000 dpi with GFP and mCherry layered over brightfield.

    The layers are drawn in this order:
    1. `bf` in grayscale.
    2. `gfp` in transparent-to-green at 50% opacity.
    3. `mCh` in transparent-to-red at 50% opacity.

    Each layer's colour limits are its own mean ± 2σ. The figure is closed after saving.
    """
    from_list = matplotlib.colors.LinearSegmentedColormap.from_list

    def limits(layer):
        px = np.asarray(layer)
        mu, sd = np.mean(px), np.std(px)
        return mu - 2 * sd, mu + 2 * sd

    # (image, colormap, alpha) in drawing order; alpha None means fully opaque.
    layers = [
        (bf, from_list("", ["black", "white"]), None),
        (gfp, from_list("", [(0,0,0,0), 'green']), .5),
        (mCh, from_list("", [(0,0,0,0), 'red']), .5),
    ]

    plt.figure(figsize=(10, 10), dpi=1000)
    for layer, colormap, opacity in layers:
        lo, hi = limits(layer)
        plt.imshow(layer, vmax=hi, vmin=lo, cmap=colormap, alpha=opacity)
    plt.savefig('%s.tiff' % filename, dpi=1000)
    plt.close()
# make_overlay(edges1, img, 'edges1_overlay')

In [None]:
# Quick visual check of each channel (1-4) for well r01c01 at timepoint 1.
# Each one is saved to channel_<n>.png.
for channel in range(1, 5):
    print('channel %i' % channel)
    img = retrieve_and_stitch_image(data_path, 1, 1, channel, 1)
    viz(img, cmap='viridis', save='channel_%i' % channel)
