In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib widget

from time import time
import numpy as np
from pathlib import Path
import tifffile
import concurrent.futures
from typing import List, Tuple, Dict
import napari
import SimpleITK as sitk

from vrAnalysis import fileManagement as fm

In [2]:
def load_and_organize_stack(folder_path: str, num_planes: int, flyback=0, pattern: str = "*_00001_*") -> List[np.ndarray]:
    """
    Load and organize a multi-channel z-stack from tiff files.
    
    Parameters
    ----------
    folder_path : str
        Path to folder containing tiff files
    num_planes: int
        Number of planes per stack
    flyback: int, optional
        Number of flyback planes, if more than 0 will be removed from head of stack
    pattern : str, optional
        Glob pattern to match relevant files
        
    Returns
    -------
    stacks : List[np.ndarray]
        List of 3D arrays of data from each color channel (z, y, x)
    """
    folder = Path(folder_path)
    files = sorted(folder.glob(pattern))
    
    if not files:
        raise ValueError(f"No files found matching pattern {pattern}")
    
    # Load all tiff files in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        tiff_data = list(executor.map(tifffile.imread, files))
    
    # Read first file to get dimensions
    if tiff_data[0].ndim == 3:
        n_channels = 1
        height, width = tiff_data[0].shape[1:]
    else:
        n_channels, height, width = tiff_data[0].shape[1:]

    # Preallocate list for storing data from each channel 
    channels = [[] for _ in range(n_channels)]

    for data in tiff_data:
        if n_channels == 1:
            data = data.reshape(data.shape[0], 1, height, width)
        
        for ichannel in range(n_channels):
            channels[ichannel].append(data[:, ichannel])

        data = None
    
    # Combine all files
    channels = [np.concatenate(channel, axis=0) for channel in channels]
    num_extra = channels[0].shape[0] % num_planes
    channels = [np.concatenate((channel, np.full((num_planes - num_extra, *channel.shape[1:]), np.nan)), axis=0) for channel in channels]
    num_repeats = channels[0].shape[0] // num_planes

    # Reshape by plane
    stacks = [np.nanmean(channel.reshape(num_repeats, num_planes, *channel.shape[1:]), axis=0) for channel in channels]
    
    # Remove flyback planes
    if flyback > 0:
        stacks = [stack[flyback:] for stack in stacks]
        
    return stacks

In [44]:
from enum import Enum
from typing import Tuple, Optional, Dict, Any
import SimpleITK as sitk
import numpy as np

class MetricType(Enum):
    MUTUAL_INFO = 'mi'
    CORRELATION = 'nc'
    MEAN_SQUARES = 'ms'

def register_translation(
    moving: np.ndarray,
    fixed: np.ndarray,
    verbose: bool = False
) -> Tuple[np.ndarray, sitk.Transform]:
    """
    Register two 3D image stacks using translation only.
    
    Parameters
    ----------
    moving : np.ndarray
        Moving image stack (z, y, x)
    fixed : np.ndarray
        Fixed image stack (z, y, x)
    verbose : bool
        Print progress
        
    Returns
    -------
    registered : np.ndarray
        Registered moving image
    transform : sitk.Transform
        The computed transform
    """
    # Convert to SimpleITK images
    moving_sitk = sitk.GetImageFromArray(moving.astype(np.float32))
    fixed_sitk = sitk.GetImageFromArray(fixed.astype(np.float32))
    
    # Initialize registration
    registration = sitk.ImageRegistrationMethod()
    
    # Translation transform
    transform = sitk.TranslationTransform(3)
    registration.SetInitialTransform(transform)
    
    # Simple normalized correlation metric
    registration.SetMetricAsCorrelation()
    
    # Simple optimizer
    registration.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100
    )
    
    # Linear interpolation
    registration.SetInterpolator(sitk.sitkLinear)
    
    if verbose:
        def callback(filter):
            print(f"Iteration: {filter.GetOptimizerIteration()}, "
                  f"Metric: {filter.GetMetricValue():.4f}")
        registration.AddCommand(sitk.sitkIterationEvent,
                              lambda: callback(registration))
    
    # Perform registration
    final_transform = registration.Execute(fixed_sitk, moving_sitk)
    
    # Apply transform
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_sitk)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(final_transform)
    
    registered_sitk = resampler.Execute(moving_sitk)
    registered = sitk.GetArrayFromImage(registered_sitk)
    
    if verbose:
        # Print the final translation
        translation = final_transform.GetParameters()
        print(f"\nFinal translation (x,y,z): {translation}")
    
    return registered, final_transform



def register_constrained_rigid(
    moving: np.ndarray,
    fixed: np.ndarray,
    max_rotation_degrees: float = 15.0,  # Maximum rotation allowed in any direction
    translation_scales: Tuple[float, float, float] = (1.0, 1.0, 0.5),  # Scale factors for x,y,z translation
    verbose: bool = False
) -> Tuple[np.ndarray, sitk.Transform]:
    """
    Register two 3D image stacks using constrained rigid transformation.
    
    Parameters
    ----------
    moving : np.ndarray
        Moving image stack (z, y, x)
    fixed : np.ndarray
        Fixed image stack (z, y, x)
    max_rotation_degrees : float
        Maximum allowed rotation in any direction (degrees)
    translation_scales : tuple
        Scale factors for translation optimization (x,y,z)
        Use smaller values for z if z-resolution is different
    verbose : bool
        Print progress
    """
    # Convert to SimpleITK images
    moving_sitk = sitk.GetImageFromArray(moving.astype(np.float32))
    fixed_sitk = sitk.GetImageFromArray(fixed.astype(np.float32))
    
    # Initialize registration
    registration = sitk.ImageRegistrationMethod()
    
    # Euler3D transform (rotation + translation)
    transform = sitk.Euler3DTransform()
    registration.SetInitialTransform(transform)
    
    # Set transform parameter scaling to constrain rotation
    # Order is: rotX, rotY, rotZ, transX, transY, transZ
    max_rot_rad = np.deg2rad(max_rotation_degrees)
    scales = [1/max_rot_rad]*3 + [1.0/s for s in translation_scales]
    registration.SetOptimizerScales(scales)
    
    # Regular step gradient descent with small steps for rotation
    registration.SetOptimizerAsRegularStepGradientDescent(
        learningRate=0.1,
        minStep=0.0001,
        numberOfIterations=200,
        relaxationFactor=0.5,
        gradientMagnitudeTolerance=1e-6
    )
    
    # Normalized correlation metric
    registration.SetMetricAsCorrelation()
    
    # Linear interpolation
    registration.SetInterpolator(sitk.sitkLinear)
    
    # Multi-resolution approach
    registration.SetShrinkFactorsPerLevel([4, 2, 1])
    registration.SetSmoothingSigmasPerLevel([2, 1, 0])
    
    if verbose:
        def callback(filter):
            if filter.GetOptimizerIteration() == 0:
                print("\nStarting registration...")
            print(f"Iteration: {filter.GetOptimizerIteration()}, "
                  f"Metric: {filter.GetMetricValue():.4f}")
            # Check if rotations are within bounds
            params = filter.GetOptimizerPosition()
            rotations_deg = np.rad2deg(params[:3])
            if np.any(np.abs(rotations_deg) > max_rotation_degrees):
                print("Warning: Rotation exceeding specified maximum!")
        registration.AddCommand(sitk.sitkIterationEvent,
                              lambda: callback(registration))
    
    # Perform registration
    final_transform = registration.Execute(fixed_sitk, moving_sitk)
    
    if verbose:
        # Print the final parameters
        params = final_transform.GetParameters()
        rotations_deg = np.rad2deg(params[:3])
        translations = params[3:]
        print("\nFinal transform parameters:")
        print(f"Rotations (deg) - X: {rotations_deg[0]:.2f}, "
              f"Y: {rotations_deg[1]:.2f}, Z: {rotations_deg[2]:.2f}")
        print(f"Translations - X: {translations[0]:.2f}, "
              f"Y: {translations[1]:.2f}, Z: {translations[2]:.2f}")
    
    # Apply transform
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_sitk)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(final_transform)
    
    registered_sitk = resampler.Execute(moving_sitk)
    registered = sitk.GetArrayFromImage(registered_sitk)
    
    return registered, final_transform



def register_zstacks(
    moving: np.ndarray,
    fixed: np.ndarray,
    transform_type: str = 'rigid',
    metric: MetricType = MetricType.CORRELATION,
    params: Optional[Dict[str, Any]] = None,
    verbose: bool = False
) -> Tuple[np.ndarray, sitk.Transform]:
    """
    Register two 3D image stacks using SimpleITK.
    
    Parameters
    ----------
    moving : np.ndarray
        Moving image stack (z, y, x)
    fixed : np.ndarray
        Fixed image stack (z, y, x)
    transform_type : str
        Type of registration: 'rigid', 'affine', or 'rigid+nonrigid'
    metric : MetricType
        Similarity metric to use
    params : dict, optional
        Override default parameters for registration
    verbose : bool
        Whether to print registration progress
        
    Returns
    -------
    registered : np.ndarray
        Registered moving image
    transform : sitk.Transform
        The computed transform(s)
    """
    # Set default parameters
    default_params = {
        'shrink_factors': (4, 2, 1),
        'smooth_sigmas': (2, 1, 0),
        'learning_rate': 0.1,
        'iterations': 200,
        'grad_tol': 1e-6,
        'sampling_percentage': 0.01
    }
    if params:
        default_params.update(params)
    params = default_params

    # Convert to SimpleITK images and normalize
    moving_sitk = sitk.GetImageFromArray(moving.astype(np.float32))
    fixed_sitk = sitk.GetImageFromArray(fixed.astype(np.float32))
    moving_sitk = sitk.Normalize(moving_sitk)
    fixed_sitk = sitk.Normalize(fixed_sitk)

    if transform_type == 'rigid+nonrigid':
        # First do rigid registration
        registered_sitk, rigid_transform = _register_single_transform(
            moving_sitk, fixed_sitk,
            transform='rigid',
            metric=metric,
            params=params,
            verbose=verbose
        )
        # Then do nonrigid registration starting from rigid result
        registered_sitk, bspline_transform = _register_single_transform(
            registered_sitk, fixed_sitk,
            transform='bspline',
            metric=metric,
            params=params,
            verbose=verbose
        )
        # Combine transforms
        composite_transform = sitk.CompositeTransform([bspline_transform, rigid_transform])
        return sitk.GetArrayFromImage(registered_sitk), composite_transform
    else:
        # Single transform registration
        registered_sitk, transform = _register_single_transform(
            moving_sitk, fixed_sitk,
            transform=transform_type,
            metric=metric,
            params=params,
            verbose=verbose
        )
        return sitk.GetArrayFromImage(registered_sitk), transform

def _register_single_transform(
    moving_sitk: sitk.Image,
    fixed_sitk: sitk.Image,
    transform: str,
    metric: MetricType,
    params: Dict[str, Any],
    verbose: bool
) -> Tuple[sitk.Image, sitk.Transform]:
    """Helper function for single-transform registration."""
    
    registration = sitk.ImageRegistrationMethod()
    
    # Set up transform
    if transform == 'rigid':
        initial_transform = sitk.Euler3DTransform()
    elif transform == 'affine':
        initial_transform = sitk.AffineTransform(3)
    elif transform == 'bspline':
        transform_domain_mesh_size = [8] * 3
        initial_transform = sitk.BSplineTransformInitializer(
            fixed_sitk, transform_domain_mesh_size
        )
    else:
        raise ValueError(f"Unknown transform type: {transform}")
    
    registration.SetInitialTransform(initial_transform)
    
    # Set up metric
    if metric == MetricType.MUTUAL_INFO:
        registration.SetMetricAsMattesMutualInformation()
    elif metric == MetricType.CORRELATION:
        registration.SetMetricAsCorrelation()
    elif metric == MetricType.MEAN_SQUARES:
        registration.SetMetricAsMeanSquares()
    
    # Set up optimizer
    if transform == 'bspline':
        registration.SetOptimizerAsLBFGSB(
            gradientConvergenceTolerance=params['grad_tol'],
            numberOfIterations=params['iterations'],
            maximumNumberOfCorrections=5,
            maximumNumberOfFunctionEvaluations=1000,
            costFunctionConvergenceFactor=1e7
        )
    else:
        registration.SetOptimizerAsGradientDescent(
            learningRate=params['learning_rate'],
            numberOfIterations=params['iterations'],
            convergenceMinimumValue=params['grad_tol'],
            convergenceWindowSize=10
        )
    
    # Set up sampling
    registration.SetMetricSamplingPercentage(params['sampling_percentage'])
    registration.SetMetricSamplingStrategy(registration.RANDOM)
    
    # Set up interpolator
    registration.SetInterpolator(sitk.sitkLinear)
    
    # Set up multi-resolution framework
    registration.SetShrinkFactorsPerLevel(params['shrink_factors'])
    registration.SetSmoothingSigmasPerLevel(params['smooth_sigmas'])
    
    if verbose:
        def callback(filter):
            print(f"Iteration: {filter.GetOptimizerIteration()}, "
                  f"Metric: {filter.GetMetricValue():.4f}")
        registration.AddCommand(sitk.sitkIterationEvent,
                              lambda: callback(registration))
    
    # Perform registration
    final_transform = registration.Execute(fixed_sitk, moving_sitk)
    
    # Apply transform
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_sitk)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(final_transform)
    
    return resampler.Execute(moving_sitk), final_transform


def visualize_registration(moving: np.ndarray, fixed: np.ndarray, registered: np.ndarray) -> None:
    """
    Visualize registration results using napari.
    
    Parameters
    ----------
    moving : np.ndarray
        Original moving image
    fixed : np.ndarray
        Fixed image
    registered : np.ndarray
        Registered moving image
    """
    viewer = napari.Viewer()
    
    # Add layers
    viewer.add_image(fixed, name='Fixed', colormap='green', blending='additive')
    viewer.add_image(registered, name='Registered', colormap='red', blending='additive')
    viewer.add_image(moving, name='Original Moving', colormap='orange', blending='additive', visible=False)
    
    return viewer

def compare_metrics(moving: np.ndarray, fixed: np.ndarray, metrics: list = ['mi', 'nc', 'ms'], verbose=False) -> Dict[str, Tuple[np.ndarray, sitk.Transform]]:
    """
    Compare different registration metrics.
    
    Parameters
    ----------
    moving : np.ndarray
        Moving image stack
    fixed : np.ndarray
        Fixed image stack
    metrics : list
        List of metrics to try
    verbose : bool
        Whether to print registration progress
        
    Returns
    -------
    dict
        Dictionary of results for each metric
    """
    results = {}
    
    for metric in metrics:
        print(f"Trying {metric} metric...")
        registered, transform = register_zstacks(moving, fixed, metric=metric, verbose=verbose)
        results[metric] = (registered, transform)
        
    return results

In [46]:
import numpy as np
from typing import Tuple
from scipy import fftpack

def phase_correlation_3d(
    moving: np.ndarray,
    fixed: np.ndarray,
    upsample_factor: int = 1
) -> Tuple[np.ndarray, Tuple[int, int, int]]:
    """
    Register two 3D image stacks using phase correlation.
    
    Parameters
    ----------
    moving : np.ndarray
        Moving image stack (z, y, x)
    fixed : np.ndarray
        Fixed image stack (z, y, x)
    upsample_factor : int
        Factor for subpixel accuracy. 1 means pixel accuracy.
        
    Returns
    -------
    registered : np.ndarray
        Registered moving image
    shifts : tuple
        (z, y, x) shifts needed to register images
    """
    # Ensure float arrays
    moving = moving.astype(float)
    fixed = fixed.astype(float)
    
    # Compute FFTs
    fixed_fft = fftpack.fftn(fixed)
    moving_fft = fftpack.fftn(moving)
    
    # Compute cross-power spectrum
    cross_power = moving_fft * fixed_fft.conjugate()
    
    # Normalize to get only phase information
    eps = np.finfo(float).eps
    normalized_cross_power = cross_power / (np.abs(cross_power) + eps)
    
    # Compute phase correlation
    correlation = np.real(fftpack.ifftn(normalized_cross_power))
    
    # Find shift that gives maximum correlation
    shifts = np.unravel_index(np.argmax(correlation), correlation.shape)
    
    # Convert shifts to centered coordinates
    shifts = [
        s if s <= shape//2 else s - shape 
        for s, shape in zip(shifts, correlation.shape)
    ]
    
    # Apply shift to register images
    registered = np.roll(np.roll(np.roll(
        moving,
        shifts[0], axis=0),
        shifts[1], axis=1),
        shifts[2], axis=2)
    
    return registered, tuple(shifts)

In [5]:
# parameters
num_planes = 30
flyback = 4
folder_path_0 = fm.serverPath(zaru=False) / "ATL060" / "2024-07-12" / "702"
folder_path_1 = fm.serverPath(zaru=False) / "ATL060" / "2024-07-15" / "702"

# get stacks
stacks_0 = load_and_organize_stack(folder_path_0, num_planes, flyback=flyback)
stacks_1 = load_and_organize_stack(folder_path_1, num_planes, flyback=flyback)

In [48]:
# Load and preprocess stacks
moving_stack = stacks_0[1]
fixed_stack = stacks_1[1]   

# Try registration
registered, transform = register_constrained_rigid(moving_stack, fixed_stack, max_rotation_degrees=5.0, translation_scales=(1.0, 1.0, 0.5), verbose=False)
params = transform.GetParameters()
translations = params[3:]
print("Translations:")
print(f"X: {translations[0]:.3f}")
print(f"Y: {translations[1]:.3f}")
print(f"Z: {translations[2]:.3f}")

registered, shifts = phase_correlation_3d(moving_stack, fixed_stack)
print(f"Detected shifts (z,y,x): {shifts}")

viewer = visualize_registration(moving_stack, fixed_stack, registered)


Translations:
X: 0.023
Y: 0.017
Z: -0.001
Detected shifts (z,y,x): (2, 1, 0)


In [23]:
# For rigid registration
registered, transform = register_zstacks(stacks_0[1], stacks_1[1], transform_type='rigid', verbose=False)

# # For rigid + nonrigid registration
# registered, transform = register_zstacks(
#     moving_stack, 
#     fixed_stack,
#     transform_type='rigid+nonrigid',
#     metric=MetricType.CORRELATION,
#     params={'learning_rate': 0.05, 'iterations': 300},
#     verbose=True
# )

In [9]:
results = compare_metrics(stacks_0[1], stacks_1[1])


Trying mi metric...

Trying nc metric...

Trying ms metric...


In [24]:
viewer = visualize_registration(stacks_0[1], stacks_1[1], stacks_0[1])

In [31]:
# register stacks to each other
best_metric = 'nc'  # or whichever you found works best
best_channel = 0 # green to start
registered, transform = register_zstacks(stacks_0[best_channel], stacks_1[best_channel], metric=best_metric, verbose=False)

In [32]:
viewer = visualize_registration(stacks_0[best_channel], stacks_1[best_channel], registered)

In [29]:
import napari

# view stacks[0] in napari
viewer = napari.Viewer()
viewer.add_image(stacks_1[0], name='Channel 0', colormap='green', blending='additive')
viewer.add_image(stacks_1[1], name='Channel 1', colormap='red', blending='additive')

# show viewer
viewer.show()

In [14]:
import time
from imageio.v3 import imread
import concurrent.futures

def benchmark_loading(folder_path: str, pattern: str = "*_00001_*"):
    files = sorted(Path(folder_path).glob(pattern))
    
    # Test tifffile
    start = time.time()
    a = [tifffile.imread(f) for f in files]
    print(f"tifffile sequential: {time.time() - start:.2f}s")
    
    # Test tifffile with threading
    start = time.time()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        b = list(executor.map(tifffile.imread, files))
    print(f"tifffile threaded: {time.time() - start:.2f}s")
    
    # Test imageio
    start = time.time()
    c = [imread(f) for f in files]
    print(f"imageio sequential: {time.time() - start:.2f}s")

    # Test tifffile again
    start = time.time()
    d = [tifffile.imread(f) for f in files]
    print(f"tifffile sequential again: {time.time() - start:.2f}s")

    print(all([np.allclose(x, y) for x, y in zip(a, b)]))
    print(all([np.allclose(x, y) for x, y in zip(a, c)]))
    print(all([np.allclose(x, y) for x, y in zip(a, d)]))

In [15]:
benchmark_loading(folder_path)

tifffile sequential: 16.22s
tifffile threaded: 0.97s
imageio sequential: 1.52s
tifffile sequential again: 1.36s
True
True
True
