This notebook provides a demonstration of PCA on condensate images.

Nejc Blaznik, based on the code by Allard Mosk, a.p.mosk@uu.nl, July 2025

## Imports

In [None]:
# ─────────────────────────────────────────────────────────────
# IPython / Display
# ─────────────────────────────────────────────────────────────
from IPython.display import display, HTML, clear_output
display(HTML("<style>.container { width:80% !important; }</style>"))

# ─────────────────────────────────────────────────────────────
# Standard Libraries
# ─────────────────────────────────────────────────────────────
import os
import sys
import time
import math
import csv
import glob
import warnings
from io import BytesIO
from decimal import Decimal
import argparse
import gc
import datetime

# ─────────────────────────────────────────────────────────────
# Core Libraries
# ─────────────────────────────────────────────────────────────
import numpy as np
from numpy import random as rng

# ─────────────────────────────────────────────────────────────
# Scientific Libraries
# ─────────────────────────────────────────────────────────────
import scipy
import scipy.signal
import scipy.odr as odr
import scipy.fftpack as fft
from scipy import signal
from scipy.ndimage import rotate, center_of_mass
from scipy.ndimage.interpolation import rotate as rotate_interp  # if needed separately
from scipy.special import eval_legendre, zeta
from scipy.optimize import curve_fit
from skimage.restoration import unwrap_phase
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
import mpmath
import astropy.io.fits as pyfits
from spacepy import pycdf

# ─────────────────────────────────────────────────────────────
# Plotting and Widgets
# ─────────────────────────────────────────────────────────────
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from matplotlib.colors import hsv_to_rgb, LogNorm
from matplotlib.transforms import IdentityTransform
from matplotlib.figure import Figure
from matplotlib.widgets import Slider, EllipseSelector, RectangleSelector
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import Arc

# ─────────────────────────────────────────────────────────────
# Image and File Handling
# ─────────────────────────────────────────────────────────────
import imageio
from PIL import Image
import cv2
import requests

# ─────────────────────────────────────────────────────────────
# Custom Modules
# ─────────────────────────────────────────────────────────────
import inpaintingfunction as inpaint
from inpaintingfunction import wtop, wbottom, wleft, wright
import OAH_functions as f1
from OAHDEV_functions import *
from OAH_refocus import *
from fitfunctions import gaussmod, tfmod, bimodalmod, tfaxialmod, gaussmod_OAH, tfmod_OAH, bimodalmod_OAH
importlib.reload(inpaint)
param_wtop, param_wbottom, param_wleft, param_wright = np.load(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/cut_arrs/CUT_COORDS.npy")
# ─────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────
kB = 1.38064852E-23
k_B = kB
m = 3.81923979E-26
m_na = m
hb = 1.0545718E-34
asc = 2.802642E-9
mu0 = 1E-50
e0 = 8.854187E-12
pix_size = 6.5E-6 / 2.63
lamb0 = 589.1E-9
k0 = 2 * np.pi / lamb0

# ─────────────────────────────────────────────────────────────
# Paths and Settings
# ─────────────────────────────────────────────────────────────
folder = '/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/OAH_processing/'

# ─────────────────────────────────────────────────────────────
# Cropping Parameters for Loaded Images
# ─────────────────────────────────────────────────────────────
crop_top    = 10
crop_bottom = -10
crop_left   = 10
crop_right  = -10

# ─────────────────────────────────────────────────────────────
# Chebyshev Polynomial Orders for Background Gradient Correction
# ─────────────────────────────────────────────────────────────
gradientremovalorder_X = 4
gradientremovalorder_Y = 2

# ─────────────────────────────────────────────────────────────
# Image Selection Parameters
# ─────────────────────────────────────────────────────────────
number_of_images          = 150   # Use a subset to speed up processing
number_validation_images  = 50    # Remaining images used for validation

# ─────────────────────────────────────────────────────────────
# Notebook Settings
# ─────────────────────────────────────────────────────────────
%matplotlib notebook
warnings.filterwarnings('ignore')


  from scipy.ndimage.interpolation import rotate as rotate_interp  # if needed separately


#### my custom functions 

In [None]:
from scipy.optimize import curve_fit

def gaussian(x, a, x0, sigma, offset):
    return a * np.exp(-(x - x0)**2 / (2 * sigma**2)) + offset

def fit_1d_data(ydata, p0='none'):
    xdata = range(len(ydata))
    # Initial guess: amplitude, center, width, offset
    if p0 == 'none':
        p0 = [max(ydata), xdata[np.argmax(ydata)], np.std(xdata), min(ydata)]
    # Fit the data
    try:
        popt, pcov = curve_fit(gaussian, xdata, ydata, p0=p0)
    
    except: 
        popt = np.array([0, 0, 0, 0])

    return popt  # fitted parameters: a, x0, sigma, offset

def get_center_width(a1_is, a2_is, p0='none'):
    fitted_centers = fit_1d_data((a1_is+a2_is).T[600:900].mean(axis=0), p0=p0)
    center = round(fitted_centers[1])
    width = abs(int(fitted_centers[2]))
    return center, width, fitted_centers




def magnetization_trace_OLD(params, smoothing=30, cut_v=[1, -1, 1, -1], center=None, width=None, mask_treshold=0.3):
    """
    Computes a magnetization trace from angular images.

    Parameters:
        params: array-like
            Contains the date and shot identifier.
        smoothing: int
            Window size for Savitzky-Golay filter (0 disables smoothing).
        cut_v: list
            Cutting vector used for preprocessing.
        center: int
            Center index for extracting the magnetization window.
        width: int
            Half-width of the window around the center.
        mask_treshold: float
            Threshold below which data is considered unreliable and masked.

    Returns:
        np.ndarray
            1D array of averaged magnetization values per frame.
    """

    date = params[0]
    shot = params[1]

    num_of_nums = get_nr_of_atoms(date, shot)
    ang_is = np.load(f"{save_folder}data/{date}_{shot}_ang_is.npy")
    
    if center == None:
        center = ang_is.shape[2]//2
        
    if width == None:
        width = 3
    print(center, width)
    magentization = []
    for i in range(num_of_nums):
        a1_is = rotate(ang_is[i][0][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
        a2_is = rotate(ang_is[i][1][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)

        if smoothing != 0:
            a1_is = savgol_filter(a1_is, smoothing, 3)
            a2_is = savgol_filter(a2_is, smoothing, 3)

        ratio = (a1_is - a2_is) / (a1_is + a2_is)
        mask = (a1_is + a2_is) <= mask_treshold

        # Fill masked values with np.nan and compute nanmean
        ratio_masked = np.ma.array(ratio, mask=mask).filled(np.nan)
        trimmed = ratio_masked[center - width:center + width]
        mean_val = np.nanmean(trimmed)

        magentization.append(mean_val)
    print(date, shot, 4*np.array(magentization).mean())

    return np.array(magentization)




def magnetization_trace_PCA(ang_is, smoothing=30, cut_v=[1, -1, 1, -1], center=None, width=None, mask_treshold=0.3):
    """
    Computes a magnetization trace from angular images.

    Parameters:
        params: array-like
            Contains the date and shot identifier.
        smoothing: int
            Window size for Savitzky-Golay filter (0 disables smoothing).
        cut_v: list
            Cutting vector used for preprocessing.
        center: int
            Center index for extracting the magnetization window.
        width: int
            Half-width of the window around the center.
        mask_treshold: float
            Threshold below which data is considered unreliable and masked.

    Returns:
        np.ndarray
            1D array of averaged magnetization values per frame.
    """

    num_of_nums = len(ang_is)

    if center == None:
        center = ang_is.shape[2]//2
        
    if width == None:
        width = 3

    print(center, width)
    magentization = []
    for i in range(num_of_nums):
        a1_is = rotate(ang_is[i][0], 0.7)
        a2_is = rotate(ang_is[i][1], 0.7)

        if smoothing != 0:
            a1_is = savgol_filter(a1_is, smoothing, 3)
            a2_is = savgol_filter(a2_is, smoothing, 3)

        ratio = (a1_is - a2_is) / (a1_is + a2_is)
        mask = (a1_is + a2_is) <= mask_treshold

        # Fill masked values with np.nan and compute nanmean
        ratio_masked = np.ma.array(ratio, mask=mask).filled(np.nan)
        trimmed = ratio_masked[center - width:center + width]
        mean_val = np.nanmean(trimmed)

        magentization.append(mean_val)
    print(date, shot, 4*np.array(magentization).mean())

    return np.array(magentization)


    
def save_gif(png_path, output_path_gif, output_path_mp4=None, frame_duration=125):
    ### SAVE AS GIF AND MP4 
    # Paths and settings
    folder_path = png_path
    output_gif_path = output_path_gif + ".gif"
    fps = 1000 // frame_duration

    # Get and sort PNG files
    png_files = sorted([
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.endswith(".png")
    ])

    # Load frames (for GIF)
    frames = [Image.open(f) for f in png_files]

    # Save as GIF
    frames[0].save(
        output_gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        loop=0
    )
    if output_path_mp4 != None:
        output_mp4_path = output_path_mp4 + ".mp4"
        # Load frames (for MP4)
        frame_array = []
        for file in png_files:
            img = cv2.imread(file)
            height, width, _ = img.shape
            frame_array.append(img)

        # Define codec and create VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_mp4_path, fourcc, fps, (width, height))

        # Write frames
        for frame in frame_array:
            out.write(frame)
        out.release()
    clear_output()
    
    
def makeATrendArr_OLD(params, smoothing=30, cut_v=[1, -1, 1, -1], center='None', width=4, mask_treshold=0.3):
    """
        params - array of parameters as defined elsewhere. 
        smoothing - window size to use for savgol_filter (0 to use no smoothing)
        returns: image array
    """
        
    date = params[0]
    shot = params[1]
    
    num_of_nums = get_nr_of_atoms(date, shot)
    try: 
        ang_is = np.load(f"{save_folder}data/{date}_{shot}_ang_is.npy")
    except: 
        ang_is = np.load(f"{save_folder}data/{date}_{shot}_ang_is_FAST.npy")

    all_ratios = []
    for i in range(num_of_nums):
        a1_is = rotate(ang_is[i][0][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
        a2_is = rotate(ang_is[i][1][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
        if smoothing != 0:
            a1_is = savgol_filter(a1_is, smoothing, 3)
            a2_is = savgol_filter(a2_is, smoothing, 3)
        ratio = (a1_is - a2_is) / (a1_is + a2_is)
        mask = (a1_is + a2_is) <= mask_treshold
        ratio_masked = np.ma.array(ratio, mask=mask)
        ratio_nan = ratio_masked.filled(np.nan)
        if center == 'None':
            center, _, _ = get_center_width(a1_is, a2_is)                
#         print(center, width)
        trimmed = ratio_nan[center-width:center+width]
        mean_trimmed = np.nanmean(trimmed, axis=0)
        all_ratios.append(mean_trimmed)
    return np.array(all_ratios)

def makeATrendArr_PCA(ang_is, smoothing=30, cut_v=[1, -1, 1, -1], center='None', width=4, mask_treshold=0.25):
    """
        params - array of parameters as defined elsewhere. 
        smoothing - window size to use for savgol_filter (0 to use no smoothing)
        returns: image array
    """

    all_ratios = []
    for i in range(len(ang_is)):
        a1_is = rotate(cut_arr(ang_is[i][0], cut_v), 0.7)
        a2_is = rotate(cut_arr(ang_is[i][1], cut_v), 0.7)
        if smoothing != 0:
            a1_is = savgol_filter(a1_is, smoothing, 3)
            a2_is = savgol_filter(a2_is, smoothing, 3)
        ratio = (a1_is - a2_is) / (a1_is + a2_is)
        mask = (a1_is + a2_is) <= mask_treshold
        ratio_masked = np.ma.array(ratio, mask=mask)
        ratio_nan = ratio_masked.filled(np.nan)
        if center == 'None':
            center, _, _ = get_center_width(a1_is, a2_is)                
#         print(center, width)
        trimmed = ratio_nan[center-width:center+width]
        mean_trimmed = np.nanmean(trimmed, axis=0)
        all_ratios.append(mean_trimmed)
    return np.array(all_ratios)


# Randomize the background, or even get the mean or something? 
def HI_refocus_custom(date, shot, num, dz_focus, quad="quad1", num_flat=0, shift_pix=[0, 0], output="ang", plot=False, cut=[xmin, xmax, zmin, zmax]):
    """
    Main function for generating the output arrays for the holgoraphic imaging, which also includes the refocusing.
    What we do, is we process image - first take a FTT, apply Tukey windows, shift, refocus, and back iFFT.
    ------
    :param date: Input date of the image to be analysed
    :param shot: Input shot of the image to be analysed
    :param num: Input sequence number of the image to be analysed
    :param dz_focus: The focus parameter that determines the focus of the image
    :param quad: Parameter to determine which quad we are cutting.
    :param output: Output the phase (ang) or amplituide (amp)
    :param plot: Boolean to specify whether to plot the image of the angle or not. Better not when used in a loop.
    :return: Returns an array of the angle of the ratio of the iFFT's of the atom and flat image.
    """
    # ------------------------------------------------- IMPORTS -----------------------------------------------------
    path = '/storage/data/' + str(date) + '/'
    image = str(shot).zfill(4) + '/'
    
    xmin, xmax, zmin, zmax = cut

    if num != 0: 
        num_flat = num_flat % num 
        
    # Opening files
    atoms = pyfits.open(path + image + '0.fits')[0].data.astype(float)[num][xmin:xmax, zmin:zmax]
    flat = pyfits.open(path + image + '1.fits')[0].data.astype(float)[num_flat][xmin:xmax, zmin:zmax]
    dark = pyfits.open(path + image + '2.fits')[0].data.astype(float).mean(axis=0)[xmin:xmax, zmin:zmax]

    # ----------------------------------------------- CORRECTIONS ---------------------------------------------------
    # Creates a squaroid dark edge
    atoms = f1.squaroid(atoms - dark, width=0.51)
    flat = f1.squaroid(flat - dark, width=0.51)
    # --------------------------------------------------- FFT --------------------------------------------------------
    # Take the FTT of the atoms
    fft_atoms = np.fft.fft2(atoms)
    fft_flat = np.fft.fft2(flat)

    # We create the Tukey windows cutouts for the data and the background. Additional cuts in x,z direction can be
    # passed, but this is optional; the default value is 0. Note that because you later cut everything in the
    # same size, this may affect the rest too.
    quad1, q1peak = f1.box_cutter_pad_ellips(fft_atoms, quad, 0, 0, edge_x=10, edge_z=80)
    flatq1, f1peak = f1.box_cutter_pad_ellips(fft_flat, quad, 0, 0, edge_x=10, edge_z=80) #indices=q1peak)
    
    # Cutting the quads in the same sizes.
    quad1cut, flatq1cut = f1.sizecomp(quad1, flatq1)

    # ------------------------------------------------ FFT SHIFT ----------------------------------------------------
    # Now we FFT shift the zero frequency to the center
    fft1 = np.fft.fftshift(quad1cut)
    flatfft1 = np.fft.fftshift(flatq1cut)

    
    # ------------------------------------------------ PHASE SHIFT & FFT SHIFT ----------------------------------------------------
    # Shift in real space = linear phase in FFT space (optional)
    shift_x, shift_z = shift_pix

    Nz, Nx = quad1cut.shape  # Shape is (z, x)
    kx = np.fft.fftfreq(Nx)  # Now in cycles per pixel
    kz = np.fft.fftfreq(Nz)
    KX, KZ = np.meshgrid(kx, kz)         # Create full 2D grid

    # Linear phase ramp
    phase_ramp = np.exp(-1j * 2 * np.pi * (KX * shift_x + KZ * shift_z))

#     # Apply ramp before shifting
#     fft1 = np.fft.fftshift(quad1cut * phase_ramp)
#     flatfft1 = np.fft.fftshift(flatq1cut * phase_ramp)

    
    # ------------------------------------------------ REFOCUSING ---------------------------------------------------
    fft_kx = np.fft.fftfreq(fft1.shape[1], d=pix_size)  # Discrete FFT Sample Frequency in x
    fft_ky = np.fft.fftfreq(fft1.shape[0], d=pix_size)  # Discrete FFT Sample Frequency in z
    fft_k2 = fft_kx[None, :] ** 2 + fft_ky[:, None] ** 2  # Discrete FFT Sample Frequency in main axes multiplied

    # Determine the focus factor and refocus
    focus = np.exp(-1j * fft_k2 * dz_focus / (2 * k0))
    fft1 = fft1 * focus * phase_ramp
    flatfft1 = flatfft1 * focus * phase_ramp


    # ------------------------------------- INVERSE FFT -------------------------------------------------
    inv1 = np.fft.ifft2(fft1) / np.fft.ifft2(flatfft1)
    inv1 = inv1[border_x:-border_x, border_z:-border_z]

    # Get Phase
    ang1 = np.angle(inv1)
    ang1 = f1.unwrapper(ang1)
    amp1 = np.abs(inv1) ** 2

    # normalize amplitude?
    normfactor = amp1.mean()  # [300:900, 300:900].mean()
    amp1 = amp1 - normfactor 
    amp1 = normalize(amp1)[0] # Use the function above to normalize the image.

    # Normalize
    normfactor = ang1.mean()  # [300:900, 300:900].mean()
    ang1 = ang1 - normfactor 
    ang1 = normalize(ang1)[0] # Use the function above to normalize the image.

    if plot:
        plt.imshow(ang1, cmap='Greys', interpolation='none', origin="lower")
        plt.title(str(dz_focus))
        plt.colorbar()
        plt.show()

    if output == "amp":
        return amp1
    elif output == "ang": 
        return ang1
    elif output == "quad":
        return quad1
    
    
def binImage(pic, xbin, zbin):
    """ A function to bin the pic file based on the bin parameters. """
    # If pic file not a multiple of bin, cut from the edge so it is.
    if pic.shape[0] % xbin != 0:
        pic = pic[:-(pic.shape[0] % xbin), :]
    if pic.shape[1] % zbin != 0:
        pic = pic[:, :-(pic.shape[1] % zbin)]
    pic = pic.reshape(pic.shape[0] // xbin, xbin, pic.shape[1] // zbin, zbin).mean(axis=3).mean(axis=1)
    return pic


def getPhysical(all_fits):
    phys_vars = []
    for shot_info in all_fits:
        xbin = 4
        zbin = 4
        pixelsize = 2.47148288973384e-06 
        kB = 1.38064852E-23
        m = 3.81923979E-26
        hb = 1.0545718E-34
        asc = 2.802642E-9
        mu0 = 1e-50
        e0 = 8.854187E-12
        fx = 115
        fz = 15
        wavelength = 589e-9
        detuning = 0
        prefactor = float((1 + 4 * detuning ** 2) * 2 * np.pi / (3 * (wavelength ** 2)) * 18. / 5.)

#         par_names = ['offset', 'ampl', 'ang', 'xmid', 'ymid', 'tfamp', 'tfxw', 'tfyw', 'gamp', 'gxw', 'gyw']
#         bin_scaling = np.array([1., 1., 1., xbin, zbin, 1., xbin, zbin, 1., xbin, zbin])
#         rng_offset = np.array([0., 0., 0., xmin, zmin, 0., 0., 0., 0., 0., 0.])
#         to_physical = np.array([1., 1., 1., pixelsize, pixelsize, prefactor, pixelsize, pixelsize, prefactor, pixelsize, pixelsize])
        
        par_names = ['offset', 'ampl', 'ang', 'xmid', 'ymid', 'gamp', 'gxw', 'gyw']
        bin_scaling = np.array([1., 1., 1., xbin, zbin,  1., xbin, zbin])
        rng_offset = np.array([0., 0., 0., xmin, zmin, 0., 0., 0.])
        to_physical = np.array([1., 1., 1., pixelsize, pixelsize, prefactor, pixelsize, pixelsize])

        # Converts the fit results to absolute pixel values in the unbinned image.

        fit_results = shot_info * bin_scaling + rng_offset
        phys_results = fit_results * to_physical

        tof = 0
        ntherm = 0
        ntf = 0
        tx = 0
        tz = 0
        mux = 0
        muz = 0
        mun = 0

        ntherm = 2 * np.pi * phys_results[5] * phys_results[6] * phys_results[7] 
        tx = 1 / kB * m / 1 * (fx * np.pi * 2 * phys_results[6]) ** 2 / (1 + (tof * fx * np.pi * 2) ** 2)
        tz = 1 / kB * m / 1 * (fz * np.pi * 2 * phys_results[7]) ** 2 / (1 + (tof * fz * np.pi * 2) ** 2)
#         mux = m / 1 * (fx * np.pi * 2 * phys_results[6]) ** 2 / (1 + (tof * fx * np.pi * 2) ** 2)
#         muz = m / 1 * (fz * np.pi * 2 * phys_results[7]) ** 2 / (1 + (tof * fz * np.pi * 2) ** 2)
#         mun = 1.47708846953 * np.power(
#             ntf * asc / (np.sqrt(hb / (m * np.power(8 * np.pi ** 3 * fx ** 2 * fz, 1. / 3.)))),
#             2. / 5.) * hb * np.power(8 * np.pi ** 3 * fx ** 2 * fz, 1. / 3.)

        phys_vars.append([ntf, ntherm, tx, tz, mux, muz, mun])
    return phys_vars 
        
    
def cut_arr(arr, cut_v):
    return np.array(arr[cut_v[0]:cut_v[1], cut_v[2]:cut_v[3]])


def get_nr_of_atoms(date, shot):
    fits_path = f'/storage/data/{date}/{str(shot).zfill(4)}/0.fits'
    nr = len(pyfits.open(fits_path)[0].data.astype(float))
    return nr

def get_nr_of_flats(date, shot):
    fits_path = f'/storage/data/{date}/{str(shot).zfill(4)}/1.fits'
    nr = len(pyfits.open(fits_path)[0].data.astype(float))
    return nr




## Allard's Code 

In [494]:
date = 20250507
shot = 66
crop_top    = 5
crop_bottom = -5
crop_left   = 50
crop_right  = -50

In [493]:
# ─────────────────────────────────────────────────────────────
# File Paths
# ─────────────────────────────────────────────────────────────
file_blank = f'full_complex_field_{date}_{shot}_flat.npy'
file_atoms = f'full_complex_field_{date}_{shot}_atoms.npy'
param_wtop, param_wbottom, param_wleft, param_wright = np.load(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/cut_arrs/{date}_{str(shot).zfill(4)}.npy")

# ─────────────────────────────────────────────────────────────
# Load Data
# ─────────────────────────────────────────────────────────────
full_blank_images = np.load(folder + file_blank, allow_pickle=True)
full_complex_images_atoms = np.load(folder + file_atoms)

# ─────────────────────────────────────────────────────────────
# Crop to Region of Interest
# ─────────────────────────────────────────────────────────────
blank_validation_images = full_blank_images[-number_validation_images:, :, crop_top:crop_bottom, crop_left:crop_right]
blank_images = full_blank_images[:number_of_images, :, crop_top:crop_bottom, crop_left:crop_right]
complex_images_atoms = full_complex_images_atoms[:number_of_images, :, crop_top:crop_bottom, crop_left:crop_right]

# ─────────────────────────────────────────────────────────────
# Image Dimensions
# ─────────────────────────────────────────────────────────────
nimages, _, xdim, ydim = blank_images.shape

In [495]:
inpaint.debuglevel=0
inpaint.set_empty_images(blank_images[:,0,:,:],remove_gradients=True, exclude_atoms_from_gradient=False)

In [497]:
# ─────────────────────────────────────────────────────────────
# Inpaint a validation image
# ─────────────────────────────────────────────────────────────
inpainted = inpaint.inpaint1(
    blank_validation_images[0, 0, :, :],
    use_svd=0,
    use_Tikhonov=0,
    remove_gradients=True,
    exclude_atoms_from_gradient=False
)

# ─────────────────────────────────────────────────────────────
# Reference image with only gradient removed
# ─────────────────────────────────────────────────────────────
true_image = inpaint.RemovePhaseGradient(
    blank_validation_images[0, 0],
    exclude_atoms_from_gradient=False
)[param_wtop:param_wbottom, param_wleft:param_wright]

# ─────────────────────────────────────────────────────────────
# Compute difference metric (scaled phase RMS error)
# ─────────────────────────────────────────────────────────────
diff_meas = 1000 * np.linalg.norm(
    np.angle(inpainted.flatten()) - np.angle(true_image.flatten())
) / (true_image.size**0.5)

print(f"{diff_meas:.3f}")

ValueError: operands could not be broadcast together with shapes (63744,) (68582,) 

In [None]:
t = 1
diffs_both = []
for i in range(2):
    inpaint.set_empty_images(blank_images[:,i,:,:],remove_gradients=True, exclude_atoms_from_gradient=False)
    image = complex_images_atoms[t, i, :, :]
    # Run inpainting
    test = inpaint.inpaint(complex_images_atoms[t:t+2, i, :, :], use_svd=0, use_Tikhonov=0, remove_gradients=True, exclude_atoms_from_gradient=False)

    # Prepare images for display
    inpainted_phase = np.angle(test[0])
    original_cropped = inpaint.RemovePhaseGradient(image)[param_wtop:param_wbottom, param_wleft:param_wright]
    original_phase = np.angle(original_cropped)
    difference = original_phase - inpainted_phase    
    diffs_both.append(difference)

In [None]:
for num in [0]: #range(150):
    PCA_phase = PCA_phases[num]
#     PCA_phase = diffs_both
    # ─────────────────────────────────────────────────────────────
    # Figure and Grid Setup
    # ─────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(17, 9))
    gs = gridspec.GridSpec(3, 5, width_ratios=[1, 1, 1, 1, 1], height_ratios=[1, 1, 2], hspace=0.05, wspace=0.05)

    # Line plot axes
    ax41 = (fig.add_subplot(gs[2, 0]), fig.add_subplot(gs[2, 1]))
    ax43 = (fig.add_subplot(gs[2, 2]), fig.add_subplot(gs[2, 3]))
    names = ["OAH-1", "OAH-2"]

    ax_ratios = (fig.add_subplot(gs[0, 4]), fig.add_subplot(gs[1, 4]),  fig.add_subplot(gs[2, 4]))

    # ─────────────────────────────────────────────────────────────
    # Loop through the two reconstructions
    # ─────────────────────────────────────────────────────────────
    for i in range(2):
        # Phase images
        ax0 = fig.add_subplot(gs[i, 0])
        ax1 = fig.add_subplot(gs[i, 1])
        ax2 = fig.add_subplot(gs[i, 2])
        ax3 = fig.add_subplot(gs[i, 3])

        ax0.set_ylabel(names[i])

        # Prepare rotated images
        in_phase = rotate(inpainted_phase, 0.7)
        or_phase = rotate(original_phase, 0.7)
        diff_phase = rotate(PCA_phase[i], 0.7)
        ang_old = rotate(angs_old_method[num][i][crop_top - 5:crop_bottom, crop_left:crop_right][inpaint.wtop:inpaint.wbottom, inpaint.wleft:inpaint.wright], 0.7)

        # Plot phase images
        im0 = ax0.imshow(in_phase, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')
        im1 = ax1.imshow(or_phase, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')
        im2 = ax2.imshow(diff_phase, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')
        im3 = ax3.imshow(ang_old, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')


        lc_lims = [[18, 25], [350, 600]]

        # Line plots: horizontal average over selected rows
        ax41[i].plot(diff_phase[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0), c='C0', label="PCA")
        ax41[i].plot(ang_old[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0), c='C1', label="Old", alpha=0.4)

        # Line plots: vertical average over selected columns
        ax43[i].plot(diff_phase.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0), c='C0', label="PCA")
        ax43[i].plot(ang_old.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0), c='C1', label="Old", alpha=0.4)

        # Titles for first row
        if i == 0:
            ax0.set_title("Reconstructed Phase")
            ax1.set_title("Original Phase")
            ax2.set_title("PCA")
            ax3.set_title("Old Method")
            ax2.axvline(x=lc_lims[1][0], c='r', ls='--', alpha=0.35)
            ax2.axvline(x=lc_lims[1][1], c='r', ls='--', alpha=0.35)
            ax2.axhline(y=lc_lims[0][0], c='r', ls='--', alpha=0.35)
            ax2.axhline(y=lc_lims[0][1], c='r', ls='--', alpha=0.35)



    a1_is = rotate(angs_old_method[num][0][crop_top - 5:crop_bottom, crop_left:crop_right][inpaint.wtop:inpaint.wbottom, inpaint.wleft:inpaint.wright], 0.7)
    a2_is = rotate(angs_old_method[num][1][crop_top - 5:crop_bottom, crop_left:crop_right][inpaint.wtop:inpaint.wbottom, inpaint.wleft:inpaint.wright], 0.7)
    ratio = (a1_is - a2_is) / (a1_is + a2_is)
    mask = (a1_is + a2_is) < 0.25
    ratio = np.ma.array(ratio, mask=mask)
    ax_ratios[1].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
    ax_ratios[1].text(len(a1_is[0])//2, 40, "OLD", weight='bold')
    ax_ratios[2].plot(ratio[18:25].sum(axis=0)/len(a1_is), c='C1', alpha=0.4)



    a1_is = rotate(PCA_phase[0], 0.7)
    a2_is = rotate(PCA_phase[1], 0.7)
    ratio = (a1_is - a2_is) / (a1_is + a2_is)
    mask = (a1_is + a2_is) < 0.25
    ratio = np.ma.array(ratio, mask=mask)
    im = ax_ratios[0].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
    ax_ratios[0].text(len(a1_is[0])//2, 40, "PCA", weight='bold')
    ax_ratios[2].plot(ratio[18:25].sum(axis=0)/len(a1_is), c='C0')

    ax_ratios[2].axhline(y=0, c='k', ls='--', alpha=0.5)

    ax_ratios[2].set_ylim([-0.2, 0.2])

    # ─────────────────────────────────────────────────────────────
    # Final Touches
    # ─────────────────────────────────────────────────────────────
    # Add legends to line plots
    ax43[0].legend(loc=1)
    # Remove ticks from all subplots
    for ax in fig.axes:
        ax.set_xticks([])
        ax.set_yticks([])

    # Global figure title
    fig.suptitle(f"{date} - {shot} - {num}")
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=1.1)
    plt.colorbar(im, ax=fig.axes, aspect=30, pad=0.01)
    
    os.makedirs(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/PreviewPlots/{date}_{shot}/", exist_ok=True)
    plt.savefig(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/PreviewPlots/{date}_{shot}/{str(num).zfill(4)}.png")
    plt.show()


## Run Tests 

In [50]:
import numpy as np
import matplotlib.pyplot as plt
import inpaintingfunction as inpaint  # your custom module

# Select one validation image
image = blank_Validation_images[0, 0, :, :]

# Parameter grid
svd_values = range(0, 400, 1)
tikhonov_values = [0, 1e-10, 1e-9, 1e-8, 1e-7]

results = []

# Sweep parameters
for i, thk in enumerate(tikhonov_values):
    rec = inpaint.inpaint1(image, use_svd=0, use_Tikhonov=thk, remove_gradients=True, exclude_atoms_from_gradient=False)
    inpainted = rec
    true = inpaint.RemovePhaseGradient(image, exclude_atoms_from_gradient=False)[inpaint.wtop:inpaint.wbottom, inpaint.wleft:inpaint.wright]
    diff = np.angle(inpainted) - np.angle(true)
    diffmeas = 1000 * np.linalg.norm(diff.flatten()) / (true.size ** 0.5)
    results.append(diffmeas)
    print(i, "/100 - ", diffmeas)


0 /100 -  37.56256330876287
1 /100 -  1966.3258423706666
2 /100 -  1966.808310980083
3 /100 -  1968.6008089935071
4 /100 -  1948.7015352183482


## Comparison Plots

### Load the files

In [517]:
all_params = [
    [20250422,  20],  #0
    [20250422,  21],  #1
    [20250422,  26],  #2
    [20250422,  32],  #3
    [20250422,  33],  #4
    [20250423,  29],  #5
    [20250423,  49],  #6
    [20250423,  50],  #7
    [20250423,  51],  #8
    [20250424,  37],  #9
    [20250424,  38],  #10
    [20250424,  46],  #11
    [20250424,  48],  #12
    [20250424,  54],  #13
    [20250428,  31],  #14
    [20250428,  33],  #15
    [20250428,  34],  #16
    [20250428,  35],  #17
    [20250428,  56],  #18
    [20250428,  58],  #19
    [20250428,  70],  #20
    [20250428,  71],  #21
    [20250428,  72],  #22
    [20250428,  73],  #23
    [20250428,  74],  #24
    [20250428,  77],  #25
    [20250428,  80],  #26
    [20250430,   6],  #27
    [20250430,  64],  #28
    [20250430,  68],  #29
    [20250430,  72],  #30
    [20250430,  99],  #31
    [20250430, 103],  #32
    [20250430, 104],  #33
    [20250501,  42],  #34
    [20250501,  57],  #35
    [20250501,  72],  #36
    [20250501,  80],  #37
    [20250501,  81],  #38
    [20250502,  16],  #39
    [20250502,  17],  #40
    [20250502,  18],  #41
    [20250502,  21],  #42
    [20250502,  32],  #43
    [20250502,  34],  #44
    [20250502,  60],  #45
    [20250502,  65],  #46
    [20250502,  74],  #47
    [20250502,  75],  #48
    [20250502,  80],  #49
    [20250502,  82],  #50
    [20250502,  83],  #51
    [20250502,  89],  #52
    [20250502,  93],  #53
    [20250502, 102],  #54
    [20250502, 102],  #55
    [20250502, 109],  #56
    [20250502, 113],  #57
    [20250507,  11],  #58
    [20250507,  17],  #59
    [20250507,  24],  #60
    [20250507,  26],  #61
    [20250507,  28],  #62
    [20250507,  29],  #63
    [20250507,  37],  #64
    [20250507,  38],  #65
    [20250507,  39],  #66
    [20250507,  45],  #67
    [20250507,  49],  #68
    [20250507,  50],  #69
    [20250507,  51],  #70
    [20250507,  65],  #71
    [20250507,  66],  #72
    [20250507,  67],  #73
    [20250507,  68],  #74
]

all_spinflip_params = [
    [20250423,  49],  #6
    [20250423,  50],  #7
    [20250423,  51],  #8
    [20250424,  37],  #9
    [20250424,  38],  #10
    [20250424,  46],  #11
    [20250428,  31],  #14
    [20250428,  33],  #15
    [20250428,  34],  #16
    [20250428,  35],  #17
    [20250501,  80],  #37
    [20250502,  16],  #39
    [20250502,  17],  #40
    [20250502,  18],  #41
    [20250502,  74],  #47
    [20250502,  75],  #48
    [20250507,  45],  #67
    [20250507,  49],  #68
    [20250507,  50],  #69
    [20250507,  51],  #70
    [20250507,  65],  #72
    [20250507,  66],  #73
    [20250507,  67],  #74
] 

all_spinpull_params = [
    [20250422,  20],  #0
    [20250422,  21],  #1
    [20250422,  26],  #2
    [20250422,  32],  #3
    [20250422,  33],  #4
    [20250423,  29],  #5
    [20250424,  37],  #9
    [20250424,  38],  #10
    [20250424,  46],  #11
]

### Load 

In [521]:
## Make sure to run the background analysis through the processPCA.py script - for which you need to have the raw files ready as well. 
param = all_params[6]
date = param[0]
shot = param[1]

date = 20250428
shot = 35

print(date, shot)

save_folder_old = '/home/bec_lab/Desktop/imgs/SOAH/SpinAnalysisApril2025/' 
save_folder_pca = '/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/PCA_processing' 
angs_old_method = np.load(f"{save_folder_old}data/{date}_{shot}_ang_is.npy")
PCA_phases = np.array([np.load(f"{save_folder_pca}/{date}_{shot}/{file}") for file in sorted(os.listdir(f"{save_folder_pca}/{date}_{shot}/"))])

# Additional Phase Unwrapping? 
PCA_phases = [[unwrap_phase(unwrap_phase(PCA_phases[i][j])) for j in range(len(PCA_phases[0]))] for i in range(len(PCA_phases))]

# Get the new cutting window 
param_wtop, param_wbottom, param_wleft, param_wright = np.load(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/cut_arrs/{date}_{str(shot).zfill(4)}.npy")
print(param_wtop, param_wbottom, param_wleft, param_wright)

20250428 35
25 78 722 2201


### Plot

In [522]:
# Preview Single plot 
arr_trend_old = makeATrendArr_OLD([date, shot], smoothing=0, mask_treshold=0.25)
arr_trend_pca = makeATrendArr_PCA(np.array(PCA_phases), smoothing=0, mask_treshold=0.25)

nr_atoms = get_nr_of_atoms(date, shot)
for num in [0]: #range(nr_atoms):
    PCA_phase = PCA_phases[num]
#     PCA_phase = diffs_both
    # ─────────────────────────────────────────────────────────────
    # Figure and Grid Setup
    # ─────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(17, 9))
    gs = gridspec.GridSpec(3, 8, width_ratios=[1, 1, 1, 1, 1, 1, 2, 2], height_ratios=[1, 1, 2], hspace=0.05, wspace=0.05)

    # Line plot axes
    ax41 = (fig.add_subplot(gs[2, 0:2]), fig.add_subplot(gs[2, 2:4]))
    ax43 = (fig.add_subplot(gs[2, 4]), fig.add_subplot(gs[2, 5]))
    names = ["OAH-1", "OAH-2"]

    ax_ratios = (fig.add_subplot(gs[0, 4:6]), fig.add_subplot(gs[1, 4:6]))
    ax_magnetization = (fig.add_subplot(gs[0:3, 6]), fig.add_subplot(gs[0:3, 7]))

    # ─────────────────────────────────────────────────────────────
    # Loop through the two reconstructions
    # ─────────────────────────────────────────────────────────────
    for i in range(2):
        # Phase images
        ax2 = fig.add_subplot(gs[i, 0:2])
        ax3 = fig.add_subplot(gs[i, 2:4])
        ax2.set_ylabel(names[i])

        # Prepare rotated images
        diff_phase = unwrap_phase(rotate(PCA_phase[i], 0.7))
        ang_old = rotate(angs_old_method[num][i][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)

        # Plot phase images
        im2 = ax2.imshow(diff_phase, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')
        im3 = ax3.imshow(ang_old, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')

        
        ax2.set_ylabel(f"OAH-{i+1}")
        
        lc_lims = [[28, 50], [550, 800]]

        # Line plots: horizontal average over selected rows
        horizontal_linecut_PCA = diff_phase[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0)
        horizontal_linecut_OLD = ang_old[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0)
        ax41[i].plot(horizontal_linecut_PCA, c='C0', label="PCA")
        ax41[i].plot(horizontal_linecut_OLD, c='C1', label="Old", alpha=0.4)

        # Line plots: vertical average over selected columns
        vertical_linecut_PCA = diff_phase.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0)
        vertical_linecut_OLD = ang_old.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0)
        ax43[i].plot(vertical_linecut_PCA, c='C0', label="PCA")
        ax43[i].plot(vertical_linecut_OLD, c='C1', label="Old", alpha=0.4)
        ax43[i].set_xlabel(f"OAH-{i+1}")
        
        ax43[i].set_ylim([-0.2, max(vertical_linecut_PCA.max(), vertical_linecut_OLD.max()) * 1.1])
        ax41[i].set_ylim([-0.2, max(horizontal_linecut_PCA.max(), horizontal_linecut_OLD.max()) * 1.1])
        
        
        # Titles for first row
        if i == 0:
            ax2.set_title("PCA")
            ax3.set_title("Old Method")
            ax2.axvline(x=lc_lims[1][0], c='r', ls='--', alpha=0.35)
            ax2.axvline(x=lc_lims[1][1], c='r', ls='--', alpha=0.35)
            ax2.axhline(y=lc_lims[0][0], c='r', ls='--', alpha=0.35)
            ax2.axhline(y=lc_lims[0][1], c='r', ls='--', alpha=0.35)


    a1_is = rotate(angs_old_method[num][0][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
    a2_is = rotate(angs_old_method[num][1][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
    ratio = (a1_is - a2_is) / (a1_is + a2_is)
    mask = (a1_is + a2_is) < 0.25
    ratio = np.ma.array(ratio, mask=mask)
    ax_ratios[1].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
    ax_ratios[1].text(len(a1_is[0])//2, 5*a1_is.shape[0]//6, "OLD", weight='bold')

    a1_is = rotate(PCA_phase[0], 0.7)
    a2_is = rotate(PCA_phase[1], 0.7)
    ratio = (a1_is - a2_is) / (a1_is + a2_is)
    mask = (a1_is + a2_is) < 0.25
    ratio = np.ma.array(ratio, mask=mask)
    im = ax_ratios[0].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
    ax_ratios[0].text(len(a1_is[0])//2, 5*a1_is.shape[0]//6, "PCA", weight='bold')
    
    
    mask_treshold = 0.251215
    vmin = -0.25
    vmax = 0.25

    
    ax_magnetization[0].imshow(arr_trend_old, aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)
    ax_magnetization[1].imshow(arr_trend_pca, aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)
    
    ax_magnetization[0].axhline(y=num, ls='--', c='r', alpha=0.9)
    ax_magnetization[1].axhline(y=num, ls='--', c='r', alpha=0.9)

    ax_magnetization[0].set_ylim([min(len(arr_trend_old), len(arr_trend_pca)), 0])
    ax_magnetization[1].set_ylim([min(len(arr_trend_old), len(arr_trend_pca)), 0])
    
    ax_magnetization[0].set_title("OLD")
    ax_magnetization[1].set_title("PCA")

    ax43[0].spines['right'].set_visible(False)
    ax43[1].spines['left'].set_visible(False)
    
    # ─────────────────────────────────────────────────────────────
    # Final Touches
    # ─────────────────────────────────────────────────────────────
    # Add legends to line plots
    ax41[0].legend(loc=1)
    # Remove ticks from all subplots
    for ax in fig.axes:
        ax.set_xticks([])
        ax.set_yticks([])

    # Global figure title
    fig.suptitle(f"{date} - {shot} - {num}")
    plt.tight_layout()
    plt.subplots_adjust(left=0.05, right=1.1)
    plt.colorbar(im, ax=fig.axes, aspect=30, pad=0.01)

    os.makedirs(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/raw_images_for_gifs/{date}_{shot}/", exist_ok=True)
#     plt.savefig(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/raw_images_for_gifs/{date}_{shot}/{str(num).zfill(4)}.png")

    # Layout adjustment
    plt.show()


<IPython.core.display.Javascript object>

In [523]:
center = 40

arr_trend_old = makeATrendArr_OLD([date, shot], smoothing=0, mask_treshold=0.25, center=center)
arr_trend_pca = makeATrendArr_PCA(np.array(PCA_phases), smoothing=0, mask_treshold=0.25, center=center)

trace_PCA = magnetization_trace_PCA(np.array(PCA_phases), center=center)
trace_OLD = magnetization_trace_OLD([date, shot], center=center)

fig, ax = plt.subplots(3, 1, figsize=(10, 8))

im = ax[0].imshow(arr_trend_old.T[::-1, ::-1], aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)
ax[1].imshow(arr_trend_pca.T[::-1, ::-1], aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)

ax[0].set_xlim([min(len(arr_trend_old), len(arr_trend_pca)), 0])
ax[1].set_xlim([min(len(arr_trend_old), len(arr_trend_pca)), 0])
ax[2].set_xlim([0, 25])

ax[0].set_ylabel("OLD")
ax[1].set_ylabel("PCA")

ax[2].plot(np.linspace(0, 25, len(arr_trend_old)), trace_PCA*8, 'o-', alpha=0.8, label="PCA")
ax[2].plot(np.linspace(0, 25, len(arr_trend_old)), trace_OLD*8, 'o-', alpha=0.2, label="OLD")
ax[2].axhline(y=0, ls='--', c='k', alpha=0.4)
ax[2].legend()

ax[2].set_yticks([-1, 1])

for axl in [ax[0], ax[1]]:
    axl.set_xticks([])
    axl.set_yticks([])
    

fig.suptitle(f"Magnetization through imaging - {date} - {shot}")
plt.savefig(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/total_magnetization/{date}_{shot}.png")
plt.tight_layout()

plt.colorbar(im, ax=ax, pad=0.01)
plt.show()


40 3
20250428 35 0.2555137178440719
40 3
20250428 35 0.1121064935902486


<IPython.core.display.Javascript object>

### Loop It and Save it 

In [503]:
for param in all_params:
    date = param[0]
    shot = param[1]
    if not os.path.exists(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/gifs_videos/{date}_{shot}.gif") and os.path.exists(f'/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/PCA_processing/{param[0]}_{param[1]}'):
        print(date, shot)
        
for param in all_params:
    date = param[0]
    shot = param[1]
    if not os.path.exists(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/gifs_videos/{date}_{shot}.gif") and os.path.exists(f'/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/PCA_processing/{param[0]}_{param[1]}'):
        date = param[0]
        shot = param[1]
        print(date, shot)

        save_folder_old = '/home/bec_lab/Desktop/imgs/SOAH/SpinAnalysisApril2025/' 
        save_folder_pca = '/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/PCA_processing' 
        angs_old_method = np.load(f"{save_folder_old}data/{date}_{shot}_ang_is.npy")
        PCA_phases = np.array([np.load(f"{save_folder_pca}/{date}_{shot}/{file}") for file in sorted(os.listdir(f"{save_folder_pca}/{date}_{shot}/"))])

        # Additional Phase Unwrapping? 
        # PCA_phases = [[unwrap_phase(unwrap_phase(PCA_phases[i][j])) for j in range(len(PCA_phases[0]))] for i in range(len(PCA_phases))]

        # Get the new cutting window 
        param_wtop, param_wbottom, param_wleft, param_wright = np.load(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/data/cut_arrs/{date}_{str(shot).zfill(4)}.npy")
        print(param_wtop, param_wbottom, param_wleft, param_wright)


        ## This whole thing, in a loop: 
        arr_trend_old = makeATrendArr_OLD([date, shot], smoothing=0, mask_treshold=0.25)
        arr_trend_pca = makeATrendArr_PCA(np.array(PCA_phases), smoothing=0, mask_treshold=0.25)

        nr_atoms = get_nr_of_atoms(date, shot)
        for num in range(nr_atoms):
            PCA_phase = PCA_phases[num]
        #     PCA_phase = diffs_both
            # ─────────────────────────────────────────────────────────────
            # Figure and Grid Setup
            # ─────────────────────────────────────────────────────────────
            fig = plt.figure(figsize=(17, 9))
            gs = gridspec.GridSpec(3, 8, width_ratios=[1, 1, 1, 1, 1, 1, 2, 2], height_ratios=[1, 1, 2], hspace=0.05, wspace=0.05)

            # Line plot axes
            ax41 = (fig.add_subplot(gs[2, 0:2]), fig.add_subplot(gs[2, 2:4]))
            ax43 = (fig.add_subplot(gs[2, 4]), fig.add_subplot(gs[2, 5]))
            names = ["OAH-1", "OAH-2"]

            ax_ratios = (fig.add_subplot(gs[0, 4:6]), fig.add_subplot(gs[1, 4:6]))
            ax_magnetization = (fig.add_subplot(gs[0:3, 6]), fig.add_subplot(gs[0:3, 7]))

            # ─────────────────────────────────────────────────────────────
            # Loop through the two reconstructions
            # ─────────────────────────────────────────────────────────────
            for i in range(2):
                # Phase images
                ax2 = fig.add_subplot(gs[i, 0:2])
                ax3 = fig.add_subplot(gs[i, 2:4])
                ax2.set_ylabel(names[i])

                # Prepare rotated images
                diff_phase = unwrap_phase(rotate(PCA_phase[i], 0.7))
                ang_old = rotate(angs_old_method[num][i][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)

                # Plot phase images
                im2 = ax2.imshow(diff_phase, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')
                im3 = ax3.imshow(ang_old, cmap='Greys', vmin=-0.4, vmax=0.4, origin='lower', aspect='auto')


                ax2.set_ylabel(f"OAH-{i+1}")

                lc_lims = [[28, 50], [550, 800]]

                # Line plots: horizontal average over selected rows
                horizontal_linecut_PCA = diff_phase[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0)
                horizontal_linecut_OLD = ang_old[lc_lims[0][0]:lc_lims[0][1]].mean(axis=0)
                ax41[i].plot(horizontal_linecut_PCA, c='C0', label="PCA")
                ax41[i].plot(horizontal_linecut_OLD, c='C1', label="Old", alpha=0.4)

                # Line plots: vertical average over selected columns
                vertical_linecut_PCA = diff_phase.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0)
                vertical_linecut_OLD = ang_old.T[lc_lims[1][0]:lc_lims[1][1]].mean(axis=0)
                ax43[i].plot(vertical_linecut_PCA, c='C0', label="PCA")
                ax43[i].plot(vertical_linecut_OLD, c='C1', label="Old", alpha=0.4)
                ax43[i].set_xlabel(f"OAH-{i+1}")

                ax43[i].set_ylim([-0.2, max(vertical_linecut_PCA.max(), vertical_linecut_OLD.max()) * 1.1])
                ax41[i].set_ylim([-0.2, max(horizontal_linecut_PCA.max(), horizontal_linecut_OLD.max()) * 1.1])


                # Titles for first row
                if i == 0:
                    ax2.set_title("PCA")
                    ax3.set_title("Old Method")
                    ax2.axvline(x=lc_lims[1][0], c='r', ls='--', alpha=0.35)
                    ax2.axvline(x=lc_lims[1][1], c='r', ls='--', alpha=0.35)
                    ax2.axhline(y=lc_lims[0][0], c='r', ls='--', alpha=0.35)
                    ax2.axhline(y=lc_lims[0][1], c='r', ls='--', alpha=0.35)


            a1_is = rotate(angs_old_method[num][0][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
            a2_is = rotate(angs_old_method[num][1][param_wtop:param_wbottom, param_wleft:param_wright], 0.7)
            ratio = (a1_is - a2_is) / (a1_is + a2_is)
            mask = (a1_is + a2_is) < 0.25
            ratio = np.ma.array(ratio, mask=mask)
            ax_ratios[1].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
            ax_ratios[1].text(len(a1_is[0])//2, 5*a1_is.shape[0]//6, "OLD", weight='bold')

            a1_is = rotate(PCA_phase[0], 0.7)
            a2_is = rotate(PCA_phase[1], 0.7)
            ratio = (a1_is - a2_is) / (a1_is + a2_is)
            mask = (a1_is + a2_is) < 0.25
            ratio = np.ma.array(ratio, mask=mask)
            im = ax_ratios[0].imshow(ratio, vmin=-0.25, vmax=0.25, cmap='RdYlBu', aspect='auto', origin='lower')
            ax_ratios[0].text(len(a1_is[0])//2, 5*a1_is.shape[0]//6, "PCA", weight='bold')


            mask_treshold = 0.251215
            vmin = -0.25
            vmax = 0.25


            ax_magnetization[0].imshow(arr_trend_old, aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)
            ax_magnetization[1].imshow(arr_trend_pca, aspect='auto', cmap = 'RdYlBu',  vmin=vmin, vmax=vmax)

            ax_magnetization[0].axhline(y=num, ls='--', c='r', alpha=0.9)
            ax_magnetization[1].axhline(y=num, ls='--', c='r', alpha=0.9)

            ax_magnetization[0].set_ylim([min(len(arr_trend_old), len(arr_trend_pca)), 0])
            ax_magnetization[1].set_ylim([min(len(arr_trend_old), len(arr_trend_pca)), 0])

            ax_magnetization[0].set_title("OLD")
            ax_magnetization[1].set_title("PCA")

            ax43[0].spines['right'].set_visible(False)
            ax43[1].spines['left'].set_visible(False)

            # ─────────────────────────────────────────────────────────────
            # Final Touches
            # ─────────────────────────────────────────────────────────────
            # Add legends to line plots
            ax41[0].legend(loc=1)
            # Remove ticks from all subplots
            for ax in fig.axes:
                ax.set_xticks([])
                ax.set_yticks([])

            # Global figure title
            fig.suptitle(f"{date} - {shot} - {num}")
            plt.tight_layout()
            plt.subplots_adjust(left=0.05, right=1.1)
            plt.colorbar(im, ax=fig.axes, aspect=30, pad=0.01)

            os.makedirs(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/raw_images_for_gifs/{date}_{shot}/", exist_ok=True)
            plt.savefig(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/raw_images_for_gifs/{date}_{shot}/{str(num).zfill(4)}.png")

            # Layout adjustment
            plt.close()
        save_gif(f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/raw_images_for_gifs/{date}_{shot}/", output_path_gif=f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/gifs_videos/{date}_{shot}", output_path_mp4=f"/home/bec_lab/Desktop/imgs/SOAH/PCA_Analysis/images/gifs_videos/{date}_{shot}" )
        clear_output()