# Align and remove blanks

Adapted from Giulia Vallardi's ImageJ macro, this notebook removes any blank frames from timelapse experiments and aligns the images. 

"Fiji macro to remove over- and under-exposed images, and align the image stacks

The settings for the alignments are: 
registration by Translation > only modify XY coordinates
Shrinkage constrain activated (this model allows a better registration based on all images, not using a reference image. It is more time consuming though)
Transform matrices are saved during registration and then applied to the other channels during transformation."

In [1]:
import os
import glob
import enum
import numpy as np
from pystackreg import StackReg
from skimage import io
from tqdm import tqdm
from octopuslite import DaskOctopusLiteLoader
from skimage import transform as tf

# Find images, organise into raw folder and load using dask octo

In [3]:
### define root directory and specific experiment and location (will later make iterable)
root_dir = '/home/nathan/data/kraken/test/'
expt = "MK0003"
pos = "Pos15"

In [4]:
### create new subdir of for raw files and move them all there
if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')):
    os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))
    files = sorted(glob.glob(os.path.join(root_dir, f'{expt}/{pos}/*.tif')))
    for file in files:
        os.rename(file, file.replace(f'{pos}', f'{pos}/{pos}_raw'))

In [61]:
### pre load files from raw file dir 
images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))

Using cropping: None


In [64]:
images['brightfield']

Unnamed: 0,Array,Chunk
Bytes,0.99 GiB,2.18 MiB
Shape,"(465, 1352, 1688)","(1, 1352, 1688)"
Count,1395 Tasks,465 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 0.99 GiB 2.18 MiB Shape (465, 1352, 1688) (1, 1352, 1688) Count 1395 Tasks 465 Chunks Type uint8 numpy.ndarray",1688  1352  465,

Unnamed: 0,Array,Chunk
Bytes,0.99 GiB,2.18 MiB
Shape,"(465, 1352, 1688)","(1, 1352, 1688)"
Count,1395 Tasks,465 Chunks
Type,uint8,numpy.ndarray


# Find blank or overexposed images and display average channel brightness

In [13]:
channel.name

'IRFP'

In [15]:
images[channel.name][0]

Unnamed: 0,Array,Chunk
Bytes,2.18 MiB,2.18 MiB
Shape,"(1352, 1688)","(1352, 1688)"
Count,1396 Tasks,1 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 2.18 MiB 2.18 MiB Shape (1352, 1688) (1352, 1688) Count 1396 Tasks 1 Chunks Type uint8 numpy.ndarray",1688  1352,

Unnamed: 0,Array,Chunk
Bytes,2.18 MiB,2.18 MiB
Shape,"(1352, 1688)","(1352, 1688)"
Count,1396 Tasks,1 Chunks
Type,uint8,numpy.ndarray


In [16]:
images[channel.name]

Unnamed: 0,Array,Chunk
Bytes,0.99 GiB,2.18 MiB
Shape,"(465, 1352, 1688)","(1, 1352, 1688)"
Count,1395 Tasks,465 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 0.99 GiB 2.18 MiB Shape (465, 1352, 1688) (1, 1352, 1688) Count 1395 Tasks 465 Chunks Type uint8 numpy.ndarray",1688  1352  465,

Unnamed: 0,Array,Chunk
Bytes,0.99 GiB,2.18 MiB
Shape,"(465, 1352, 1688)","(1, 1352, 1688)"
Count,1395 Tasks,465 Chunks
Type,uint8,numpy.ndarray


In [20]:
np.mean(images[channel.name][0].compute())

71.97978245323762

In [17]:
%%time
# pixel range criteria
max_pixel, min_pixel = 200, 2

mean_arrays = {}
dodgy_frame_list = set([])
for channel in tqdm(images.channels, desc = f'Finding mean values of image channels'):
    mean_arrays[channel.name] = np.mean(images[channel.name], axis = (1,2)).compute() 
    for frame, mean_value in enumerate(mean_arrays[channel.name]):
        if max_pixel < mean_value or mean_value < min_pixel:
            dodgy_frame_list.add(frame)
dodgy_frame_list = list(dodgy_frame_list)

print('Number of under/over-exposed frames:', len(dodgy_frame_list))

Finding mean values of image channels: 100%|██████████| 4/4 [00:34<00:00,  8.65s/it]

Number of under/over-exposed frames: 7
CPU times: user 11.7 s, sys: 3.4 s, total: 15.1 s
Wall time: 34.6 s





In [25]:
dodgy_frame_list

[288, 324, 299, 109, 46, 370, 55]

# Filtering to remove blank or overexposed frames from image array and mean value arrays

In [65]:
filtered_images= {}
for channel in images.channels:
    filtered_images[channel.name] = np.delete(images[channel], dodgy_frame_list, axis = 0)
    mean_arrays[channel.name] = np.delete(mean_arrays[channel.name], dodgy_frame_list, axis = 0) 

In [27]:
filtered_images

{'BRIGHTFIELD': dask.array<concatenate, shape=(459, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'GFP': dask.array<concatenate, shape=(458, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'RFP': dask.array<concatenate, shape=(458, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'IRFP': dask.array<concatenate, shape=(458, 1352, 1688), dtype=uint8, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>}

# Select reference image to base alignment around

In [66]:
print('Average channel brightness for selection of reference image:')
for channel in images.channels:
    print(f'{channel.value}: {channel.name}:', np.mean(mean_arrays[channel.name]))

Average channel brightness for selection of reference image:
0: BRIGHTFIELD: 37.62035125299296
1: GFP: 78.58584791196907
2: RFP: 4.305421344208685
3: IRFP: 72.24970833064663


In [67]:
# manually select reference channel by adding index
# reference_channel = filtered_images['IRFP']
# automatically select reference channel from max average pixel value (ie. brightest channel)
reference_channel = images.channels[max([(channel.value, np.mean(mean_arrays[channel.name])) for channel in images.channels])[0]]
reference_image = filtered_images[reference_channel.name]
print('Reference channel:', reference_channel.name)

Reference channel: IRFP


## Set cropped area of reference image to base alignment around (whole image struggles to compute)

In [68]:
crop_area = 500
reference_image = reference_image[:,int((reference_image.shape[2]-crop_area)/2):int(reference_image.shape[2]-(reference_image.shape[2]-crop_area)/2),int((reference_image.shape[1]-crop_area)/2):int(reference_image.shape[1]-(reference_image.shape[1]-crop_area)/2)].compute()
reference_image.shape

(458, 500, 500)

# Register alignment and save out

In [69]:
%%time
# create operator using transformation type (translation)
sr = StackReg(StackReg.TRANSLATION) 

# register each frame to the previous as transformation matrices/tensor
transform_tensor = sr.register_stack(reference_image, reference = 'first')

# save out transform tensor
np.save(os.path.join(root_dir, f'{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)

CPU times: user 53 s, sys: 1.14 s, total: 54.1 s
Wall time: 54.5 s


In [33]:
transform_tensor.shape

(458, 3, 3)

# Apply transformation matrix to all channels and save out images

In [86]:
%%time
### iterating over channels
# create aligned image dir if does not exist 
if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')):
    os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
# iterate over channels
for channel in images.channels:
    #iterate over all images in channel
    for i in tqdm(range(len(transform_tensor)), desc = f'Aligning {channel.name.lower()} channel {channel.value+1}/{len(images.channels)}'):#filtered_images[channel.name]))):
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
        # transform image
        transformed_image = (tf.warp(filtered_images[channel.name][i,...].compute(), transform_matrix, preserve_range=True)).astype(np.uint8)
        # set transformed image pathname by editing base dir
        fn = images.files(channel.name)[i].replace('_raw', '_aligned')
        # save trans image out
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1 / 4


Aligning brightfield channel 1/4: 100%|██████████| 458/458 [01:22<00:00,  5.55it/s]


Aligning gfp channel 2 / 4


Aligning gfp channel 2/4: 100%|██████████| 458/458 [00:56<00:00,  8.08it/s]


Aligning rfp channel 3 / 4


Aligning rfp channel 3/4: 100%|██████████| 458/458 [00:58<00:00,  7.89it/s]


Aligning irfp channel 4 / 4


Aligning irfp channel 4/4: 100%|██████████| 458/458 [01:11<00:00,  6.39it/s]

CPU times: user 2min 35s, sys: 30.7 s, total: 3min 6s
Wall time: 4min 28s





# check images

In [75]:
import napari

In [87]:
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'), crop = (1200,1600))
old_aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))
viewer = napari.Viewer()
for channel in aligned_images.channels:
#    if channel.name == 'IRFP':
        viewer.add_image(aligned_images[channel.name], name = channel.name, blending = 'additive', contrast_limits = [0,255])
        #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps
                         #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps
        #viewer.add_image(old_aligned_images[channel.name], name = channel.name+'raw', blending = 'additive', contrast_limits = [0,255])
                         #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps

Using cropping: (1200, 1600)
Using cropping: None


# Troubleshooting jumpy transforms

-- still have jumpy transforms using first frame ref and filtered

# Check transform tensor for jumpy transtitions

In [78]:
for j, i in enumerate(transform_tensor):
    if 150 > j > 100:
        print("frame", j,'\n', i)

frame 101 
 [[  1.           0.         -13.65152047]
 [  0.           1.         -22.70402882]
 [  0.           0.           1.        ]]
frame 102 
 [[  1.           0.         -14.45267518]
 [  0.           1.         -23.35706889]
 [  0.           0.           1.        ]]
frame 103 
 [[  1.           0.         -14.59337138]
 [  0.           1.         -21.27063231]
 [  0.           0.           1.        ]]
frame 104 
 [[  1.           0.         -14.79626622]
 [  0.           1.         -20.1216003 ]
 [  0.           0.           1.        ]]
frame 105 
 [[  1.           0.         -14.35232363]
 [  0.           1.         -25.79305233]
 [  0.           0.           1.        ]]
frame 106 
 [[  1.           0.         -14.42558353]
 [  0.           1.         -25.12203227]
 [  0.           0.           1.        ]]
frame 107 
 [[  1.           0.         -14.57673501]
 [  0.           1.         -17.46695681]
 [  0.           0.           1.        ]]
frame 108 
 [[  1.         

# problematic shift at frame 107

In [79]:
reference_image[106]

array([[76, 76, 74, ..., 70, 74, 74],
       [79, 71, 71, ..., 72, 73, 77],
       [72, 74, 75, ..., 76, 74, 73],
       ...,
       [75, 74, 76, ..., 69, 71, 69],
       [75, 75, 72, ..., 70, 70, 70],
       [71, 75, 72, ..., 72, 72, 73]], dtype=uint8)

In [80]:
reference_image[107]

array([[76, 77, 75, ..., 73, 72, 73],
       [74, 74, 73, ..., 78, 74, 78],
       [73, 75, 76, ..., 76, 79, 72],
       ...,
       [74, 77, 72, ..., 71, 74, 72],
       [73, 70, 73, ..., 68, 69, 76],
       [76, 73, 73, ..., 72, 72, 71]], dtype=uint8)

In [81]:
reference_image[108]

array([[70, 74, 76, ..., 77, 74, 74],
       [75, 73, 73, ..., 73, 73, 76],
       [76, 75, 71, ..., 74, 77, 73],
       ...,
       [76, 79, 75, ..., 72, 73, 74],
       [78, 75, 74, ..., 71, 71, 70],
       [71, 72, 73, ..., 74, 71, 70]], dtype=uint8)

In [82]:
### checking raw images
for i in range(100, 110):
    print(np.amax(reference_image[i]))
    print(np.amin(reference_image[i]))

255
60
255
60
255
60
255
60
255
59
255
59
255
60
255
60
255
61
255
60


# Is the problematic shift a result of the stack reg, ie can i reproduce it in an individual frame by frame registration

In [83]:
### problematic frame
i = 107

In [84]:
StackReg(StackReg.TRANSLATION).register(reference_image[i-1], reference_image[i])

array([[ 1.        ,  0.        , -0.18532186],
       [ 0.        ,  1.        ,  0.05314004],
       [ 0.        ,  0.        ,  1.        ]])

#### is the stack registration a cumulative measure over 1183 frames?

In [85]:
stack_reg = []
for i in tqdm(range(1,1192)):
    stack_reg.append(StackReg(StackReg.TRANSLATION).register(reference_image[i-1], reference_image[i]))

 38%|███▊      | 457/1191 [01:05<01:45,  6.97it/s]


IndexError: index 458 is out of bounds for axis 0 with size 458

In [None]:
np.sum(stack_reg, axis = 0)

# Try different transformation methods

In [None]:
# transformation matrix should take on the form

#             [1 , 0 , x
#              0 , 1 , y
#              0 , 0 , 1]

# where x and y are the translate magnitudes

In [None]:
i = 107

In [None]:
### float translation
StackReg(StackReg.TRANSLATION).register(images['gfp'][i-1], images['gfp'][i])

In [None]:
### int translation
StackReg(StackReg.TRANSLATION).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

In [None]:
### float rigid body
StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i])

In [None]:
### np.rint rigid body
np.rint(StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i]))

In [188]:
### integer-ising the matrix zeroes some important numbers

In [None]:
### int rigidbody
StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

In [None]:
### float affine
StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i])

In [None]:
### np.rint affine
np.rint(StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i]))

In [None]:
### int affine
StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

# seems like all the alignment methods produce similar shifted outputs... is it the images?

In [None]:
for channel in images.channels:
    print(channel.name, StackReg(StackReg.TRANSLATION).register(images[channel.name][i-1], images[channel.name][i]).astype(np.int8))

# is it the gfp channel? checking each channel for max transformation

In [5]:
trans_tensors = {}
for channel in images.channels:
    print('Starting channel:', channel.name)
    trans_tensor = []
#     for i in tqdm(range(1, len(images['gfp']))):
#         ### create transformation matrix for i'th and i+1'th frame
#         trans_matrix = StackReg(StackReg.TRANSLATION).register(images[channel.name][i-1], images[channel.name][i]).astype(np.int8)
#         trans_tensor.append(trans_matrix)

    trans_tensor = np.stack(trans_tensor)
    trans_tensors[channel.name] = trans_tensor
    print(channel.name, np.amax(trans_tensor))

Starting channel: BRIGHTFIELD


100%|██████████| 1199/1199 [25:51<00:00,  1.29s/it]


BRIGHTFIELD 126
Starting channel: GFP


100%|██████████| 1199/1199 [26:16<00:00,  1.31s/it]


GFP 108
Starting channel: RFP


100%|██████████| 1199/1199 [27:56<00:00,  1.40s/it]


RFP 47
Starting channel: IRFP


100%|██████████| 1199/1199 [27:06<00:00,  1.36s/it]

IRFP 24





In [6]:
import pickle
with open('all_ch_trans_tensors.json', 'wb') as fp:
    pickle.dump(trans_tensors, fp)

# Checking the alignment tensors of each channel

In [10]:
for channel in trans_tensors:
    print(channel, np.amax(trans_tensors[channel]))

BRIGHTFIELD 126
GFP 108
RFP 47
IRFP 24


# irfp channel has lowest max shift in so test run alignment on that 

In [15]:
transform_tensor = trans_tensors['IRFP']

In [16]:
np.amax(transform_tensor)

24

In [23]:
len(images['gfp']), len(transform_tensor)

(1200, 1199)

In [26]:
%%time

for channel in images.channels:
    print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
    #iterate over all images in channel
    for i in tqdm(range(len(images[channel.name]))):
        # skip dodgy frames and don't save out into aligned folder
        if i in dodgy_frame_list or i == 1199:
            continue
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
        # transform image
        transformed_image = (tf.warp(images[channel.name][i,...], transform_matrix)*255).astype(np.uint8)
        # set transformed image pathname by editing base dir
        fn = images.files(channel.name)[i].replace('_raw', '_aligned')
        # save trans image out
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1 / 4


100%|██████████| 1200/1200 [01:28<00:00, 13.62it/s]


Aligning gfp channel 2 / 4


100%|██████████| 1200/1200 [01:28<00:00, 13.62it/s]


Aligning rfp channel 3 / 4


100%|██████████| 1200/1200 [01:26<00:00, 13.92it/s]


Aligning irfp channel 4 / 4


100%|██████████| 1200/1200 [01:31<00:00, 13.06it/s]

CPU times: user 5min 11s, sys: 43.9 s, total: 5min 55s
Wall time: 5min 54s





# Batch execute

In [1]:
root_dir = '/home/nathan/data/kraken/commitment/test/'


In [None]:
%%time
alignment(expt_list = ['MK0000', 'MK0001', 'MK0002', 'MK0003'], 
          max_pixel = 200, 
          min_pixel = 2, 
          crop_area = 500)

In [None]:
viewer = napari.Viewer()
for channel in aligned_images.channels:
    viewer.add_image(aligned_images[channel.name], name = channel.name, blending = 'additive', contrast_limits = [0,255], colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps

In [5]:
def alignment(expt_list, max_pixel, min_pixel, crop_area):

    ### Iterate over all experiments defined in expt_list
    for expt in expt_list:
        # Find all positions in that experiment
        pos_list = [pos for pos in os.listdir(os.path.join(root_dir, expt)) if 'Pos' in pos]
        ### Iterate over all positions in that experiment
        for pos in pos_list:
            ### create new subdir of for raw files and move them all there
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))
                files = sorted(glob.glob(os.path.join(root_dir, f'{expt}/{pos}/*.tif')))
                for file in files:
                    os.rename(file, file.replace(f'{pos}', f'{pos}/{pos}_raw'))

            ### pre load files from raw file dir 
            images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))

            ### measure mean pixel value arrays and use to find under/over-exposed frames
            # create empty dicts and sets to store values in 
            mean_arrays = {}
            dodgy_frame_list = set([])
            # iterate over channels
            for channel in tqdm(images.channels):
                print(f'Finding mean values of {channel.name.lower()} images', pos, expt)
                # find mean pixel values for each channel
                mean_arrays[channel.name] = np.mean(images[channel.name], axis = (1,2)).compute() 
                # iterate over frames
                for frame, mean_value in enumerate(mean_arrays[channel.name]):
                    if max_pixel < mean_value or mean_value < min_pixel:
                        # if frame does not meet inclusion criteria then add to dodgy list
                        dodgy_frame_list.add(frame)
            dodgy_frame_list = list(dodgy_frame_list)
            print('Number of under/over-exposed frames:', len(dodgy_frame_list), pos, expt)

            ### create aligned image dir if does not exist 
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))

            ### Automatically pick reference image to perform alignment on 
            # Pick channel based on index of brightest channel from maximum mean pixel array
            reference_channel = images.channels[max([(channel.value, np.mean(mean_arrays[channel.name])) for channel in images.channels])[0]]
            # Define reference images
            reference_image = images[reference_channel.name]
            reference_image = reference_image[:,int((reference_image.shape[2]-crop_area)/2):int(reference_image.shape[2]-(reference_image.shape[2]-crop_area)/2),int((reference_image.shape[1]-crop_area)/2):int(reference_image.shape[1]-(reference_image.shape[1]-crop_area)/2)].compute()
            reference_image.shape
            print('Automatically selected and cropped reference image:', reference_channel.name)

            ### Register alignment
            print('Registering alignment for', pos, expt)
            # create operator using transformation type (translation)
            sr = StackReg(StackReg.TRANSLATION) 
            # register each frame using reference image to the previous as transformation matrices/tensor
            transform_tensor = sr.register_stack(reference_image, reference = 'previous').astype(np.uint8)
            # save out transform tensor
            np.save(os.path.join(root_dir, f'{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)

            ### Perform alignment
            # create aligned image dir if does not exist 
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
            # iterate over channels
            for channel in images.channels:
                print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
                #iterate over all images in channel
                for i in tqdm(range(len(images[channel.name]))):
                    # skip dodgy frames and don't save out into aligned folder
                    if i in dodgy_frame_list:
                        continue
                    # load specific transform matrix for that frame
                    transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
                    # transform image
                    transformed_image = (tf.warp(images[channel.name][i,...], transform_matrix)*255).astype(np.uint8)
                    # set transformed image pathname by editing base dir
                    fn = images.files(channel.name)[i].replace('_raw', '_aligned')
                    # save trans image out
                    io.imsave(fn, transformed_image, check_contrast=False)

# Repeat alignment for larger datasets

# Compile stacks and save out?

In [None]:
start()
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
stop()