In [None]:
import os
import time
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from skimage import io
from cupyx.scipy.ndimage import fourier_shift

###############################################################################
# Utility Functions
###############################################################################

def _center_crop(img, n=1):
    """Return the central 1/n-per-axis region of a 2-D image.

    ``n = 1`` returns the full image, ``n = 2`` the central 50 % along each axis.
    """
    h, w = img.shape
    crop_h, crop_w = h // n, w // n
    # Halving the discarded margin (rather than dividing it by n) keeps the
    # crop centred for every n, not just for n = 1 and n = 2.
    start_h, start_w = (h - crop_h) // 2, (w - crop_w) // 2
    return img[start_h:start_h + crop_h, start_w:start_w + crop_w]


def _nearest_phase_index(phase_rad, nominal_phases):
    """Return the index of the nominal phase closest to ``phase_rad``.

    Distance is measured on the circle (modulo 2*pi), so a phase just below
    2*pi matches 0 rather than 4*pi/3. Ties resolve to the lowest index.
    Pure host-side arithmetic: no device transfers.
    """
    two_pi = 2 * np.pi

    def circular_distance(p):
        d = (phase_rad - p) % two_pi
        return min(d, two_pi - d)

    return min(range(len(nominal_phases)),
               key=lambda i: circular_distance(nominal_phases[i]))


def load_raw_stack_from_folder(folder, save_first_cropped=True, crop_factor=1):
    """Load the nine raw SIM frames from ``folder`` and return them on the GPU.

    A file is considered when its lower-cased name ends in ``.tif``/``.tiff``
    and contains ``"angle"``. The angle is the second whitespace-separated
    token before the first ``-``; the phase is the ``<int>.<frac>`` value after
    the first ``_`` of the segment following the first ``-`` (presumably names
    like ``"Angle 45-Phase_0.071.tif"`` — TODO confirm against the data).
    Files whose names cannot be parsed, or whose angle is not 45/90/135, are
    reported and skipped.

    Args:
        folder: Directory containing the raw TIFF frames.
        save_first_cropped: If True, write the first (angle 45, phase ~0)
            cropped frame to ``<folder>/Raw_cropped_reference.tiff``.
        crop_factor: Keep the central ``1/crop_factor`` of each axis
            (1 = full image, 2 = central 50 %). Must be >= 1.

    Returns:
        cupy.ndarray of float32, shape (9, H, W): frames ordered by angle
        (45, 90, 135) and, within each angle, by nearest nominal phase
        (0, 2*pi/3, 4*pi/3).

    Raises:
        ValueError: if ``crop_factor < 1``, an angle has fewer than three
            files, or the total number of selected files is not nine.
    """
    t0 = time.time()

    if crop_factor < 1:
        raise ValueError("crop_factor must be >= 1")

    angle_order = ['45', '90', '135']
    approx_phase_values_rad = [0, 2*np.pi/3, 4*np.pi/3]
    angle_to_files = {angle: [] for angle in angle_order}

    print(f"\n📂 Scanning folder: {folder}")
    for f in os.listdir(folder):
        if f.lower().endswith((".tif", ".tiff")) and "angle" in f.lower():
            try:
                parts = f.split('-')
                angle = parts[0].split()[1]
                pieces = parts[1].split('_')[1].split('.')
                phase_str = pieces[0] + '.' + pieces[1]
                norm_phase_val = float(phase_str)
                # Filenames encode phase in normalised units where 0.071
                # corresponds to one 2*pi/3 step (assumption — TODO confirm).
                phase_rad = norm_phase_val * (2 * np.pi / 3) / 0.071
                full_path = os.path.join(folder, f)
                # KeyError here means an angle other than 45/90/135.
                angle_to_files[angle].append((phase_rad, full_path))
                print(f"✔️ Found: {f} → angle={angle}, phase={phase_rad:.3f} rad")
            except (IndexError, KeyError, ValueError) as e:
                print(f"⚠️ Skipped: {f} → Error: {e}")

    file_list = []
    for angle in angle_order:
        files = angle_to_files[angle]
        if len(files) < 3:
            raise ValueError(f"❌ Not enough files for angle {angle}. Found {len(files)}.")
        sorted_files = sorted(
            files,
            key=lambda x: _nearest_phase_index(x[0], approx_phase_values_rad),
        )
        file_list.extend(path for _, path in sorted_files)

    if len(file_list) != 9:
        raise ValueError(f"\n❌ Only {len(file_list)} files selected (expected 9). Check your angles/phases.")

    images = []
    for i, path in enumerate(file_list):
        img = io.imread(path).astype(np.float32)
        img_cropped = _center_crop(img, crop_factor)

        if i == 0 and save_first_cropped:
            ref_path = os.path.join(folder, "Raw_cropped_reference.tiff")
            io.imsave(ref_path, img_cropped.astype(np.float32))
            print(f"💾 Saved cropped reference image: {ref_path}")

        images.append(img_cropped)
        print(f"📏 Loaded & Cropped: {os.path.basename(path)} → shape: {img_cropped.shape}")

    stack_cpu = np.stack(images, axis=0)
    print(f"⏱️ Image loading + cropping time: {time.time() - t0:.2f} s")
    return cp.asarray(stack_cpu)


###############################################################################
# SIM Reconstruction (GPU-accelerated)
###############################################################################

def wiener_filter(f, otf, alpha=0.01):
    """Apply a Wiener-style filter to a spectrum on the GPU (CuPy).

    Computes ``conj(otf) * f / (|otf|**2 + alpha)`` element-wise. ``alpha``
    keeps the denominator non-zero where the OTF vanishes.
    """
    numerator = cp.conj(otf) * f
    denominator = cp.abs(otf) ** 2 + alpha
    return numerator / denominator

def sim_reconstruct(stack,
                    phases=(0, 2*cp.pi/3, 4*cp.pi/3),
                    angles=(cp.pi/4, cp.pi/2, 3*cp.pi/4),
                    k_exc_rel=0.25,
                    alpha=0.01):
    """Combine raw SIM frames into a single image on the GPU (CuPy).

    For each illumination angle:
    1. Separate the frames into a zero order and ±1 orders using the nominal
       ``phases``.
    2. Move the ±1 orders by ±k_exc along the illumination direction.
    3. Pass every order through ``wiener_filter``.
    All filtered orders are then summed and inverse-transformed.

    Args:
        stack: cupy array, shape (n_frames, Ny, Nx). Frames are angle-major:
            the ``len(phases)`` frames of ``angles[0]`` first, then those of
            ``angles[1]``, and so on. Extra trailing frames are ignored.
        phases: Nominal illumination phase of each frame within an angle (rad).
        angles: Illumination orientation of each frame group (rad).
        k_exc_rel: Excitation frequency as a fraction of the largest radial
            frequency on the FFT grid.
        alpha: Wiener regularisation constant.

    Returns:
        cupy array, shape (Ny, Nx): magnitude of the inverse FFT of the
        combined spectrum.

    Raises:
        ValueError: if ``angles``/``phases`` is empty or ``stack`` holds fewer
            than ``len(angles) * len(phases)`` frames.

    NOTE(review): the OTF is a binary disc (|k| <= 0.5 cycles/pixel) and the
    grid is not up-sampled, so shifted orders wrap around circularly. A
    physical SIM reconstruction would use a measured/modelled OTF on a 2x grid.
    """
    t0 = time.time()

    n_phases = len(phases)
    if n_phases == 0 or len(angles) == 0:
        raise ValueError("angles and phases must be non-empty")
    n_frames_needed = len(angles) * n_phases
    if stack.shape[0] < n_frames_needed:
        raise ValueError(
            f"stack has {stack.shape[0]} frames; {n_frames_needed} needed "
            f"({len(angles)} angles x {n_phases} phases)."
        )

    Ny, Nx = stack.shape[1:]
    ky, kx = cp.meshgrid(cp.fft.fftfreq(Ny), cp.fft.fftfreq(Nx), indexing="ij")
    k_abs = cp.sqrt(kx**2 + ky**2)
    # Excitation frequency (cycles/pixel); the same for every angle.
    k_exc = k_exc_rel * float(cp.max(k_abs))

    # Real-space pixel coordinates for the modulation ramps (broadcast to Ny x Nx).
    y = cp.arange(Ny)[:, None]
    x = cp.arange(Nx)[None, :]

    F = cp.fft.fftn(stack[:n_frames_needed], axes=(-2, -1))

    # Demodulation weights: the plain mean yields the 0th order, and the
    # exp(+i*phi) / exp(-i*phi) weightings yield the two first orders.
    phases_arr = cp.asarray(phases)
    w_zero = cp.full(n_phases, 1.0 / n_phases, dtype=cp.complex128)
    w_plus = cp.exp(1j * phases_arr) / n_phases
    w_minus = cp.exp(-1j * phases_arr) / n_phases

    otf = (k_abs <= 0.5).astype(cp.float32)

    orders = []
    for o, theta in enumerate(angles):
        Fphi = F[o * n_phases:(o + 1) * n_phases]

        F_zero = cp.tensordot(w_zero, Fphi, axes=(0, 0))
        F_plus = cp.tensordot(w_plus, Fphi, axes=(0, 0))
        F_minus = cp.tensordot(w_minus, Fphi, axes=(0, 0))

        # Shift the +1 order by +k0 and the -1 order by -k0 in frequency space.
        # A spectral shift is a plane-wave modulation in real space:
        # FFT(f * exp(2*pi*i*k0.r))(k) = F(k - k0). (ndimage.fourier_shift
        # on a spectrum would instead translate the image and leave every
        # spectral magnitude where it was.)
        ramp = cp.exp(2j * cp.pi * k_exc * (np.sin(theta) * y + np.cos(theta) * x))
        F_plus_shifted = cp.fft.fft2(cp.fft.ifft2(F_plus) * ramp)
        F_minus_shifted = cp.fft.fft2(cp.fft.ifft2(F_minus) * cp.conj(ramp))

        orders.extend([F_zero, F_plus_shifted, F_minus_shifted])

    # Summation lets dtypes promote naturally instead of downcasting into a
    # pre-allocated accumulator.
    F_comb = sum(wiener_filter(F_ord, otf, alpha) for F_ord in orders)

    hi_res = cp.abs(cp.fft.ifft2(F_comb))
    print(f"⏱️ SIM reconstruction time: {time.time() - t0:.2f} s")
    return hi_res


###############################################################################
# Main Execution
###############################################################################

if __name__ == "__main__":
    # GPU pipeline: load frames, reconstruct, save the result, show it next
    # to the first raw frame.
    total_start = time.time()

    folder = r"C:\Users\sd80731\Desktop\Codes\OneDrive_2025-07-29\Tulane sample for sim"

    stack = load_raw_stack_from_folder(folder)
    print("\n✅ SIM stack loaded:", stack.shape)

    sim_img_gpu = sim_reconstruct(stack, alpha=0.02)
    sim_img = cp.asnumpy(sim_img_gpu)

    # Save SIM image
    output_path = os.path.join(folder, "SIM_output_halfsize.tiff")
    io.imsave(output_path, sim_img.astype(np.float32))
    print(f"💾 SIM image saved to: {output_path}")

    # The first frame of the stack is the cropped reference (angle 45, phase 0);
    # use it directly instead of re-reading the TIFF from disk.
    raw_ref = cp.asnumpy(stack[0])

    # Display
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(raw_ref, cmap="gray")
    plt.title("Cropped Raw (Angle 45, Phase 0)")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(sim_img, cmap="gray")
    plt.title("SIM Reconstructed Image")
    plt.axis("off")

    plt.tight_layout()
    # Report before show(): with a blocking backend show() waits until the
    # window is closed, which would otherwise be counted as runtime.
    print(f"\n⏱️ Total runtime: {time.time() - total_start:.2f} s")
    plt.show()


In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from scipy.ndimage import fourier_shift

###############################################################################
# Utility Functions
###############################################################################

def _center_crop(img, n=1):
    """Return the central 1/n-per-axis region of a 2-D image.

    ``n = 1`` returns the full image, ``n = 2`` the central 50 % along each axis.
    """
    h, w = img.shape
    crop_h, crop_w = h // n, w // n
    # Halving the discarded margin (rather than dividing it by n) keeps the
    # crop centred for every n, not just for n = 1 and n = 2.
    start_h, start_w = (h - crop_h) // 2, (w - crop_w) // 2
    return img[start_h:start_h + crop_h, start_w:start_w + crop_w]


def _nearest_phase_index(phase_rad, nominal_phases):
    """Return the index of the nominal phase closest to ``phase_rad``.

    Distance is measured on the circle (modulo 2*pi), so a phase just below
    2*pi matches 0 rather than 4*pi/3. Ties resolve to the lowest index.
    """
    two_pi = 2 * np.pi

    def circular_distance(p):
        d = (phase_rad - p) % two_pi
        return min(d, two_pi - d)

    return min(range(len(nominal_phases)),
               key=lambda i: circular_distance(nominal_phases[i]))


def load_raw_stack_from_folder(folder, save_first_cropped=True, crop_factor=1):
    """Load the nine raw SIM frames from ``folder`` as a NumPy stack.

    A file is considered when its lower-cased name ends in ``.tif``/``.tiff``
    and contains ``"angle"``. The angle is the second whitespace-separated
    token before the first ``-``; the phase is the ``<int>.<frac>`` value after
    the first ``_`` of the segment following the first ``-`` (presumably names
    like ``"Angle 45-Phase_0.071.tif"`` — TODO confirm against the data).
    Files whose names cannot be parsed, or whose angle is not 45/90/135, are
    reported and skipped.

    Args:
        folder: Directory containing the raw TIFF frames.
        save_first_cropped: If True, write the first (angle 45, phase ~0)
            cropped frame to ``<folder>/Raw_cropped_reference.tiff``.
        crop_factor: Keep the central ``1/crop_factor`` of each axis
            (1 = full image, 2 = central 50 %). Must be >= 1.

    Returns:
        numpy.ndarray of float32, shape (9, H, W): frames ordered by angle
        (45, 90, 135) and, within each angle, by nearest nominal phase
        (0, 2*pi/3, 4*pi/3).

    Raises:
        ValueError: if ``crop_factor < 1``, an angle has fewer than three
            files, or the total number of selected files is not nine.
    """
    t0 = time.time()

    if crop_factor < 1:
        raise ValueError("crop_factor must be >= 1")

    angle_order = ['45', '90', '135']
    approx_phase_values_rad = [0, 2*np.pi/3, 4*np.pi/3]
    angle_to_files = {angle: [] for angle in angle_order}

    print(f"\n📂 Scanning folder: {folder}")
    for f in os.listdir(folder):
        if f.lower().endswith((".tif", ".tiff")) and "angle" in f.lower():
            try:
                parts = f.split('-')
                angle = parts[0].split()[1]
                pieces = parts[1].split('_')[1].split('.')
                phase_str = pieces[0] + '.' + pieces[1]
                norm_phase_val = float(phase_str)
                # Filenames encode phase in normalised units where 0.071
                # corresponds to one 2*pi/3 step (assumption — TODO confirm).
                phase_rad = norm_phase_val * (2 * np.pi / 3) / 0.071
                full_path = os.path.join(folder, f)
                # KeyError here means an angle other than 45/90/135.
                angle_to_files[angle].append((phase_rad, full_path))
                print(f"✔️ Found: {f} → angle={angle}, phase={phase_rad:.3f} rad")
            except (IndexError, KeyError, ValueError) as e:
                print(f"⚠️ Skipped: {f} → Error: {e}")

    file_list = []
    for angle in angle_order:
        files = angle_to_files[angle]
        if len(files) < 3:
            raise ValueError(f"❌ Not enough files for angle {angle}. Found {len(files)}.")
        sorted_files = sorted(
            files,
            key=lambda x: _nearest_phase_index(x[0], approx_phase_values_rad),
        )
        file_list.extend(path for _, path in sorted_files)

    if len(file_list) != 9:
        raise ValueError(f"\n❌ Only {len(file_list)} files selected (expected 9). Check your angles/phases.")

    images = []
    for i, path in enumerate(file_list):
        img = io.imread(path).astype(np.float32)
        img_cropped = _center_crop(img, crop_factor)

        if i == 0 and save_first_cropped:
            ref_path = os.path.join(folder, "Raw_cropped_reference.tiff")
            io.imsave(ref_path, img_cropped.astype(np.float32))
            print(f"💾 Saved cropped reference image: {ref_path}")

        images.append(img_cropped)
        print(f"📏 Loaded & Cropped: {os.path.basename(path)} → shape: {img_cropped.shape}")

    stack = np.stack(images, axis=0)
    print(f"⏱️ Image loading + cropping time: {time.time() - t0:.2f} s")
    return stack


###############################################################################
# SIM Reconstruction (CPU-based)
###############################################################################

def wiener_filter(f, otf, alpha=0.01):
    """Apply a Wiener-style filter to a spectrum (NumPy).

    Computes ``conj(otf) * f / (|otf|**2 + alpha)`` element-wise. ``alpha``
    keeps the denominator non-zero where the OTF vanishes.
    """
    numerator = np.conj(otf) * f
    denominator = np.abs(otf) ** 2 + alpha
    return numerator / denominator

def sim_reconstruct(stack,
                    phases=(0, 2*np.pi/3, 4*np.pi/3),
                    angles=(np.pi/4, np.pi/2, 3*np.pi/4),
                    k_exc_rel=0.25,
                    alpha=0.01):
    """Combine raw SIM frames into a single image on the CPU (NumPy).

    For each illumination angle:
    1. Separate the frames into a zero order and ±1 orders using the nominal
       ``phases``.
    2. Move the ±1 orders by ±k_exc along the illumination direction.
    3. Pass every order through ``wiener_filter``.
    All filtered orders are then summed and inverse-transformed.

    Args:
        stack: ndarray, shape (n_frames, Ny, Nx). Frames are angle-major:
            the ``len(phases)`` frames of ``angles[0]`` first, then those of
            ``angles[1]``, and so on. Extra trailing frames are ignored.
        phases: Nominal illumination phase of each frame within an angle (rad).
        angles: Illumination orientation of each frame group (rad).
        k_exc_rel: Excitation frequency as a fraction of the largest radial
            frequency on the FFT grid.
        alpha: Wiener regularisation constant.

    Returns:
        ndarray, shape (Ny, Nx): magnitude of the inverse FFT of the combined
        spectrum.

    Raises:
        ValueError: if ``angles``/``phases`` is empty or ``stack`` holds fewer
            than ``len(angles) * len(phases)`` frames.

    NOTE(review): the OTF is a binary disc (|k| <= 0.5 cycles/pixel) and the
    grid is not up-sampled, so shifted orders wrap around circularly. A
    physical SIM reconstruction would use a measured/modelled OTF on a 2x grid.
    """
    t0 = time.time()

    n_phases = len(phases)
    if n_phases == 0 or len(angles) == 0:
        raise ValueError("angles and phases must be non-empty")
    n_frames_needed = len(angles) * n_phases
    if stack.shape[0] < n_frames_needed:
        raise ValueError(
            f"stack has {stack.shape[0]} frames; {n_frames_needed} needed "
            f"({len(angles)} angles x {n_phases} phases)."
        )

    Ny, Nx = stack.shape[1:]
    ky, kx = np.meshgrid(np.fft.fftfreq(Ny), np.fft.fftfreq(Nx), indexing="ij")
    k_abs = np.sqrt(kx**2 + ky**2)
    # Excitation frequency (cycles/pixel); the same for every angle.
    k_exc = k_exc_rel * float(np.max(k_abs))

    # Real-space pixel coordinates for the modulation ramps (broadcast to Ny x Nx).
    y = np.arange(Ny)[:, None]
    x = np.arange(Nx)[None, :]

    F = np.fft.fftn(stack[:n_frames_needed], axes=(-2, -1))

    # Demodulation weights: the plain mean yields the 0th order, and the
    # exp(+i*phi) / exp(-i*phi) weightings yield the two first orders.
    phases_arr = np.asarray(phases)
    w_zero = np.full(n_phases, 1.0 / n_phases, dtype=np.complex128)
    w_plus = np.exp(1j * phases_arr) / n_phases
    w_minus = np.exp(-1j * phases_arr) / n_phases

    otf = (k_abs <= 0.5).astype(np.float32)

    orders = []
    for o, theta in enumerate(angles):
        Fphi = F[o * n_phases:(o + 1) * n_phases]

        F_zero = np.tensordot(w_zero, Fphi, axes=(0, 0))
        F_plus = np.tensordot(w_plus, Fphi, axes=(0, 0))
        F_minus = np.tensordot(w_minus, Fphi, axes=(0, 0))

        # Shift the +1 order by +k0 and the -1 order by -k0 in frequency space.
        # A spectral shift is a plane-wave modulation in real space:
        # FFT(f * exp(2*pi*i*k0.r))(k) = F(k - k0). (ndimage.fourier_shift
        # on a spectrum would instead translate the image and leave every
        # spectral magnitude where it was.)
        ramp = np.exp(2j * np.pi * k_exc * (np.sin(theta) * y + np.cos(theta) * x))
        F_plus_shifted = np.fft.fft2(np.fft.ifft2(F_plus) * ramp)
        F_minus_shifted = np.fft.fft2(np.fft.ifft2(F_minus) * np.conj(ramp))

        orders.extend([F_zero, F_plus_shifted, F_minus_shifted])

    # Summation lets dtypes promote naturally instead of silently downcasting
    # complex128 orders into a complex64 accumulator.
    F_comb = sum(wiener_filter(F_ord, otf, alpha) for F_ord in orders)

    hi_res = np.abs(np.fft.ifft2(F_comb))
    print(f"⏱️ SIM reconstruction time: {time.time() - t0:.2f} s")
    return hi_res


###############################################################################
# Main Execution
###############################################################################

if __name__ == "__main__":
    # CPU pipeline: load frames, reconstruct, save the result, show it next
    # to the first raw frame.
    total_start = time.time()

    folder = r"C:\Users\sd80731\Desktop\Codes\OneDrive_2025-07-28\Data for Sim Angles 45-90-135"

    stack = load_raw_stack_from_folder(folder)
    print("\n✅ SIM stack loaded:", stack.shape)

    sim_img = sim_reconstruct(stack, alpha=0.02)

    # Save SIM image
    output_path = os.path.join(folder, "SIM_output_halfsize.tiff")
    io.imsave(output_path, sim_img.astype(np.float32))
    print(f"💾 SIM image saved to: {output_path}")

    # The first frame of the stack is the cropped reference (angle 45, phase 0);
    # use it directly instead of re-reading the TIFF from disk.
    raw_ref = stack[0]

    # Display
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(raw_ref, cmap="gray")
    plt.title("Cropped Raw (Angle 45, Phase 0)")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(sim_img, cmap="gray")
    plt.title("SIM Reconstructed Image")
    plt.axis("off")

    plt.tight_layout()
    # Report before show(): with a blocking backend show() waits until the
    # window is closed, which would otherwise be counted as runtime.
    print(f"\n⏱️ Total runtime: {time.time() - total_start:.2f} s")
    plt.show()
