In [1]:
import os, random

# Restrict CUDA to physical GPU 3. CUDA reads this variable only when it is first initialised,
# so it must be set before any GPU library (euler_gpu, torch) is imported. With a single visible
# GPU, that device is addressed as "cuda:0" below.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import tifffile
import euler_gpu
import numpy as np
import torch
import h5py
from tqdm import tqdm
from math import pi as PI

device = torch.device("cuda:0")
BATCH_SIZE = 256  # batch size handed to euler_gpu.initialize for the grid search
red_channel = 1   # channel that is segmented and used to compute the alignment
fixed_frame = 0   # index of the reference (fixed) frame in the first file

MAX_NUM_CENTROIDS = 50  # fixed number of centroid rows stored per frame

In [2]:
# Load the first recording chunk. The alignment loop reuses `inputs` for files[0] instead of
# re-reading it from disk, so this path must be the same file as os.path.join(tiff_folder, files[0]).
# NOTE(review): later cells index this as inputs[frame, channel, ...]; presumably
# (frames, channels, height, width) — TODO confirm.
tiff_path = "/storage/fs/store1/brian/swimming_videos_RFa/Folder_20250214153740_RFa/20250214_Experiment_01_0-1999.tif"
inputs = tifffile.imread(tiff_path)

## Align every frame to a fixed reference frame across multiple files (each frame's search starts from the previous frame's best transform)

In [3]:
# A single exact-zero offset, appended to every search grid below so that "keep the previous
# transform unchanged" is always one of the candidates (the linspaces alone do not contain 0).
TH_CONST = np.array([0])

# Translation offsets searched around the previous frame's x / y transform.
# NOTE(review): units are whatever euler_gpu uses for translation (likely normalized coords) — confirm.
XY_RANGE_STEP2 = np.concatenate([
    np.linspace(-0.005, 0.005, 10, dtype=np.float32),
    TH_CONST,
])

# Rotation offsets in degrees: [0, 0.5] and [359.5, 360] cover about +/-0.5 degrees once the
# alignment loop wraps the sum modulo 360.
ALIGN_TH_RANGE = np.concatenate([
    np.linspace(0, .5, 10, dtype=np.float32),
    np.linspace(359.5, 360, 10, dtype=np.float32),
    TH_CONST,
])

In [4]:
# Folder holding the raw recording chunks; aligned outputs (RIG_*) are written back here too.
tiff_folder = "/storage/fs/store1/brian/swimming_videos_RFa/Folder_20250214153740_RFa"
# Consecutive chunks of one recording, processed in this order. The first entry must be the
# same file that was loaded into `inputs`.
files = [
    "20250214_Experiment_01_0-1999.tif",
    "20250214_Experiment_01_2000-3999.tif",
    "20250214_Experiment_01_4000-5999.tif",
]


In [31]:
import cv2
import numpy as np
from scipy.ndimage import label
import tifffile
import matplotlib.pyplot as plt

def segment_sparse_blobs(
    image,
    blur_sigma=1.5,
    threshold_method="adaptive",  # "adaptive" or "global"
    global_thresh_val=100,
    min_area=20,
    morph_kernel_size=3,
    adaptive_block_size=35,
    adaptive_c=-10,
):
    """
    Segment sparse bright blobs in a grayscale image.

    Args:
        image (np.ndarray): 2-D grayscale image. Non-uint8 input is min-max rescaled to 0-255.
        blur_sigma (float): Sigma of the Gaussian blur applied before thresholding.
        threshold_method (str): "adaptive" (local Gaussian-weighted threshold) or "global".
        global_thresh_val (int): Threshold on the 0-255 scale, used only for "global".
        min_area (int): Minimum area (pixels) to keep a blob.
        morph_kernel_size (int): Size of the elliptical kernel used for a morphological
            OPENING, which removes specks/thin bridges smaller than the kernel (it does not
            fill holes).
        adaptive_block_size (int): Odd neighbourhood size for the adaptive threshold.
        adaptive_c (float): Constant subtracted from the local weighted mean. A negative value
            means a pixel must exceed its neighbourhood mean by more than |adaptive_c|.

    Returns:
        labeled_mask (np.ndarray): Label image. Blobs with area >= min_area keep their original
            connected-component label; everything else is 0. Labels are NOT renumbered, so
            there may be gaps.
        centroids (np.ndarray): float32 array of shape (num_labels, 2). Row `k - 1` holds the
            integer-truncated (x, y) centroid of label `k`, or (0, 0) if that label was
            discarded for being smaller than min_area.
        contours (list): External contours of the kept blobs, in label order.

    Raises:
        ValueError: If threshold_method is not "adaptive" or "global".
    """
    if threshold_method not in ("adaptive", "global"):
        raise ValueError(f"threshold_method must be 'adaptive' or 'global', got {threshold_method!r}")

    # Convert to 8-bit if needed (OpenCV's thresholding functions below expect uint8 here).
    if image.dtype != np.uint8:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Step 1: Smooth to suppress pixel noise before thresholding. ksize=(0, 0) makes OpenCV
    # derive the kernel size from sigma, so this is a real blur, not a no-op.
    blurred = cv2.GaussianBlur(image, (0, 0), blur_sigma)

    # Step 2: Threshold to a 0/255 binary image.
    if threshold_method == "adaptive":
        binary = cv2.adaptiveThreshold(
            blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, blockSize=adaptive_block_size, C=adaptive_c)
    else:
        _, binary = cv2.threshold(blurred, global_thresh_val, 255, cv2.THRESH_BINARY)

    # Step 3: Morphological opening removes isolated specks smaller than the kernel.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_kernel_size,) * 2)
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

    # Step 4: Label connected components (scipy's default cross-shaped structure).
    labeled, num_labels = label(cleaned)

    # Step 5: Filter small blobs and compute centroids. Index label_val - 1 is reserved for
    # label_val even when the blob is dropped, so rows stay aligned with label values.
    centroids = np.zeros((num_labels, 2), dtype=np.float32)
    contours = []
    output_mask = np.zeros_like(labeled)

    # One pass over the image gives every blob's area (instead of a full-image sum per label).
    areas = np.bincount(labeled.ravel(), minlength=num_labels + 1)

    for label_val in range(1, num_labels + 1):
        if areas[label_val] < min_area:
            continue
        mask = (labeled == label_val).astype(np.uint8)
        output_mask[mask > 0] = label_val
        M = cv2.moments(mask)
        if M["m00"] != 0:
            centroids[label_val - 1] = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"]))
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours.extend(cnts)

    return output_mask, centroids, contours

def pad_centroids(centroids, max_num=None, fill_value=-1):
    """
    Pad (or truncate) a centroid array to a fixed number of rows.

    Besides giving every frame the same shape, this surfaces centroid overflows: when there
    are more centroids than rows, a warning is printed.

    Args:
        centroids (np.ndarray): (N, 2) array of centroids.
        max_num (int, optional): Number of output rows. Defaults to the global MAX_NUM_CENTROIDS.
        fill_value (float): Value for padding rows; -1 (default) marks "no centroid".

    Returns:
        np.ndarray: New float32 array of shape (max_num, 2). If N > max_num, only the first
            max_num rows are kept. Truncation is deterministic (not random sampling), so row i
            still refers to the same label as in the input.
    """
    if max_num is None:
        max_num = MAX_NUM_CENTROIDS
    centroids = np.asarray(centroids, dtype=np.float32)
    num_centroids = centroids.shape[0]

    if num_centroids > max_num:
        print(f"Warning: Too many centroids ({num_centroids}). Truncated after {max_num} centroids. This is NOT a truly random set.")
        # Copy so the result never aliases the caller's array.
        return centroids[:max_num].copy()

    padded_centroids = np.full((max_num, 2), fill_value, dtype=np.float32)
    padded_centroids[:num_centroids] = centroids
    return padded_centroids


def get_cents_from_rois(roi_img, max_num=None):
    """
    Compute the (x, y) centroid of every positive label in a label image.

    Matches the centroid convention of `segment_sparse_blobs`: x is the column mean, y is the
    row mean, both truncated with int(). Row `k - 1` of the result belongs to label `k`.

    Args:
        roi_img (np.ndarray): 2-D integer label image; values <= 0 are treated as background.
        max_num (int, optional): Minimum number of output rows. Defaults to the global
            MAX_NUM_CENTROIDS.

    Returns:
        np.ndarray: float32 array of shape (max(max_num, max_label), 2). Labels that do not
            occur stay (0, 0). When a label exceeds max_num, the array grows instead of raising
            IndexError, so `pad_centroids` can apply its usual truncation policy.
    """
    if max_num is None:
        max_num = MAX_NUM_CENTROIDS
    roi_img = np.asarray(roi_img)
    max_label = int(roi_img.max()) if roi_img.size else 0

    centroids = np.zeros((max(max_num, max_label), 2), dtype=np.float32)
    if max_label < 1:
        return centroids

    # Raw spatial moments per label in a single pass: m00 = pixel count, m10 = sum of column
    # indices, m01 = sum of row indices. The sums are exact integers (stored as float64), so
    # the ratios equal those from cv2.moments on each binary mask.
    rows, cols = np.nonzero(roi_img > 0)
    labels = roi_img[rows, cols].astype(np.intp)
    m00 = np.bincount(labels, minlength=max_label + 1)
    m10 = np.bincount(labels, weights=cols, minlength=max_label + 1)
    m01 = np.bincount(labels, weights=rows, minlength=max_label + 1)

    for label_val in np.flatnonzero(m00[1:]) + 1:
        centroids[label_val - 1] = (int(m10[label_val] / m00[label_val]),
                                    int(m01[label_val] / m00[label_val]))
    return centroids

In [33]:
from skimage.filters import median, gaussian
from skimage.morphology import disk

In [37]:
## Align every frame to the fixed image. Each frame's grid search is centred on the previous frame's
## best transformation, so only small variations of it are tried.
other_chan = (red_channel - 1) % inputs.shape[1]  # kept for later cells; every channel is warped below

# Reference image: binarized blob mask of the red channel of `fixed_frame` in the first file.
# Moving frames are binarized the same way, so the grid search compares like with like.
fixed_image = inputs[fixed_frame, red_channel]
dtyp = fixed_image.dtype
fixed_image = (segment_sparse_blobs(fixed_image)[0] > 0).astype(dtyp)

# Start the first search around the identity transform (x, y, theta) = (0, 0, 0).
prev_transform = tuple(torch.tensor([0.0], device=device) for _ in range(3))

for file in files:
    if file == files[0]:  # `inputs` already holds the first file; skip re-reading it from disk.
        ins = inputs
    else:
        ins = tifffile.imread(os.path.join(tiff_folder, file))

    n_frames, n_chans = ins.shape[0], ins.shape[1]
    out = np.empty_like(ins)
    # Centroid tables are allocated per file and sized to THIS file's frame count. A single
    # allocation sized from `inputs` would overflow on a longer file and leave stale rows from the
    # previous file in a shorter one.
    pre_centroids = np.full((n_frames, MAX_NUM_CENTROIDS, 2), fill_value=-1, dtype=np.float32)
    rig_centroids = np.full((n_frames, MAX_NUM_CENTROIDS, 2), fill_value=-1, dtype=np.float32)

    for frame in tqdm(range(n_frames)):
        # Preprocess the moving frame exactly like the fixed image.
        rois, centroids, _ = segment_sparse_blobs(ins[frame, red_channel])
        moving_image = (rois > 0).astype(dtyp)
        pre_centroids[frame] = pad_centroids(centroids)

        # Search a small grid around the previous best transform. The theta grid is given in
        # degrees while the returned/warp theta is in radians, hence the conversion.
        align_x_range_2 = np.add(XY_RANGE_STEP2, prev_transform[0].cpu().numpy())
        align_y_range_2 = np.add(XY_RANGE_STEP2, prev_transform[1].cpu().numpy())
        align_the_range = np.mod(np.add(ALIGN_TH_RANGE, (prev_transform[2].cpu().numpy() * 180) / PI), 360)
        memory_dict = euler_gpu.initialize(fixed_image, moving_image, align_x_range_2, align_y_range_2, align_the_range, BATCH_SIZE, device)
        best_score, best_transformation = euler_gpu.grid_search(memory_dict)
        dx, dy, dtheta = best_transformation[0], best_transformation[1], best_transformation[2]

        # Warp every channel of the raw frame, so no channel of `out` is left uninitialized.
        for chan in range(n_chans):
            in_img = torch.Tensor(ins[frame, chan][np.newaxis, np.newaxis, ...]).to(device=device)
            out[frame, chan] = euler_gpu.transform_image(in_img, dx, dy, dtheta, memory_dict).cpu().numpy()

        # Warp the label image with nearest-neighbour interpolation so label values survive, then
        # recompute centroids. Entry i of pre_/rig_centroids refers to the same blob WITHIN a frame
        # (before vs. after the transform), but not across frames.
        warped_rois = euler_gpu.transform_image(
            torch.tensor(rois[np.newaxis, np.newaxis, ...], device=device, dtype=torch.float32),
            dx, dy, dtheta, memory_dict, interpolation="nearest",
        ).cpu().numpy().astype(int)[0, 0]
        rig_centroids[frame] = pad_centroids(get_cents_from_rois(warped_rois))

        if file == files[0] and frame == fixed_frame:
            # The fixed frame is aligned against itself, so its centroids must not move.
            assert np.all((pre_centroids[frame] < 0) | (rig_centroids[frame] < 0) | (rig_centroids[frame] == pre_centroids[frame])), "Centroids should not warp on the first frame"

        prev_transform = best_transformation  # The next frame searches around this frame's best transform

    print(prev_transform)  # Printed so the search can be restarted from here if a later file crashes
    tifffile.imwrite(os.path.join(tiff_folder, "RIG_" + file), out, imagej=True)

    h5_name = "RIG_CENTS_" + os.path.splitext(file)[0] + ".h5"
    with h5py.File(os.path.join(tiff_folder, h5_name), 'w') as f:
        f.create_dataset("raw_img_centroids", data=pre_centroids)
        f.create_dataset("post_rig_centroids", data=rig_centroids)


100%|██████████| 2000/2000 [43:30<00:00,  1.31s/it]


(tensor([0.0089], device='cuda:0'), tensor([0.0222], device='cuda:0'), tensor([6.6283e-05], device='cuda:0'))


100%|██████████| 2000/2000 [43:23<00:00,  1.30s/it]


(tensor([-0.0161], device='cuda:0'), tensor([-0.0133], device='cuda:0'), tensor([0.0060], device='cuda:0'))


100%|██████████| 2000/2000 [45:31<00:00,  1.37s/it]


(tensor([-9.3132e-10], device='cuda:0'), tensor([-0.0056], device='cuda:0'), tensor([0.0090], device='cuda:0'))




In [36]:
# Re-save the centroid tables for `file` (the last file the alignment loop processed).
# NOTE(review): the alignment loop already writes this exact h5 for every file, so this cell
# only repeats the final save; it overwrites the file ('w' mode). Consider deleting it.
h5_path = os.path.join(tiff_folder, "RIG_CENTS_" + file.replace("tiff", "h5").replace("tif", "h5"))
with h5py.File(h5_path, 'w') as h5_out:
    for dataset_name, table in (("raw_img_centroids", pre_centroids), ("post_rig_centroids", rig_centroids)):
        h5_out.create_dataset(dataset_name, data=table)