In [23]:
from scipy.ndimage import gaussian_filter
from skimage.morphology import remove_small_objects
from skimage.measure import label
import matplotlib.pyplot as plt
from tifffile import imread, imwrite
import cv2
import numpy as np
from scipy import stats
import os

In [8]:
def get_mask(
    im,
    threshold=200,
    sigma=2,
    blur_cutoff=0.70,
    cc_threshold=100000,
    channel=1,
):
    """
    Build a labeled foreground mask from a multi-channel image.

    Pixels whose value in ``channel`` exceeds ``threshold`` are marked as
    "bright". That marking is Gaussian-smoothed, and every pixel whose smoothed
    value is below ``blur_cutoff`` becomes foreground. Foreground connected
    components smaller than ``cc_threshold`` pixels are removed, and the rest
    are labeled using full (8-) connectivity.

    Parameters
    ----------
    im : np.ndarray
        Image indexed as (rows, cols, channels).
    threshold : int or float, default 200
        Intensity above which a pixel of ``channel`` counts as bright.
    sigma : float, default 2
        Standard deviation of the Gaussian used to smooth the bright marking.
    blur_cutoff : float, default 0.70
        Smoothed-marking value below which a pixel is kept as foreground.
    cc_threshold : int, default 100000
        Minimum connected-component size (in pixels) to keep.
    channel : int, default 1
        Channel index to threshold (1 = green for RGB or BGR data).

    Returns
    -------
    np.ndarray
        Integer label image: 0 for background, 1..N for the kept components.
        Label numbers are arbitrary; compare ``result > 0`` across images.
    """
    bright = (im[:, :, channel] > threshold).astype(float)
    blurred = gaussian_filter(bright, sigma=sigma)
    # Smoothing closes small gaps, so the cutoff keeps regions that are
    # mostly *not* bright rather than reacting to isolated pixels.
    mask = blurred < blur_cutoff

    # Drop small specks, keeping only large connected regions.
    mask = remove_small_objects(mask, min_size=cc_threshold)

    return label(mask, connectivity=2)

In [18]:
# Input sections, matched case-insensitively on the .tif extension
# (so e.g. "SEC_01.TIF" is not silently skipped).
im_pth = 'test_data/cropped_sections'
images = sorted(
    os.path.join(im_pth, f)
    for f in os.listdir(im_pth)
    if f.lower().endswith('.tif')
)

In [27]:
def downsample_mask(mask: np.ndarray, ds: int) -> np.ndarray:
    """
    Subsample a mask by keeping every ``ds``-th row and column.

    No smoothing or pooling is applied. The result is a strided slice of the
    input that starts at pixel (0, 0), so a downsampled index ``i`` corresponds
    to full-resolution index ``ds * i``.
    """
    stride = slice(None, None, ds)
    return mask[stride, stride]

def estimate_affine_transform(
    mask_ref_ds: np.ndarray,
    mask_mov_ds: np.ndarray,
    criteria: tuple,
    warp_init: np.ndarray = None
) -> np.ndarray:
    """
    Estimate a 2×3 affine warp aligning ``mask_mov_ds`` to ``mask_ref_ds`` via ECC.

    The warp follows cv2.findTransformECC's convention. It maps reference
    (template) coordinates into moving (input) coordinates, so it must be
    applied with ``cv2.WARP_INVERSE_MAP`` to resample the moving image onto
    the reference grid.

    Parameters
    ----------
    mask_ref_ds : np.ndarray
        Reference (template) image; converted to float32.
    mask_mov_ds : np.ndarray
        Moving (input) image; converted to float32.
    criteria : tuple
        OpenCV termination criteria ``(type, max_count, epsilon)``.
    warp_init : np.ndarray, optional
        Initial 2×3 guess; identity when omitted. It is copied and cast to
        float32, so the caller's array is never modified.

    Returns
    -------
    np.ndarray
        The refined 2×3 float32 warp matrix. The ECC correlation coefficient
        is computed by OpenCV but intentionally not returned.

    Raises
    ------
    cv2.error
        Raised by OpenCV, e.g. when ECC does not converge.
    """
    if warp_init is None:
        warp = np.eye(2, 3, dtype=np.float32)
    else:
        # ECC requires a float32 matrix, and the OpenCV binding may write its
        # result back into a matching input array. Work on a private copy.
        warp = np.array(warp_init, dtype=np.float32, copy=True)

    _cc, warp_matrix = cv2.findTransformECC(
        templateImage = mask_ref_ds.astype(np.float32),
        inputImage    = mask_mov_ds.astype(np.float32),
        warpMatrix    = warp,
        motionType    = cv2.MOTION_AFFINE,
        criteria      = criteria,
        inputMask     = None,
        gaussFiltSize = 1
    )
    return warp_matrix

def compute_border_color(img: np.ndarray) -> tuple:
    """
    Compute the most frequent value (mode) of each channel of ``img``.

    The modes are returned in the image's own channel order.
    ``cv2.warpAffine`` applies ``borderValue[i]`` to channel ``i`` of the array
    it is given and has no notion of RGB vs BGR. Reversing the tuple would
    therefore swap the first and third channels' fill values. Images read
    with tifffile keep the file's sample order, typically RGB.

    Parameters
    ----------
    img : np.ndarray
        Non-negative integer image, shape (H, W) or (H, W, C).

    Returns
    -------
    tuple of int
        One mode per channel, in channel order (a 1-tuple for a 2D image).
    """
    planes = img[..., np.newaxis] if img.ndim == 2 else img
    modes = []
    for c in range(planes.shape[-1]):
        # bincount is much faster than np.unique for small integer ranges.
        counts = np.bincount(planes[..., c].ravel())
        modes.append(int(np.argmax(counts)))
    return tuple(modes)

def upsample_warp_matrix(warp_ds: np.ndarray, ds: int) -> np.ndarray:
    """
    Rescale a 2×3 affine warp from downsampled to full-resolution pixel units.

    When full-resolution coordinates are ``ds`` times the downsampled ones,
    the 2×2 linear part is unchanged. Only the translation column
    (rows 0 and 1 of column 2) is multiplied by ``ds``. The input matrix is
    left untouched; a modified copy is returned.
    """
    scaled = warp_ds.copy()
    scaled[:2, 2] *= ds
    return scaled

def warp_image_with_affine(
    img: np.ndarray,
    M: np.ndarray,
    output_shape: tuple,
    border_color: tuple,
    interp=cv2.INTER_NEAREST
) -> np.ndarray:
    """
    Resample ``img`` through the 2×3 affine ``M``.

    ``M`` is treated as the inverse map (destination → source coordinates),
    via ``cv2.WARP_INVERSE_MAP``. This matches the warps produced by
    cv2.findTransformECC. ``output_shape`` is OpenCV's ``dsize``, i.e.
    ``(width, height)``. Destination pixels that sample outside ``img`` are
    set to the constant ``border_color``.
    """
    warp_kwargs = dict(
        dsize=output_shape,
        flags=interp | cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=border_color,
    )
    return cv2.warpAffine(img, M, **warp_kwargs)

In [28]:
# Alignment parameters
ds = 8  # masks are subsampled by this factor before ECC (speed vs. precision)
# ECC stops after 5000 iterations or when the correlation gain per iteration < 1e-7
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 5000, 1e-7)

# Gather the input TIFFs (case-insensitive extension match).
# Sorting is lexicographic, so numbered files should be zero-padded
# (e.g. sec_02 before sec_10) for the alignment chain to follow section order.
im_pth = 'test_data/cropped_sections'
images = sorted(
    os.path.join(im_pth, f)
    for f in os.listdir(im_pth)
    if f.lower().endswith('.tif')
)
if not images:
    raise FileNotFoundError(f"No .tif images found in {im_pth!r}")

# Prepare output folder (exist_ok already tolerates an existing directory)
out_dir = 'test_data/cropped_sections/aligned'
os.makedirs(out_dir, exist_ok=True)

# 1) The first image is the initial reference: save it unchanged and build its mask
ref_path = images[0]
current_ref = imread(ref_path)
imwrite(os.path.join(out_dir, os.path.basename(ref_path)), current_ref)
current_ref_mask = get_mask(current_ref)

In [30]:
# Re-seed the reference from the first image so this cell produces the same
# result no matter how many times it is re-run (the loop below overwrites
# current_ref / current_ref_mask).
current_ref = imread(images[0])
current_ref_mask = get_mask(current_ref)
h, w = current_ref.shape[:2]  # every aligned output is warped to this size

for img_path in images[1:]:
    mov = imread(img_path)
    mov_mask = get_mask(mov)

    # get_mask returns *label* images whose numbering is arbitrary per image.
    # Compare foreground vs. background only, so ECC is not fooled by
    # differing label values.
    r_ds = downsample_mask(current_ref_mask > 0, ds)
    m_ds = downsample_mask(mov_mask > 0, ds)

    # ECC on downsampled masks → rescale the warp to full resolution
    warp_ds = estimate_affine_transform(r_ds, m_ds, criteria)
    M_full = upsample_warp_matrix(warp_ds, ds)

    # Fill uncovered pixels with the moving image's dominant colour
    fill = compute_border_color(mov)

    # Warp onto the reference grid and save under the original filename
    aligned = warp_image_with_affine(mov, M_full, (w, h), fill)
    out_p = os.path.join(out_dir, os.path.basename(img_path))
    imwrite(out_p, aligned)

    # Chain: each aligned section is the reference for the next one
    current_ref = aligned
    current_ref_mask = get_mask(current_ref)

In [None]:
aligned_im_path = 'test_data/cropped_sections/aligned'
aligned_ims = sorted(
    os.path.join(aligned_im_path, f)
    for f in os.listdir(aligned_im_path)
    if f.lower().endswith('.tif')
)

# Grid size: 4 columns, and as many rows as needed (ceiling division, >= 1 row).
# round() would drop images (5 → 1 row) and give 0 rows for 1–2 images.
n = len(aligned_ims)
cols = 4
rows = max(1, -(-n // cols))

# squeeze=False always yields a 2D array of Axes, so flatten() is safe
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
axes = axes.flatten()

# Plot each image in its own subplot
for ax, img_path in zip(axes, aligned_ims):
    img = imread(img_path)
    ax.imshow(img)
    ax.set_title('Aligned ' + os.path.basename(img_path), fontsize=8)
    ax.axis('off')

# Hide the unused trailing axes when n < rows * cols
for ax in axes[n:]:
    ax.axis('off')

plt.tight_layout()
plt.show()