In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread, imwrite
from scipy.signal import fftconvolve
from scipy.ndimage import shift as nd_shift
from tqdm import tqdm

# --- Configuration ---------------------------------------------------------
# Folder holding the multi-channel .tif stacks to align (channel 0 is read as
# the DIC image, channel 1 as the fluorescence image). Override with the
# ALIGN_INPUT_FOLDER environment variable when the L: drive is not mapped.
INPUT_FOLDER = os.environ.get(
    "ALIGN_INPUT_FOLDER",
    r"L:\43-RVZ\AIMicroscopy\Mitarbeiter\2_Data\1_NikonTi2\2025_07_29 60xWI 1.5x nanobody titration\TIF\stacked",
)
# Root folder for every output (aligned stacks, per-file heatmaps, summary
# plot). Defaults to the input folder; override with ALIGN_OUTPUT_ROOT.
SEGMENTATION_FOLDER = os.environ.get("ALIGN_OUTPUT_ROOT", INPUT_FOLDER)
OUTPUT_FOLDER = os.path.join(SEGMENTATION_FOLDER, "aligned_mean_shift")
REPORT_FOLDER = os.path.join(SEGMENTATION_FOLDER, "cross_correlation_reports")
SUMMARY_REPORT = os.path.join(SEGMENTATION_FOLDER, "cross_correlation_summary.png")
# Half-width (pixels) of the shift window. It sets the heatmap crop and the
# summary-plot range, and main() also uses it as the outlier threshold for
# per-file shifts.
MAX_SHIFT_DISPLAY = 50

os.makedirs(OUTPUT_FOLDER, exist_ok=True)
os.makedirs(REPORT_FOLDER, exist_ok=True)

In [2]:
def _zscore(image):
    """Return ``image`` as float32 with zero mean and unit standard deviation.

    Raises
    ------
    ValueError
        If the image is constant or contains non-finite values. Dividing by a
        zero/NaN std would fill the image with NaN/inf and make the downstream
        correlation peak meaningless.
    """
    image = np.asarray(image, dtype=np.float32)
    mean = np.mean(image)
    std = np.std(image)
    if not np.isfinite(std) or std == 0:
        raise ValueError("image has zero or non-finite standard deviation; cannot normalize")
    return (image - mean) / std


def load_and_normalize(filepath):
    """Read a stack and return its first two channels, z-score normalized.

    Parameters
    ----------
    filepath : str
        Path to a .tif stack indexed as (channel, y, x). Channel 0 is treated
        as the DIC image and channel 1 as the fluorescence image.
        NOTE(review): a channels-last (y, x, channel) file with >= 2 rows would
        pass the shape check below but be read along the wrong axis -- confirm
        the acquisition export order.

    Returns
    -------
    dic, fluorescence : ndarray (float32)
        Each channel shifted to zero mean and scaled to unit std, so the
        cross-correlation compares structure rather than absolute intensity.

    Raises
    ------
    ValueError
        If the stack is not 3-D with at least two channels, or either channel
        is constant. ``calculate_shifts`` catches this and skips the file.
    """
    stack = np.asarray(imread(filepath))
    if stack.ndim != 3 or stack.shape[0] < 2:
        raise ValueError(
            f"expected a (channel, y, x) stack with at least 2 channels, got shape {stack.shape}"
        )
    dic = _zscore(stack[0])
    fluorescence = _zscore(stack[1])
    return dic, fluorescence

In [3]:
def fft_cross_correlation(dic, fluorescence):
    """Cross-correlate two images with an FFT and locate the correlation peak.

    The correlation is computed as the convolution of ``dic`` with the
    point-reflected ``fluorescence`` image. The result is cropped to the size of
    ``dic`` (``mode='same'``), so the zero-lag term sits at index ``shape // 2``
    along each axis.

    Parameters
    ----------
    dic, fluorescence : 2-D ndarray
        Images indexed ``[y, x]``.

    Returns
    -------
    correlation : ndarray
        Cross-correlation map, same shape as ``dic``.
    (shift_y, shift_x) : tuple of numpy integers
        Peak position relative to the zero-lag index. A positive value means a
        feature appears at a larger index in ``dic`` than in ``fluorescence``
        (``dic[n] ~ fluorescence[n - shift]``).
    """
    # Reflecting both spatial axes turns convolution into correlation.
    reflected = np.flip(fluorescence, axis=(0, 1))
    correlation = fftconvolve(dic, reflected, mode='same')

    peak_index = np.array(np.unravel_index(np.argmax(correlation), correlation.shape))
    zero_lag_index = np.array(correlation.shape) // 2
    shift_y, shift_x = peak_index - zero_lag_index
    return correlation, (shift_y, shift_x)

In [4]:
def plot_correlation_heatmap(correlation, shift, filename, save_folder, max_shift_display=50):
    """Save a PNG heatmap of the cross-correlation map around zero lag.

    Parameters
    ----------
    correlation : 2-D ndarray
        Correlation map whose zero-lag term is at ``shape // 2``, as produced by
        ``fft_cross_correlation``.
    shift : (int, int)
        Detected (shift_y, shift_x), marked in red.
    filename : str
        Source file name; its stem names the output PNG.
    save_folder : str
        Destination folder, created if missing.
    max_shift_display : int
        Half-width (in lag pixels) of the displayed window. The window is
        clamped to the map bounds.
    """
    os.makedirs(save_folder, exist_ok=True)
    center_y, center_x = np.array(correlation.shape) // 2

    # Clamp the window to the map: an unclamped negative start would wrap
    # around (Python slicing) and show the wrong region.
    y0 = max(center_y - max_shift_display, 0)
    y1 = min(center_y + max_shift_display + 1, correlation.shape[0])
    x0 = max(center_x - max_shift_display, 0)
    x1 = min(center_x + max_shift_display + 1, correlation.shape[1])
    cropped = correlation[y0:y1, x0:x1]

    # Axes are in lag units derived from the actual crop (always from the
    # parameter, never the global). Pixel centres sit on integer lags, hence
    # the half-pixel margins. With imshow's default origin='upper', row 0 (the
    # most negative y lag) is drawn at the top.
    extent = [
        (x0 - center_x) - 0.5,
        (x1 - 1 - center_x) + 0.5,
        (y1 - 1 - center_y) + 0.5,
        (y0 - center_y) - 0.5,
    ]

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(cropped, cmap='viridis', extent=extent)
        ax.scatter(shift[1], shift[0], color='red', label=f'Shift: ({shift[1]}, {shift[0]})')
        fig.colorbar(image, ax=ax, label='Cross-correlation')
        ax.set_xlabel('X shift (pixels)')
        ax.set_ylabel('Y shift (pixels)')
        ax.set_title('FFT-based Cross-Correlation')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        outpath = os.path.join(save_folder, f"Cross-Correlation_{os.path.splitext(filename)[0]}.png")
        fig.savefig(outpath)
    finally:
        # Always release the figure; this runs once per file in a loop.
        plt.close(fig)

In [5]:
def plot_shift_summary(shifts, max_shift_display, outpath):
    """Scatter-plot per-file shifts with their mean, save to ``outpath``, return the mean.

    Parameters
    ----------
    shifts : array-like, shape (N, 2)
        Per-file (shift_y, shift_x) values.
    max_shift_display : float
        Axis half-range. This matches the heatmap window and the outlier
        threshold, so every shift kept by that threshold is visible.
    outpath : str
        Path of the PNG to write.

    Returns
    -------
    ndarray, shape (2,)
        Mean (shift_y, shift_x) as floats.

    Raises
    ------
    ValueError
        If ``shifts`` is empty or not shaped (N, 2). The mean of zero shifts
        is NaN and would only fail later, when it is turned into a crop.
    """
    shifts = np.asarray(shifts, dtype=float)
    if shifts.size == 0:
        raise ValueError("no shifts to summarise; cannot compute a mean shift")
    if shifts.ndim != 2 or shifts.shape[1] != 2:
        raise ValueError(f"expected shifts of shape (N, 2), got {shifts.shape}")
    mean_shift = np.mean(shifts, axis=0)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.scatter(shifts[:, 1], shifts[:, 0], c='blue', alpha=0.7, label='Individual Shifts')
        ax.scatter(mean_shift[1], mean_shift[0], c='red', marker='x', s=100,
                   label=f"Mean Shift ({mean_shift[1]:.2f}, {mean_shift[0]:.2f})")
        ax.set_xlim([-max_shift_display, max_shift_display])
        # Y limits are reversed so negative Y shifts appear at the top, matching
        # the image-style orientation of the per-file correlation heatmaps.
        ax.set_ylim([max_shift_display, -max_shift_display])
        ax.set_xlabel('X shift (pixels)')
        ax.set_ylabel('Y shift (pixels)')
        ax.set_title('Summary of Channel Shifts')
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        fig.savefig(outpath)
    finally:
        plt.close(fig)
    return mean_shift

In [6]:
def calculate_shifts(input_folder, max_shift_display, report_folder):
    """Estimate the DIC/fluorescence shift for every .tif stack in a folder.

    Files are processed in sorted name order so reports and results are
    reproducible. A file that fails to load/normalize is reported and skipped.
    Each successful file gets a correlation heatmap in ``report_folder``.

    Parameters
    ----------
    input_folder : str
        Folder scanned (non-recursively) for files ending in ``.tif``
        (case-insensitive).
    max_shift_display : int
        Half-width of the heatmap window passed to ``plot_correlation_heatmap``.
    report_folder : str
        Destination of the per-file heatmaps.

    Returns
    -------
    shifts : ndarray of int64, shape (N, 2)
        (shift_y, shift_x) per processed file. Always 2-D, even when N == 0, so
        callers can index ``shifts[:, 0]`` safely.
    filenames : list of str
        Names of the processed files, aligned with ``shifts``.
    """
    tif_files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith('.tif'))
    shifts = []
    filenames = []

    for filename in tqdm(tif_files, desc="Calculating shifts"):
        filepath = os.path.join(input_folder, filename)
        try:
            dic, fluorescence = load_and_normalize(filepath)
        except Exception as e:
            print(f"Error reading {filename}: {e}")
            continue
        correlation, shift = fft_cross_correlation(dic, fluorescence)
        shifts.append(shift)
        filenames.append(filename)
        plot_correlation_heatmap(correlation, shift, filename, report_folder, max_shift_display)

    # np.array([]) would be 1-D; reshape keeps the (N, 2) contract for N == 0.
    return np.array(shifts, dtype=np.int64).reshape(-1, 2), filenames

In [7]:
def apply_mean_shift_and_crop(input_folder, output_folder, mean_shift, moving_channels=(1,)):
    """Register the moving channel(s) onto channel 0 and crop all channels to the overlap.

    ``mean_shift`` follows the ``fft_cross_correlation`` convention: the
    reference (channel 0, DIC) satisfies ``dic[y, x] ~ moving[y - shift_y,
    x - shift_x]``. The moving channels are therefore shifted by ``+mean_shift``.
    Every channel is then cropped to the region where both are defined.

    The shift is rounded to whole pixels (``np.rint``, ties to even) and applied
    by slicing. Raw intensities are written unchanged (no interpolation), and
    the per-file estimates are integer-pixel anyway.

    Parameters
    ----------
    input_folder : str
        Folder scanned (non-recursively) for ``.tif`` stacks shaped (channel, y, x).
    output_folder : str
        Destination for ``aligned_<filename>``; created if missing.
    mean_shift : sequence of two floats
        (shift_y, shift_x) to apply.
    moving_channels : container of int, default (1,)
        Channel indices to shift. All other channels are only cropped, like the
        reference. NOTE(review): the shift is measured between channels 0 and 1
        only; confirm before adding further channels here.

    Raises
    ------
    ValueError
        If ``mean_shift`` is not finite (e.g. the mean of an empty set).
    """
    shift_y, shift_x = (float(v) for v in mean_shift)
    if not (np.isfinite(shift_y) and np.isfinite(shift_x)):
        raise ValueError(f"mean_shift must be finite, got {tuple(mean_shift)}")
    dy = int(np.rint(shift_y))
    dx = int(np.rint(shift_x))

    os.makedirs(output_folder, exist_ok=True)
    tif_files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith('.tif'))
    for filename in tqdm(tif_files, desc="Applying mean shift and cropping"):
        filepath = os.path.join(input_folder, filename)
        try:
            stack = np.asarray(imread(filepath))
        except Exception as e:
            print(f"Error reading {filename}: {e}")
            continue
        if stack.ndim != 3:
            print(f"Skipping {filename}: expected a (channel, y, x) stack, got shape {stack.shape}")
            continue
        height, width = stack.shape[1:]
        if abs(dy) >= height or abs(dx) >= width:
            print(f"Skipping {filename}: shift ({dy}, {dx}) leaves no overlap for shape {stack.shape}")
            continue

        # Overlap in reference coordinates: rows n with 0 <= n - dy < height.
        ref_rows = slice(max(dy, 0), height + min(dy, 0))
        ref_cols = slice(max(dx, 0), width + min(dx, 0))
        # The same pixels in moving-channel coordinates (n - dy).
        mov_rows = slice(max(-dy, 0), height - max(dy, 0))
        mov_cols = slice(max(-dx, 0), width - max(dx, 0))

        aligned_stack = np.stack([
            channel[mov_rows, mov_cols] if index in moving_channels else channel[ref_rows, ref_cols]
            for index, channel in enumerate(stack)
        ])
        output_path = os.path.join(output_folder, f"aligned_{filename}")
        imwrite(output_path, aligned_stack)

In [8]:
def main():
    """Run the pipeline: per-file shifts -> outlier filter -> mean shift -> aligned stacks.

    Files whose |shift| exceeds MAX_SHIFT_DISPLAY on either axis are excluded
    from the mean and listed. If no usable shift remains, nothing is written to
    the aligned-output folder.

    Returns
    -------
    ndarray or None
        The applied mean (shift_y, shift_x), or None if no usable shift was found.
    """
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    os.makedirs(REPORT_FOLDER, exist_ok=True)
    shifts, filenames = calculate_shifts(INPUT_FOLDER, MAX_SHIFT_DISPLAY, REPORT_FOLDER)
    # An empty result may come back 1-D; normalise to (N, 2) before indexing.
    shifts = np.asarray(shifts).reshape(-1, 2)
    print(f"\nProcessed {len(filenames)} files.")
    if len(shifts) == 0:
        print("No shifts could be computed; nothing to align.")
        return None

    # MAX_SHIFT_DISPLAY doubles as the outlier threshold: shifts outside the
    # displayed window are treated as spurious correlation peaks.
    valid_mask = (np.abs(shifts[:, 0]) <= MAX_SHIFT_DISPLAY) & (np.abs(shifts[:, 1]) <= MAX_SHIFT_DISPLAY)
    valid_shifts = shifts[valid_mask]
    rejected_filenames = [name for name, keep in zip(filenames, valid_mask) if not keep]
    print(f"Valid shifts for summary: {len(valid_shifts)}")
    if rejected_filenames:
        print("Excluded as outliers: " + ", ".join(rejected_filenames))
    if len(valid_shifts) == 0:
        print("All shifts were outliers; nothing to align.")
        return None

    mean_shift = plot_shift_summary(valid_shifts, MAX_SHIFT_DISPLAY, SUMMARY_REPORT)
    print(f"Mean shift: (Y: {mean_shift[0]:.2f}, X: {mean_shift[1]:.2f})")
    print(f"Summary plot saved to: {SUMMARY_REPORT}")
    apply_mean_shift_and_crop(INPUT_FOLDER, OUTPUT_FOLDER, mean_shift)
    return mean_shift

if __name__ == "__main__":
    main()

Calculating shifts: 100%|██████████████████████████████████████████████████████████████| 93/93 [04:17<00:00,  2.77s/it]



Processed 93 files.
Valid shifts for summary: 89
Mean shift: (Y: -14.36, X: -4.76)
Summary plot saved to: L:\43-RVZ\AIMicroscopy\Mitarbeiter\2_Data\1_NikonTi2\2025_07_29 60xWI 1.5x nanobody titration\TIF\stacked\cross_correlation_summary.png


Applying mean shift and cropping: 100%|████████████████████████████████████████████████| 93/93 [01:15<00:00,  1.23it/s]
