# Align and remove under/overexposed images

Adapted from Giulia Vallardi's ImageJ macro, this notebook removes any under/overexposed frames from timelapse experiments and aligns the images. 

The structure of this notebook is:

1. Find images, organise containing directory as a 'raw images' folder and load using the octopuslite dask loader.
2. Find over/underexposed images by measuring each channel and frame for average pixel intensity.
3. Select a reference channel to center the alignment on
4. Register alignment and save out transformation tensor
5. Apply transformation matrix to all channels and save out images
6. Check images using Napari

In [1]:
import os
import glob
import enum
import numpy as np
from pystackreg import StackReg
from skimage import io
from tqdm import tqdm
from octopuslite import DaskOctopusLiteLoader
from skimage import transform as tf

## 1. Find images, organise and load using octopuslite

In [2]:
### Data location: root directory plus the single experiment / position handled in this
### walkthrough (the batch section at the end of the notebook iterates over many)
root_dir = "/home/nathan/data/kraken/test/"
expt, pos = "ND0001", "Pos14"

In [9]:
### create a '<pos>_raw' subdirectory for the raw files and move them all there
pos_dir = os.path.join(root_dir, expt, pos)
raw_dir = os.path.join(pos_dir, f'{pos}_raw')
# only reorganise on the first run: once the raw dir exists the files are assumed to be in it
if not os.path.exists(raw_dir):
    os.mkdir(raw_dir)
    for file in sorted(glob.glob(os.path.join(pos_dir, '*.tif'))):
        # build the destination from the raw dir + file name rather than string-replacing
        # the position name, which would also rewrite any other occurrence of it in the path
        os.rename(file, os.path.join(raw_dir, os.path.basename(file)))

In [3]:
### lazily load the raw frames of this position from its '<pos>_raw' directory
images = DaskOctopusLiteLoader(os.path.join(root_dir, expt, pos, f'{pos}_raw'))

Using cropping: None


## 2. Identify under/overexposed images and display average channel brightness

In [4]:
%%time
# acceptable range for the mean pixel value of a frame
max_pixel, min_pixel = 200, 2
# per-channel arrays of mean frame intensity, keyed by channel name
mean_arrays = {}
# indices of badly exposed frames; a set so a frame flagged in several channels is kept once
dodgy_frame_list = set()
for channel in tqdm(images.channels, desc = 'Finding mean values of image channels'):
    # average over the two trailing axes -> one mean value per frame
    channel_means = np.mean(images[channel.name], axis = (1,2)).compute()
    mean_arrays[channel.name] = channel_means
    # flag every frame whose mean falls outside [min_pixel, max_pixel]
    dodgy_frame_list.update(
        frame
        for frame, mean_value in enumerate(channel_means)
        if mean_value > max_pixel or mean_value < min_pixel
    )
# convert to a list so it can be used as an index list for deletion
dodgy_frame_list = list(dodgy_frame_list)
print('Number of under/over-exposed frames:', len(dodgy_frame_list))

Finding mean values of image channels: 100%|██████████| 4/4 [00:10<00:00,  2.53s/it]

Number of under/over-exposed frames: 13
CPU times: user 28.5 s, sys: 3.77 s, total: 32.3 s
Wall time: 10.1 s





In [5]:
# inspect the flagged frame indices (built from a set, so the order is arbitrary)
dodgy_frame_list

[1018, 450, 294, 1190, 936, 618, 1037, 270, 430, 19, 888, 729, 922]

#### 2a. Filtering to remove blank or overexposed frames from image array and mean value arrays

In [6]:
# per-channel image stacks with the flagged frames removed, keyed by channel name
filtered_images = {}
for channel in images.channels:
    # index the loader by channel name, as everywhere else in this notebook
    filtered_images[channel.name] = np.delete(images[channel.name], dodgy_frame_list, axis = 0)
    # trim the mean-value array only while it still has one entry per raw frame, so that
    # re-running this cell does not delete a second batch of (valid) frames from it
    if len(mean_arrays[channel.name]) == images[channel.name].shape[0]:
        mean_arrays[channel.name] = np.delete(mean_arrays[channel.name], dodgy_frame_list, axis = 0)

In [7]:
# inspect the filtered stacks: one entry per channel, with the flagged frames removed
filtered_images

{'BRIGHTFIELD': dask.array<concatenate, shape=(1187, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'GFP': dask.array<concatenate, shape=(1187, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'RFP': dask.array<concatenate, shape=(1187, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'IRFP': dask.array<concatenate, shape=(1187, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>}

## 3. Select reference image to base alignment around

In [8]:
# report the mean brightness of each channel (after filtering) to help pick a reference
print('Average channel brightness for selection of reference image:')
for channel in images.channels:
    channel_average = np.mean(mean_arrays[channel.name])
    print(f'{channel.value}: {channel.name}:', channel_average)

Average channel brightness for selection of reference image:
0: BRIGHTFIELD: 34.03064165089353
1: GFP: 28.038749247736085
2: RFP: 3.7034679555476733
3: IRFP: 46.306161060584216


In [9]:
# manually select the reference channel by its index in images.channels
# (index 1 is GFP for this dataset, see the brightness listing above)
reference_channel = images.channels[1]
# alternative: automatically select the brightest channel (highest average pixel value).
# NOTE: use a key function for this -- taking max() over (channel.value, mean) tuples compares
# channel.value first and therefore always returns the highest-numbered channel.
#reference_channel = max(images.channels, key = lambda channel: np.mean(mean_arrays[channel.name]))
# filtered frame stack of the chosen channel; the registration is computed on this
reference_image = filtered_images[reference_channel.name]
reference_channel.name

'GFP'

#### 3a. Set cropped area of reference image to base alignment around 
Optional step, as registration struggles on the full-size stack (roughly 1200 frames of 1352 × 1688 pixels); registering on a smaller central window is much faster.

In [10]:
crop_area = 500
# crop a central crop_area x crop_area window out of every frame of the reference stack.
# The row window is derived from the row count (axis 1) and the column window from the
# column count (axis 2); sizing each from the other axis shifts the window off-centre
# on non-square frames.
_, n_rows, n_cols = reference_image.shape
# clamp at 0 so a crop_area larger than the frame simply keeps the whole axis
row_start = max(0, (n_rows - crop_area) // 2)
col_start = max(0, (n_cols - crop_area) // 2)
reference_image = reference_image[:,
                                  row_start:row_start + crop_area,
                                  col_start:col_start + crop_area]
reference_image.shape

(1187, 500, 500)

## 4. Register alignment and save out transformation tensor
The transformation tensor is a stack of 3 × 3 transformation matrices, one per retained frame (shape `(n_frames, 3, 3)`).

In [15]:
%%time
# translation-only registration operator
sr = StackReg(StackReg.TRANSLATION)

# register every frame of the reference stack against the previous one; the result holds
# one transformation matrix per frame
transform_tensor = sr.register_stack(reference_image, reference = 'previous')

# keep the tensor on disk next to the image folders, named after the reference channel
tensor_path = os.path.join(root_dir, expt, pos, f'{reference_channel.name.lower()}_transform_tensor.npy')
np.save(tensor_path, transform_tensor)

CPU times: user 2min 18s, sys: 6.91 s, total: 2min 25s
Wall time: 2min 25s


 `/home/nathan/analysis/miniconda3/envs/cellx/lib/python3.9/site-packages/pystackreg/pystackreg.py:379: UserWarning: Detected axis 2 as the possible time axis for the stack due to its low variability, but axis 0 was supplied for registration. Are you sure you supplied the correct axis?
  warnings.warn(`

In [13]:
# sanity check: expect one 3x3 matrix per retained frame
transform_tensor.shape

(1187, 3, 3)

## 5. Apply transformation matrix to all channels and save out images

In [17]:
%%time
### apply the per-frame transforms to every channel and write the aligned frames out
# create aligned image dir if it does not exist
aligned_dir = os.path.join(root_dir, expt, pos, f'{pos}_aligned')
if not os.path.exists(aligned_dir):
    os.mkdir(aligned_dir)
# iterate over channels
for channel in images.channels:
    # raw file names of this channel, looked up once instead of once per frame
    raw_files = images.files(channel.name)
    # iterate over all retained frames in the channel
    for i in tqdm(range(len(transform_tensor)), desc = f'Aligning {channel.name.lower()} channel {channel.value+1}/{len(images.channels)}'):
        # wrap the transformation matrix of this frame for use with skimage
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...], rotation = None)
        # transform the frame, keeping the original intensity range, and store as 8-bit
        transformed_image = (tf.warp(filtered_images[channel.name][i,...].compute(), transform_matrix, preserve_range=True)).astype(np.uint8)
        # output path: same file name, inside the aligned dir. Built from the basename so
        # that a '_raw' elsewhere in the path can never be rewritten by accident.
        # NOTE(review): the i-th *filtered* frame is saved under the i-th *raw* file name, so
        # output names stay contiguous but no longer match the original frame numbers once
        # frames have been removed -- confirm this renumbering is intended.
        fn = os.path.join(aligned_dir, os.path.basename(raw_files[i]))
        # save the aligned frame
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1/4: 100%|██████████| 1187/1187 [01:26<00:00, 13.76it/s]
Aligning gfp channel 2/4: 100%|██████████| 1187/1187 [01:22<00:00, 14.44it/s]
Aligning rfp channel 3/4: 100%|██████████| 1187/1187 [01:21<00:00, 14.48it/s]
Aligning irfp channel 4/4: 100%|██████████| 1187/1187 [01:26<00:00, 13.77it/s]

CPU times: user 5min 22s, sys: 11.2 s, total: 5min 33s
Wall time: 5min 36s





## 6. Check alignment using Napari

In [18]:
import napari

In [19]:
# lazily reload the freshly written aligned frames
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, expt, pos, f'{pos}_aligned'))

# overlay every channel as an additive layer in a single viewer
viewer = napari.Viewer()
for channel in aligned_images.channels:
    viewer.add_image(
        aligned_images[channel.name],
        name = channel.name,
        blending = 'additive',
        contrast_limits = [0,255],
    )

Using cropping: None


## Batch execute

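Do all of the above, but for many experiment IDs and every position within them. Note that the `alignment` function called below is defined in the final cell of this notebook, so that cell must be run first.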

In [20]:
# root directory for the batch run; alignment() reads this notebook-level variable
root_dir = '/home/nathan/data/kraken/ras/'

In [None]:
%%time
# experiments to process in this batch; every position found under each one is aligned
batch_expts = ['ND0000', 'ND0001']
alignment(expt_list = batch_expts, max_pixel = 200, min_pixel = 2, crop_area = 500)

Using cropping: None


Finding mean values of image channels: 100%|██████████| 4/4 [01:05<00:00, 16.39s/it]


Number of under/over-exposed frames: 5
Automatically selected and cropped reference image: IRFP
Registering alignment for Pos13 ND0000


Aligning brightfield channel 1/4: 100%|██████████| 431/431 [00:25<00:00, 16.58it/s]
Aligning gfp channel 2/4: 100%|██████████| 431/431 [00:27<00:00, 15.63it/s]
Aligning rfp channel 3/4: 100%|██████████| 431/431 [00:27<00:00, 15.95it/s]
Aligning irfp channel 4/4: 100%|██████████| 431/431 [00:27<00:00, 15.84it/s]


Using cropping: None


Finding mean values of image channels: 100%|██████████| 4/4 [01:08<00:00, 17.00s/it]


Number of under/over-exposed frames: 1
Automatically selected and cropped reference image: IRFP
Registering alignment for Pos5 ND0000


Aligning brightfield channel 1/4: 100%|██████████| 435/435 [00:26<00:00, 16.56it/s]
Aligning gfp channel 2/4: 100%|██████████| 435/435 [00:28<00:00, 15.48it/s]
Aligning rfp channel 3/4:  52%|█████▏    | 226/435 [00:13<00:13, 15.84it/s]

In [None]:
viewer = napari.Viewer()
# Give each channel a different colormap. Take a snapshot of napari's simple colormaps
# (newest entry first, the order popitem() would hand them out in) instead of popping
# them: popitem() permanently removes entries from napari's own registry, so repeated
# runs of this cell would eventually run out of colormaps.
colormap_choices = list(napari.utils.colormaps.SIMPLE_COLORMAPS.items())[::-1]
for idx, channel in enumerate(aligned_images.channels):
    viewer.add_image(aligned_images[channel.name],
                     name = channel.name,
                     blending = 'additive',
                     contrast_limits = [0,255],
                     # wrap around if there are more channels than colormaps
                     colormap = colormap_choices[idx % len(colormap_choices)])

In [23]:
def alignment(expt_list, max_pixel, min_pixel, crop_area):
    """Remove badly exposed frames and align all positions of the given experiments.

    For every experiment ID in ``expt_list``, each directory under
    ``root_dir/<expt>`` whose name contains 'Pos' is processed as follows:

    1. On the first run, move the ``*.tif`` files into ``<pos>/<pos>_raw``.
    2. Compute the mean pixel value of every frame in every channel and flag frames
       whose mean lies outside ``[min_pixel, max_pixel]`` in any channel.
    3. Drop the flagged frames from all channels.
    4. Pick the channel with the highest average brightness as the reference and crop
       a central ``crop_area`` x ``crop_area`` window from it.
    5. Register the cropped reference stack (translation only, each frame against the
       previous one) and save the resulting transformation tensor as
       ``<pos>/<channel>_transform_tensor.npy``.
    6. Apply the per-frame transforms to every channel and write the aligned frames to
       ``<pos>/<pos>_aligned``.

    The data root is read from the notebook-level variable ``root_dir``, which must be
    set before calling this function.

    Parameters
    ----------
    expt_list : iterable of str
        Experiment directory names under ``root_dir``.
    max_pixel, min_pixel : number
        Upper and lower limit for the mean pixel value of an acceptable frame.
    crop_area : int
        Side length in pixels of the central window used for registration.

    Returns
    -------
    None
        Results are written to disk.
    """
    window = int(crop_area)

    ### Iterate over all experiments defined in expt_list
    for expt in expt_list:
        expt_dir = os.path.join(root_dir, expt)
        # find all position directories in that experiment (stray files are ignored)
        pos_list = [pos for pos in os.listdir(expt_dir)
                    if 'Pos' in pos and os.path.isdir(os.path.join(expt_dir, pos))]

        ### Iterate over all positions in that experiment
        for pos in pos_list:
            pos_dir = os.path.join(expt_dir, pos)
            raw_dir = os.path.join(pos_dir, f'{pos}_raw')
            aligned_dir = os.path.join(pos_dir, f'{pos}_aligned')

            ### create new subdir for raw files and move them all there (first run only)
            if not os.path.exists(raw_dir):
                os.mkdir(raw_dir)
                for file in sorted(glob.glob(os.path.join(pos_dir, '*.tif'))):
                    # join dir + basename instead of string-replacing the position name,
                    # which would also rewrite any other occurrence of it in the path
                    os.rename(file, os.path.join(raw_dir, os.path.basename(file)))

            ### pre load files from raw file dir
            images = DaskOctopusLiteLoader(raw_dir)

            ### measure mean pixel values and use them to find under/over-exposed frames
            mean_arrays = {}
            # a set, so a frame flagged in several channels is only recorded once
            dodgy_frames = set()
            for channel in tqdm(images.channels, desc = 'Finding mean values of image channels'):
                # mean value of each frame in this channel
                channel_means = np.mean(images[channel.name], axis = (1,2)).compute()
                mean_arrays[channel.name] = channel_means
                # frames whose mean lies outside the accepted range
                out_of_range = (channel_means > max_pixel) | (channel_means < min_pixel)
                dodgy_frames.update(np.flatnonzero(out_of_range).tolist())
            # sorted purely so the list is deterministic
            dodgy_frame_list = sorted(dodgy_frames)
            print('Number of under/over-exposed frames:', len(dodgy_frame_list))

            ### remove the flagged frames from the image stacks and the mean value arrays
            filtered_images = {}
            for channel in images.channels:
                filtered_images[channel.name] = np.delete(images[channel.name], dodgy_frame_list, axis = 0)
                mean_arrays[channel.name] = np.delete(mean_arrays[channel.name], dodgy_frame_list, axis = 0)

            # nothing to register if no channel has a usable frame left
            if not any(len(channel_means) for channel_means in mean_arrays.values()):
                print('No correctly exposed frames left for', pos, expt, '- skipping')
                continue

            ### Automatically pick the reference channel: the brightest one on average.
            # A key function is required here: max() over (channel.value, mean) tuples
            # compares channel.value first and would always pick the last channel.
            reference_channel = max(images.channels,
                                    key = lambda channel: np.mean(mean_arrays[channel.name]))

            ### Crop a central window out of the reference stack
            reference_stack = filtered_images[reference_channel.name]
            _, n_rows, n_cols = reference_stack.shape
            # rows are centred using the row count and columns using the column count;
            # clamping at 0 keeps the whole axis if the window is larger than the frame
            row_start = max(0, (n_rows - window) // 2)
            col_start = max(0, (n_cols - window) // 2)
            reference_image = reference_stack[:,
                                              row_start:row_start + window,
                                              col_start:col_start + window].compute()
            print('Automatically selected and cropped reference image:', reference_channel.name)

            ### Register alignment
            print('Registering alignment for', pos, expt)
            # create operator using transformation type (translation)
            sr = StackReg(StackReg.TRANSLATION)
            # register each frame to the previous one. The matrices are kept as returned:
            # casting them to uint8 would truncate sub-pixel shifts and wrap negative ones.
            transform_tensor = sr.register_stack(reference_image, reference = 'previous')
            # save out transform tensor
            np.save(os.path.join(pos_dir, f'{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)

            ### Perform alignment
            # create aligned image dir if it does not exist
            if not os.path.exists(aligned_dir):
                os.mkdir(aligned_dir)
            for channel in images.channels:
                # raw file names of this channel, looked up once instead of once per frame
                raw_files = images.files(channel.name)
                for i in tqdm(range(len(transform_tensor)), desc = f'Aligning {channel.name.lower()} channel {channel.value+1}/{len(images.channels)}'):
                    # wrap the transformation matrix of this frame for use with skimage
                    transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...], rotation = None)
                    # transform the frame, keeping the intensity range, and store as 8-bit
                    transformed_image = (tf.warp(filtered_images[channel.name][i,...].compute(), transform_matrix, preserve_range=True)).astype(np.uint8)
                    # same file name, inside the aligned dir.
                    # NOTE(review): the i-th *filtered* frame is saved under the i-th *raw*
                    # file name, so output names stay contiguous but no longer match the
                    # original frame numbers once frames were removed -- confirm intended.
                    fn = os.path.join(aligned_dir, os.path.basename(raw_files[i]))
                    # save the aligned frame
                    io.imsave(fn, transformed_image, check_contrast=False)