# Align and remove blanks

Adapted from Giulia Vallardi's ImageJ macro, this notebook removes any blank frames from timelapse experiments and aligns the images. 

"Fiji macro to remove over- and under-exposed images, and align the image stacks

The settings for the alignments are: 
registration by Translation > only modify XY coordinates
Shrinkage constrain activated (this model allows a better registration based on all images, not using a reference image. It is more time consuming though)
Transform matrices are saved during registration and then applied to the other channels during transformation."

# To do: decide how to import DaskOctopus

In [4]:
import octopuslite

ModuleNotFoundError: No module named 'octopuslite'

In [1]:
import os
import glob
import enum
import numpy as np
from pystackreg import StackReg
from skimage import io
from tqdm import tqdm
from daskoctopus import DaskOctopusLiteLoader
from skimage import transform as tf

# Find images, organise them into a raw folder and load them using DaskOctopusLiteLoader

In [2]:
### root data directory, experiment ID and imaging position to process
### (single hard-coded position for now; to be made iterable later)
root_dir = "/home/nathan/data/kraken/commitment/test/"
expt, pos = "MK0003", "Pos1"

In [3]:
### create a new subdir for the raw files and move them all there
# The destination is built from each file's basename. String-replacing `pos`
# inside the full path would rewrite *every* occurrence of `pos` (e.g. one in
# the filename or in a parent folder) and send the file to the wrong place.
pos_dir = os.path.join(root_dir, f'{expt}/{pos}')
raw_dir = os.path.join(pos_dir, f'{pos}_raw')
# only organise the files the first time, i.e. while the raw folder does not exist yet
if not os.path.exists(raw_dir):
    os.mkdir(raw_dir)
    # sorted so the files are moved in a deterministic order
    for file in sorted(glob.glob(os.path.join(pos_dir, '*.tif'))):
        os.rename(file, os.path.join(raw_dir, os.path.basename(file)))

In [4]:
### point the loader at the raw-file directory of the selected position
raw_image_dir = os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')
images = DaskOctopusLiteLoader(raw_image_dir)

Using cropping: None


# Find blank or overexposed images and display average channel brightness

In [6]:
%%time
# Inclusion criteria: a frame is rejected when its mean pixel value is above
# max_pixel or below min_pixel in any channel
max_pixel, min_pixel = 200, 2

### (re)load the files from the raw file dir
raw_image_dir = os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')
images = DaskOctopusLiteLoader(raw_image_dir)

### per-channel, per-frame mean pixel values and the frames failing the criteria
mean_arrays = {}
dodgy_frames = set()
for channel in tqdm(images.channels):
    print(f'Finding mean values of {channel.name.lower()} images')
    # average over axes 1 and 2 so that one mean value is left per frame
    channel_means = np.mean(images[channel.name], axis=(1, 2)).compute()
    mean_arrays[channel.name] = channel_means
    # a set is used so a frame that fails in several channels is only recorded once
    dodgy_frames.update(
        frame
        for frame, mean_value in enumerate(channel_means)
        if mean_value > max_pixel or mean_value < min_pixel
    )
# later cells use the list form
dodgy_frame_list = list(dodgy_frames)

print('Number of under/over-exposed frames:', len(dodgy_frame_list))

Using cropping: None


  0%|          | 0/4 [00:00<?, ?it/s]

Finding mean values of brightfield images


 25%|██▌       | 1/4 [00:03<00:09,  3.16s/it]

Finding mean values of gfp images


 50%|█████     | 2/4 [00:06<00:06,  3.20s/it]

Finding mean values of rfp images


 75%|███████▌  | 3/4 [00:09<00:03,  3.22s/it]

Finding mean values of irfp images


100%|██████████| 4/4 [00:12<00:00,  3.20s/it]

Number of under/over-exposed frames: 8
CPU times: user 27.1 s, sys: 6.09 s, total: 33.2 s
Wall time: 13.6 s





# Select reference image to base alignment around

In [7]:
# Show the overall mean brightness of every channel (index: name: value)
# to help pick the channel used as registration reference
print('Average channel brightness for selection of reference image:')
for channel in images.channels:
    channel_brightness = np.mean(mean_arrays[channel.name])
    print(f'{channel.value}: {channel.name}:', channel_brightness)

Average channel brightness for selection of reference image:
0: BRIGHTFIELD: 28.949467
1: GFP: 62.79905
2: RFP: 6.4401646
3: IRFP: 76.01225


In [8]:
# manually select reference channel by adding index
# (the index is the position in `images.channels`, as listed in the brightness printout above)
reference_channel = images.channels[1]
# automatically select reference channel from max average pixel value (ie. brightest channel)
# NOTE(review): as written, max() compares (channel.value, mean) tuples, so the line below
# would pick the channel with the highest index rather than the brightest one
#reference_channel = images.channels[max([(channel.value, np.mean(mean_arrays[channel.name])) for channel in images.channels])[0]]
# image stack of the chosen channel, used as input for the registration
reference_image = images[reference_channel.name]
# display the name of the selected channel as the cell output
reference_channel.name

'GFP'

## Crop the reference image to a central area to base the alignment on (the whole image struggles to compute)

In [31]:
# side length (in pixels) of the central square the registration is computed on
crop_area = 500
# Always crop from the full stack of the reference channel, so re-running this
# cell does not try to crop (and .compute()) the already cropped array again.
full_reference = images[reference_channel.name]
# axis 0 is the frame axis (frames are indexed as [i, ...] elsewhere); axes 1 and 2 are the image axes
_, n_rows, n_cols = full_reference.shape
# Centre the window along each image axis using that axis' OWN length
# (previously the two lengths were swapped, which mis-centres the crop for
# non-square images). Clamp at 0 so a crop larger than the image keeps the whole axis.
row_start = max((n_rows - crop_area) // 2, 0)
col_start = max((n_cols - crop_area) // 2, 0)
reference_image = full_reference[:,
                                 row_start:row_start + crop_area,
                                 col_start:col_start + crop_area].compute()
reference_image.shape

(1200, 500, 500)

# Register alignment and save out

In [10]:
%%time
# create operator using transformation type (translation)
sr = StackReg(StackReg.TRANSLATION)

# register each frame to the previous one, giving one transformation matrix per frame
# The matrices are snapped to whole pixels: round to the nearest integer
# (instead of truncating towards zero) and store in a signed 16-bit type, since
# int8 overflows once a shift exceeds +/-127 px.
transform_tensor = np.rint(sr.register_stack(reference_image, reference = 'previous')).astype(np.int16)

# save out transform tensor next to the raw/aligned folders of this position
np.save(os.path.join(root_dir, f'{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)



CPU times: user 23min 39s, sys: 2min 19s, total: 25min 58s
Wall time: 25min 56s


In [11]:
transform_tensor.shape

(1200, 3, 3)

# Apply transformation matrix to all channels and save out images

In [12]:
%%time
### iterating over channels
# create aligned image dir if does not exist
aligned_dir = os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')
if not os.path.exists(aligned_dir):
    os.mkdir(aligned_dir)
# set gives constant-time membership tests inside the per-frame loop
dodgy_frames = set(dodgy_frame_list)
# iterate over channels
for channel in images.channels:
    print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
    # look up the stack and its filenames once per channel rather than once per frame
    channel_stack = images[channel.name]
    channel_files = images.files(channel.name)
    # iterate over all images in channel
    for i in tqdm(range(len(channel_stack))):
        # skip dodgy frames and don't save out into aligned folder
        if i in dodgy_frames:
            continue
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...], rotation = None)
        # transform image and rescale the warped result by 255 before casting to 8-bit
        transformed_image = (tf.warp(channel_stack[i,...], transform_matrix)*255).astype(np.uint8)
        # Save under the same filename inside the aligned dir. Building the path
        # from the basename avoids rewriting any other '_raw' in the full path.
        fn = os.path.join(aligned_dir, os.path.basename(channel_files[i]))
        # save trans image out
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1 / 4


100%|██████████| 1200/1200 [02:27<00:00,  8.12it/s]


Aligning gfp channel 2 / 4


100%|██████████| 1200/1200 [02:21<00:00,  8.51it/s]


Aligning rfp channel 3 / 4


100%|██████████| 1200/1200 [02:19<00:00,  8.59it/s]


Aligning irfp channel 4 / 4


100%|██████████| 1200/1200 [02:23<00:00,  8.35it/s]

CPU times: user 6min 26s, sys: 16.4 s, total: 6min 43s
Wall time: 9min 32s





# Check alignment using Napari

In [13]:
import napari

In [14]:
# Load the aligned images back in and overlay all channels in one napari viewer
aligned_dir = os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')
aligned_images = DaskOctopusLiteLoader(aligned_dir)

viewer = napari.Viewer()
for channel in aligned_images.channels:
    viewer.add_image(
        aligned_images[channel.name],
        name=channel.name,
        blending='additive',
        contrast_limits=[0, 255],
    )
    # optional extra argument (disabled):
    #   colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()
    # lazy hack to randomly generate different colormaps

Using cropping: None


# Batch execute

In [1]:
# root directory holding the experiment folders; `alignment` reads this as a global
root_dir = '/home/nathan/data/kraken/commitment/test/'

In [None]:
%%time
# Batch-process every position of the listed experiments.
# NOTE: `alignment` is defined in a later cell of this notebook; run that cell first.
alignment(
    expt_list=['MK0000', 'MK0001', 'MK0002', 'MK0003'],
    max_pixel=200,
    min_pixel=2,
    crop_area=500,
)

In [None]:
viewer = napari.Viewer()
# Give every channel a different colormap. Take a snapshot of napari's simple
# colormaps (last entry first, the order popitem() would have produced) instead
# of popping them: popitem() permanently removes entries from the shared
# registry each time this cell is run.
available_colormaps = list(napari.utils.colormaps.SIMPLE_COLORMAPS.items())[::-1]
for idx, channel in enumerate(aligned_images.channels):
    viewer.add_image(aligned_images[channel.name],
                     name = channel.name,
                     blending = 'additive',
                     contrast_limits = [0,255],
                     # wrap around if there are more channels than colormaps
                     colormap = available_colormaps[idx % len(available_colormaps)])

In [5]:
def alignment(expt_list, max_pixel, min_pixel, crop_area):
    """Remove blank/over-exposed frames and align all channels of every position.

    For each experiment in ``expt_list`` and each entry of its folder whose
    name contains 'Pos', this function:

    1. moves the position's ``*.tif`` files into ``<pos>/<pos>_raw`` (only if
       that folder does not exist yet),
    2. flags every frame whose mean pixel value is above ``max_pixel`` or below
       ``min_pixel`` in any channel,
    3. registers a central crop of the brightest channel with a
       translation-only StackReg model and saves the transforms as
       ``<pos>/<channel>_transform_tensor.npy``,
    4. applies the per-frame transform to every channel and writes the
       non-flagged frames into ``<pos>/<pos>_aligned``.

    The root data directory is read from the global ``root_dir``.

    Parameters
    ----------
    expt_list : iterable of str
        Names of the experiment folders inside ``root_dir``.
    max_pixel, min_pixel : number
        Upper / lower bound on the mean pixel value of an acceptable frame.
    crop_area : int
        Side length in pixels of the central square used for registration.

    Returns
    -------
    None
        Results are written to disk.
    """
    ### Iterate over all experiments defined in expt_list
    for expt in expt_list:
        expt_dir = os.path.join(root_dir, expt)
        # Find all positions in that experiment
        pos_list = [pos for pos in os.listdir(expt_dir) if 'Pos' in pos]
        ### Iterate over all positions in that experiment
        for pos in pos_list:
            pos_dir = os.path.join(expt_dir, pos)
            raw_dir = os.path.join(pos_dir, f'{pos}_raw')
            aligned_dir = os.path.join(pos_dir, f'{pos}_aligned')

            ### create new subdir for raw files and move them all there
            if not os.path.exists(raw_dir):
                os.mkdir(raw_dir)
                for file in sorted(glob.glob(os.path.join(pos_dir, '*.tif'))):
                    # destination built from the basename: replacing `pos` in the
                    # full path would rewrite every occurrence of it
                    os.rename(file, os.path.join(raw_dir, os.path.basename(file)))

            ### pre load files from raw file dir
            images = DaskOctopusLiteLoader(raw_dir)

            ### measure mean pixel value arrays and use to find under/over-exposed frames
            mean_arrays = {}
            # set: a frame failing in several channels is recorded once, and
            # membership tests in the alignment loop are constant-time
            dodgy_frames = set()
            for channel in tqdm(images.channels):
                print(f'Finding mean values of {channel.name.lower()} images', pos, expt)
                # mean over axes 1 and 2 leaves one value per frame
                channel_means = np.mean(images[channel.name], axis = (1,2)).compute()
                mean_arrays[channel.name] = channel_means
                for frame, mean_value in enumerate(channel_means):
                    if max_pixel < mean_value or mean_value < min_pixel:
                        # frame does not meet inclusion criteria
                        dodgy_frames.add(frame)
            print('Number of under/over-exposed frames:', len(dodgy_frames), pos, expt)

            ### create aligned image dir if does not exist
            if not os.path.exists(aligned_dir):
                os.mkdir(aligned_dir)

            ### Automatically pick reference image to perform alignment on
            # Brightest channel = highest overall mean pixel value. Comparing
            # (channel.value, mean) tuples would always select the highest
            # channel index instead, so the mean is used as the sort key.
            reference_channel = max(images.channels,
                                    key = lambda channel: np.mean(mean_arrays[channel.name]))
            # Crop a central crop_area x crop_area window, centring each image
            # axis on its own length (clamped at 0 for images smaller than the crop)
            full_reference = images[reference_channel.name]
            _, n_rows, n_cols = full_reference.shape
            row_start = max((n_rows - crop_area) // 2, 0)
            col_start = max((n_cols - crop_area) // 2, 0)
            reference_image = full_reference[:,
                                             row_start:row_start + crop_area,
                                             col_start:col_start + crop_area].compute()
            print('Automatically selected and cropped reference image:', reference_channel.name)

            ### Register alignment
            print('Registering alignment for', pos, expt)
            # create operator using transformation type (translation)
            sr = StackReg(StackReg.TRANSLATION)
            # register each frame to the previous one as transformation matrices/tensor
            # NOTE(review): flagged frames are still part of the registered stack - confirm this is intended
            # Snap to whole pixels with rounding and a SIGNED 16-bit type: an
            # unsigned 8-bit cast wraps negative shifts to large positive values.
            transform_tensor = np.rint(sr.register_stack(reference_image, reference = 'previous')).astype(np.int16)
            # save out transform tensor
            np.save(os.path.join(pos_dir, f'{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)

            ### Perform alignment
            # iterate over channels
            for channel in images.channels:
                print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
                # look up the stack and its filenames once per channel
                channel_stack = images[channel.name]
                channel_files = images.files(channel.name)
                # iterate over all images in channel
                for i in tqdm(range(len(channel_stack))):
                    # skip dodgy frames and don't save out into aligned folder
                    if i in dodgy_frames:
                        continue
                    # load specific transform matrix for that frame
                    transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...], rotation = None)
                    # transform image and rescale the warped result by 255 before casting to 8-bit
                    transformed_image = (tf.warp(channel_stack[i,...], transform_matrix)*255).astype(np.uint8)
                    # same filename inside the aligned dir (built from the basename
                    # so no other '_raw' in the path gets rewritten)
                    fn = os.path.join(aligned_dir, os.path.basename(channel_files[i]))
                    # save trans image out
                    io.imsave(fn, transformed_image, check_contrast=False)

# Compile stacks and save out?

In [None]:
# NOTE(review): `start` and `stop` are not defined or imported in any visible cell
# of this notebook - presumably timing helpers; confirm where they come from before running
start()
# reload the aligned images using the global `root_dir`, `expt` and `pos`
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
stop()