# Align and remove blanks

Adapted from Giulia Vallardi's ImageJ macro, this notebook removes any blank frames from timelapse experiments and aligns the images. 

"Fiji macro to remove over- and under-exposed images, and align the image stacks

The settings for the alignments are: 
registration by Translation > only modify XY coordinates
Shrinkage constrain activated (this model allows a better registration based on all images, not using a reference image. It is more time consuming though)
Transform matrices are saved during registration and then applied to the other channels during transformation."

# to - do: decide how to import dask octopus

In [4]:
import octopuslite

ModuleNotFoundError: No module named 'octopuslite'

In [1]:
import os
import glob
import enum
import numpy as np
from pystackreg import StackReg
from skimage import io
from tqdm import tqdm
from daskoctopus import DaskOctopusLiteLoader
from skimage import transform as tf

# Find images, organise into raw folder and load using dask octo

In [3]:
### define root directory and specific experiment and location (will later make iterable)
root_dir = '/home/nathan/data/kraken/commitment/test/'
expt = "MK0003"
pos = "Pos1"

In [5]:
### create new subdir of for raw files and move them all there
if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')):
    os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))
    files = sorted(glob.glob(os.path.join(root_dir, f'{expt}/{pos}/*.tif')))
    for file in files:
        os.rename(file, file.replace(f'{pos}', f'{pos}/{pos}_raw'))

In [4]:
### pre load files from raw file dir 
images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))

Using cropping: None


In [7]:
images['irfp']

Unnamed: 0,Array,Chunk
Bytes,10.20 GiB,8.71 MiB
Shape,"(1200, 1352, 1688)","(1, 1352, 1688)"
Count,3600 Tasks,1200 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 10.20 GiB 8.71 MiB Shape (1200, 1352, 1688) (1, 1352, 1688) Count 3600 Tasks 1200 Chunks Type float32 numpy.ndarray",1688  1352  1200,

Unnamed: 0,Array,Chunk
Bytes,10.20 GiB,8.71 MiB
Shape,"(1200, 1352, 1688)","(1, 1352, 1688)"
Count,3600 Tasks,1200 Chunks
Type,float32,numpy.ndarray


# Find blank or overexposed images and display average channel brightness

In [8]:
%%time
# pixel range criteria
max_pixel, min_pixel = 200, 2

### find mean values ### pre load files from raw file dir 
images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))

mean_arrays = {}
dodgy_frame_list = set([])
for channel in tqdm(images.channels, desc = f'Finding mean values of {channel.name.lower()} images'):
    mean_arrays[channel.name] = np.mean(images[channel.name], axis = (1,2)).compute() 
    for frame, mean_value in enumerate(mean_arrays[channel.name]):
        if max_pixel < mean_value or mean_value < min_pixel:
            dodgy_frame_list.add(frame)
dodgy_frame_list = list(dodgy_frame_list)

print('Number of under/over-exposed frames:', len(dodgy_frame_list))

Using cropping: None


  0%|          | 0/4 [00:00<?, ?it/s]

Finding mean values of brightfield images


 25%|██▌       | 1/4 [00:34<01:44, 34.75s/it]

Finding mean values of gfp images


 50%|█████     | 2/4 [01:11<01:11, 35.86s/it]

Finding mean values of rfp images


 75%|███████▌  | 3/4 [01:45<00:35, 35.23s/it]

Finding mean values of irfp images


100%|██████████| 4/4 [02:20<00:00, 35.03s/it]

Number of under/over-exposed frames: 8
CPU times: user 23 s, sys: 9.18 s, total: 32.2 s
Wall time: 2min 20s





# Filtering to remove blank or overexposed frames from image array and mean value arrays

In [18]:
filtered_images= {}
for channel in images.channels:
    filtered_images[channel.name] = np.delete(images[channel], dodgy_frame_list, axis = 0)
    mean_arrays[channel.name] = np.delete(mean_arrays[channel.name], dodgy_frame_list, axis = 0) 

In [20]:
filtered_images

{'BRIGHTFIELD': dask.array<concatenate, shape=(1192, 1352, 1688), dtype=float32, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'GFP': dask.array<concatenate, shape=(1192, 1352, 1688), dtype=float32, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'RFP': dask.array<concatenate, shape=(1192, 1352, 1688), dtype=float32, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>,
 'IRFP': dask.array<concatenate, shape=(1192, 1352, 1688), dtype=float32, chunksize=(1, 1352, 1688), chunktype=numpy.ndarray>}

# write out dask octo function to forget certain frames

### or write something to dask octo to delete from array???

In [37]:
filtered_images= {}
for channel in images.channels:
    images[channel.name] = np.delete(images[channel], dodgy_frame_list, axis = 0)

TypeError: 'DaskOctopusLiteLoader' object does not support item assignment

# Select reference image to base alignment around

In [21]:
print('Average channel brightness for selection of reference image:')
for channel in images.channels:
    print(f'{channel.value}: {channel.name}:', np.mean(mean_arrays[channel.name]))

Average channel brightness for selection of reference image:
0: BRIGHTFIELD: 28.574265
1: GFP: 62.654087
2: RFP: 6.0335283
3: IRFP: 76.07041


In [23]:
# manually select reference channel by adding index
# reference_channel = filtered_images['IRFP']
# automatically select reference channel from max average pixel value (ie. brightest channel)
reference_channel = images.channels[max([(channel.value, np.mean(mean_arrays[channel.name])) for channel in images.channels])[0]]
reference_image = filtered_images[reference_channel.name]
print('Reference channel:', reference_channel.name)

Reference channel: IRFP


In [24]:
reference_image

Unnamed: 0,Array,Chunk
Bytes,10.13 GiB,8.71 MiB
Shape,"(1192, 1352, 1688)","(1, 1352, 1688)"
Count,7108 Tasks,1192 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 10.13 GiB 8.71 MiB Shape (1192, 1352, 1688) (1, 1352, 1688) Count 7108 Tasks 1192 Chunks Type float32 numpy.ndarray",1688  1352  1192,

Unnamed: 0,Array,Chunk
Bytes,10.13 GiB,8.71 MiB
Shape,"(1192, 1352, 1688)","(1, 1352, 1688)"
Count,7108 Tasks,1192 Chunks
Type,float32,numpy.ndarray


## Set cropped area of reference image to base alignment around (whole image struggles to compute)

In [25]:
crop_area = 500
reference_image = reference_image[:,int((reference_image.shape[2]-crop_area)/2):int(reference_image.shape[2]-(reference_image.shape[2]-crop_area)/2),int((reference_image.shape[1]-crop_area)/2):int(reference_image.shape[1]-(reference_image.shape[1]-crop_area)/2)].compute()
reference_image.shape

(1192, 500, 500)

# Register alignment and save out

In [26]:
%%time
# create operator using transformation type (translation)
sr = StackReg(StackReg.TRANSLATION) 

# register each frame to the previous as transformation matrices/tensor
transform_tensor = sr.register_stack(reference_image, reference = 'first')

# save out transform tensor
np.save(os.path.join(root_dir, f'{expt}/{pos}/filtered_irfp_transform_tensor.npy'), transform_tensor)



CPU times: user 2min 34s, sys: 5.53 s, total: 2min 40s
Wall time: 2min 40s


In [114]:
transform_tensor.shape

(1192, 3, 3)

# Apply transformation matrix to all channels and save out images

In [104]:
%%time
### iterating over channels
# create aligned image dir if does not exist 
if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_filtered_aligned')):
    os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_filtered_aligned'))
# iterate over channels
for channel in images.channels:
    print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
    #iterate over all images in channel
    for i in tqdm(range(1, len(filtered_images[channel.name]))):
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
        # transform image
        transformed_image = (tf.warp(filtered_images[channel.name][i,...], transform_matrix)*255).astype(np.uint8)
        # set transformed image pathname by editing base dir
        fn = images.files(channel.name)[i].replace('_raw', '_filtered_aligned')
        # save trans image out
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1 / 4


100%|██████████| 1191/1191 [01:35<00:00, 12.45it/s]


Aligning gfp channel 2 / 4


100%|██████████| 1191/1191 [01:33<00:00, 12.77it/s]


Aligning rfp channel 3 / 4


100%|██████████| 1191/1191 [01:31<00:00, 12.97it/s]


Aligning irfp channel 4 / 4


100%|██████████| 1191/1191 [01:35<00:00, 12.49it/s]

CPU times: user 5min 30s, sys: 48 s, total: 6min 17s
Wall time: 6min 16s





In [None]:
transformed_image = (tf.warp(filtered_images[channel.name][i,...], transform_matrix, preserve_range = True)).astype(np.uint8)

# check images

In [5]:
import napari

In [6]:
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_filtered_aligned'))
old_aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
viewer = napari.Viewer()
for channel in aligned_images.channels:
    if channel.name == 'IRFP':
        viewer.add_image(aligned_images[channel.name], name = channel.name, blending = 'additive', contrast_limits = [0,255])
        #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps
                         #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps
        viewer.add_image(old_aligned_images[channel.name], name = channel.name+'old', blending = 'additive', contrast_limits = [0,255])
                         #, colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps

Using cropping: None
Using cropping: None


# Troubleshooting jumpy transforms

# Check transform tensor for jumpy transtitions

In [151]:
for j, i in enumerate(transform_tensor):
    if j > 1170:
        print("frame", j,'\n', i)

frame 1171 
 [[  1   0 -33]
 [  0   1 124]
 [  0   0   1]]
frame 1172 
 [[  1   0 -31]
 [  0   1 125]
 [  0   0   1]]
frame 1173 
 [[  1   0 -37]
 [  0   1 124]
 [  0   0   1]]
frame 1174 
 [[  1   0 -32]
 [  0   1 125]
 [  0   0   1]]
frame 1175 
 [[  1   0 -33]
 [  0   1 125]
 [  0   0   1]]
frame 1176 
 [[  1   0 -31]
 [  0   1 125]
 [  0   0   1]]
frame 1177 
 [[  1   0 -32]
 [  0   1 126]
 [  0   0   1]]
frame 1178 
 [[  1   0 -37]
 [  0   1 126]
 [  0   0   1]]
frame 1179 
 [[  1   0 -34]
 [  0   1 126]
 [  0   0   1]]
frame 1180 
 [[  1   0 -30]
 [  0   1 127]
 [  0   0   1]]
frame 1181 
 [[  1   0 -34]
 [  0   1 127]
 [  0   0   1]]
frame 1182 
 [[  1   0 -34]
 [  0   1 127]
 [  0   0   1]]
frame 1183 
 [[   1    0  -34]
 [   0    1 -128]
 [   0    0    1]]
frame 1184 
 [[  1   0 -34]
 [  0   1 127]
 [  0   0   1]]
frame 1185 
 [[  1   0 -34]
 [  0   1 127]
 [  0   0   1]]
frame 1186 
 [[   1    0  -31]
 [   0    1 -128]
 [   0    0    1]]
frame 1187 
 [[   1    0  -28]
 [   0 

# problematic shift at frame 1183

In [41]:
reference_image[1182]

array([[59, 61, 62, ..., 75, 76, 75],
       [60, 61, 59, ..., 71, 73, 76],
       [61, 60, 60, ..., 72, 74, 77],
       ...,
       [70, 71, 70, ..., 81, 86, 85],
       [66, 68, 72, ..., 83, 87, 85],
       [70, 70, 71, ..., 84, 87, 85]], dtype=uint8)

In [40]:
reference_image[1183]

array([[57, 60, 59, ..., 74, 70, 76],
       [60, 62, 62, ..., 76, 74, 74],
       [60, 61, 65, ..., 75, 71, 73],
       ...,
       [70, 68, 67, ..., 89, 88, 85],
       [69, 65, 65, ..., 83, 84, 84],
       [66, 68, 67, ..., 81, 85, 86]], dtype=uint8)

In [42]:
reference_image[1184]

array([[61, 60, 62, ..., 72, 73, 74],
       [56, 58, 60, ..., 74, 71, 73],
       [60, 59, 63, ..., 70, 73, 72],
       ...,
       [75, 75, 77, ..., 84, 84, 86],
       [72, 76, 74, ..., 83, 80, 89],
       [75, 77, 76, ..., 82, 84, 82]], dtype=uint8)

In [44]:
### checking raw images
for i in range(1181, 1186):
    print(np.amax(reference_image[i]))
    print(np.amin(reference_image[i]))

255
50
255
49
255
50
255
50
255
49


# Is the problematic shift a result of the stack reg, ie can i reproduce it in an individual frame by frame registration

In [124]:
### problematic frame
i = 1183

In [193]:
StackReg(StackReg.TRANSLATION).register(reference_image[i-1], reference_image[i])

array([[ 1.        ,  0.        , -0.27052358],
       [ 0.        ,  1.        , -0.08158292],
       [ 0.        ,  0.        ,  1.        ]])

#### is the stack registration a cumulative measure over 1183 frames?

In [194]:
stack_reg = []
for i in tqdm(range(1,1192)):
    stack_reg.append(StackReg(StackReg.TRANSLATION).register(reference_image[i-1], reference_image[i]))

100%|██████████| 1191/1191 [02:38<00:00,  7.53it/s]


In [199]:
np.sum(stack_reg, axis = 0)

array([[1191.        ,    0.        ,  -29.8818751 ],
       [   0.        , 1191.        ,  127.99044408],
       [   0.        ,    0.        , 1191.        ]])

# Try different transformation methods

In [None]:
# transformation matrix should take on the form

#             [1 , 0 , x
#              0 , 1 , y
#              0 , 0 , 1]

# where x and y are the translate magnitudes

In [183]:
i = 1184

In [184]:
### float translation
StackReg(StackReg.TRANSLATION).register(images['gfp'][i-1], images['gfp'][i])

array([[1.        , 0.        , 1.69270859],
       [0.        , 1.        , 0.54300491],
       [0.        , 0.        , 1.        ]])

In [185]:
### int translation
StackReg(StackReg.TRANSLATION).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

array([[1, 0, 1],
       [0, 1, 0],
       [0, 0, 1]], dtype=int8)

In [186]:
### float rigid body
StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i])

array([[ 9.99999870e-01,  5.09415957e-04,  1.32909743e+00],
       [-5.09415957e-04,  9.99999870e-01,  9.69781071e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]])

In [187]:
### np.rint rigid body
np.rint(StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i]))

array([[ 1.,  0.,  1.],
       [-0.,  1.,  1.],
       [ 0.,  0.,  1.]])

In [188]:
### integer-ising the matrix zeroes some important numbers

In [189]:
### int rigidbody
StackReg(StackReg.RIGID_BODY).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

array([[0, 0, 1],
       [0, 0, 0],
       [0, 0, 1]], dtype=int8)

In [190]:
### float affine
StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i])

array([[ 1.00010486e+00,  7.31302774e-04,  1.14884038e+00],
       [-4.18274933e-04,  9.99898269e-01,  9.75798050e-01],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]])

In [191]:
### np.rint affine
np.rint(StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i]))

array([[ 1.,  0.,  1.],
       [-0.,  1.,  1.],
       [ 0.,  0.,  1.]])

In [192]:
### int affine
StackReg(StackReg.AFFINE).register(images['gfp'][i-1], images['gfp'][i]).astype(np.int8)

array([[1, 0, 1],
       [0, 0, 0],
       [0, 0, 1]], dtype=int8)

# seems like all the alignment methods produce similar shifted outputs... is it the images?

In [136]:
for channel in images.channels:
    print(channel.name, StackReg(StackReg.TRANSLATION).register(images[channel.name][i-1], images[channel.name][i]).astype(np.int8))

BRIGHTFIELD [[  1   0  84]
 [  0   1 -15]
 [  0   0   1]]
GFP [[1 0 0]
 [0 1 0]
 [0 0 1]]
RFP [[1 0 0]
 [0 1 0]
 [0 0 1]]
IRFP [[ 1  0 -1]
 [ 0  1  0]
 [ 0  0  1]]


# is it the gfp channel? checking each channel for max transformation

In [5]:
trans_tensors = {}
for channel in images.channels:
    print('Starting channel:', channel.name)
    trans_tensor = []
#     for i in tqdm(range(1, len(images['gfp']))):
#         ### create transformation matrix for i'th and i+1'th frame
#         trans_matrix = StackReg(StackReg.TRANSLATION).register(images[channel.name][i-1], images[channel.name][i]).astype(np.int8)
#         trans_tensor.append(trans_matrix)

    trans_tensor = np.stack(trans_tensor)
    trans_tensors[channel.name] = trans_tensor
    print(channel.name, np.amax(trans_tensor))

Starting channel: BRIGHTFIELD


100%|██████████| 1199/1199 [25:51<00:00,  1.29s/it]


BRIGHTFIELD 126
Starting channel: GFP


100%|██████████| 1199/1199 [26:16<00:00,  1.31s/it]


GFP 108
Starting channel: RFP


100%|██████████| 1199/1199 [27:56<00:00,  1.40s/it]


RFP 47
Starting channel: IRFP


100%|██████████| 1199/1199 [27:06<00:00,  1.36s/it]

IRFP 24





In [6]:
import pickle
with open('all_ch_trans_tensors.json', 'wb') as fp:
    pickle.dump(trans_tensors, fp)

# Checking the alignment tensors of each channel

In [10]:
for channel in trans_tensors:
    print(channel, np.amax(trans_tensors[channel]))

BRIGHTFIELD 126
GFP 108
RFP 47
IRFP 24


# irfp channel has lowest max shift in so test run alignment on that 

In [15]:
transform_tensor = trans_tensors['IRFP']

In [16]:
np.amax(transform_tensor)

24

In [23]:
len(images['gfp']), len(transform_tensor)

(1200, 1199)

In [26]:
%%time

for channel in images.channels:
    print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
    #iterate over all images in channel
    for i in tqdm(range(len(images[channel.name]))):
        # skip dodgy frames and don't save out into aligned folder
        if i in dodgy_frame_list or i == 1199:
            continue
        # load specific transform matrix for that frame
        transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
        # transform image
        transformed_image = (tf.warp(images[channel.name][i,...], transform_matrix)*255).astype(np.uint8)
        # set transformed image pathname by editing base dir
        fn = images.files(channel.name)[i].replace('_raw', '_aligned')
        # save trans image out
        io.imsave(fn, transformed_image, check_contrast=False)

Aligning brightfield channel 1 / 4


100%|██████████| 1200/1200 [01:28<00:00, 13.62it/s]


Aligning gfp channel 2 / 4


100%|██████████| 1200/1200 [01:28<00:00, 13.62it/s]


Aligning rfp channel 3 / 4


100%|██████████| 1200/1200 [01:26<00:00, 13.92it/s]


Aligning irfp channel 4 / 4


100%|██████████| 1200/1200 [01:31<00:00, 13.06it/s]

CPU times: user 5min 11s, sys: 43.9 s, total: 5min 55s
Wall time: 5min 54s





# Batch execute

In [1]:
root_dir = '/home/nathan/data/kraken/commitment/test/'


In [None]:
%%time
alignment(expt_list = ['MK0000', 'MK0001', 'MK0002', 'MK0003'], 
          max_pixel = 200, 
          min_pixel = 2, 
          crop_area = 500)

In [None]:
viewer = napari.Viewer()
for channel in aligned_images.channels:
    viewer.add_image(aligned_images[channel.name], name = channel.name, blending = 'additive', contrast_limits = [0,255], colormap = napari.utils.colormaps.SIMPLE_COLORMAPS.popitem()) # lazy hack to randomly generate different colormaps

In [5]:
def alignment(expt_list, max_pixel, min_pixel, crop_area):

    ### Iterate over all experiments defined in expt_list
    for expt in expt_list:
        # Find all positions in that experiment
        pos_list = [pos for pos in os.listdir(os.path.join(root_dir, expt)) if 'Pos' in pos]
        ### Iterate over all positions in that experiment
        for pos in pos_list:
            ### create new subdir of for raw files and move them all there
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))
                files = sorted(glob.glob(os.path.join(root_dir, f'{expt}/{pos}/*.tif')))
                for file in files:
                    os.rename(file, file.replace(f'{pos}', f'{pos}/{pos}_raw'))

            ### pre load files from raw file dir 
            images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_raw'))

            ### measure mean pixel value arrays and use to find under/over-exposed frames
            # create empty dicts and sets to store values in 
            mean_arrays = {}
            dodgy_frame_list = set([])
            # iterate over channels
            for channel in tqdm(images.channels):
                print(f'Finding mean values of {channel.name.lower()} images', pos, expt)
                # find mean pixel values for each channel
                mean_arrays[channel.name] = np.mean(images[channel.name], axis = (1,2)).compute() 
                # iterate over frames
                for frame, mean_value in enumerate(mean_arrays[channel.name]):
                    if max_pixel < mean_value or mean_value < min_pixel:
                        # if frame does not meet inclusion criteria then add to dodgy list
                        dodgy_frame_list.add(frame)
            dodgy_frame_list = list(dodgy_frame_list)
            print('Number of under/over-exposed frames:', len(dodgy_frame_list), pos, expt)

            ### create aligned image dir if does not exist 
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))

            ### Automatically pick reference image to perform alignment on 
            # Pick channel based on index of brightest channel from maximum mean pixel array
            reference_channel = images.channels[max([(channel.value, np.mean(mean_arrays[channel.name])) for channel in images.channels])[0]]
            # Define reference images
            reference_image = images[reference_channel.name]
            reference_image = reference_image[:,int((reference_image.shape[2]-crop_area)/2):int(reference_image.shape[2]-(reference_image.shape[2]-crop_area)/2),int((reference_image.shape[1]-crop_area)/2):int(reference_image.shape[1]-(reference_image.shape[1]-crop_area)/2)].compute()
            reference_image.shape
            print('Automatically selected and cropped reference image:', reference_channel.name)

            ### Register alignment
            print('Registering alignment for', pos, expt)
            # create operator using transformation type (translation)
            sr = StackReg(StackReg.TRANSLATION) 
            # register each frame using reference image to the previous as transformation matrices/tensor
            transform_tensor = sr.register_stack(reference_image, reference = 'previous').astype(np.uint8)
            # save out transform tensor
            np.save(os.path.join(root_dir, f'{expt}/{pos}/{reference_channel.name.lower()}_transform_tensor.npy'), transform_tensor)

            ### Perform alignment
            # create aligned image dir if does not exist 
            if not os.path.exists(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned')):
                os.mkdir(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
            # iterate over channels
            for channel in images.channels:
                print('Aligning', channel.name.lower(), 'channel', channel.value+1, '/', len(images.channels))
                #iterate over all images in channel
                for i in tqdm(range(len(images[channel.name]))):
                    # skip dodgy frames and don't save out into aligned folder
                    if i in dodgy_frame_list:
                        continue
                    # load specific transform matrix for that frame
                    transform_matrix = tf.EuclideanTransform(matrix = transform_tensor[i,...],rotation = None)
                    # transform image
                    transformed_image = (tf.warp(images[channel.name][i,...], transform_matrix)*255).astype(np.uint8)
                    # set transformed image pathname by editing base dir
                    fn = images.files(channel.name)[i].replace('_raw', '_aligned')
                    # save trans image out
                    io.imsave(fn, transformed_image, check_contrast=False)

# Repeat alignment for larger datasets

# Compile stacks and save out?

In [None]:
start()
aligned_images = DaskOctopusLiteLoader(os.path.join(root_dir, f'{expt}/{pos}/{pos}_aligned'))
stop()