In [None]:
# Copyright (c) 2025, ETH Zurich

In [1]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.stats as stats
import scipy.ndimage
from tqdm import tqdm
import spekpy as spk
import h5py

from scipy import interpolate

from scipy.interpolate import griddata
from scipy import interpolate

In [2]:
# Location of the rave-sim checkout, resolved relative to the notebook's working directory.
rave_sim_dir = Path('../rave-sim').resolve()
# Placeholder: replace with the directory in which simulation outputs are stored.
simulations_dir = Path('<PATH_TO_STORE_SIMULATIONS>')
# NOTE(review): scratch_dir is not referenced anywhere else in this notebook.
scratch_dir = simulations_dir

# Make the nist_lookup package importable; this must run before the import below.
sys.path.insert(0, str(rave_sim_dir / "nist_lookup"))
# xray_delta_beta(material, density, energy): used below with [0] as delta and [1] as beta.
from nist_lookup.xraydb_plugin import xray_delta_beta

In [3]:
# Make the big-wave simulation helpers importable; must run before the imports below.
sys.path.insert(0, str(rave_sim_dir / "big-wave"))
import multisim  # simulation setup (multisim.setup_simulation)
import config  # configuration loading (config.load)
import util  # detector grid and wavefront loading helpers
import propagation  # sampling helpers (max_dx, convert_energy_wavelength)


In [4]:
def calculate_G1_height(eng, material='Si', density=2.34, phase_shift=np.pi):
    """Return the grating-line height that produces a given phase shift.

    A line of height ``t`` shifts the phase by ``2*pi*delta*t/lambda``;
    solving for ``t`` gives ``t = phase_shift * lambda / (2*pi*delta)``.

    Parameters
    ----------
    eng : float
        Photon energy in eV.
    material : str, optional
        Material name passed to ``xray_delta_beta`` (default ``'Si'``).
    density : float, optional
        Density passed to ``xray_delta_beta`` (default 2.34, the value used
        for Si throughout this notebook).
    phase_shift : float, optional
        Target phase shift in radians (default ``pi``).

    Returns
    -------
    float
        Line height in metres.
    """
    # Constants are local so the function does not depend on cells run later.
    h = 6.62607004 * 10**(-34)  # Planck constant in m^2 kg / s
    c_0 = 299792458  # speed of light in m / s
    eV_to_joule = 1.602176634 * 10**(-19)

    lambda_ = h * c_0 / (eng * eV_to_joule)
    delta_diff = xray_delta_beta(material, density, eng)[0]
    return phase_shift * lambda_ / (2 * np.pi * delta_diff)


def signal_retrieval_least_squares(data, period=None, axis=-1):
    """Retrieve offset, phase and visibility from phase-stepping curves.

    The model ``a*sin(phi) + b*cos(phi) + c`` is fitted by linear least
    squares to every curve along ``axis``; ``phi`` advances by
    ``2*pi/period`` per step.

    Parameters
    ----------
    data : ndarray
        Stepping curves, with the phase steps along ``axis``.
    period : int or float, optional
        Number of steps per fringe period; defaults to the number of steps.
    axis : int, optional
        Axis of ``data`` holding the phase steps (default: last axis).

    Returns
    -------
    dabs : ndarray
        Fitted offset ``c`` multiplied by the number of steps.
    dphase : ndarray
        ``-arctan2(a, b)``.
    dvis : ndarray
        ``sqrt(a**2 + b**2) / c``.
    chi2 : float
        Mean of the residuals reported by ``np.linalg.lstsq`` (NaN, with a
        warning, when lstsq reports none).
    """
    curves = data if axis == -1 else np.moveaxis(data, axis, -1)
    n_steps = curves.shape[-1]
    step_period = n_steps if period is None else period

    phases = np.linspace(0, 2 * np.pi * n_steps / (step_period), n_steps, endpoint=False)
    design = np.column_stack((np.sin(phases), np.cos(phases), np.ones(n_steps)))

    flat_curves = curves.reshape((-1, n_steps)).T
    coeffs, residuals, _rank, _singular = np.linalg.lstsq(design, flat_curves, rcond=-1)
    coeffs = coeffs.T.reshape((*curves.shape[:-1], -1))

    sin_part = coeffs[..., 0]
    cos_part = coeffs[..., 1]
    offset = coeffs[..., 2]

    dphase = -np.arctan2(sin_part, cos_part)
    dvis = np.sqrt(sin_part**2 + cos_part**2) / offset
    # Convert the per-step mean into the total over the scan; done in place
    # only after the visibility has used the unscaled offset.
    offset *= n_steps
    return offset, dphase, dvis, np.nanmean(residuals)


def calculate_pixel_intensity(x, fringe, pxEdges, statistics = 'sum'):
    """Aggregate ``fringe`` samples located at ``x`` into the pixels bounded by ``pxEdges``.

    ``statistics`` is forwarded to ``scipy.stats.binned_statistic`` (e.g.
    ``'sum'`` or ``'mean'``); samples outside the edges are ignored.
    Returns one value per pixel.
    """
    binned = stats.binned_statistic(x, fringe, statistic=statistics, bins=pxEdges)
    return binned.statistic

def perform_binned_signal_retrieval(x, wf, pxSize, nrSteps, plot_curve = True):
    """Bin phase-stepping profiles into pixels and retrieve the signals.

    Pixel edges are laid out symmetrically around x = 0 (edges at
    +/- pxSize/2, +/- 3*pxSize/2, ...) out to the extent of ``x``; samples
    beyond the outermost edges are dropped by the binning.

    Parameters
    ----------
    x : ndarray
        Sample positions of the profiles (same unit as ``pxSize``).
    wf : ndarray
        Profiles indexed as ``wf[step, sample]``; rows ``0..nrSteps-1`` are used.
    pxSize : float
        Pixel width.
    nrSteps : int
        Number of phase steps, also used as the stepping period.
    plot_curve : bool, optional
        If True, plot the retrieved transmission, phase and visibility.

    Returns
    -------
    int_px : ndarray
        Summed intensity per pixel, shape (nrSteps, n_pixels).
    trans, phase, vis : ndarray
        Retrieved signals per pixel.
    pxEdges : ndarray
        Pixel edges used for the binning.
    """
    leftSide = np.arange(0 - pxSize / 2, np.min(x), -pxSize)
    rightSide = np.arange(0 + pxSize / 2, np.max(x), pxSize)
    pxEdges = np.concatenate([np.flip(leftSide), rightSide])

    # Summed intensity per pixel for every phase step.
    int_px = np.asarray([calculate_pixel_intensity(x, wf[i, :], pxEdges) for i in range(nrSteps)])

    trans, phase, vis, _ = signal_retrieval_least_squares(int_px, period=nrSteps, axis=0)
    if plot_curve:
        # Plot inline: no separate plotting helper exists in this notebook.
        px_centres = 0.5 * (pxEdges[:-1] + pxEdges[1:])
        fig, axes = plt.subplots(nrows=3, sharex=True)
        for ax, values, label in zip(axes, (trans, phase, vis),
                                     ("Transmission", "Phase [rad]", "Visibility")):
            ax.plot(px_centres, values)
            ax.set_ylabel(label)
        axes[-1].set_xlabel("Pixel centre position")
    return int_px, trans, phase, vis, pxEdges


def perform_binning(x, wf, pxSize):
    """Average the profile ``wf`` (sampled at ``x``) over pixels of width ``pxSize``.

    Pixel edges sit at +/- pxSize/2, +/- 3*pxSize/2, ... out to the extent
    of ``x``; samples outside the outermost edges are ignored. Returns the
    per-pixel mean (NaN for pixels without samples).
    """
    half = pxSize / 2
    negative_edges = np.arange(0 - half, np.min(x), -pxSize)[::-1]
    positive_edges = np.arange(0 + half, np.max(x), pxSize)
    edges = np.concatenate([negative_edges, positive_edges])
    return stats.binned_statistic(x, wf, bins=edges, statistic='mean').statistic

    
def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    Ties resolve to the first occurrence; multi-dimensional input yields a
    flat index.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

def get_subdict(dict_, idx):
    """Return a new dict with every value of ``dict_`` indexed by ``idx``.

    Keys and their order are preserved.
    """
    return {key: values[idx] for key, values in dict_.items()}

def calc_Vis_theoretical(eng, Edes,m):
    """Evaluate V = 2/pi * |sin^2(pi/2 * Edes/eng) * sin(m*pi/2 * Edes/eng)|.

    ``eng`` and ``Edes`` must share a unit; ``m`` is the integer factor in the
    second sine (presumably a Talbot order — confirm).
    """
    half_turn = np.pi / 2 * Edes / eng
    envelope = np.sin(half_turn)**2
    harmonic = np.sin(m * np.pi / 2 * Edes / eng)
    return 2/np.pi * np.abs(envelope * harmonic)

def mu_h2o(eng):
    """Linear attenuation coefficient of water, mu = 4*pi*beta/lambda.

    Parameters
    ----------
    eng : float
        Photon energy in eV.

    Returns
    -------
    float
        Attenuation coefficient in 1/m (the wavelength is computed in metres).
        Uses beta of 'H2O' at density 1.0 from ``xray_delta_beta``.
    """
    # Constants are local: the function previously relied on globals that are
    # only defined in a later cell.
    h = 6.62607004 * 10**(-34)  # Planck constant in m^2 kg / s
    c_0 = 299792458  # speed of light in m / s
    eV_to_joule = 1.602176634 * 10**(-19)

    lambda_ = h * c_0 / (eng * eV_to_joule)
    beta = xray_delta_beta('H2O', 1.0, eng)[1]
    return 4 * np.pi / lambda_ * beta

In [5]:
# Physical constants
h = 6.62607004 * 10**(-34)  # Planck constant [m^2 kg / s]
c_0 = 299792458  # speed of light [m / s]
eV_to_joule = 1.602176634 * 10**(-19)  # [J / eV]
N_A = 6.02214086 * 10**23  # Avogadro constant [1 / mol]

# Design energy [eV] and the corresponding wavelength [m]
E_des = 46000
lambda_ = h * c_0 / (E_des * eV_to_joule)

# Grating periods [m]; all three gratings share the G2 period
p2 = 4.2 * 10**(-6)
p0 = p2
p1 = p2

# Inter-grating distance: 5 * p2^2 / (2 * lambda) / 2
Dn_5 = 5 * p2**2 / (2 * lambda_) / 2

# Longitudinal positions [m]
z_g0 = 0.1
z_g1 = z_g0 + Dn_5
z_g2 = z_g0 + 2 * Dn_5
z_detector = z_g2 + 0.01

# Grating line heights [m]
h0 = 180e-6
h2 = h0
h1 = 59e-6

print(f"Z0:  {z_g0}")
print(f"Z1:  {z_g1}")
print(f"Z2:  {z_g2}")
print(f"Z Detector:  {z_detector}")
print(Dn_5)

Z0:  0.1
Z1:  0.9180881351464869
Z2:  1.736176270292974
Z Detector:  1.746176270292974
0.8180881351464869


In [6]:
# Number of samples of the propagation grid.
N = 2**26
# Photon energy [eV] passed (as a wavelength) to max_dx; matches the 70 kVp spectrum's upper bound.
max_energy = 70000
# Grid spacing returned by propagation.max_dx.
# NOTE(review): the meaning of the 200e-6 argument is not documented here —
# presumably a lateral extent in metres; confirm against propagation.max_dx.
dx = propagation.max_dx(z_g0, 200e-6, N, propagation.convert_energy_wavelength(max_energy))

In [7]:
# Display the grid spacing chosen above.
dx

1.586158759891987e-10

In [9]:
# X-ray tube spectrum from SpekPy: 70 kVp, 0.1 keV bins, th = 10 (anode angle in degrees per SpekPy)
s = spk.Spek(kvp=70, dk = 0.1, th = 10)  # create the unfiltered spectrum
s.multi_filter((('Be', 0.15), ('Al', 3)))  # add filtration: 0.15 mm Be + 3 mm Al
k, f = s.get_spectrum(edges=True)  # photon energies [keV] and spectral fluence

# NOTE(review): energyRange, dE and filtering are not used anywhere else in this notebook.
energyRange = [4000, 70000]
dE = 100
filtering = 0.000

# Energy grid [eV] on which the spectrum is stored: 5 keV to 70 keV in 0.1 keV steps.
energies = np.arange(5, 70+0.1, 0.1)*1e3

tube_spectrum_txt = interpolate.interp1d(k*1e3, f, fill_value = 'extrapolate')
spec_txt = tube_spectrum_txt(energies)

# Single definition of the output location (previously the literal was repeated).
spectrum_file = Path('../spectra/spectrum_70_spekpy_filtered_3mmAl.h5')
# h5py does not create missing parent directories; make sure it exists on a fresh checkout.
spectrum_file.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(str(spectrum_file), 'w') as h5:
    # Normalised so the stored values sum to one over `energies`.
    h5.create_dataset('pdf', data =  spec_txt/ np.sum(spec_txt))
    h5.create_dataset('energy', data = energies)

# Absolute path handed to the simulation config.
path_to_spectrum = os.path.abspath(spectrum_file)

In [10]:
# Absolute path of the spectrum file written above; it is passed to the simulation config.
path_to_spectrum

'<ABSOLUTE_PATH_TO_REPOSITORY>/spectra/spectrum_70_spekpy_filtered_3mmAl.h5'

In [None]:
# Tube spectrum on the simulation energy grid (unnormalised SpekPy output).
fig, ax = plt.subplots()
ax.plot(energies, spec_txt)
ax.set_xlabel("Photon energy [eV]")
ax.set_ylabel("Spectral fluence (SpekPy units)")
ax.set_title("70 kVp spectrum, 0.15 mm Be + 3 mm Al");

In [None]:
# Line shifts encoded in the names of the G0 defect model files.
structure_shift = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
notebook_dir = os.getcwd()
path_to_models = os.path.abspath(os.path.join(notebook_dir, "..", "grating__modells/Models_Crack"))

# One model file per shift; the file suffix is the shift without its leading "0."
# (0.05 -> "05", 0.1 -> "1", ...).
grid_path_g0 = [
    os.path.join(path_to_models, f"Phase_shift_42_{str(shift).split('.')[1]}.npy")
    for shift in structure_shift
]

source_sizes = np.array([10, 50, 100, 150, 200])*1e-6

dx_g0 = 1.0151416062788299e-08  # Grid size of modells in x-direction
dz_g0 = 180e-6 / 20  # Grid size of model in z-direction

In [None]:
# Total number of simulations launched below: one per model file and source size.
len(grid_path_g0) * len(source_sizes)

In [None]:
import subprocess

# GPU propagation executable; it is invoked once per source point.
fastwave_exe = rave_sim_dir / "fast-wave" / "build-Release" / "fastwave"
# Pin the runs to the first GPU without modifying this kernel's own environment.
fastwave_env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}

for grid_path in grid_path_g0:
    for ss in source_sizes:

        config_dict = {
            "sim_params": {
                "N": N,
                "dx": dx,
                "z_detector": z_g2 + 500e-6,
                "detector_size": 0.004,
                "detector_pixel_size_x": 1e-7,
                "detector_pixel_size_y": 1.0,
                "chunk_size": 256 * 1024 * 1024 // 16,  # use 256MB chunks
            },
            "use_disk_vector": False,
            "save_final_u_vectors": False,
            "dtype": "c8",
            "multisource": {
                "type": "points",
                "energy_range": [11000, 70000],
                "x_range": [-ss.item(), ss.item()],
                "z": 0.0,
                "nr_source_points": 200,
                "seed": 1,
                "spectrum": path_to_spectrum,
            },
            "elements": [
                {
                    # G0 with the defect model loaded from grid_path
                    "type": "sample",
                    "z_start": z_g0,
                    "pixel_size_x": dx_g0,
                    "pixel_size_z": dz_g0,
                    "grid_path": grid_path,
                    "materials": [["Au", 19.32], ["C5H8O2", 1.19]],
                    "x_positions": [0.0],
                },
                {
                    # carbon layer filling the rest of the G0 thickness
                    "type": "grating",
                    "pitch": 4.2*1e-6,
                    "dc": [1.0, 1.0],
                    "z_start": z_g0 + 180*1e-6,
                    "thickness": 500*1e-6 - h0,
                    "nr_steps": 1,
                    "x_positions": [0.0],
                    "substrate_thickness": 0.0,
                    "mat_a": ["C", 2.26],
                    "mat_b": None,
                    "mat_substrate": None
                },
                {
                    # G1
                    "type": "grating",
                    "pitch": p1,
                    "dc": [0.5, 0.5],
                    "z_start": z_g1,
                    "thickness": h1,
                    "nr_steps": 10,
                    "x_positions": [0.0],
                    "substrate_thickness": 200 * 1e-6 - h1,
                    "mat_a": ["Si", 2.34],
                    "mat_b": None,
                    "mat_substrate": ["Si", 2.34],
                },
                {
                    # G2, stepped over one period in 5 positions
                    "type": "grating",
                    "pitch": p2,
                    "dc": [0.5, 0.5],
                    "z_start": z_g2,
                    "thickness": h2,
                    "nr_steps": 30,
                    "x_positions": (np.arange(5) * p2/5).tolist(),
                    "substrate_thickness": 500*1e-6 - h2,
                    "mat_a": ["C5H8O2", 1.19],
                    "mat_b": ["Au", 19.32],
                    "mat_substrate": ["C", 2.26],
                },
            ],
        }
        sim_path = multisim.setup_simulation(config_dict, Path("."), simulations_dir)

        # Argument list instead of a shell string; failed source points are reported
        # instead of being silently ignored.
        failed_points = []
        for i in tqdm(range(config_dict["multisource"]["nr_source_points"])):
            result = subprocess.run([str(fastwave_exe), "-s", str(i), str(sim_path)], env=fastwave_env)
            if result.returncode != 0:
                failed_points.append(i)
        if failed_points:
            print(f"fastwave failed for {len(failed_points)} source point(s) of {sim_path}: {failed_points}")

In [None]:
# matplotlib style: start from the defaults, then apply the overrides below in one go
plt.style.use("default")

# set FIGWIDTH to latex's \textwidth
FIGWIDTH = 3

plt.rcParams.update({
    # fonts and figure geometry
    "font.family": "serif",
    "font.size": 6,
    "figure.figsize": (FIGWIDTH, FIGWIDTH * 2 / 3),
    "figure.dpi": 300,
    "figure.constrained_layout.use": "True",
    # images
    "image.interpolation": "bicubic",
    "image.cmap": "Greys_r",
    # axes
    "axes.edgecolor": "0.7",
    "axes.linewidth": "1",
    # legend
    "legend.frameon": False,
    # markers
    "lines.markersize": 3,
    # Okabe-Ito palette, each colour paired with a marker
    "axes.prop_cycle": plt.cycler(
        color=[
            "#000000",
            "#E69F00",
            "#56B4E9",
            "#009E73",
            "#F0E442",
            "#0072B2",
            "#D55E00",
            "#CC79A7",
        ],
        marker=["o", "^", "s", "p", "D", "v", "v", "d"],
    ),
    "axes.grid": False,
})

In [None]:
def _shift_from_grid_path(grid_path):
    """Parse the G0 line shift encoded at the end of a model file name.

    Both ``..._05.npy`` (digits after an implicit "0.") and ``..._0.05.npy``
    (explicit decimal) return 0.05.
    """
    suffix = Path(grid_path).stem.rsplit("_", 1)[-1]
    if "." in suffix:
        return float(suffix)
    if not suffix.isdigit():
        raise ValueError(f"cannot parse line shift from {grid_path!r}")
    return float("0." + suffix)


def _process_simulation(sim_path, pxSize, sample_thickness, mu_sample):
    """Reduce one simulation directory to a record of scalars and line profiles.

    Returns ``(record, energies)``; any failure propagates to the caller.
    """
    config_dict = config.load(Path(sim_path / 'config.yaml'))
    sp = config_dict["sim_params"]
    detector_x = util.detector_x_vector(sp["detector_size"], sp["detector_pixel_size_x"])
    # NOTE(review): this box spans |x| <= pxSize (width 2*pxSize) and the result is
    # binned into pxSize-wide pixels afterwards — confirm the extra blur is intended.
    pixel_rectangle = np.abs(detector_x) <= pxSize

    wavefronts = util.load_wavefronts_filtered(sim_path, x_range=None, energy_range=[10000, 70000])

    elements = config_dict["elements"]
    if "grid_path" in elements[0]:
        # Defect model present: G1 is the third element.
        shift = _shift_from_grid_path(elements[0]["grid_path"])
        pitch_g1 = elements[2]["pitch"]
    else:
        shift = 0.0
        pitch_g1 = elements[1]["pitch"]
    phase_steps = len(elements[-1]["x_positions"])

    # Sum all loaded wavefront intensities, optionally attenuated by a sample.
    summed_wf = np.zeros_like(wavefronts[0][0])
    energies = []
    for wf, _x_point, eng in wavefronts:
        energies.append(eng)
        if sample_thickness:
            summed_wf += wf * np.exp(-mu_sample(eng) * sample_thickness)
        else:
            summed_wf += wf

    convolved = np.asarray([np.convolve(summed_wf[i, :], pixel_rectangle, mode='same')
                            for i in range(phase_steps)])

    _int_px, trans, phase, vis, _pxEdges = perform_binned_signal_retrieval(
        detector_x, convolved, pxSize, phase_steps, plot_curve=False)

    # Average over the central half of the pixel row.
    middle_axis = int(trans.shape[0] / 2)
    idx_start = middle_axis - int(middle_axis / 2)
    idx_end = middle_axis + int(middle_axis / 2)
    mean_I = np.mean(trans[idx_start:idx_end])
    mean_vis = np.mean(vis[idx_start:idx_end])

    record = {
        "I": mean_I,
        "VIS": mean_vis,
        "Signal": np.sqrt(1e5 * mean_I) * mean_vis,
        "sourceSize": 2 * config_dict["multisource"]["x_range"][1],
        "shift_grating": shift,
        "visibility": vis,
        "transmission": trans,
        "phase": phase,
        "p1": pitch_g1,
    }
    return record, energies


def calculate_dictionary(simulation_dir, pxSize, sample_thickness = None, mu_sample = None):
    """Collect retrieved signals from every simulation stored in ``simulation_dir``.

    Parameters
    ----------
    simulation_dir : str or Path
        Directory containing one sub-directory per simulation (each with a
        ``config.yaml``). Entries are processed in sorted order.
    pxSize : float
        Detector pixel size used for binning.
    sample_thickness : float, optional
        If truthy, each wavefront is weighted by ``exp(-mu_sample(eng) * sample_thickness)``.
    mu_sample : callable, optional
        Attenuation coefficient as a function of energy; required with ``sample_thickness``.

    Returns
    -------
    (dict, ndarray)
        Dict of arrays with keys I, VIS, Signal, sourceSize, shift_grating,
        visibility, transmission, phase, p1 (one entry per accepted
        simulation), and the energies of the last accepted simulation.
        Simulations that fail to load or process are reported and skipped.
    """
    keys = ["I", "VIS", "Signal", "sourceSize", "shift_grating",
            "visibility", "transmission", "phase", "p1"]
    columns = {key: [] for key in keys}
    energies = []

    for sim in tqdm(sorted(os.listdir(simulation_dir))):
        sim_path = Path(simulation_dir) / sim
        try:
            record, sim_energies = _process_simulation(sim_path, pxSize, sample_thickness, mu_sample)
        except Exception as e:
            print(e)
            print(f"rejected {sim}")
            continue
        # Append only once a simulation is fully processed so all columns stay aligned.
        for key in keys:
            columns[key].append(record[key])
        energies = sim_energies

    return {key: np.asarray(values) for key, values in columns.items()}, np.asarray(energies)
    

In [None]:
# 75 µm detector pixels. The second return value (energies of the last processed
# simulation) gets its own name so the spectrum grid `energies` used by the
# spectrum plot above is not overwritten.
dict_plots, sim_energies = calculate_dictionary(simulations_dir, 75*1e-6)

In [None]:
# Account for bridge fraction: scale the simulated visibilities by a fixed factor.
# The scaled values are derived from untouched copies, so re-running this cell
# does not compound the factor.
BRIDGE_FRACTION_FACTOR = 0.91
dict_plots.setdefault('VIS_unscaled', dict_plots['VIS'])
dict_plots.setdefault('visibility_unscaled', dict_plots['visibility'])
dict_plots['VIS'] = dict_plots['VIS_unscaled'] * BRIDGE_FRACTION_FACTOR
dict_plots['visibility'] = dict_plots['visibility_unscaled'] * BRIDGE_FRACTION_FACTOR

In [None]:
def plot_against_keys(dict_, axis1, axis2, scale1, scale2, c1, c2, keys = None, figsize = (30, 16), return_set = False):
    """Plot selected quantities of ``dict_`` as maps over two parameters.

    Each quantity is linearly interpolated (``griddata``) from the scattered
    points ``(dict_[axis1]*scale1, dict_[axis2]*scale2)`` onto the grid
    ``(c1, c2)`` and drawn in its own panel with a colour bar.

    Parameters
    ----------
    dict_ : dict
        Mapping from quantity name to 1-D array; must contain ``axis1``,
        ``axis2`` and every entry of ``keys``.
    axis1, axis2 : str
        Keys used for the x and y coordinates.
    scale1, scale2 : float
        Factors applied to the x and y coordinates.
    c1, c2 : ndarray
        Target grid (e.g. from ``np.meshgrid``) in scaled coordinates.
    keys : list of str, optional
        Quantities to plot; defaults to
        ``['VIS', 'I', 'Signal', 'ANG', 'normalized_dose', 'SNR']``.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    return_set : bool, optional
        If True, return ``(contour_sets, axes, interpolated_grids)``.
    """
    titles = {'I': 'Intensity',
              'VIS': 'Visibility',
              'Signal': r"$\sqrt{I} * V$",
              'ANG': r"$p_2 / D$",
              'DOSE_G1': 'Intensity at G1',
              'SNR': r"$\sqrt{I} * V * \frac{D}{p_2}$",
              'weighted_SNR': "SNR for equal dose as GI-BCT",
              'normalized_dose': "weighted dose at G1"}

    if keys is None:
        keys = ['VIS', 'I', 'Signal', 'ANG', 'normalized_dose', 'SNR']
    n_panels = len(keys)
    # Same layout as before for even counts and 3 keys; also valid for 1 or 5 keys.
    nrows = max(n_panels // 2, 1)
    ncols = -(-n_panels // nrows)  # ceiling division
    fig, ax = plt.subplots(figsize=figsize, sharex=True, sharey=True,
                           nrows=nrows, ncols=ncols, squeeze=False)
    ax = ax.ravel()
    for unused in ax[n_panels:]:
        unused.set_visible(False)

    x_points = dict_[axis1] * scale1
    y_points = dict_[axis2] * scale2
    x_lim = (dict_[axis1].min() * scale1, dict_[axis1].max() * 0.985 * scale1)
    y_lim = (dict_[axis2].min() * scale2, dict_[axis2].max() * 0.99 * scale2)

    conts = []
    Tis = []
    for i, k in enumerate(keys):
        vmin = np.min(dict_[k])
        vmax = np.max(dict_[k])
        level_boundaries = np.linspace(vmin, vmax, 5)
        ax[i].set_title(titles[k], fontsize=20)

        Ti = griddata((x_points, y_points), dict_[k], (c1, c2), method='linear')
        Tis.append(Ti)

        if k == 'I':
            # Hand-tuned colour range and crop for the intensity map of this study.
            level_boundaries = np.linspace(15, 25, 10)
            # origin='lower' puts row 0 (smallest axis2 value) at the bottom of the
            # extent, consistent with the contour panels and the y axis.
            im = ax[i].imshow(Ti[:47, :42], vmin=18, vmax=23, cmap='inferno',
                              origin='lower', extent=[0.0, 0.4925, 20.0, 296.99])
        else:
            im = ax[i].contourf(c1, c2, Ti, levels=20, vmin=vmin, vmax=vmax, cmap='inferno')
            conts.append(im)

        cbar = fig.colorbar(im, ax=ax[i], ticks=level_boundaries)
        cbar.ax.tick_params(axis='both', which='both', labelsize=15)
        ax[i].set_xlim(*x_lim)
        ax[i].set_ylim(*y_lim)
        ax[i].tick_params(axis='x', labelsize=20)
        ax[i].tick_params(axis='y', labelsize=20)
        ax[i].set_aspect('auto')

    fig.supxlabel('Shift as a fraction of 2π', fontsize=20, y=-0.05)
    fig.supylabel('Source size [µm]', fontsize=20, x=0.07)
    fig.suptitle('Influence of local line shift in G0', fontsize=20, y=1.05)
    if return_set:
        return conts, ax, Tis


In [None]:
# Interpolation grid: line shift from 0 in steps of 0.012 (below 0.6) and
# source size from 20 µm in steps of 6 µm (below 320 µm).
sh_grat = np.arange(0, 0.6, 0.012)
s_size = np.arange(20, 320, 6)

# NOTE: despite the names, p0v/p2v are the shift and source-size grids, not pitches.
p0v, p2v = np.meshgrid(sh_grat, s_size)

keys = ['VIS', 'Signal', 'I']
# scale2 = 1e6 converts the stored source size (metres) to µm to match s_size.
plot_against_keys(dict_plots, 'shift_grating', 'sourceSize', 1, 1e6, p0v, p2v, keys, figsize = (19,5), return_set = False)

In [None]:
# Stand-alone visibility map (single panel, publication size).
scale1 = 1  # line shift is used as stored
scale2 = 1e6  # source size: m -> µm

sh_grat = np.arange(0, 0.6, 0.012)
s_size = np.arange(20, 320, 6)
c1, c2 = np.meshgrid(sh_grat, s_size)

vmin = np.min(dict_plots['VIS'])
vmax = np.max(dict_plots['VIS'])

Ti = griddata((dict_plots['shift_grating']*scale1, dict_plots['sourceSize']*scale2),
              dict_plots['VIS'], (c1, c2), method='linear')

fig, ax = plt.subplots(figsize=(3.5, 3))
im = ax.contourf(c1, c2, Ti, levels=20, vmin=vmin, vmax=vmax, cmap='inferno')
cbar = fig.colorbar(im, ax=ax)
# Hand-picked tick positions for the published figure.
cbar.set_ticks([0.07, 0.12, 0.17])
cbar.ax.tick_params(axis='both', which='both', labelsize=8)

ax.set_xlim(dict_plots['shift_grating'].min()*scale1, dict_plots['shift_grating'].max()*0.985*scale1)
ax.set_ylim(dict_plots['sourceSize'].min()*scale2, dict_plots['sourceSize'].max()*0.985*scale2)
ax.tick_params(axis='x', labelsize=8)
ax.tick_params(axis='y', labelsize=8)
ax.set_aspect('auto')
ax.set_title('Visibility', fontsize=8)
ax.set_xlabel('Shift as a fraction of 2π', fontsize=8)
ax.set_ylabel('Source size [µm]', fontsize=8);
