In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread, imwrite
from scipy.signal import fftconvolve
from scipy.ndimage import shift as nd_shift
from tqdm import tqdm

# Half-width (pixels) of the cropped cross-correlation heatmap. main() also
# uses it as the outlier threshold: shifts larger than this on either axis are
# excluded from the summary.
MAX_SHIFT_DISPLAY = 50

# Source data. These are absolute, machine-specific paths; edit them per dataset.
INPUT_FOLDER = r"D:\Data\FK_P001_EX005_2025_03_21 60xWIA 1.5x CRISPRi mCh\TIF\results\stacked"
MASK_FOLDER = r"D:\Data\FK_P001_EX005_2025_03_21 60xWIA 1.5x CRISPRi mCh\TIF\results\stacked\cyto3_gpu_results"

# Every output lives directly beneath INPUT_FOLDER: aligned stacks,
# per-file heatmaps, and the overall summary figure.
OUTPUT_FOLDER, REPORT_FOLDER, SUMMARY_REPORT = (
    os.path.join(INPUT_FOLDER, leaf)
    for leaf in (
        "aligned_shift_bymask",
        "cross_correlation_reports",
        "cross_correlation_summary.png",
    )
)

In [2]:
def load_and_normalize(filepath, mask, fluo_channel=1):
    """Load a stack's fluorescence channel and z-score it together with a mask.

    Parameters
    ----------
    filepath : str
        Path to a TIFF stack read with ``tifffile.imread``. The fluorescence
        image is taken as ``stack[fluo_channel]`` (indexing the first axis).
    mask : array-like
        Reference image used in place of a DIC channel (here, the binary
        segmentation mask).
    fluo_channel : int, optional
        Index along the stack's first axis that holds the fluorescence image.
        The default of 1 matches the previously hard-coded channel.

    Returns
    -------
    dic, fluorescence : np.ndarray (float32)
        Zero-mean, unit-std versions of ``mask`` and of the fluorescence
        channel. A constant image (std == 0, e.g. an empty mask) is returned
        as all zeros rather than NaN.
    """
    stack = imread(filepath)
    fluorescence = _zscore(stack[fluo_channel])
    # The mask stands in for channel 0 (DIC) as the alignment reference.
    dic = _zscore(mask)
    return dic, fluorescence


def _zscore(image):
    """Return ``image`` as float32 with zero mean and unit std (all zeros if constant)."""
    image = np.asarray(image, dtype=np.float32)
    std = np.std(image)
    if std == 0:
        # Dividing by zero would turn the whole image into NaN.
        return np.zeros_like(image)
    return (image - np.mean(image)) / std

In [3]:
def fft_cross_correlation(dic, fluorescence):
    """Cross-correlate two 2-D images via FFT and locate the correlation peak.

    The map is ``fftconvolve(dic, fluorescence flipped on both axes)`` in
    ``'same'`` mode. Assuming both inputs have the same shape, zero lag sits
    at index ``shape // 2``.

    Returns
    -------
    correlation : np.ndarray
        Cross-correlation map, same shape as ``dic``.
    (shift_y, shift_x) : tuple
        Offset of the peak from the map centre. Shifting ``fluorescence`` by
        this amount (e.g. with ``scipy.ndimage.shift``) moves it onto ``dic``.
    """
    reversed_fluo = np.flip(fluorescence, axis=(0, 1))
    corr_map = fftconvolve(dic, reversed_fluo, mode='same')

    peak_rc = np.unravel_index(np.argmax(corr_map), corr_map.shape)
    origin_rc = np.array(corr_map.shape) // 2
    offset_y, offset_x = (peak - origin for peak, origin in zip(peak_rc, origin_rc))
    return corr_map, (offset_y, offset_x)

In [4]:
def plot_correlation_heatmap(correlation, shift, filename, save_folder, max_shift_display=50):
    """Save a heatmap of ``correlation`` cropped to lags within ±``max_shift_display``.

    Parameters
    ----------
    correlation : np.ndarray
        2-D cross-correlation map whose zero lag sits at ``shape // 2``
        (the convention used by ``fft_cross_correlation``).
    shift : tuple
        ``(shift_y, shift_x)`` peak offset, marked in red on the plot.
    filename : str
        Source stack name; its stem is used in the output file name.
    save_folder : str
        Directory for the PNG. It is created if it does not exist.
    max_shift_display : int, optional
        Requested half-width of the crop in pixels. It is clamped to the lags
        actually present in ``correlation``.
    """
    n_rows, n_cols = correlation.shape
    center_y, center_x = n_rows // 2, n_cols // 2
    # Clamp the window to the map: a negative slice start would silently
    # index from the end instead of cropping around the centre.
    half_y = int(min(max_shift_display, center_y, n_rows - 1 - center_y))
    half_x = int(min(max_shift_display, center_x, n_cols - 1 - center_x))
    cropped = correlation[
        center_y - half_y:center_y + half_y + 1,
        center_x - half_x:center_x + half_x + 1,
    ]
    # The extent is given in pixel edges (±0.5) so each pixel centre lands on
    # its integer lag. Rows grow downward, matching imshow's default origin='upper'.
    extent = [-half_x - 0.5, half_x + 0.5, half_y + 0.5, -half_y - 0.5]

    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(cropped, cmap='viridis', extent=extent)
    ax.scatter(shift[1], shift[0], color='red', label=f'Shift: ({shift[1]}, {shift[0]})')
    fig.colorbar(image, ax=ax, label='Cross-correlation')
    ax.set_xlabel('X shift (pixels)')
    ax.set_ylabel('Y shift (pixels)')
    ax.set_title('FFT-based Cross-Correlation')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    os.makedirs(save_folder, exist_ok=True)
    outpath = os.path.join(save_folder, f"Cross-Correlation_{os.path.splitext(filename)[0]}.png")
    fig.savefig(outpath)
    # Close this specific figure; hundreds of files are plotted in one run.
    plt.close(fig)

In [5]:
def plot_shift_summary(shifts, max_shift_display, outpath):
    """Scatter-plot per-file ``(shift_y, shift_x)`` pairs with their mean and save it.

    Parameters
    ----------
    shifts : sequence of (shift_y, shift_x) pairs
        One measured shift per file. It may be empty.
    max_shift_display : number
        Both axes span ``[-max_shift_display, max_shift_display]``.
    outpath : str
        Destination path of the saved figure.

    Returns
    -------
    np.ndarray, shape (2,)
        Mean ``(shift_y, shift_x)``. It is NaN on both axes when ``shifts`` is empty.
    """
    # reshape keeps an empty input as (0, 2) so column slicing below still works.
    shifts = np.asarray(shifts, dtype=float).reshape(-1, 2)
    if len(shifts):
        mean_shift = shifts.mean(axis=0)
    else:
        # Avoid numpy's "Mean of empty slice" warning; nothing to average.
        mean_shift = np.full(2, np.nan)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(shifts[:, 1], shifts[:, 0], c='blue', alpha=0.7, label='Individual Shifts')
    ax.scatter(mean_shift[1], mean_shift[0], c='red', marker='x', s=100, label=f"Mean Shift ({mean_shift[1]:.2f}, {mean_shift[0]:.2f})")
    ax.set_xlim([-max_shift_display, max_shift_display])
    ax.set_ylim([-max_shift_display, max_shift_display])
    ax.set_xlabel('X shift (pixels)')
    ax.set_ylabel('Y shift (pixels)')
    ax.set_title('Summary of Channel Shifts')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    fig.savefig(outpath)
    plt.close(fig)
    return mean_shift

In [6]:
def get_mask_path(stack_filename, mask_folder):
    """Return the expected segmentation-mask path for a stack file.

    The mask is expected at ``<mask_folder>/<stem>_masks.tif``, where
    ``<stem>`` is ``stack_filename`` without its extension. For example,
    ``foo_stack.tif`` maps to ``foo_stack_masks.tif``, which is the same
    result the previous substring-replacement version gave for that pattern.
    Unlike substring replacement, this also handles an upper-case ``.TIF``
    extension and ``.tif``/``_stack`` occurring elsewhere in the name.

    Parameters
    ----------
    stack_filename : str
        File name of the stack (callers pass a basename).
    mask_folder : str
        Directory containing the masks.

    Returns
    -------
    str
        Candidate mask path. Existence is not checked here.
    """
    stem, _ext = os.path.splitext(stack_filename)
    return os.path.join(mask_folder, f"{stem}_masks.tif")

In [7]:
def process_stack(filepath, output_folder, report_folder, max_shift_display, mask_folder):
    """Align one stack's fluorescence channel to its segmentation mask and save it.

    Steps: locate the mask, binarise it, cross-correlate it with the
    fluorescence channel, save a heatmap report, shift the fluorescence by the
    peak offset, and write a 2-plane float32 TIFF ``[mask, shifted fluorescence]``
    to ``output_folder`` as ``aligned_<filename>``.

    Parameters
    ----------
    filepath : str
        Path to the input stack.
    output_folder, report_folder : str
        Destination directories. Both are created if missing.
    max_shift_display : int
        Half-width of the heatmap crop passed to ``plot_correlation_heatmap``.
    mask_folder : str
        Directory searched via ``get_mask_path``.

    Returns
    -------
    (filename, shift)
        ``shift`` is the ``(shift_y, shift_x)`` peak offset. It is the
        placeholder ``(0, 0)`` when no mask is found and the file is skipped.

    Raises
    ------
    ValueError
        If the mask and the fluorescence channel differ in shape.
    """
    filename = os.path.basename(filepath)
    mask_path = get_mask_path(filename, mask_folder)

    if not os.path.exists(mask_path):
        print(f"Mask not found for {filename}, skipping")
        return filename, (0, 0)

    # Binarise the label image: every labelled pixel (> 0) becomes 1.
    mask = (imread(mask_path) > 0).astype(np.uint8)

    dic, fluorescence = load_and_normalize(filepath, mask)
    # Fail before writing any report: mismatched shapes would give a
    # meaningless correlation and break np.stack below.
    if dic.shape != fluorescence.shape:
        raise ValueError(
            f"Mask shape {dic.shape} does not match fluorescence shape "
            f"{fluorescence.shape} for {filename}"
        )
    correlation, shift = fft_cross_correlation(dic, fluorescence)

    os.makedirs(report_folder, exist_ok=True)
    plot_correlation_heatmap(correlation, shift, filename, report_folder, max_shift_display)

    # NOTE(review): this shifts the z-scored fluorescence returned by
    # load_and_normalize, so the saved channel holds normalized values rather
    # than raw intensities. Confirm that this is intended.
    shifted_fluorescence = nd_shift(fluorescence, shift, mode='nearest')

    # uint8 mask and float32 fluorescence promote to float32 when stacked.
    output_stack = np.stack([mask, shifted_fluorescence]).astype(np.float32)
    os.makedirs(output_folder, exist_ok=True)
    imwrite(os.path.join(output_folder, f"aligned_{filename}"), output_stack)

    return filename, shift

In [8]:
# plot_shift_summary is defined once, in cell In [5]. A byte-identical second
# definition used to sit here; it was removed because a repeated `def`
# silently shadows the earlier one and makes edits to cell In [5] ineffective.

In [None]:
def main():
    """Align every ``.tif`` stack in INPUT_FOLDER and summarise the measured shifts.

    Only the top level of INPUT_FOLDER is listed, and matching is
    case-insensitive on ``.tif``. Files whose mask is missing are reported as
    skipped and excluded from the summary, rather than counted as zero shifts.
    Shifts larger than MAX_SHIFT_DISPLAY on either axis are treated as outliers
    and are listed but not plotted.
    """
    # Sort for a deterministic processing and report order (os.listdir order is arbitrary).
    tif_files = sorted(f for f in os.listdir(INPUT_FOLDER) if f.lower().endswith('.tif'))
    shifts = []
    filenames = []
    skipped = []
    for filename in tqdm(tif_files, desc="Processing stacks"):
        filepath = os.path.join(INPUT_FOLDER, filename)
        # process_stack returns a placeholder (0, 0) when the mask is missing.
        # Track those files separately so they do not pull the mean toward zero.
        has_mask = os.path.exists(get_mask_path(filename, MASK_FOLDER))
        fname, shift = process_stack(filepath, OUTPUT_FOLDER, REPORT_FOLDER, MAX_SHIFT_DISPLAY, MASK_FOLDER)
        if not has_mask:
            skipped.append(fname)
            continue
        shifts.append(shift)
        filenames.append(fname)

    # reshape keeps an empty result as (0, 2) so column-wise checks still work.
    shifts = np.array(shifts, dtype=int).reshape(-1, 2)
    filenames = np.array(filenames)
    # Filter out large shifts on either axis.
    within_range = (np.abs(shifts) <= MAX_SHIFT_DISPLAY).all(axis=1)
    valid_shifts = shifts[within_range]

    print(f"\nProcessed {len(tif_files)} files.")
    if skipped:
        print(f"Skipped {len(skipped)} file(s) without a mask.")
    print(f"Valid shifts for summary: {len(valid_shifts)}")
    if len(valid_shifts) > 0:
        mean_shift = plot_shift_summary(valid_shifts, MAX_SHIFT_DISPLAY, SUMMARY_REPORT)
        print(f"Mean shift: (Y: {mean_shift[0]:.2f}, X: {mean_shift[1]:.2f})")
        print(f"Summary plot saved to: {SUMMARY_REPORT}")
    else:
        print("No valid shifts; summary plot not written.")

    # Report outliers so they can be inspected by hand.
    outlier_filenames = filenames[~within_range]
    if len(outlier_filenames) > 0:
        print(f"\nExcluded {len(outlier_filenames)} outlier(s) from summary (shift > {MAX_SHIFT_DISPLAY}):")
        for fname in outlier_filenames:
            print(f"  {fname}")

if __name__ == "__main__":
    main()

Processing stacks:  11%|██████▌                                                       | 41/389 [02:30<22:12,  3.83s/it]