# Align and remove under/overexposed images

This notebook removes any under/overexposed frames from timelapse experiments and aligns the images. 

### The aligned images will be used later in the object localisation and tracking steps. 

The structure of this notebook is:

1. Load images using the octopuslite dask loader.
2. Find over/underexposed images by measuring the mean pixel intensity of each frame in each channel.
3. Select a reference channel to centre the alignment on.
4. Register the alignment and save out the transformation tensor.
5. (Optional) Apply the transformation matrices to all channels and save out the images.
6. Check the images using Napari.
7. Use a single function to iterate over many experiments and many positions.

In [11]:
import os
import glob
import numpy as np
from pystackreg import StackReg
from skimage.io import imsave
from tqdm import tqdm
from octopuslite import DaskOctopusLiteLoader, image_generator
from skimage import transform as tf

## 1. Find images, organise and load using octopuslite

Define the root directory and the specific experiment and position to align.

In [10]:
root_dir = '/home/nathan/data/'
expt = 'ND0012'
pos = 'Pos0'

Create a new subdirectory for the image files and move them all there. This keeps the miscellaneous non-image files (such as transformation matrices and tracking files) easy to find later on, rather than lost among many single-frame timelapse images.

In [3]:
image_path = f'{root_dir}/{expt}/{pos}/{pos}_images'
if not os.path.exists(image_path):
    os.mkdir(image_path)
    files = sorted(glob.glob(f'{root_dir}/{expt}/{pos}/*.tif'))
    for file in files:
        # move by basename so other occurrences of `pos` in the path are left untouched
        os.rename(file, f'{image_path}/{os.path.basename(file)}')

Lazily load the image arrays and associated information using the octopuslite `DaskOctopusLiteLoader`, and display the channels found. Note that the optional background removal is not applied at this stage.

In [4]:
images = DaskOctopusLiteLoader(image_path, remove_background = False)
print([channel.name for channel in images.channels])

['BRIGHTFIELD', 'GFP', 'RFP', 'IRFP']


## 2. Identify under/overexposed images and display average channel brightness

In [5]:
%%time
# pixel range criteria: frames with a mean outside [min_pixel, max_pixel] are rejected
max_pixel, min_pixel = 200, 2
# dict of per-frame mean values, keyed by channel name
mean_arrays = {}
# set of dodgy frame indices (only unique entries)
dodgy_frames = set()
# iterate over channels
for channel in tqdm(images.channels, desc='Finding mean values of image channels'):
    # find mean value of each frame in each channel
    mean_arrays[channel.name] = [np.mean(img) for img in image_generator(images.files(channel.name))]
    # iterate over frames
    for frame, mean_value in enumerate(mean_arrays[channel.name]):
        # flag the frame if its mean pixel value falls outside the criteria
        if mean_value > max_pixel or mean_value < min_pixel:
            dodgy_frames.add(frame)
# sorted list of unique dodgy frame indices
dodgy_frame_list = sorted(dodgy_frames)
print('Number of under/over-exposed frames:', len(dodgy_frame_list))

Finding mean values of image channels: 100%|██████████| 4/4 [00:13<00:00,  3.50s/it]

Number of under/over-exposed frames: 19
CPU times: user 11.5 s, sys: 2.49 s, total: 14 s
Wall time: 14 s
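The per-channel check above can also be expressed with NumPy boolean masks. This is a minimal sketch, assuming `mean_arrays` has the same structure as in the cell above (channel name → list of per-frame means); `find_dodgy_frames` is a hypothetical helper, not part of octopuslite. The bounds are inclusive, matching the strict `<`/`>` rejection test above.

```python
import numpy as np

def find_dodgy_frames(mean_arrays, min_pixel=2, max_pixel=200):
    """Return sorted frame indices whose mean intensity lies outside
    [min_pixel, max_pixel] in *any* channel (hypothetical helper)."""
    dodgy = set()
    for means in mean_arrays.values():
        means = np.asarray(means, dtype=float)
        # boolean mask of frames outside the accepted intensity window
        bad = (means < min_pixel) | (means > max_pixel)
        dodgy.update(np.flatnonzero(bad).tolist())
    return sorted(dodgy)

# toy example: frame 1 is blank in GFP, frame 3 is saturated in RFP
example = {
    'GFP': [40.0, 0.5, 41.0, 39.0],
    'RFP': [7.0, 6.5, 7.2, 250.0],
}
flagged = find_dodgy_frames(example)
print(flagged)  # [1, 3]
```

A frame is rejected if it fails in any one channel, because a single bad channel at a timepoint is enough to break the aligned multichannel stack.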





### 2a. Filter blanks from main image folder into separate directory

This step is optional, as `DaskOctopusLiteLoader` has a parameter that filters the images on load, but applying that filter every time the images are loaded is time-consuming for large datasets.

In [30]:
# make the blanks dir if it does not already exist
blanks_dir = f'{root_dir}/{expt}/{pos}/{pos}_blanks'
os.makedirs(blanks_dir, exist_ok=True)
# move blank images (matched by their 9-digit zero-padded frame number) into this directory
for channel in images.channels:
    for f in images.files(channel.name):
        if any(str(i).zfill(9) in os.path.basename(f) for i in dodgy_frame_list):
            os.rename(f, f'{blanks_dir}/{os.path.basename(f)}')
# reload image arrays now that blanks are filtered out
images = DaskOctopusLiteLoader(image_path, remove_background=False)
images['gfp']

| | Array | Chunk |
|---|---|---|
| Bytes | 3.48 GiB | 2.18 MiB |
| Shape | (1638, 1352, 1688) | (1, 1352, 1688) |
| Count | 4914 Tasks | 1638 Chunks |
| Type | uint8 | numpy.ndarray |
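The file-matching rule above assumes each filename carries its frame number as a 9-digit zero-padded field. This minimal sketch partitions a file list into kept and blank files before anything is moved, so the selection can be inspected first. `split_blank_files` and the example filenames are hypothetical, not the octopuslite naming scheme. Matching only against the basename keeps digits in the directory path from causing false hits.

```python
import os

def split_blank_files(files, dodgy_frames, pad=9):
    """Partition file paths into (kept, blanks).

    Assumes each filename contains its frame number zero-padded to
    `pad` digits, the same assumption as `str(i).zfill(9) in f`.
    """
    tags = {str(i).zfill(pad) for i in dodgy_frames}
    kept, blanks = [], []
    for f in files:
        name = os.path.basename(f)
        if any(tag in name for tag in tags):
            blanks.append(f)
        else:
            kept.append(f)
    return kept, blanks

# hypothetical filenames, for illustration only
files = [f'/data/Pos0/Pos0_images/img_channel001_time{t:09d}_z000.tif' for t in range(4)]
kept, blanks = split_blank_files(files, [1, 3])
print([os.path.basename(f) for f in blanks])
```

Once the `blanks` list looks right, each file can be moved with `os.rename(f, os.path.join(blanks_dir, os.path.basename(f)))`.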


## 3. Select reference image to base alignment around

Display the mean intensity of each channel to help choose a reference channel. The brightest channel (the one picked by the automatic option below) is not necessarily the best choice for alignment.

In [31]:
print('Average channel brightness for selection of reference image:')
for channel in images.channels:
    print(f'{channel.value}: {channel.name}:', np.mean(mean_arrays[channel.name]))

Average channel brightness for selection of reference image:
0: BRIGHTFIELD: 40.96403863128827
1: GFP: 45.97970142886046
2: RFP: 6.969645431239206
3: IRFP: 47.732877102157886


In [33]:
# manually select reference channel by adding index
reference_channel = images.channels[1]
# automatically select reference channel from max average pixel value (ie. brightest channel)
#reference_channel = max(images.channels, key=lambda channel: np.mean(mean_arrays[channel.name]))
reference_channel.name

'GFP'
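Taking `max()` over `(channel.value, mean)` tuples would compare the channel index first and so always return the last channel. Keying `max()` on the mean avoids this. A minimal, self-contained sketch using the channel means printed above (rounded) in place of `mean_arrays`:

```python
# per-channel mean intensities, rounded from the output above
channel_means = {'BRIGHTFIELD': 40.96, 'GFP': 45.98, 'RFP': 6.97, 'IRFP': 47.73}

# max() keyed on the mean returns the brightest channel name
brightest = max(channel_means, key=channel_means.get)
print(brightest)  # IRFP

# pitfall: tuples compare element by element, so the index decides first
pairs = [(0, 50.0), (1, 10.0)]  # hypothetical (index, mean) pairs
print(max(pairs)[0])  # 1, even though channel 0 is brighter
```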

### 3a. Set cropped area of reference image to base alignment around

Cropping is used because alignment struggles on large arrays such as `shape = (1200, 1353, 1682)`. This step is optional, but you will still need to run `.compute()` on the dask array to load the reference images into memory before performing the alignment.

In [37]:
%%time
# crop central window out of reference image
reference_image = DaskOctopusLiteLoader(image_path, 
                                        crop = (500, 500)
                                       )[reference_channel.name].compute()
reference_image.shape

Using cropping: (500, 500)
CPU times: user 4min 53s, sys: 3min 39s, total: 8min 32s
Wall time: 55.6 s


(1638, 500, 500)
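When the drift is a global stage shift, a central window carries the same translation as the full frame while being much cheaper to register. This is a minimal sketch of a centred crop on a `(T, Y, X)` stack. It assumes the window is centred. `central_crop` is a hypothetical helper, and `DaskOctopusLiteLoader`'s own `crop` logic may round odd sizes differently.

```python
import numpy as np

def central_crop(stack, crop):
    """Crop a (T, Y, X) stack to a centred (T, crop[0], crop[1]) window."""
    _, height, width = stack.shape
    ch, cw = crop
    if ch > height or cw > width:
        raise ValueError('crop window larger than image')
    # top-left corner of the centred window
    y0 = (height - ch) // 2
    x0 = (width - cw) // 2
    return stack[:, y0:y0 + ch, x0:x0 + cw]

stack = np.zeros((3, 1352, 1688), dtype=np.uint8)
cropped = central_crop(stack, (500, 500))
print(cropped.shape)  # (3, 500, 500)
```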

## 4. Register alignment and save out transformation tensor
The transformation tensor is a series of 3×3 transformation matrices, one per frame, stacked into an array of shape `(n_frames, 3, 3)`.

In [42]:
%%time
# create operator using transformation type (translation)
sr = StackReg(StackReg.TRANSLATION) 

# register each frame to the previous as transformation matrices/tensor
transform_tensor = sr.register_stack(reference_image, reference = 'previous')

# save out transform tensor
np.save(f'{root_dir}/{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy', transform_tensor)



CPU times: user 3min 19s, sys: 3.78 s, total: 3min 23s
Wall time: 3min 23s


In [47]:
transform_tensor.shape

(1638, 3, 3)
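Each `(3, 3)` entry is a homogeneous-coordinate matrix. For `StackReg.TRANSLATION` it is the identity with the x and y shifts in the last column. This minimal sketch, using hypothetical shifts, shows how frame-to-previous translations chain into matrices expressed relative to the first frame. Whether `register_stack` returns pairwise or already-accumulated matrices for `reference='previous'` should be checked against the pystackreg documentation. The sketch only illustrates the structure.

```python
import numpy as np

def translation_matrix(tx, ty):
    """3x3 homogeneous matrix for a pure (x, y) translation."""
    m = np.eye(3)
    m[0, 2] = tx
    m[1, 2] = ty
    return m

# hypothetical frame-to-previous shifts (in pixels) for a 4-frame stack
pairwise = [translation_matrix(0, 0),
            translation_matrix(1.5, -0.5),
            translation_matrix(2.0, 1.0),
            translation_matrix(-0.5, 0.0)]

# chain the pairwise shifts so every frame is expressed relative to frame 0
cumulative = [pairwise[0]]
for m in pairwise[1:]:
    cumulative.append(m @ cumulative[-1])
tensor = np.stack(cumulative)

print(tensor.shape)                      # (4, 3, 3)
print(tensor[-1, 0, 2], tensor[-1, 1, 2])  # 3.0 0.5
```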

## 5. (Optional) Apply transformation matrix to all channels and save out images in separate directory

This step consumes a lot of time and disk space just to replicate the images with minor translational shifts. It is advised to instead pass the saved tensor to the `transforms` parameter of `DaskOctopusLiteLoader`, as in step 6.

In [14]:
%%time
### iterating over channels
# create aligned image dir if does not exist 
if not os.path.exists(f'{root_dir}/{expt}/{pos}/{pos}_aligned'):
    os.mkdir(f'{root_dir}/{expt}/{pos}/{pos}_aligned')
# iterate over channels
for channel in images.channels:
    #iterate over all images in channel
    for i in tqdm(range(len(transform_tensor)), 
                  desc = f'Aligning {channel.name.lower()} channel {channel.value+1}/{len(images.channels)}'):
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],
                                                 rotation = None)
        # transform image
        transformed_image = (tf.warp(images[channel.name][i,...].compute(), 
                                     transform_matrix, preserve_range=True)).astype(np.uint8)
        # set transformed image pathname by editing base dir
        fn = images.files(channel.name)[i].replace('_images', '_aligned')
        # save trans image out
        imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1/4: 100%|██████████| 1067/1067 [02:46<00:00,  6.42it/s]
Aligning gfp channel 2/4: 100%|██████████| 1067/1067 [02:48<00:00,  6.32it/s]
Aligning rfp channel 3/4: 100%|██████████| 1067/1067 [03:06<00:00,  5.72it/s]
Aligning irfp channel 4/4: 100%|██████████| 1067/1067 [03:04<00:00,  5.79it/s]

CPU times: user 5min 24s, sys: 51.4 s, total: 6min 15s
Wall time: 11min 45s
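Applying the tensor on the fly avoids writing a second copy of every image. This is a minimal sketch of that idea, assuming the same `tf.warp`/`EuclideanTransform` convention as the cell above, where the matrix is used as the output-to-input coordinate map. `iter_aligned` is a hypothetical helper, not the loader's internal implementation. The toy check uses `order=0` (nearest neighbour) so the result is exact. The default `order=1` (bilinear) gives smoother sub-pixel shifts.

```python
import numpy as np
from skimage import transform as tf

def iter_aligned(frames, transform_tensor, order=1):
    """Yield each frame warped by its matrix, one frame at a time.

    `frames` can be any indexable (T, Y, X) array, e.g. a dask array,
    so only one frame is held in memory per iteration.
    """
    for i, matrix in enumerate(transform_tensor):
        frame = np.asarray(frames[i])  # computes a single chunk if lazy
        tform = tf.EuclideanTransform(matrix=matrix)
        # warp treats tform as the map from output to input coordinates
        warped = tf.warp(frame, tform, order=order, preserve_range=True)
        yield warped.astype(frame.dtype)

# toy check: one bright pixel, shifted by a pure translation
frame = np.zeros((1, 32, 32), dtype=np.uint8)
frame[0, 10, 20] = 255                 # row 10, column 20
shift = np.eye(3)
shift[0, 2], shift[1, 2] = 2, 3        # x (column) +2, y (row) +3
aligned = next(iter_aligned(frame, shift[np.newaxis], order=0))
peak = np.argwhere(aligned == 255).tolist()
print(peak)  # [[7, 18]]
```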





## 6. Check alignment using Napari

In [53]:
import napari

In [56]:
aligned_images = DaskOctopusLiteLoader(image_path, 
                                       #crop = (1200,1600), 
                                       transforms = f'{root_dir}/{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy',
                                       remove_background=False)
viewer = napari.Viewer()
for channel in aligned_images.channels:
    viewer.add_image(aligned_images[channel.name], 
                     name = channel.name, 
                     blending = 'additive', 
                     contrast_limits = [0,255])

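## 7. Batch execute

Run all of the above steps for a list of experiment IDs and for every position (`Pos*` subdirectory) within each experiment. The `alignment` function is defined in the cell below, so run that cell first.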

In [None]:
%%time
alignment(root_dir = '/home/nathan/data/kraken/ras',
          expt_list = ['ND0012'],
          max_pixel = 200, 
          min_pixel = 2,
          alignment_channel = 'gfp',
          crop_area = 500, 
          save_out_images = False, ### if False, only the transformation tensor is saved, not aligned copies of the images
          overwrite = False) ### if False, positions with a pre-existing transformation tensor are skipped

Finding mean values of image channels:  75%|███████▌  | 3/4 [04:03<01:19, 79.86s/it]

In [42]:
def alignment(root_dir, expt_list, max_pixel, min_pixel, alignment_channel, crop_area, save_out_images, overwrite):

    ### Iterate over all experiments defined in expt_list
    for expt in expt_list:
        # Find all positions in that experiment
        pos_list = [pos for pos in os.listdir(f'{root_dir}/{expt}')
                    if 'Pos' in pos
                    and os.path.isdir(f'{root_dir}/{expt}/{pos}')]
        ### Iterate over all positions in that experiment
        for pos in pos_list:
            pos_dir = f'{root_dir}/{expt}/{pos}'
            ### unless overwriting, skip positions whose images subdir and transform tensor already exist
            existing_transforms = glob.glob(f'{pos_dir}/*transform*.npy')
            if not overwrite and os.path.exists(f'{pos_dir}/{pos}_images') and existing_transforms:
                print(existing_transforms, f'file found, skipping {expt}/{pos}')
                continue

            print(f'Starting {expt}/{pos}')

            ### create new subdir for raw files and move them all there
            image_path = f'{pos_dir}/{pos}_images'
            if not os.path.exists(image_path):
                os.mkdir(image_path)
                for file in sorted(glob.glob(f'{pos_dir}/*.tif')):
                    # move by basename so other occurrences of `pos` in the path are untouched
                    os.rename(file, f'{image_path}/{os.path.basename(file)}')

            ### lazily load files from raw file dir
            images = DaskOctopusLiteLoader(image_path, remove_background=False)

            ### measure mean pixel value arrays and use to find under/over-exposed frames
            # dict of per-frame mean values, keyed by channel name
            mean_arrays = {}
            # set of dodgy frame indices (only unique entries)
            dodgy_frames = set()
            # iterate over channels
            for channel in tqdm(images.channels, desc='Finding mean values of image channels'):
                # find mean value of each frame in each channel
                mean_arrays[channel.name] = [np.mean(img) for img in image_generator(images.files(channel.name))]
                # flag frames whose mean pixel value falls outside the criteria
                for frame, mean_value in enumerate(mean_arrays[channel.name]):
                    if mean_value > max_pixel or mean_value < min_pixel:
                        dodgy_frames.add(frame)
            dodgy_frame_list = sorted(dodgy_frames)
            print('Number of under/over-exposed frames:', len(dodgy_frame_list))

            ### move blank images into a separate directory
            blanks_dir = f'{pos_dir}/{pos}_blanks'
            os.makedirs(blanks_dir, exist_ok=True)
            for channel in images.channels:
                for f in images.files(channel.name):
                    if any(str(i).zfill(9) in os.path.basename(f) for i in dodgy_frame_list):
                        os.rename(f, f'{blanks_dir}/{os.path.basename(f)}')

            # reload so the loader's file lists no longer include the moved blanks
            images = DaskOctopusLiteLoader(image_path, remove_background=False)

            # crop central window out of reference image with blanks removed
            reference_image = DaskOctopusLiteLoader(image_path,
                                                    crop=(crop_area, crop_area)
                                                   )[alignment_channel].compute()

            ### Register alignment
            print(f'Registering alignment for {expt}/{pos}')
            # create operator using transformation type (translation)
            sr = StackReg(StackReg.TRANSLATION)
            # register each frame to the previous as transformation matrices/tensor
            transform_tensor = sr.register_stack(reference_image, reference='previous')
            # save out transform tensor
            np.save(f'{pos_dir}/{alignment_channel}_transform_tensor.npy', transform_tensor)

            if save_out_images:
                ### Perform alignment
                # create aligned image dir if it does not exist
                aligned_dir = f'{pos_dir}/{pos}_aligned'
                os.makedirs(aligned_dir, exist_ok=True)
                # iterate over channels
                for channel in images.channels:
                    # iterate over all (non-blank) images in channel
                    for i in tqdm(range(len(transform_tensor)),
                                  desc=f'Aligning {channel.name.lower()} channel {channel.value+1}/{len(images.channels)}'):
                        # load specific transform matrix for that frame
                        transform_matrix = tf.EuclideanTransform(matrix=transform_tensor[i, ...])
                        # transform image
                        transformed_image = (tf.warp(images[channel.name][i, ...].compute(),
                                                     transform_matrix, preserve_range=True)).astype(np.uint8)
                        # set transformed image pathname in the aligned dir
                        fn = f'{aligned_dir}/{os.path.basename(images.files(channel.name)[i])}'
                        # save transformed image out
                        imsave(fn, transformed_image, check_contrast=False)
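Before launching a long batch it can help to list which positions would actually be processed. This minimal sketch mirrors the skip rule in `alignment()`: a position is skipped when its `{pos}_images` subdir exists and a `*transform*.npy` file is present, unless `overwrite` is set. `positions_to_process` is a hypothetical helper, and the example folder is a throwaway built with `tempfile`.

```python
import glob
import os
import tempfile

def positions_to_process(expt_dir, overwrite=False):
    """Return the Pos* subdirectories of expt_dir that still need aligning."""
    todo = []
    for pos in sorted(os.listdir(expt_dir)):
        pos_dir = os.path.join(expt_dir, pos)
        if 'Pos' not in pos or not os.path.isdir(pos_dir):
            continue
        # same "already done" test as in alignment()
        done = (os.path.exists(os.path.join(pos_dir, f'{pos}_images'))
                and glob.glob(os.path.join(pos_dir, '*transform*.npy')))
        if overwrite or not done:
            todo.append(pos)
    return todo

# throwaway experiment folder: Pos0 already aligned, Pos1 not
with tempfile.TemporaryDirectory() as expt_dir:
    os.makedirs(os.path.join(expt_dir, 'Pos0', 'Pos0_images'))
    open(os.path.join(expt_dir, 'Pos0', 'gfp_transform_tensor.npy'), 'w').close()
    os.makedirs(os.path.join(expt_dir, 'Pos1'))
    pending = positions_to_process(expt_dir)
    everything = positions_to_process(expt_dir, overwrite=True)

print(pending)     # ['Pos1']
print(everything)  # ['Pos0', 'Pos1']
```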