# PyGRAPES Simulation script

## With default settings, this requires a GPU and uses approximately 5 GB of VRAM.
#### The simulation reconstructs an Au Siemens star with a height of 5 nm from simulated data.
#### Running the notebook on a CPU is possible in principle, but has not been tested.

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import time
import scipy.io
import h5py
import matplotlib.patches as patches
import sys
import torchvision.transforms as transforms
import random
import gc
import torchvision
from datetime import datetime
from skimage.draw import disk, ellipse
from scipy.ndimage import gaussian_filter
from IPython.display import display, clear_output
from torchvision.transforms import v2
import os
import re 

## Check that CUDA is available.

### In principle, the code should be able to run without a CUDA device, but this has not been tested.

In [None]:

!nvidia-smi



In [None]:
# Report the CUDA devices PyTorch can see. When one is present, make GPU 0 the
# default device so that every tensor created later in the notebook lives on it.
print("cuda device count:", torch.cuda.device_count())
print("cuda available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    print("cuda is not available. default device is CPU. Not tested")
else:
    torch.set_default_device('cuda:0')
    torch.cuda.set_device('cuda:0')
    print("setting default device to cuda")
    num_gpus = torch.cuda.device_count()
    # One line per visible GPU: name and total memory in GiB.
    for i in range(num_gpus):
        properties = torch.cuda.get_device_properties(torch.device(f'cuda:{i}'))
        print(
            f"GPU {i}: {properties.name}, "
            f"Total Memory: {properties.total_memory / 1024 ** 3:.2f} GB"
        )
    print("CUDA version:", torch.version.cuda)
    current_device = torch.cuda.current_device()
    print("current CUDA device is", current_device)

In [None]:
def gpu_memcheck():
    """Print the memory PyTorch currently has allocated on the active CUDA device.

    The byte count from ``torch.cuda.memory_allocated()`` is floor-divided down to
    whole MiB before printing.
    """
    mib = torch.cuda.memory_allocated() // 1024 ** 2
    print("Currently allocated memory:", mib, "MB")


# Input Settings 

### The following three cells define the reconstruction settings, such as simulation size, resolution, wavelength and many others.

## (1/3) Define the initial volume size and the detector window size

In [None]:
# Voxel pitch (x, y) in metres. x is normal to the sample surface; y is
# perpendicular to the x-ray propagation direction.
voxel_size = torch.tensor((5e-9, 40e-9))
# z resolution along the beam: the thickness of one multislice slab.
slab_thickness = torch.tensor(500e-9)
# x, y, z size (voxels) of the full simulated volume.
full_sim_size = [1300, 250, 750]
# x, y, z size of the optimisable structure; y and z follow full_sim_size by default.
params_size = [10, *full_sim_size[1:]]
# [vertical, horizontal] detector window in pixels (the "asize" of conventional
# ptychography codes such as PtychoShelves).
crop_window_size = [400, 180]
print(f"Voxel size, x: {(voxel_size[0] * 1e9).item()} nm, y: {(voxel_size[1] * 1e9).item()} nm")
print(f"Slab thickness: {(slab_thickness * 1e9).item()} nm")

# Legacy knobs: not used, but keep them at 0.
probe_buffer = 0
shift_amount = 0


## (2/3) Define the experimental wavelength and any refractive indices used in the reconstruction.

##### The code refers to `Audelta` and `Aubeta`, but their values can be changed for other materials as long as the variable names are kept the same.

In [None]:
# Simulation wavelength in metres: 1.99984e-10 m (~0.2 nm), i.e. a photon energy of
# about 6.2 keV (E[keV] ~= 12.398 / lambda[Angstrom]), commonly used at cSAXS, PSI.
lambda_val = torch.tensor(1.99983862803066e-10)
# Refractive-index decrement (delta) and absorption index (beta) of Au at this energy.
# Only Au is simulated, so only these two constants are defined.
Audelta = torch.tensor(7.99106419e-05)
Aubeta = torch.tensor(1.24435501e-05)



## (3/3) Define reconstruction-specific parameters.

In [None]:
# Optimization settings.
# ----------------------------------------------------------------- optimisation switches
# Leave True.
using_n_scans_per_pixel = True
# True: accumulate gradients over all scans before updating the structure;
# False: update after every scan. Best left True.
do_gradient_accumulation = True
# True when loading a previously saved/optimised state.
load_previous = False
# True computes TV only over each scan's region instead of the whole volume. Best left False.
use_per_scan_TV = False
# True probe-mode modelling is not implemented; leave True regardless.
use_multiple_probe_modes = True
# False disables noise modelling entirely. To turn noise off while keeping the model,
# set the noise and its learning rate to zero instead.
use_noise = True
# Optimise sub-pixel shifts and scan positions (can stay True with a zero LR).
optimize_scan_offsets = True
# Oversample the structure with more z slices than the simulation uses. Leave True and
# set the factor to 1 to disable it effectively.
oversample_structure = True
oversample_factor = 1

# ----------------------------------------------------- learning rates and initial guess
# Magnitude (in voxels) of the random fluctuation of the initial guess.
init_xr_scaling_factor = 0.02
# Offset added to every value of the initial guess.
init_xr_substrateamt = 0.00
# Looks high, but 1e5 has been tested to work reasonably well.
init_xr_learning_rate = 1e5
# Not used; leave at 0.1.
grads_target = 0.1
# Per-scan position optimisation.
scan_positions_LR_fine = 1e-15
# Position optimisation of a whole set of scans (coarse motor moves in stitched scans).
scan_positions_LR_coarse = 5e-15
# Not used at the moment.
probe_prop_LR = 1e-4
# The initial noise guess is Gaussian with this standard deviation and mean.
noise_std = noise_mean = 0
# Target average gradient magnitude for the probe (used instead of a learning rate).
probe_grads_target = 1.0e-8

# ---------------------------------------------------------- total-variation regularisation
# Relative TV weights along x (unused in the demo), y and z; best left at 1 initially.
tvx_strength = tvy_strength = tvz_strength = 1
# Overall weight of the summed TV penalty.
total_TV_strength = 1e-3

# ----------------------------------------------------------- additional simulation settings
# Substrate layers counted from the bottom; enough to fully reflect/absorb the beam.
substrate_layers = 200
# Extra spacing between probe and substrate.
probe_substrate_buffer = 2
# Padding in the final multislice stage if the beam reflects too early or too steeply.
top_buffer = 0
# Extra propagation distance before / after the beam meets the structure.
pre_prop_dist_multiplier = 1.5
post_prop_dist_multiplier = 1
# Incidence angle used to try out these settings.
test_inc_angle = 0.7
# Extra slices; unused, but currently required for the code to run.
slab_pad_pre = slab_pad_post = 0

# ------------------------------------------------------------------------- hyperparameters
# Number of iterations.
num_iters = 50
# Iterations at which probe optimisation and the TV penalty start.
probestart = tv_start = 2
# Iterations to wait before reducing the LR when the loss increases.
patience = 5
# Iteration at which LR scheduling begins.
divergence_count_start = 5


# Define Auxiliary Functions

### Run all of these cells, then skip forward in the notebook.

In [None]:
def wave_interact_partial_voxels(wave, c, slab_thickness, wavelength,ndelta,nbeta):
    """Multiply ``wave`` by the transmission of one slab of partially filled voxels.

    ``c`` is the per-voxel fill fraction. A fraction ``c`` of each voxel carries the
    complex index ``ndelta + 1j*nbeta``; the remaining ``1 - c`` is vacuum (index 0).
    Both fractions use the thin-slab phase factor
    ``exp(1j*2*pi*n*slab_thickness/wavelength)``.
    """
    filled = torch.exp(1j * 2 * torch.pi * (ndelta + 1j * nbeta) * slab_thickness / wavelength)
    empty = torch.exp(1j * 2 * torch.pi * (0 + 1j * 0) * slab_thickness / wavelength)
    transmission = c * filled + (1 - c) * empty
    return wave * transmission

def wave_interact_partial_voxels_AuAg(wave, c, slab_thickness, wavelength,ndelta,nbeta):
    """Partial-voxel slab transmission with a separate substrate material.

    The bottom ``substrate_layers`` rows (notebook global) form the substrate. They use
    the global Au index ``Audelta + 1j*Aubeta``. The rows above use ``ndelta + 1j*nbeta``.
    In both regions a voxel with fill fraction ``c`` transmits
    ``c * exp(1j*2*pi*n*t/lambda) + (1 - c)`` (the vacuum part). This is the same mixing
    rule as ``wave_interact_partial_voxels``.

    The split point is derived from ``wave.size(0)``, so waves taller than
    ``full_sim_size[0]`` are supported. ``c`` must have the same number of rows as ``wave``.
    """
    n_top = wave.size(0) - substrate_layers
    wavetop, wavesub = torch.split(wave, [n_top, substrate_layers])
    ctop, csub = torch.split(c, [n_top, substrate_layers])
    vacuum = torch.exp(1j * 2 * torch.pi * (0 + 1j * 0) * slab_thickness / wavelength)
    top_material = torch.exp(1j * 2 * torch.pi * (ndelta + 1j * nbeta) * slab_thickness / wavelength)
    sub_material = torch.exp(1j * 2 * torch.pi * (Audelta + 1j * Aubeta) * slab_thickness / wavelength)
    top = wavetop * (ctop * top_material + (1 - ctop) * vacuum)
    # Include the vacuum fraction here too, so that an unfilled substrate voxel transmits
    # rather than zeroing the wave.
    substrate = wavesub * (csub * sub_material + (1 - csub) * vacuum)
    return torch.cat((top, substrate), dim=0)

def wave_interact_full_complex(wave, c, slab_thickness, wavelength):
    """Apply a thin-slab phase screen built from a fully complex per-voxel value ``c``.

    The screen is ``exp(1j*2*pi*c*slab_thickness/wavelength)``. Any imaginary part of
    ``c`` therefore attenuates the wave.
    """
    phase_screen = torch.exp(1j * 2 * torch.pi * (c) * slab_thickness / wavelength)
    return wave * phase_screen

def wave_interact_AndProp_partial_voxels_deltabeta(wave, c_r, c_i, slab_thickness, wavelength,ndelta,nbeta,tf):
    """Transmit ``wave`` through one slab, then propagate it with the transfer function ``exp(tf)``.

    ``c_r`` scales the refractive decrement ``ndelta`` and ``c_i`` scales the absorption
    ``nbeta`` independently. ``tf`` is the propagator *exponent* in FFT order; it is
    exponentiated here.
    """
    transmitted = wave * torch.exp(1j * 2 * torch.pi *
                                   (ndelta * (c_r) + 1j * nbeta * c_i) * slab_thickness / wavelength)
    spectrum = torch.fft.fft2(transmitted)
    return torch.fft.ifft2(spectrum * torch.exp(tf))

def optimise_probe_prop(probe_in,propdist,lambda_val,voxel_size):
    """Propagate ``probe_in`` by ``propdist`` with the paraxial (long-distance) transfer function."""
    return longprop(probe_in, get_tf_longprop(propdist, lambda_val, voxel_size, probe_in.size()))



In [None]:
def create_freq_mesh(voxel_size,shape):
    """Return the FFT-ordered spatial-frequency grids for a 2-D field.

    Args:
        voxel_size: (row pitch, column pitch); frequencies are divided by these, so they
            come out in cycles per unit of ``voxel_size``.
        shape: (rows, cols, ...) of the field; only the first two entries are used.

    Returns:
        (uu, vv), both shaped (rows, cols). ``uu`` varies along columns and is scaled by
        ``1 / voxel_size[1]``; ``vv`` varies along rows and is scaled by ``1 / voxel_size[0]``.
    """
    u = torch.fft.fftfreq(shape[1])
    v = torch.fft.fftfreq(shape[0])
    # Explicit 'ij' matches the current default (first input indexes rows). Passing it
    # silences the deprecation warning and guards against the planned default change.
    vv, uu = torch.meshgrid(v, u, indexing='ij')
    return uu / voxel_size[1], vv / voxel_size[0]
# uutest,_ = create_freq_mesh(voxel_size,full_sim_size)
# plt.imshow(uutest.cpu())
def get_tf_longprop(propdist,lambda_val,voxel_size,grid_shape):
    """Paraxial Fresnel transfer function ``exp(-1j*pi*lambda*z*(u**2 + v**2))`` in FFT order.

    ``u`` and ``v`` come from ``create_freq_mesh`` for ``grid_shape`` and ``voxel_size``.
    """
    u, v = create_freq_mesh(voxel_size, grid_shape)
    rho_sq = u**2 + v**2
    return torch.exp(-1 * 1j * torch.pi * lambda_val * propdist * rho_sq)

def get_tf_nearfield(propdist,lambda_val,voxel_size,grid_shape):
    """Return the *exponent* of the angular-spectrum propagator (the caller applies ``exp``).

    The exponent is ``2j*pi*(z/lambda)*sqrt(1 - lambda**2*(u**2 + v**2))``. Components with
    a non-positive radicand (evanescent waves) get exponent 0.

    NOTE(review): a zero exponent becomes a factor of 1 after ``exp``, so evanescent
    components pass through unattenuated rather than being suppressed. Confirm this is
    intended.
    """
    u, v = create_freq_mesh(voxel_size, grid_shape)
    quad = 1 - (u**2 + v**2) * (lambda_val**2)
    propagating = quad > 0
    exponent = 2j * torch.pi * (propdist / lambda_val) * torch.sqrt(torch.clamp(quad, min=0))
    return exponent * propagating

def farfield_PSI_prop(wave_in,lambda_val,propdist,voxel_size):
    """Single-FFT Fresnel propagation over ``propdist`` (port of the PSI far-field propagator).

    Lengths are normalised by ``voxel_size[0]``. The field is multiplied by the input chirp
    ``exp(1j*pi*r2/(lambda*z))`` and Fourier transformed with centred shifts. The output
    chirp ``exp(1j*pi*lambda*z*r2/(N0*N1))`` and a ``-1j`` prefactor are then applied.
    Here ``r2`` is the squared centred pixel index.

    Returns:
        Complex tensor with the same shape as ``wave_in``.
    """
    N = wave_in.size()
    # Centred index grids with exactly N entries (-N/2, ..., N/2 - 1), matching MATLAB's
    # -N/2:floor((N-1)/2). The previous torch.arange(-N/2, floor((N-1)/2)) dropped the
    # last point for even N, which made r2 one row/column too small.
    g1 = torch.arange(N[0]) - N[0] / 2
    g2 = torch.arange(N[1]) - N[1] / 2
    x, y = torch.meshgrid(g1, g2, indexing='ij')
    r2 = x**2 + y**2
    propdist = propdist / voxel_size[0]
    lambda_val = lambda_val / voxel_size[0]
    chirp_in = torch.exp(1j * torch.pi * r2 / (lambda_val * propdist))
    chirp_out = torch.exp(1j * torch.pi * lambda_val * propdist * r2 / (N[0] * N[1]))
    return -1j * chirp_out * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(wave_in * chirp_in)))

def farfield_PSI_prop_2(wave_in,lambda_val,propdist,voxel_size):
    """Separable variant of the single-FFT Fresnel propagator.

    Steps:
      1. Multiply by a quadratic phase in the row index ``x``, normalised with
         ``voxel_size[1]``.
      2. FFT along dim 0 (with a centred shift).
      3. Multiply by an fft-shifted quadratic phase in the column index ``y``, normalised
         with ``voxel_size[0]``.
      4. FFT along dim 1.
      5. Apply ``-1j`` times the output chirp, normalised with ``voxel_size[0]``.

    NOTE(review): the row-direction chirp uses ``voxel_size[1]`` and the column chirp uses
    ``voxel_size[0]``, the opposite of ``create_freq_mesh``'s convention. Confirm this is
    intended.

    Returns:
        Complex tensor with the same shape as ``wave_in``.
    """
    N = wave_in.size()
    # Centred index grids with exactly N entries (-N/2, ..., N/2 - 1); the exclusive-end
    # arange used before produced N-1 points for even N.
    g1 = torch.arange(N[0]) - N[0] / 2
    g2 = torch.arange(N[1]) - N[1] / 2
    x, y = torch.meshgrid(g1, g2, indexing='ij')
    r2 = x**2 + y**2
    H = torch.exp(1j * torch.pi * lambda_val/voxel_size[0] * propdist/voxel_size[0] * r2/(N[0]*N[1]))
    pre_exp = -1j * H
    tf_inner_x = torch.exp(1j * torch.pi * (x**2) / ((lambda_val/voxel_size[1])*(propdist/voxel_size[1])))
    tf_inner_y = torch.fft.fftshift(torch.exp(1j * torch.pi * (y**2) / ((lambda_val/voxel_size[0])*(propdist/voxel_size[0]))))
    along_rows = torch.fft.fft(torch.fft.fftshift(wave_in * tf_inner_x), dim=0)
    wout = pre_exp * torch.fft.ifftshift(torch.fft.fft(along_rows * tf_inner_y, dim=1))
    return wout

def farfield_PSI_prop1d(wave_in,lambda_val,propdist,voxel_size):
    """Single-FFT Fresnel propagation whose chirps depend only on the row index.

    ``voxel_size`` is a scalar pitch used to normalise ``lambda_val`` and ``propdist``. The
    chirps use ``r2 = x**2`` (x = centred row index) and the output chirp is divided by
    ``N0**2``. The FFT itself is still 2-D.

    Returns:
        Complex tensor with the same shape as ``wave_in``.
    """
    N = wave_in.size()
    # Centred index grids with exactly N entries (-N/2, ..., N/2 - 1); the exclusive-end
    # arange used before produced N-1 points for even N.
    g1 = torch.arange(N[0]) - N[0] / 2
    g2 = torch.arange(N[1]) - N[1] / 2
    x, _ = torch.meshgrid(g1, g2, indexing='ij')
    r2 = x**2
    propdist = propdist / voxel_size
    lambda_val = lambda_val / voxel_size
    chirp_in = torch.exp(1j * torch.pi * r2 / (lambda_val * propdist))
    chirp_out = torch.exp(1j * torch.pi * lambda_val * propdist * r2 / (N[0]**2))
    return -1j * chirp_out * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(wave_in * chirp_in)))


def longprop(wave_in,h):
    """Propagate ``wave_in`` by multiplying its 2-D spectrum by the transfer function ``h``.

    Computes ``ifft2(fft2(wave_in) * h)``. ``h`` is in FFT (unshifted) order and must
    broadcast against ``wave_in``.

    This runs once per slab of the multislice loops. The unused flux sums that were
    computed here (two full reductions per call) have been dropped.
    """
    return torch.fft.ifft2(torch.fft.fft2(wave_in) * h)

def wave_propagate_2d_RI(wavein,h):
    """Propagate by the transfer function ``exp(h)``, conserving the summed spectral magnitude.

    The spectrum ``fft2(wavein)`` is multiplied by ``exp(h)``. It is then rescaled so that
    ``sum(|spectrum|)`` is unchanged, and inverse transformed.

    The product is an ordinary complex multiply. The previous hand-expanded form used
    ``exp(h)`` for both the real and imaginary parts, which, e.g., turned the identity
    propagator (h = 0) into a 45-degree phase rotation.
    """
    f1 = torch.fft.fft2(wavein)
    oldflux = torch.sum(torch.abs(f1))
    fh = f1 * torch.exp(h)
    newflux = torch.sum(torch.abs(fh))
    return torch.fft.ifft2(fh * (oldflux / newflux))

def fresnel_exit(wave_in,lambda_val,distance,dx,dy):
    """Multiply the 2-D FFT of ``wave_in`` by a Fresnel quadratic-phase term.

    ``phase_term = exp(1j*k*distance) * exp(-1j*k/(2*distance)*(X**2 + Y**2))`` with
    ``k = 2*pi/lambda_val``. ``X`` and ``Y`` are spatial coordinates built from ``dx`` and
    ``dy`` with ``'ij'`` indexing.

    NOTE(review): the phase is built on *spatial* coordinates but applied to the *Fourier*
    spectrum, and no inverse FFT is taken, so ``output_wave`` is in the frequency domain.
    This is not a standard Fresnel propagator; behaviour is kept unchanged here. Confirm
    intent before relying on it. ``lambda_val`` is expected to be a tensor, because
    ``torch.exp`` is applied to ``1j*k*distance``.

    Returns:
        (output_wave, phase_term), both shaped like ``wave_in``.
    """
    k = 2 * torch.pi / lambda_val
    nx, ny = wave_in.shape
    x = torch.linspace(-nx // 2, nx // 2 - 1, nx) * dx
    y = torch.linspace(-ny // 2, ny // 2 - 1, ny) * dy
    X, Y = torch.meshgrid(x, y, indexing='ij')
    # Fresnel propagation phase term in the spatial domain.
    phase_term = torch.exp(1j * k * distance) * torch.exp(-1j * k / (2 * distance) * (X**2 + Y**2))
    # ...applied to the Fourier spectrum (see NOTE above).
    output_wave = torch.fft.fft2(wave_in) * phase_term
    return output_wave, phase_term
    

In [None]:
def MSForward_GI_SS_novol_partialvoxel_pre(xr,full_sim_size,params_size,voxel_size,slab_thickness,inc_angle,slab_pad_pre,slab_pad_post,probe,probe_buffer,wavemask,shift_amt,substrate_layers,init_xr_substrateamt,top_buffer,subpixel_shift_y,subpixel_shift_z,hx_shift,hz_shift):
    """Pre-structure stage of the grazing-incidence multislice forward model.

    Steps:
      1. Insert the tilted probe ``probe_substrate_buffer`` rows above the top of the
         ``substrate_layers`` substrate rows, centred in y.
      2. Apply the Fourier sub-pixel shift.
      3. Propagate through ``n_slices`` slabs in which only the substrate rows are filled.
         Their total length is ``pre_post_amt``: half of whatever part of the reflection
         path the structure (``xr.size(2)`` slabs) does not cover.

    Reads the notebook globals ``probe_substrate_buffer``, ``lambda_val``, ``Audelta`` and
    ``Aubeta``. The arguments ``params_size``, ``slab_pad_pre``, ``slab_pad_post``,
    ``probe_buffer``, ``wavemask``, ``shift_amt``, ``init_xr_substrateamt`` and
    ``top_buffer`` are accepted for call-site compatibility but are not used.

    Returns:
        (c_wave, pre_post_amt, probe_insertion_Y):
        - c_wave: the wavefield cropped to its bottom ``full_sim_size[0]`` rows.
        - pre_post_amt: the free-propagation length (clamped to 1e-6 when negative).
        - probe_insertion_Y: the first column of the inserted probe.
    """
    n_slices = 50
    probe_halfsize = (probe.size(0) * voxel_size[0]) / 2
    probe_insertion_Y = int((full_sim_size[1] - probe.size(1)) / 2)
    # Horizontal path for the beam to travel down to the substrate and back up at inc_angle.
    reflection_distance = 2 * (probe_halfsize + probe_substrate_buffer * voxel_size[0]) / (torch.tan(torch.deg2rad(torch.tensor(inc_angle))))
    # The part of that path not covered by the structure slabs is split before/after it.
    pre_post_amt = (reflection_distance - (xr.size(2) * slab_thickness)) / 2
    if pre_post_amt < 0:
        pre_post_amt = 1e-6
    pre_prop_dist = pre_post_amt / n_slices

    # The field must hold probe + buffer + substrate, and at least the full volume height.
    new_c_wave_size = max(probe.size(0) + substrate_layers + probe_substrate_buffer, full_sim_size[0])
    c_wave = torch.zeros((new_c_wave_size, full_sim_size[1]), dtype=torch.complex64)

    # Fill fraction of the free-space slabs: only the bottom substrate rows are material.
    newphantom = torch.zeros(c_wave.size(0), c_wave.size(1))
    newphantom[-int(substrate_layers):, :] = 1
    tf2 = get_tf_longprop(pre_prop_dist, lambda_val, voxel_size, c_wave.size())

    # Probe rows span [-top, -bottom) counted from the bottom of the field.
    probe_insertion_x_coords = ((substrate_layers + probe_substrate_buffer + probe.size(0)),
                                (substrate_layers + probe_substrate_buffer))
    probe_in = probe
    if probe_insertion_x_coords[1] < substrate_layers:
        # Zero the probe rows that would overlap the substrate. Work on a copy so that the
        # caller's probe (possibly a leaf tensor that requires grad) is not modified in place.
        probe_cutoff_amount = substrate_layers - probe_insertion_x_coords[1]
        probe_in = probe.clone()
        probe_in[-probe_cutoff_amount:, :] = 0
    c_wave[-probe_insertion_x_coords[0]:-probe_insertion_x_coords[1], probe_insertion_Y:int(probe_insertion_Y + probe.size(1))] = probe_tilt_gradient(probe_in, inc_angle, voxel_size, pre_prop_dist)
    c_wave = probe_subpixel_shift_fourier(c_wave, subpixel_shift_z, subpixel_shift_y, slab_thickness, voxel_size, hx_shift, hz_shift)

    for _ in range(n_slices):
        c_wave = wave_interact_partial_voxels(c_wave, newphantom, pre_prop_dist, lambda_val, Audelta, Aubeta)
        c_wave = longprop(c_wave, tf2)

    # NOTE(review): ifft2(fft2(x)) is x up to rounding; kept so results stay bit-identical.
    c_wave = torch.fft.ifft2(torch.fft.fft2(c_wave))
    c_wave = c_wave[-int(full_sim_size[0]):, :]
    return c_wave, pre_post_amt, probe_insertion_Y
        

In [None]:
def MSForward_GI_SS_novol_partialvoxel_mid(xr,cwavein,full_sim_size,params_size,voxel_size,slab_thickness,inc_angle,substrate_layers):
    """Propagate the wavefield through the optimisable structure, one z-slice at a time.

    Each slice ``xr[:, :, z]`` is padded to the full volume height: zeros above it and
    ``substrate_layers`` rows of ones below it. It is then applied
    ``round(full_sim_size[2] / xr.size(2))`` times, each time followed by a free-space step
    of ``slab_thickness``. When more than 24000 MiB of CUDA memory is allocated, tensors
    saved for backward are kept on the CPU (pinned).

    Reads the notebook globals ``lambda_val``, ``Audelta`` and ``Aubeta``. ``params_size``
    and ``inc_angle`` are not used.

    Returns:
        (c_wave, midslice, sideslice):
        - c_wave: the wavefield after the structure.
        - midslice: |row -substrate_layers| for every z-slice.
        - sideslice: |centre column| for every z-slice.
        Both magnitude arrays are recorded without autograd, after the last repeat of each
        slice.
    """
    tf = get_tf_longprop(slab_thickness, lambda_val, voxel_size, cwavein.size())
    n_z = xr.size(2)
    upscale_slices = int(torch.round(torch.tensor(full_sim_size[2] / n_z)))
    midslice = torch.zeros(cwavein.size(1), n_z)
    sideslice = torch.zeros(cwavein.size(0), n_z)
    centre_col = int(full_sim_size[1] / 2)

    def _advance(wave, slab):
        # One transmission through ``slab`` followed by one free-space step.
        wave = wave_interact_partial_voxels(wave, slab, slab_thickness, lambda_val, Audelta, Aubeta)
        return longprop(wave, tf)

    c_wave = cwavein
    for z in range(n_z):
        structure = xr[:, :, z]
        vacuum_rows = int(full_sim_size[0] - structure.size(0) - substrate_layers)
        slab = torch.nn.functional.pad(structure, (0, 0, vacuum_rows, 0), value=0)
        slab = torch.nn.functional.pad(slab, (0, 0, 0, substrate_layers), value=1)
        for _ in range(upscale_slices):
            if (torch.cuda.memory_allocated() / (1024**2)) > 24000:
                with torch.autograd.graph.save_on_cpu(pin_memory=True):
                    c_wave = _advance(c_wave, slab)
            else:
                c_wave = _advance(c_wave, slab)
            with torch.no_grad():
                midslice[:, z] = torch.abs(c_wave[-substrate_layers, :])
                sideslice[:, z] = torch.abs(c_wave[:, centre_col])
    return c_wave, midslice, sideslice
    


In [None]:
def MSForward_GI_SS_novol_partialvoxel_post(xr,cwavein,full_sim_size,params_size,voxel_size,slab_thickness,inc_angle,substrate_layers,prepostdist,probe_insertion_Y,probesize,init_xr_substrateamt):
    """Post-structure stage: free propagation above the substrate, then the tilt phase ramp.

    Steps:
      1. Zero-pad the incoming wave by 50 columns on each side and 200 rows on top.
      2. Propagate through 100 slabs in which only the bottom ``substrate_layers`` rows are
         filled. Their total length is ``prepostdist * post_prop_dist_multiplier``
         (notebook global).
      3. Apply ``probe_tilt_gradient``.

    Also reads the globals ``lambda_val``, ``Audelta`` and ``Aubeta``. The arguments
    ``xr``, ``full_sim_size``, ``params_size``, ``slab_thickness``, ``probe_insertion_Y``,
    ``probesize`` and ``init_xr_substrateamt`` are unused.
    """
    n_slices = 100
    side_pad = 50
    top_pad = 200
    step = (prepostdist * post_prop_dist_multiplier) / n_slices
    c_wave = torch.nn.functional.pad(cwavein, (side_pad, side_pad, top_pad, 0), value=0)

    substrate_mask = torch.zeros(c_wave.size(0), c_wave.size(1))
    substrate_mask[-int(substrate_layers):, :] = 1
    tf2 = get_tf_longprop(step, lambda_val, voxel_size, c_wave.size())
    for _ in range(n_slices):
        transmitted = wave_interact_partial_voxels(c_wave, substrate_mask, step, lambda_val, Audelta, Aubeta)
        c_wave = longprop(transmitted, tf2)
    return probe_tilt_gradient(c_wave, inc_angle, voxel_size, step)

In [None]:
def multislice_3stage(this_xr,
                full_sim_size, params_size, voxel_size, slab_thickness,
                inc_angle, slab_pad_pre, slab_pad_post, probes_in_shifted,
                probe_buffer,wave_deletion_mask,shift_amount,substrate_layers,init_xr_substrateamt,top_buffer,blankew,subpixel_shift_y,subpixel_shift_z,hx_shift,hz_shift):
    """Run the pre -> mid -> post forward stages and return the far-field magnitude.

    The pre stage is always called with ``wavemask=1``, ``shift_amt=0`` and ``top_buffer=0``.
    ``wave_deletion_mask``, ``shift_amount``, ``top_buffer`` and ``blankew`` are not used.

    Returns:
        (F1, midslice):
        - F1: ``|fftshift(fft2(exit_wave.flip(0)))|``, bilinearly resized to the shape of
          ``probes_in_shifted``.
        - midslice: the per-slice diagnostic from the mid stage.
    """
    probesize = probes_in_shifted.size()

    entry_wave, pre_post_amt, probe_insertion_Y = MSForward_GI_SS_novol_partialvoxel_pre(
        this_xr, full_sim_size, params_size, voxel_size, slab_thickness,
        inc_angle, slab_pad_pre, slab_pad_post, probes_in_shifted,
        probe_buffer, 1, 0, substrate_layers, init_xr_substrateamt, 0,
        subpixel_shift_y, subpixel_shift_z, hx_shift, hz_shift)
    structure_wave, midslice, _sideslice = MSForward_GI_SS_novol_partialvoxel_mid(
        this_xr, entry_wave, full_sim_size, params_size, voxel_size,
        slab_thickness, inc_angle, substrate_layers)
    exit_wave = MSForward_GI_SS_novol_partialvoxel_post(
        this_xr, structure_wave, full_sim_size, params_size, voxel_size,
        slab_thickness, inc_angle, substrate_layers, pre_post_amt,
        probe_insertion_Y, probesize, init_xr_substrateamt)

    far_field = torch.abs(torch.fft.fftshift(torch.fft.fft2(exit_wave.flip(0))))
    resized = torch.nn.functional.interpolate(far_field[None, None], size=probesize, mode='bilinear')
    return resized.squeeze(), midslice


In [None]:
    

def probe_tilt_gradient(probe,inc_angle,voxel_size,slab_thickness,wavelength=None):
    """Tilt ``probe`` by ``inc_angle`` degrees with a linear phase ramp along its rows.

    The path-length ramp runs linearly from 0 at row 0 to
    ``voxel_size[0] * rows * tan(inc_angle)`` at the last row. It is constant along
    columns. The probe is multiplied by ``exp(+2j*pi*ramp/wavelength)``.

    Args:
        probe: 2-D (complex) field.
        inc_angle: incidence angle in degrees.
        voxel_size: (row pitch, column pitch).
        slab_thickness: accepted for call-site compatibility; not used.
        wavelength: wavelength in the same units as ``voxel_size``. Defaults to the
            notebook global ``lambda_val``.

    Returns:
        Complex tensor with the same shape as ``probe``.
    """
    if wavelength is None:
        wavelength = lambda_val
    tan_angle = torch.tan(torch.deg2rad(torch.tensor(inc_angle)))
    phase_gradient_Y = torch.linspace(0, voxel_size[0] * probe.shape[0] * tan_angle, probe.shape[0])
    phase_gradient_X = torch.linspace(0, voxel_size[1] * probe.shape[1] * tan_angle, probe.shape[1])
    PhaserampmatY, _ = torch.meshgrid(-phase_gradient_Y, phase_gradient_X, indexing='ij')
    # Multiply by a unit-modulus factor directly. Rebuilding the probe from abs/angle loses
    # precision and has an ill-defined gradient where |probe| == 0.
    return probe * torch.exp(-1j * 2 * torch.pi * PhaserampmatY / wavelength)
  



In [None]:
def fill_3d_tensor(A,structure_height,oversample_structure,oversample_factor):
    """Convert a 2-D height map into a stack of per-layer fill fractions.

    Before the final flip, layer ``k`` holds ``clamp(A - k, 0, 1)``: the fraction of that
    layer occupied by a column of height ``A``. Steps:
      1. If ``oversample_structure == True`` and ``oversample_factor != 1``, average groups
         of ``oversample_factor`` entries along the last axis (strided ``conv3d``).
      2. Match axis 1 to the global ``params_size[1]``: crop the trailing excess, or, if
         too small, trilinearly resize the whole volume to ``params_size``.
      3. Reverse the layer axis.
    """
    rows, cols = A.shape
    levels = torch.arange(structure_height)[:, None, None].expand(structure_height, rows, cols)
    output3d = torch.clamp(A.repeat(structure_height, 1, 1) - levels, 0, 1)

    if oversample_structure == True and oversample_factor != 1:
        kernel = torch.ones(1, 1, 1, 1, oversample_factor) / oversample_factor
        output3d = torch.nn.functional.conv3d(
            output3d[None, None], kernel, stride=(1, 1, oversample_factor),
            padding=0, dilation=1, groups=1).squeeze()

    excess = output3d.size(1) - params_size[1]
    if excess > 0:
        output3d = output3d[:, :-int(excess), :]
    elif excess < 0:
        output3d = torch.nn.functional.interpolate(output3d[None, None], size=params_size, mode='trilinear').squeeze()

    return output3d.flip(0)


In [None]:
def interp_3d_tensor(A,structure_height,oversample_structure,oversample_factor):
    """Match axis 1 of the 3-D structure ``A`` to the global ``params_size[1]``.

    Oversample averaging is currently disabled, so ``A`` passes through unchanged whatever
    ``oversample_structure`` is. ``structure_height`` and ``oversample_factor`` are unused.

    If ``A.size(1)`` is larger than ``params_size[1]``, the trailing excess along axis 1 is
    cropped. If it is smaller, the whole volume is trilinearly resized to ``params_size``.
    """
    # Bound unconditionally: binding this only under ``oversample_structure == True`` made
    # the False case fail with UnboundLocalError, although that branch was the identity.
    output3d = A
    excess = output3d.size(1) - params_size[1]
    if excess > 0:
        output3d = output3d[:, :-int(excess), :]
    elif excess < 0:
        output3d = torch.nn.functional.interpolate(output3d.unsqueeze(0).unsqueeze(0), size=params_size, mode='trilinear').squeeze()
    return output3d


In [None]:
def get_maximum_propdist(full_sim_size,voxel_size,lambda_val):
    """Print ``full_sim_size[0] * voxel_size[0]**2 / lambda_val`` as the maximum propagation distance.

    Only the row (x) axis is reported; nothing is returned. The unused column value and
    the ``torch.tensor(voxel_size)`` copy have been dropped. That copy warned whenever
    ``voxel_size`` was already a tensor.
    """
    dist_picker_r = full_sim_size[0] * (voxel_size[0])**2 / lambda_val
    print('the maximum prop dist for these params =',"%5.5E" % (dist_picker_r),'m')
    
def get_maximum_propdist2(delta_res,lambda_val):
    """Print ``0.32 * delta_res**2 / lambda_val`` as the maximum propagation distance (returns None)."""
    limit = 0.32 * delta_res ** 2 / lambda_val
    print('the maximum prop dist for these params =', "%5.5E" % limit, 'm')
    


In [None]:
def flux_rescale(probe_in, flux_in): 
    """Scale ``probe_in`` so that its total intensity ``sum(|probe|**2)`` equals ``flux_in``.

    The phase is unchanged and the result is complex. Intensity scales with the square of
    the amplitude, so the amplitude factor is ``sqrt(flux_in / current_flux)``.
    """
    current_flux = torch.sum(torch.abs(probe_in) ** 2)
    amplitude_factor = torch.sqrt(flux_in / current_flux)
    return amplitude_factor * torch.abs(probe_in) * (torch.cos(torch.angle(probe_in)) + 1j * torch.sin(torch.angle(probe_in)))
def retain_flux(previous_amp,previous_phase, current_amp, current_phase):
    """Rescale the current probe so that its total intensity matches the previous probe's.

    Both probes are built as ``amp * exp(1j*phase)``. The current probe is multiplied by
    ``sqrt(previous_flux / current_flux)``, where flux is ``sum(|probe|**2)``.

    Returns:
        (amp, phase) of the rescaled current probe. ``sum(amp**2)`` equals the previous
        flux, and ``phase`` is ``current_phase`` wrapped by ``torch.angle`` to (-pi, pi].
    """
    current_probe = current_amp * torch.exp(1j * current_phase)
    previous_probe = previous_amp * torch.exp(1j * previous_phase)
    current_flux = torch.sum(torch.abs(current_probe) ** 2)
    previous_flux = torch.sum(torch.abs(previous_probe) ** 2)
    # Amplitude factor is the square root of the target/current intensity ratio.
    current_probe_rescaled = torch.sqrt(previous_flux / current_flux) * current_probe
    return torch.abs(current_probe_rescaled), torch.angle(current_probe_rescaled)

In [None]:
def make_prop_mask(inputwave,rad,blur,aspr):
    """Return a soft-edged elliptical mask with the 2-D shape of ``inputwave`` (float32 tensor).

    Pixels within ``rad`` of the array centre are set to 1. The distance is measured as
    ``sqrt(x**2 + (y/aspr)**2)``, with ``y`` along rows and ``x`` along columns. The mask is
    then Gaussian-blurred with ``sigma = blur`` and clipped to at most 1. Only the shape of
    ``inputwave`` is used.

    ``aspr`` may be a number, NumPy value or tensor. Any tensor (including
    ``nn.Parameter`` or one that requires grad) is detached and converted to NumPy.
    """
    if isinstance(aspr, torch.Tensor):
        aspr = aspr.detach().cpu().numpy()

    mask_size = inputwave.size()
    # Pixel-offset grids from the array centre (x along columns, y along rows).
    x, y = np.meshgrid(np.arange(mask_size[1]) - mask_size[1] / 2,
                       np.arange(mask_size[0]) - mask_size[0] / 2)
    distance = np.sqrt(x**2 + (y / aspr)**2)
    circle_mask = np.where(distance <= rad, 1, 0).astype(float)
    propmask_b = torch.tensor(gaussian_filter(circle_mask, blur)).to(torch.float32)
    propmask_b[propmask_b > 1] = 1
    return propmask_b



In [None]:
def get_tf_modelprobe(propdist,lambda_val,voxel_size,grid_shape):
    """Paraxial Fresnel transfer function ``exp(-1j*pi*lambda*z*(u**2 + v**2))`` for the model probe.

    ``create_freq_mesh`` already divides ``u`` by ``voxel_size[1]`` and ``v`` by
    ``voxel_size[0]``, so the frequencies are physical. No second pixel-size division is
    applied; the previous extra division left the frequencies in units of 1/length**2.
    """
    u, v = create_freq_mesh(voxel_size, grid_shape)
    return torch.exp(-1j * torch.pi * lambda_val * propdist * (u**2 + v**2))

    
def make_csaxs_model_probe(wavelength,probe_dims,desired_pixel_size,
                           probe_diameter,central_stop_diameter,zone_plate_diameter,
                          outer_zone_width,prop_dist):
    """Build a model zone-plate probe and propagate it to the sample plane.

    Steps:
      1. Focal length ``zp_f = zone_plate_diameter * outer_zone_width / wavelength``.
      2. Work on a grid ``upsample = 10`` times ``probe_dims``. The pupil-plane pixel size
         is ``(zp_f + prop_dist) * wavelength / (Nprobe * desired_pixel_size)``.
      3. Pupil = soft disk of ``probe_diameter`` minus soft disk of
         ``central_stop_diameter`` (via ``make_prop_mask``), times the thin-lens phase
         ``exp(-1j*pi*r2/(wavelength*zp_f))``.
      4. Zero-pad ``padsize`` rows top and bottom (for the pixel aspect ratio
         ``desired_pixel_size[1]/desired_pixel_size[0]``).
      5. Propagate over ``zp_f + prop_dist`` with a single-FFT Fresnel step, then crop the
         centre.

    NOTE(review): the pupil grid is built from ``Nprobe[1]`` in both directions, and
    ``torch.arange`` excludes its end point. ``Nprobe`` is always a multiple of 10, so each
    side has ``Nprobe[1] - 1`` points. This is kept unchanged.

    Returns:
        Complex tensor cropped around the centre, with about
        ``probe_dims[0] + 2*padsize/upsample`` rows and ``probe_dims[1]`` columns.
    """
    zp_f = zone_plate_diameter * outer_zone_width / wavelength
    upsample = 10
    voxel_ratio = desired_pixel_size[1] / desired_pixel_size[0]

    defocus = prop_dist
    Nprobe = upsample * torch.tensor(probe_dims)
    padsize = int(((Nprobe[0] * (voxel_ratio - 1)) / 2))
    # Pupil-plane pixel size that maps to the requested sample-plane pixel after propagation.
    desired_pixel_size = (zp_f + defocus) * wavelength / (Nprobe * desired_pixel_size)
    r1_pix = probe_diameter / desired_pixel_size
    r2_pix = central_stop_diameter / desired_pixel_size
    xvec = torch.arange(-Nprobe[1] / 2, torch.floor((Nprobe[1] - 1) / 2))
    yvec = torch.arange(-Nprobe[1] / 2, torch.floor((Nprobe[1] - 1) / 2))
    x, y = torch.meshgrid(xvec, yvec, indexing='ij')
    r2 = (x * desired_pixel_size[1])**2 + (y * desired_pixel_size[1])**2
    # Annular pupil: outer aperture minus central stop (r2 only supplies the mask shape).
    w = make_prop_mask(r2, (r1_pix[1] / 2).cpu().numpy(), 0.1, 1)
    w += -make_prop_mask(r2, (r2_pix[1] / 2).cpu().numpy(), 0.1, 1)
    # Thin-lens phase of the zone plate.
    tf = torch.exp(-1j * torch.pi * (r2) / (wavelength * zp_f))
    wc = w * tf

    N = Nprobe

    wcp = torch.nn.functional.pad(wc, (0, 0, padsize, padsize), mode='constant', value=0)
    r2p = torch.nn.functional.pad(r2, (0, 0, padsize, padsize), mode='constant', value=0)
    propdist = (zp_f + defocus)
    # Single-FFT Fresnel propagation. It uses the ``wavelength`` argument (not the notebook
    # global ``lambda_val``), so it stays consistent with the lens phase above.
    probe_hr1 = -1j * (torch.exp(1j * torch.pi * wavelength * propdist * r2p / (N[0] * N[1]))
                       * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(
                           wcp * torch.exp(1j * torch.pi * r2p / (wavelength * propdist))))))

    # Crop around the centre: padded height along rows, probe_dims[1] along columns.
    phrs = torch.tensor((probe_hr1.size(0) / 2, probe_hr1.size(1) / 2))
    cropinds = (int(float(phrs[0]) - float((probe_dims[0] + padsize * 2 / upsample) / 2)),
                int(float(phrs[0]) + float((probe_dims[0] + padsize * 2 / upsample) / 2)),
                int(float(phrs[1]) - float((probe_dims[1]) / 2)),
                int(float(phrs[1]) + float((probe_dims[1]) / 2)),
               )
    model_probe1 = probe_hr1[cropinds[0]:cropinds[1], cropinds[2]:cropinds[3]]
    return model_probe1


In [None]:
def probe_subpixel_shift(inputprobe,shift_amount_vert,shift_amount_hor,slab_thickness,y_voxel_size):
    """Resample a complex 2-D probe on a shifted grid (bilinear ``grid_sample``, ``align_corners=True``).

    The real and imaginary parts are resampled as two channels. Each shift is divided by
    half the corresponding dimension and added to the normalised [-1, 1] sampling
    coordinates. Samples outside the probe come back as zero. ``slab_thickness`` and
    ``y_voxel_size`` are unused.

    Returns:
        complex64 tensor with the same shape as ``inputprobe``.
    """
    rows, cols = inputprobe.size(0), inputprobe.size(1)
    row_offset = shift_amount_vert / (rows * 0.5)
    col_offset = shift_amount_hor / (cols * 0.5)

    channels = torch.zeros(1, 2, rows, cols)
    channels[:, 0, :, :] = torch.real(inputprobe)
    channels[:, 1, :, :] = torch.imag(inputprobe)

    row_coords = torch.linspace(-1, 1, rows) + row_offset
    col_coords = torch.linspace(-1, 1, cols) + col_offset
    row_grid, col_grid = torch.meshgrid(row_coords, col_coords, indexing='ij')
    # grid_sample reads (x, y) = (column, row) from the last grid dimension.
    grid = torch.zeros(1, rows, cols, 2)
    grid[:, :, :, 0] = col_grid
    grid[:, :, :, 1] = row_grid

    sampled = torch.nn.functional.grid_sample(channels, grid, align_corners=True, mode="bilinear").squeeze()
    return (sampled[0, :, :] + 1j * sampled[1, :, :]).to(torch.complex64)
def probe_subpixel_shift_fourier(inputprobe,shift_amount_vert,shift_amount_hor,slab_thickness,voxel_size,hx_shift=0,hz_shift=0):
    """Translate ``inputprobe`` by applying a linear phase ramp to its 2-D spectrum.

    The column shift is ``shift_amount_hor + hx_shift`` and the row shift is
    ``shift_amount_vert + hz_shift``. Both are in the physical units of ``voxel_size``,
    because ``create_freq_mesh`` divides the frequencies by it. ``slab_thickness`` is unused.
    """
    spectrum = torch.fft.fft2(inputprobe)
    freq_cols, freq_rows = create_freq_mesh(voxel_size, spectrum.size())
    col_shift = shift_amount_hor + hx_shift
    row_shift = shift_amount_vert + hz_shift
    ramp = torch.exp(-1j * 2 * torch.pi * (freq_cols * col_shift + freq_rows * row_shift))
    return torch.fft.ifft2(spectrum * ramp)
 

In [None]:
def rescale_0_1(input):
    """Linearly map ``input`` onto [0, 1] (min -> 0, max -> 1).

    A constant input (zero range) is returned unchanged, as the same object.
    """
    lowest = torch.min(input)
    span = torch.max(input - lowest)
    if span == 0:
        return input
    return (input - lowest) / span

In [None]:
# Print the CUDA memory currently allocated by PyTorch.
gpu_memcheck()

## Load scan positions, define the scan ROI, and specify the incidence angle:
### An incidence angle must be defined for each sub-scan, even if the angles are all the same.
### This simulation script comes with premade scan-position files (`demo_pz_values.npy`, `demo_px_values.npy`). You can modify them, or load your own as needed.


In [None]:

np.random.seed(2)
file_paths = [r'blank']

all_px_values = []
all_pz_values = []
all_ROIlist = []
all_ROI_inds = []
full_scan_identifier_list = []
# Loop over the list of file paths
file_iter = 0
n_sub_scans = 1 #load this number of sub scans iwthin a scan
dataset_shape = np.zeros(len(file_paths)*n_sub_scans)
all_pz_values = np.load("demo_pz_values.npy")
all_px_values = np.load("demo_px_values.npy")
for file_path in file_paths:

            full_scan_identifier_list += [file_iter]*all_pz_values.shape[1]
            
            print("subscan",file_iter, "loaded")
            file_iter = file_iter + 1
            




# all_pz_values *= 0.6
# Select the region of interest (ROI): scan points whose x position lies in
# (px_block[0], px_block[1]) and whose z position lies in (25, 105).
px_block = [5,25] 
# NOTE(review): the original inline note said "Keep at 65, 35!!" while the z
# bounds below are 105/25 — confirm which limits are intended.
# The 2nd and 3rd clauses require pz < 0 and pz > 0 at the same time, so they
# never match; they act as disabled placeholders for additional ROI bands.
ROI_inds = np.where((all_px_values < px_block[1]) & (all_px_values > px_block[0]) & (all_pz_values < 105) & (all_pz_values > 25) |
                    (all_px_values < px_block[1]) & (all_px_values > px_block[0]) & (all_pz_values < 0) & (all_pz_values > 0) |
                    (all_px_values < px_block[1]) & (all_px_values > px_block[0]) & (all_pz_values < 0) & (all_pz_values > 0)) 

# 1 for scan points inside the ROI, 0 otherwise (used to colour the scatter plot).
cvals = np.zeros((1,all_px_values.shape[1]))

cvals[0][ROI_inds[1]] = 1
print("size of scans in ROI:",all_px_values[ROI_inds].size)

# Cumulative scan counts = boundaries between subscans.
scan_categories = np.cumsum(dataset_shape)
print("scan_categories",scan_categories)

# For each ROI index, the subscan (file) it belongs to.
which_subset = np.digitize(ROI_inds[1],scan_categories)

# Correct each scan number by subtracting the scans of all previous subscans,
# since every subscan's indices start from zero; this is needed to read the
# h5 data with the correct indices.
scan_no_corrector = np.zeros_like(which_subset)

sc2 = np.insert(scan_categories,0,0)[:-1]
# for n1 in range(1,(which_subset.shape[0])):
#     scan_no_corrector[n1] = sc2[which_subset[n1]]#-sc2[which_subset[n1]]
# Leftover from the disabled loop above (which started at index 1): copy
# element 1 into element 0. Guarded so an ROI with fewer than two scans no
# longer raises IndexError.
if scan_no_corrector.size > 1:
    scan_no_corrector[0] = scan_no_corrector[1]

corrected_ROI_inds = (ROI_inds[1]-scan_no_corrector)

# ROI scan indices grouped per subscan.
ROI_inds_sub = [[] for _ in range(len(file_paths)*n_sub_scans)]
for ii in range(len(file_paths)*n_sub_scans):
    ROI_inds_sub[ii] = corrected_ROI_inds[which_subset==(ii)].astype(int).tolist()

num_scans = all_px_values[ROI_inds].size


# Incidence angle for each subscan (one entry per subscan, even if identical).
scan_inc_angle_index = [0.7]
# Build per-scan incidence angles and subscan identifiers for the ROI scans.
# NOTE(review): subscan 0 uses len(ROI_inds[1]) (all ROI scans) while later
# subscans use len(ROI_inds_sub[ii]); with several subscans the lists would
# become longer than num_scans — confirm intent.
for ii in range(len(ROI_inds_sub)):
    
    if ii == 0:
        inc_angle_list = np.ones(len(ROI_inds[1]))*scan_inc_angle_index[ii]
        
        scan_identifier_list = np.ones(len(ROI_inds[1]))*ii
    else:
        inc_angle_list = np.concatenate((inc_angle_list,np.ones(len(ROI_inds_sub[ii]))*scan_inc_angle_index[ii]))
        scan_identifier_list = np.concatenate((scan_identifier_list,np.ones(len(ROI_inds_sub[ii]))*ii))

#fix inc angles list 


# Base colour per scan point: red inside the ROI (cvals == 1), blue outside.
colors = ['red' if val == 1 else 'blue' for val in cvals[0]]
# Every scan point is drawn with an 'x' marker.
markers_list = ['x' for _ in full_scan_identifier_list]
# Per-subscan shade of each base colour, keyed by scan identifier.
# (Key 6 was previously typed as a duplicate 7, which silently dropped 'tomato'.)
color_table = {
    0: {'blue': 'mediumblue', 'red': 'cornflowerblue'},
    1: {'blue': 'mediumblue', 'red': 'lightsteelblue'},
    2: {'blue': 'mediumblue', 'red': 'fuchsia'},
    3: {'blue': 'darkviolet', 'red': 'deeppink'},
    4: {'blue': 'darkviolet', 'red': 'hotpink'},
    5: {'blue': 'darkviolet', 'red': 'gold'},
    6: {'blue': 'darkred', 'red': 'tomato'},
    7: {'blue': 'darkred', 'red': 'orangered'},
    8: {'blue': 'darkred', 'red': 'coral'},
    9: {'blue': 'darkolivegreen', 'red': 'springgreen'},
    10: {'blue': 'darkolivegreen', 'red': 'limegreen'},
    11: {'blue': 'darkolivegreen', 'red': 'palegreen'},
}

# Recolour each point by its subscan; identifiers/colours without an entry keep their base colour.
for ii, scan_id in enumerate(full_scan_identifier_list):
    shades = color_table.get(scan_id, {})
    colors[ii] = shades.get(colors[ii], colors[ii])

plt.figure(figsize=[15,10])
for apx, apz, color, marks in zip(all_px_values[0], all_pz_values[0], colors, markers_list):
    plt.scatter(apx, apz, c=color, s=70, marker=marks)


## Convert scan positions into voxel positions. 
### This includes discretization to the nearest voxel. Discretization errors are compensated in the simulation through probe shifts, which are also calculated. 
#### Scan/probe positions are optimized during reconstruction according to their learning rate. Set the learning rate to 0 to disable this.

In [None]:
num_scans = len(ROI_inds[0])
print("num scans in this sim",num_scans)

# Extent of one scan's simulation window in voxels (y, z).
scan_size_y = params_size[1]
scan_size_z = params_size[2]

#visualise where the scans are.
plt.figure()

# Convert ROI scan positions to voxel units: positions are scaled by 1e-6
# (presumably um -> m, matching voxel_size / slab_thickness — confirm) and
# divided by voxel_size[1] for x and by slab_thickness for z.
px_voxels = (all_px_values[ROI_inds]*1e-6) / voxel_size[1].cpu().numpy()
pz_voxels = (all_pz_values[ROI_inds]*1e-6) / slab_thickness.cpu().numpy()
px_voxels = px_voxels.reshape(-1,1) 
pz_voxels = pz_voxels.reshape(-1,1) 
# Shift so the smallest position sits (rounded) half a window from zero: all
# positions become positive and the outermost scans are not cut off.
if np.min(px_voxels) < 0:
    px_voxels += np.round(abs(np.min(px_voxels))+scan_size_y/2)

elif np.min(px_voxels) >= 0:
    px_voxels -= np.round(abs(np.min(px_voxels))-scan_size_y/2)
    
if np.min(pz_voxels) < 0:
    pz_voxels += np.round(abs(np.min(pz_voxels))+scan_size_z/2)
elif np.min(pz_voxels) >= 0:
    pz_voxels -= np.round(abs(np.min(pz_voxels))-scan_size_z/2)
# #check it makes sense
# plt.scatter(px_voxels,pz_voxels)
# plt.scatter(np.round(px_voxels,decimals=0),np.round(pz_voxels,decimals=0))

# plt.xlabel("voxel position (X)")
# plt.ylabel(r"voxel position (Z) (beamdir $\rightarrow$)")

# Sub-voxel remainders (rounded minus exact position), compensated in the
# simulation through probe shifts; note the sign flip for z.
subpixel_shifts_z =  -torch.tensor((np.round(pz_voxels,decimals=0) - pz_voxels))
subpixel_shifts_y =  torch.tensor((np.round(px_voxels,decimals=0) - px_voxels))


# Integer [start, stop) voxel windows for every scan.
# note 21 03 24: +1 is added to every index so that none is zero; this helps if oversampling is used.
# Positions are read as scalars via [i, 0]: calling int() on a 1-element
# array (pz_voxels[i] has shape (1,)) is deprecated since NumPy 1.25.
zposindex = [(int(np.round(pz_voxels[i, 0]-scan_size_z/2)+1),
              int(np.round(pz_voxels[i, 0]+scan_size_z/2+1)+1))
             for i in range(num_scans)]
yposindex = [(int(np.round(px_voxels[i, 0]-scan_size_y/2)+1),
              int(np.round(px_voxels[i, 0]+scan_size_y/2+1)+1))
             for i in range(num_scans)]

# plt.figure()
# plt.subplot(2,1,1)
# plt.plot((subpixel_shifts_z).cpu())
# plt.title("subpixel shifts Z")
# plt.subplot(2,1,2)
# plt.plot((subpixel_shifts_y).cpu())
# plt.title("subpixel shifts Y")


In [None]:

# Overall simulated extent along z and y: max() over each window list picks
# the largest (start, stop) pair, and the inner max() takes its stop index.
full_sim_size_zdir, full_sim_size_ydir = (max(max(windows)) for windows in (zposindex, yposindex))

# Reconstruction grid along z; dividing by 1 means no downscaling at present.
xr_z_size = int(full_sim_size_zdir / 1)
print("size of all sims [y,z],",full_sim_size_ydir,full_sim_size_zdir)
downscaling_factor_z = full_sim_size_zdir / xr_z_size
print("downscaling factor z (number of slices 1 z voxel represents):" ,downscaling_factor_z)
# Map the z windows onto the reconstruction grid.
z_indices_downscaled = zposindex
new_zposindex = [tuple(int(bound / downscaling_factor_z) for bound in window) for window in zposindex]

# Same for y.
xr_y_size = int(full_sim_size_ydir / 1)
downscaling_factor_y = full_sim_size_ydir / xr_y_size
print("downscaling factor y (number of slices 1 y voxel represents):" ,downscaling_factor_y)
# Map the y windows onto the reconstruction grid.
y_indices_downscaled = yposindex
new_yposindex = [tuple(int(bound / downscaling_factor_y) for bound in window) for window in yposindex]


In [None]:
# Tilted-plane correction has not been fully implemented; in testing it made no
# significant difference, so it is disabled. With this False the mask cell
# below builds an all-ones mask.
do_tilted_plane_corr = False


In [None]:
#make masks
# Zero-initialised stack with one crop_window_size[0] x crop_window_size[1]
# frame per scan (the last axis indexes the scan).
num_scans1 = num_scans
plt.figure(figsize=[10,10])
frame_rows, frame_cols = crop_window_size[0], crop_window_size[1]
GT = torch.zeros((frame_rows, frame_cols, num_scans))

# Without tilted-plane correction the mask is uniform: all ones, shaped like GT.
if do_tilted_plane_corr == 0:
    mask = torch.ones_like(GT)


In [None]:
#Assign GT
# GT_out_pack is a second name for the same GT tensor (no copy is made).
GT_out_pack = GT
print("Size of GT exit waves:",GT_out_pack.size())


In [None]:
def print_progress_bar(iteration, total, bar_length=20):
    """Draw a one-line text progress bar on stdout, overwriting the previous one.

    Args:
        iteration: Number of completed steps.
        total: Total number of steps. A non-positive total is shown as
            complete instead of raising ZeroDivisionError.
        bar_length: Width of the bar in characters.

    The progress fraction is clamped to [0, 1], so an overshooting
    ``iteration`` can no longer draw a bar longer than ``bar_length``.
    """
    progress = 1.0 if total <= 0 else iteration / total
    progress = min(max(progress, 0.0), 1.0)
    arrow = '=' * int(round(bar_length * progress))
    spaces = ' ' * (bar_length - len(arrow))
    # '\r' returns to the start of the line so the bar redraws in place.
    sys.stdout.write(f'\rthis iter progress: [{arrow + spaces}] {int(progress * 100)}%')
    sys.stdout.flush()


In [None]:
def blur_probe_test(probe_in, blur_strength_x, blur_strength_y):
    """Gaussian-blur a complex probe by filtering its real and imaginary parts separately.

    ``blur_strength_x`` / ``blur_strength_y`` are the Gaussian sigmas for the
    first and second array axes. Each part is filtered on the CPU with
    scipy's gaussian_filter and converted back to float32, so the result is
    complex64.
    """
    sigmas = (blur_strength_x, blur_strength_y)

    def _blur_part(part):
        filtered = gaussian_filter(part.cpu().numpy(), sigmas)
        return torch.tensor(filtered).to(torch.float32)

    real_blurred = _blur_part(torch.real(probe_in))
    imag_blurred = _blur_part(torch.imag(probe_in))
    return real_blurred + 1j * imag_blurred


In [None]:
# Probe settings. With use_model_probe True a model probe is built in the
# branch below; otherwise that branch reads recon_probe_pixel_size / new_probe_T.
use_model_probe = True
# One probe mode per subscan identifier (identifiers start at 0).
n_probe_modes = int(np.max(scan_identifier_list+1))
model_probe_pixel_size = voxel_size#torch.tensor((4.4570e-08,4.4570e-08))#torch.zeros(n_probe_modes,2)
# NOTE(review): probe_xvoxelsizes and pretiltangles are not read by any active
# code in this part of the notebook (pretiltangles only in a commented-out call).
probe_xvoxelsizes = torch.linspace(7e-9,7e-9,n_probe_modes) 
# Per-mode row shift (torch.roll along dim 0) applied to each model probe; all zero here.
rolls = torch.linspace(0,0,n_probe_modes)
pretiltangles = torch.linspace(0,0,n_probe_modes)

# Scalar multiplier applied to the propagated model probe.
new_flux = 4.0e-8
# Display settings for probe / FFT plots.
probe_vmax = 8e-4
probe_aspr = voxel_size[0]/voxel_size[1]
recon_FFT_vmin = -5
recon_FFT_vmax = 0
# Propagation distance passed to make_csaxs_model_probe and optimise_probe_prop.
new_prop = -8.5e-4#-13.5e-4 best for 06deg,07deg, -10e-4 beter for 09deg
# Zero padding (pixels per side) added before propagation and cropped off afterwards.
probepad = 400

initprobesize = [182,182]
# NOTE(review): cd1 is not used by any active code in this part of the notebook.
cd1 = int((initprobesize[0]-crop_window_size[1])/2)
probecenter=None
# Gaussian widths for the (commented-out) make_gauss_model_probe alternative.
probesigma_x=30.0
probesigma_y=100.0

if use_model_probe == True:
    # Optics parameters for make_csaxs_model_probe (alternative values in the trailing comments).
    probe_diameter = 240e-6 #240e-6             #170e-6 #240e-6   #300e-6
    central_stop_diameter = 48e-6#45e-6    #55e-6
    zone_plate_diameter = 490e-6#330e-6           #200e-6   #330e-6
    outer_zone_width = 43e-9#90e-9           #55e-9   #70e-9
    # init_model_probe = make_gauss_model_probe(initprobesize[1]*int(voxel_size[1]/voxel_size[0]),initprobesize[1],probecenter,probesigma_x,probesigma_y)
    # Reference probe: build the model, zero-pad by probepad, propagate by
    # new_prop with optimise_probe_prop, crop the padding off and scale by new_flux.
    init_model_probe = make_csaxs_model_probe(lambda_val,initprobesize,model_probe_pixel_size,
                           probe_diameter,central_stop_diameter,zone_plate_diameter,
                         outer_zone_width,new_prop)
    init_model_probe = optimise_probe_prop(torch.nn.functional.pad(init_model_probe,(probepad,probepad,probepad,probepad)),new_prop,lambda_val,model_probe_pixel_size)[probepad:-probepad,probepad:-probepad]*new_flux
    # Stack of probe modes: (rows, cols, n_probe_modes), complex64.
    model_probes = torch.zeros(init_model_probe.size(0),init_model_probe.size(1),n_probe_modes,dtype=torch.complex64)
    # NOTE(review): the loop below starts at n1 = 0, so this slot is rebuilt and
    # overwritten whenever n_probe_modes >= 1.
    model_probes[:,:,0] = init_model_probe #first one is always voxel size
    for n1 in range(0,n_probe_modes):
        print("probe mode:",n1)
        # this_model_probe = make_gauss_model_probe(initprobesize[1]*int(voxel_size[1]/voxel_size[0]),initprobesize[1],probecenter,probesigma_x,probesigma_y)
        this_model_probe = make_csaxs_model_probe(lambda_val,initprobesize,model_probe_pixel_size,
                               probe_diameter,central_stop_diameter,zone_plate_diameter,
                             outer_zone_width,new_prop)
        # Zero-pad rows evenly (top/bottom) so the height matches init_model_probe.
        newsd = int((init_model_probe.size(0)-this_model_probe.size(0))/2)
        
        this_model_probe = torch.nn.functional.pad(this_model_probe,(0,0,newsd,newsd),value=0)
        # An odd height difference leaves the probe one row short: add it at the bottom.
        if this_model_probe.size() != model_probes[:,:,n1].size():
            this_model_probe = torch.nn.functional.pad(this_model_probe,(0,0,0,1),value=0)
            
            
        # Propagate as for the reference probe, then roll by rolls[n1] rows and scale by new_flux.
        this_model_probe = optimise_probe_prop(torch.nn.functional.pad(this_model_probe,(probepad,probepad,probepad,probepad)),new_prop,lambda_val,model_probe_pixel_size)[probepad:-probepad,probepad:-probepad]
        # this_model_probe[-200:,:] = init_model_probe[-200:,:] 
        this_model_probe = torch.roll(this_model_probe,int(rolls[n1]),0)*new_flux
        model_probes[:,:,n1] = this_model_probe#probe_tilt_gradient(this_model_probe,pretiltangles[n1],voxel_size,slab_thickness)
    
    # NOTE(review): init_model_probe was already multiplied by new_flux above, so
    # orig_csaxs_probe carries new_flux**2 — confirm this is intended.
    orig_csaxs_probe = ((init_model_probe) * new_flux)
    print("using model probe, with dimensions of", init_model_probe.size())
else:
    # NOTE(review): this branch does not define model_probes, which the
    # following plotting code and the probe-parameter setup index.
    probe_pixel_size = recon_probe_pixel_size
    orig_csaxs_probe = (new_probe_T)

# First probe mode on its own.
plt.figure(figsize=[5,5])
plt.imshow(torch.abs(model_probes[:,:,0]).cpu(),aspect=probe_aspr,vmin=0,vmax=probe_vmax)
plt.colorbar()
plt.title("input probe for sim")

# All probe modes, at most four per row. The grid is sized from n_probe_modes;
# the previous fixed 5x4 grid failed for more than 20 modes.
n_mode_cols = max(1, min(4, n_probe_modes))
n_mode_rows = int(np.ceil(n_probe_modes / n_mode_cols))
plt.figure(figsize=[5,10])
for n1 in range(n_probe_modes):
    plt.subplot(n_mode_rows,n_mode_cols,n1+1)
    plt.imshow(torch.abs(model_probes[:,:,n1]).cpu(),aspect=probe_aspr,vmin=0,vmax=probe_vmax)
    plt.colorbar()
    plt.title(f"Probe Mode {n1}")

plt.figure()
plt.subplot(1,2,1)
# Log-magnitude of the centred FFT of probe mode 0, cropped to the detector window.
this_exit_wave_crop = v2.CenterCrop(size=(crop_window_size))(torch.abs(torch.fft.fftshift(torch.fft.fft2((model_probes[:,:,0]))))+1e-20)
plt.imshow(torch.log(this_exit_wave_crop).cpu(),aspect=1)
plt.title("FFT of input probe")
plt.subplot(1,2,2)
# Reference frame from GT for comparison. The scan index is clamped so this
# also works when the ROI contains fewer than six scans.
gt_view_idx = min(5, GT.size(2) - 1)
plt.imshow(torch.log(torch.abs(GT[:,:,gt_view_idx]**0.5)+1e-10).cpu(),vmin=recon_FFT_vmin,vmax=recon_FFT_vmax)


In [None]:

def combine_probe_modes(multimode_probe_in,scan_number):
    """Return the real-space probe to use for scan ``scan_number``.

    Reads two notebook globals: ``probe_is_FFT`` (if truthy, the input is a
    centred Fourier-space probe, i.e. ``fftshift(fft2(probe))``, and is
    transformed back to real space first) and ``scan_identifier_list``
    (maps a scan number to its probe-mode index).

    Args:
        multimode_probe_in: 2-D probe (returned as is) or 3-D stack whose
            last axis indexes probe modes.
        scan_number: Index into ``scan_identifier_list``; only used for 3-D input.

    Raises:
        ValueError: if the probe is neither 2-D nor 3-D (previously this
            surfaced as an UnboundLocalError).
    """
    if probe_is_FFT:
        # ifftshift is the exact inverse of fftshift; using fftshift here only
        # undid the forward shift for even-sized axes.
        multimode_probe_in = torch.fft.ifft2(torch.fft.ifftshift(multimode_probe_in))

    n_dims = multimode_probe_in.dim()
    if n_dims == 3:
        return multimode_probe_in[:,:,int(scan_identifier_list[scan_number])]
    if n_dims == 2:
        return multimode_probe_in
    raise ValueError(f"expected a 2-D or 3-D probe tensor, got {n_dims} dimensions")


In [None]:
def add_second_mode(probe_in,pad1):
    """Squeeze a 2-D complex probe vertically by padding rows and resampling back.

    ``pad1`` zero rows are added above and below the probe; the padded array
    is then resampled (nearest neighbour) back to the original size.
    Real and imaginary parts are processed separately and recombined.
    """
    target_size = probe_in.size()

    def _pad_and_resample(part):
        padded = torch.nn.functional.pad(part, (0, 0, pad1, pad1))
        resampled = torch.nn.functional.interpolate(padded[None, None], size=target_size)
        return resampled.squeeze()

    return _pad_and_resample(torch.real(probe_in)) + 1j * _pad_and_resample(torch.imag(probe_in))

In [None]:
def create_iso_random_initguess(xsize,ysize,zsize,aspr):
    """Create a smooth random initial guess for the structure volume.

    Uniform noise is drawn on a grid whose y axis is coarsened by ``aspr``,
    resampled (nearest neighbour) to ``(xsize, ysize, zsize)`` so features are
    stretched along y, then smoothed with a Gaussian filter (sigma 5 on every axis).

    Args:
        xsize, ysize, zsize: Output extents.
        aspr: y coarsening ratio; the noise grid has ``int(ysize / aspr)`` rows,
            at least one (a large ratio previously produced an empty grid).

    Returns:
        float32 tensor. Size-1 axes are squeezed away before a leading axis is
        added, e.g. shape ``(1, ysize, zsize)`` when ``xsize == 1``.
    """
    coarse_rows = max(1, int(ysize/aspr))
    str1 = torch.rand(xsize,coarse_rows,zsize)
    str2 = torch.nn.functional.interpolate(str1.unsqueeze(0).unsqueeze(0),size=[xsize,ysize,zsize]).squeeze().unsqueeze(0)
    # Removed: an unused test-curve / torch.meshgrid computation (dead code that
    # also triggered torch's meshgrid "indexing" UserWarning).
    str3 = torch.tensor(gaussian_filter(str2.cpu().numpy(),5)).to(torch.float32)
    return str3

# This cell sets the PyTorch initial parameters for the optimisation and defines the parameter tensors, optimizers and learning rates.

In [None]:
print("full area covered by optimizable volume:", "%5.5E" % (slab_thickness*full_sim_size[2]*1e6), "um" )
# y/z voxel aspect ratio of the reconstruction grid and related size ratios.
yzaspr = (slab_thickness*downscaling_factor_z)/(voxel_size[1]*downscaling_factor_y)
volsizeratio = xr_z_size/xr_y_size
xy_voxel_size_ratio = voxel_size[0]/voxel_size[1]
print("params size", params_size)

# ---- run configuration flags ----
using_n_scans_per_pixel = True
# using_rprop = False
do_gradient_accumulation = True
# Load a previously saved/optimised probe instead of starting from the model probe.
load_previous = False
# Whether TV is calculated over the whole volume or just over each scan.
use_per_scan_TV = False
use_support_mask = False
use_multiple_probes = False
use_building_blocks = False
use_multiple_probe_modes = True
use_noise = True
optimize_scan_offsets = True
probe_is_FFT = 0
oversample_structure = True
# Dividing by the number of scans per pixel requires gradient accumulation, so
# force it on. (This line used `==`, a comparison, so the flag never changed.)
if using_n_scans_per_pixel and not do_gradient_accumulation:
    print("cannot have n_scans_per_pixel div if gradient accumulation is not used, setting ot true")
    do_gradient_accumulation = True

if use_multiple_probe_modes == True:
    print("using multiple probe modes")

    # Start from a copy of the (rows, cols, n_probe_modes) model probe stack.
    probe_in = model_probes.clone()

else:
    # NOTE(review): probe_in is not assigned on this branch, so the
    # probes_param setup below raises NameError unless probe_in already exists.
    print("not using multiple probe modes")

    
# Scalar parameters optimised alongside the probe (probe_prop_amt by prop_optim).
probe_prop_amt = torch.nn.Parameter(torch.tensor(1e-5))
probefluxmult = torch.nn.Parameter(torch.tensor(1,dtype=torch.float32)) #6e2 for scan 509

if load_previous:
    # Resume from a previously saved probe tensor (hard-coded absolute path).
    probes_presave = torch.load('/home/lubs/pytorchnotebooks/Au_realdatatensor_probes_050608_conv_27-11-2024.pt')
    # probe_in = torch.nn.Parameter(probes_presave)
    probes_param = torch.nn.Parameter(probes_presave)
else:
    # The optimised probe parameter, stored in Fourier space when probe_is_FFT is set.
    # NOTE(review): fft2 transforms the last two dims, which for a
    # (rows, cols, modes) stack are cols and modes — confirm this is intended.
    if probe_is_FFT == True:
        probes_param = torch.nn.Parameter(torch.fft.fftshift(torch.fft.fft2(probe_in)))
    else:
        probes_param = torch.nn.Parameter(probe_in)

# Extent of the optimised volume along x. Both configurations currently use a
# single slice (required for building blocks; params_size[0] was the alternative).
xr_x_size = 1

# Slice index used to watch the reconstruction.
viewing_slice = 5

# Extra padding added on y and z around the volume.
yzpad = 0
print("final param size for this sim:", xr_x_size, xr_y_size + yzpad*2, xr_z_size + yzpad*2)

# Initial structure guess = uniform noise * init_xr_scaling_factor + init_xr_substrateamt.
init_xr_scaling_factor = 0.02
init_xr_substrateamt = 0.00
init_xr_learning_rate = 1e5#30e-1
grads_target = 0.2
if optimize_scan_offsets == True:
    # Learnable scan-position corrections: the sub-voxel remainders (rounded
    # minus exact voxel position) multiplied by voxel_size[0] (z) / voxel_size[1] (y).
    # NOTE(review): pz_voxels is in units of slab_thickness, yet the z remainder is
    # scaled by voxel_size[0] — confirm. Also, when this flag is False the shifts
    # stay unscaled (voxel units) and hx_shift/hz_shift are never defined.
    subpixel_shifts_z = torch.nn.Parameter(-((torch.tensor((np.round(pz_voxels,decimals=0) - pz_voxels))*voxel_size[0])).clone())
    subpixel_shifts_y = torch.nn.Parameter((torch.tensor((np.round(px_voxels,decimals=0) - px_voxels))*voxel_size[1]).clone())
    # Per-probe-mode shifts, initialised with tiny random values.
    hx_shift = torch.nn.Parameter(torch.rand(n_probe_modes)*1e-13)
    hz_shift = torch.nn.Parameter(torch.rand(n_probe_modes)*1e-13)
    scan_positions_optim = torch.optim.SGD([{'params': [subpixel_shifts_z,subpixel_shifts_y], 'lr': 1e-15},{'params': [hx_shift], 'lr': 5e-15},{'params': [hz_shift], 'lr': 5e-15}])
    print("position optimization on")


# if load_previous == True:
    
#     pre_xr = torch.load('/home/lubs/pytorchnotebooks/Au_realdatatensor_0506deg_311024_conv.pt').detach() #this file had oversample factor 2 !!!
#     # substrate_mean = torch.mean(pre_xr[:,:,3500:4000])
#     # # pre_xr[:,:,:500] = 0
#     # pre_xr -= substrate_mean
#     pre_xr *= 12
#     pre_xr = torch.clamp(pre_xr,0,4.1)
    
#     xr = torch.nn.Parameter(pre_xr)
#     oversample_factor = 5

# if not load_previous:
if oversample_structure == True:
    # Volume of shape (1, xr_y_size, xr_z_size*oversample_factor) filled with
    # uniform noise; the oversampling factor is currently 1.
    # NOTE(review): unlike the else branch, no yzpad padding is added here.
    oversample_factor = 1
    xr = torch.nn.Parameter(torch.rand(1,xr_y_size,xr_z_size*oversample_factor)*init_xr_scaling_factor+init_xr_substrateamt)#torch.nn.Parameter((create_iso_random_initguess(xr_x_size,xr_y_size+yzpad*2,(xr_z_size+yzpad*2)*oversample_factor,(yzaspr))*init_xr_scaling_factor+init_xr_substrateamt))
else:
    # Smoothed random initial guess, padded by yzpad on y and z.
    oversample_factor = 1
    xr = torch.nn.Parameter((create_iso_random_initguess(xr_x_size,xr_y_size+yzpad*2,xr_z_size+yzpad*2,(yzaspr))*init_xr_scaling_factor+init_xr_substrateamt))


if use_building_blocks != True:
    # Coverage map: starts at 1 and gains 1 for every scan window covering a pixel.
    n_scans_per_pixel = (torch.ones_like(xr))
    for n1, (y_start, y_stop) in enumerate(new_yposindex):
        z_start, z_stop = new_zposindex[n1]
        n_scans_per_pixel[:, (y_start+yzpad):(y_stop+yzpad),
                          (z_start+yzpad):(z_stop+yzpad)] += 1

    plt.imshow(n_scans_per_pixel[0,:,:].cpu())
    plt.colorbar()

    status = "using n scans per pixel" if using_n_scans_per_pixel else "not using n scans per pixel"
    print(status)


#model noise

# noise_guess = torch.nn.Parameter(torch.clamp(torch.normal(torch.ones(crop_window_size[0],crop_window_size[1],num_scans)*noise_mean,torch.ones(crop_window_size[0],crop_window_size[1],num_scans)*noise_std),min=0))

# Create an optimizer for the parameter xr #good lr is 2e-3!!
#normal optim
#are you doing gradient accumulation?
# Report which optional features are active for this run.
print('using gradient accumulation' if do_gradient_accumulation == True else 'not using gradient accumulation')
print('TV is calculated per scan area' if use_per_scan_TV == True else 'TV is calculated over entire volume')
print('support mask for grads is used' if use_support_mask == True else 'support mask for grads is not used')
print('using noise addition' if use_noise == 1 else 'not using noise addition')

#### Are you using rprop or Adam ?


# Structure optimiser: RAdam without gradient accumulation; otherwise SGD,
# plus a separate Adam optimiser over the same xr parameter.
if do_gradient_accumulation == 0:
    optimizer = torch.optim.RAdam([{'params': [xr], 'lr': init_xr_learning_rate,'weight_decay':0}])  # You can adjust the learning rate as needed    
else:
    optimizer = torch.optim.SGD([{'params': [xr], 'lr': init_xr_learning_rate,'weight_decay':0}])  # You can adjust the learning rate as needed    
    optimizer_adam = torch.optim.Adam([{'params': [xr], 'lr': 8e-3,'weight_decay':0}])
probe_params = [ {'params': [probes_param], 'lr': 5e-2, 'weight_decay': 1e-20}]
probe_grads_target = 1.0e-8
# noise_optim = torch.optim.RAdam([{'params': [noise_guess], 'lr': 5e-7}])  # You can adjust the learning rate as needed

probe_optimizer = torch.optim.SGD(probe_params)
# Probe LR scheduler: halves the probe learning rate once the monitored value
# (mode='min') has not improved for more than 4 scheduler steps.
probe_LR_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(probe_optimizer, mode='min', 
                    factor=0.5, patience=4, threshold=0.0001, threshold_mode='rel',
                    cooldown=0, min_lr=0, eps=1e-07)

prop_optim = torch.optim.RAdam([{'params': [probe_prop_amt], 'lr': 1e-4}]) #lr 5e-7 is agressive  # You can adjust the learning rate as needed
num_iters = 50
# Loss record, one row per iteration and one column per scan.
loss_tracker = np.zeros((num_iters,num_scans))
total_sim_start_time = time.time()
iters_out = torch.tensor([])
dzz = 0
dzy = 0
# Total-variation (TV) weights per direction and overall.
tvx_alpha =  1e1*(xr_x_size*slab_thickness*downscaling_factor_z)/(params_size[0]*voxel_size[0]) #weight for TV in x direction
print("tvx alpha",tvx_alpha)
tvy_alpha = 1#weight for TV in y direction
print("tvy alpha",tvy_alpha)
tvz_alpha = 1 #weight for TV in z direction
tvt_alpha = 3e-3 #total weight for all TV penalties  #if parial voxels, 1e-7 is typical, multiply by 1e6 or so if using complex

probestart = 2 #iteration number at which the probe will begin optimising.
tv_start = 2 #iteration number at which TV penalty will begin to apply.

patience = 5
divergence_count_start = 5
divergence_thr = 2
# Zero buffers shaped like xr and probefluxmult.
xrg = torch.zeros_like(xr)
PFM = torch.zeros_like(probefluxmult)
# orig_probe_max = torch.max(probe_amp)
# NOTE(review): this message reports divergence_thr+1 (= 3), which is neither the
# `patience` variable (5) nor the scheduler's patience (4) — confirm which is meant.
print("patience (number of iters before reducing LR):",(divergence_thr+1))
postresizeprobemask_rad = 95


In [None]:
if oversample_structure:
    # Scale the z windows by the structure oversampling factor.
    new_zposindex = [tuple(int(bound * oversample_factor) for bound in window) for window in zposindex]
        

In [None]:

# Release cached GPU memory, clear gradients of all optimisers, run Python GC,
# then report the allocated GPU memory.
torch.cuda.empty_cache()
optimizer.zero_grad()
probe_optimizer.zero_grad()
prop_optim.zero_grad()
gc.collect()
# if use_noise == 1:
    # noise_optim.zero_grad()
gpu_memcheck()

In [None]:

# EWCI = Exit Wave Crop Indices: margins cropped from the simulated far-field
# pattern as F1[EWCI[0]:-EWCI[1], EWCI[2]:-EWCI[3]] (first axis start/end,
# second axis start/end) so it matches the size of the data. Useful when
# working with real experimental data.
# NOTE: EWCI[1] and EWCI[3] must be >= 1; a slice ending at -0 is empty.
EWCI = [528,528,1,1]


## Set additional simulation settings:

In [None]:
substrate_layers = 200 #number of layers of substrate, counted from the bottom upwards. This should be enough to fully reflect/absorb the incoming beam.
probe_substrate_buffer = 2 #extra spacing between probe and substrate.
top_buffer = 0 #add padding in the final multislice component if the beam reflects too early or too steeply.
pre_prop_dist_multiplier = 1.5 #add extra propagation distance before the beam encounters the structure.
post_prop_dist_multiplier = 1 #add extra propagation distance after the beam encounters the structure.
test_inc_angle = 0.7 # the incidence angle used to test these settings.
slab_pad_pre = 0 #add extra slices; not used, but currently required for the code to run.
slab_pad_post = 0 #add extra slices; not used, but currently required for the code to run.


## Run a quick multislice forward simulation for diagnostic purposes.

### Here you can check that the beam comes in, reflects, and exits, all within the space of the volume you defined.
### This is a good point to check that your simulation settings are okay, or whether you need to go back and modify something, without having to run the entire reconstruction. It should only take a few seconds on a GPU.
#### The simulation follows this same code.

In [None]:

# Not used in this cell.
iii = 25
with torch.no_grad():
    
    # All-zero structure used for the mid (reflection) stage.
    teststrin =  torch.zeros(params_size)
        
    # Probe for scan 0 and its size (used to resample the far-field pattern).
    probein = combine_probe_modes(probes_param,0) 
    probesize = probein.size()
    # Three-stage multislice at test_inc_angle: pre -> input slice, mid -> slice
    # after reflection (+ substrate footprint and side view), post -> exit wave.
    pre_EW,pre_post_amt,probe_insertion_Y = MSForward_GI_SS_novol_partialvoxel_pre(
                        torch.zeros(params_size),
                        full_sim_size, params_size, voxel_size, slab_thickness,
                        test_inc_angle, slab_pad_pre, slab_pad_post, probein,
                        probe_buffer,1,0,substrate_layers,init_xr_substrateamt,0,0,0,0,0)
    midEW,midslice,sideslice = MSForward_GI_SS_novol_partialvoxel_mid(teststrin,pre_EW,
                   full_sim_size,params_size,voxel_size,slab_thickness,test_inc_angle,(substrate_layers))
#     midEW*=make_probecut_mask(midEW,1,5e-1)
    # NOTE(review): the post stage is given xr (the current random initial guess)
    # rather than the all-zero test structure — confirm this is intended.
    exit_wave = MSForward_GI_SS_novol_partialvoxel_post(xr,midEW,
                    full_sim_size,params_size,voxel_size,slab_thickness,test_inc_angle,(substrate_layers),pre_post_amt,probe_insertion_Y,probesize,init_xr_substrateamt)
#     F2 = torch.abs(farfield_PSI_prop(exit_wave,lambda_val,7.36e0,voxel_size))
    # Far-field magnitude |fftshift(fft2(exit_wave))|, bilinearly resampled to the
    # probe size and cropped with EWCI.
    F1 = torch.abs(torch.fft.fftshift(torch.fft.fft2(exit_wave)))
    F1 = torch.nn.functional.interpolate(F1.unsqueeze(0).unsqueeze(0),size=(probesize),mode='bilinear').squeeze()
#     F2 = torch.nn.functional.interpolate(F2.unsqueeze(0).unsqueeze(0),size=(probesize),mode='bilinear').squeeze()
    F1 = (F1)[EWCI[0]:-EWCI[1],EWCI[2]:-EWCI[3]]
#     F2 = (F2)[EWCI[0]:-EWCI[1],EWCI[2]:-EWCI[3]]

    
    
    # Diagnostic plots: the three stages, the far-field pattern and the side view.
    plt.figure(figsize=[15,10])
    plt.subplot(2,3,1)
    plt.imshow(torch.abs(pre_EW).cpu(),aspect=0.1)
    plt.title('|Input slice|')
    plt.xlabel("y (voxels)")
    plt.ylabel("x (voxels)")
    plt.subplot(2,3,2)
    plt.imshow(torch.abs(midEW).cpu(),aspect=0.1)
    plt.title('|Slice after reflection|')
    plt.xlabel("y (voxels)")
    plt.ylabel("x (voxels)")
    plt.subplot(2,3,3)
    plt.imshow(torch.abs(exit_wave).cpu(),aspect=0.1)
    plt.title('|Exit wave| (final slice)')
    plt.xlabel("y (voxels)")
    plt.ylabel("x (voxels)")
    # pre_FT = torch.fft.fftshift(torch.fft.fft2(pre_EW))
    # mid_FT = torch.fft.fftshift(torch.fft.fft2(midEW))

    # plt.figure(figsize=[5,5])
    plt.subplot(2,3,4)
    plt.imshow(torch.log(torch.abs((F1))+1e-20).cpu(),aspect=1,vmin=recon_FFT_vmin,vmax=recon_FFT_vmax)
    plt.title("FT[exit wave]")
    plt.xlabel("uy (voxels)")
    plt.ylabel("uz (voxels)")
    
    # plt.figure(figsize=[10,5])
    plt.subplot(2,3,5)
    plt.imshow((torch.abs(sideslice.cpu())),aspect='auto')
    plt.title("side-on view")
    plt.xlabel("z (voxels)")
    plt.ylabel("x (voxels)")

    # plt.subplot(2,3,5)
    # plt.imshow(torch.log(torch.abs((F1))+1e-20).cpu(),aspect=1,vmin=recon_FFT_vmin,vmax=recon_FFT_vmax)
    # plt.title("FT[exit wave]")
    # plt.subplot(1,3,3)
    # plt.imshow(torch.log(torch.abs((GT[:,:,-5]**0.5))+1e-20).cpu(),aspect=1,vmin=recon_FFT_vmin,vmax=recon_FFT_vmax)
    
    print("exit wave norm",torch.norm(torch.abs(F1)))
    print("exit wave max",torch.max(torch.abs(F1)))
    

## Assert that the exit wave simulation window size matches the data/detector window size.
### Not needed for simulated data, but useful for real data.

In [None]:
# Compare the cropped simulated pattern with the data crop window. If the
# assertion fails, adjust EWCI by the printed size difference.
F1_size = F1.size()
for label, value in (("size of GT  crop", crop_window_size), ("size of EW", F1_size)):
    print(label, value)
sizes_match = (F1_size[0] == crop_window_size[0]) & (F1_size[1] == crop_window_size[1])
assert sizes_match, "EWCI isnt the same"

In [None]:
# plt.figure()
# plt.imshow((torch.abs((midslice)).detach().cpu()))
# plt.title("beam footprint on substrate")
plt.figure()
# Magnitude of the substrate slice along its middle column (index full_sim_size[2]/2).
mid_column = int(full_sim_size[2]/2)
beam_profile = torch.abs(midslice.cpu())[:, mid_column]
plt.plot(beam_profile)

plt.title("beam profile at midpoint")

## Define the beam footprint mask. It is used to apply a real-space constraint to the volume.

In [None]:
# Footprint radius in reconstruction voxels. The fixed 1.9e-6 radius is divided
# by the y voxel size for y; for z it is also multiplied by oversample_factor and
# divided by sin(incidence angle of the first scan), elongating it along z.
beam_footprint_radius_y_voxels = (1.9e-6/(voxel_size[1]*downscaling_factor_y))
beam_footprint_radius_z_voxels = (1.9e-6*oversample_factor/np.sin(np.deg2rad(inc_angle_list[0])))/(slab_thickness*downscaling_factor_z)
# z/y radius ratio passed to make_prop_mask.
beam_footprint_ratio = beam_footprint_radius_z_voxels/beam_footprint_radius_y_voxels
beam_footprint_buffer = 1
beam_footprint_blur = 5
# Footprint map over the volume: starts at 1 and, for every scan window, adds
# make_prop_mask(...) + 1e-3. The mask is computed on the window rotated by 90
# degrees and rotated back afterwards.
beamfootprints = (torch.ones(1,xr.size(1),xr.size(2)))
print(beamfootprints.size())
for n1 in range(len(new_yposindex)):
    
    beamfootprints[:,(new_yposindex[n1][0]):(new_yposindex[n1][1]),
                      (new_zposindex[n1][0]):(new_zposindex[n1][1])] += make_prop_mask(beamfootprints[0,int(new_yposindex[n1][0]):int(new_yposindex[n1][1]),int(new_zposindex[n1][0]):int(new_zposindex[n1][1])].rot90(1),(beam_footprint_radius_y_voxels+beam_footprint_buffer).cpu().numpy(),beam_footprint_blur,beam_footprint_ratio).rot90(-1)+1e-3


    
# Preview the footprint mask of a single scan (index 2; needs at least three scans in the ROI).
iii = 2
thisxrtest = xr[0,int(new_yposindex[iii][0]+yzpad):int(new_yposindex[iii][1]+yzpad),
                   int(new_zposindex[iii][0]+yzpad):int(new_zposindex[iii][1]+yzpad)]
masktest1 = make_prop_mask(thisxrtest.rot90(1),(beam_footprint_radius_y_voxels+beam_footprint_buffer).cpu().numpy(),beam_footprint_blur,beam_footprint_ratio).rot90(-1)+1e-4
# beamfootprints = beamfootprints.unsqueeze(-1)
# print(thisxrtest.size())
# print(masktest1.size())
# print(midslice.size())
plt.figure(figsize=[10,5])
plt.subplot(1,2,1)
plt.imshow(masktest1.detach().cpu(),aspect=2)
plt.title("single scan footprint mask")
plt.subplot(1,2,2)
plt.imshow(torch.abs(midslice).detach().cpu(),aspect=2/oversample_factor)
plt.title("beam footprint on substrate")
plt.figure(figsize=[15,5])
plt.subplot(1,2,1)
plt.imshow(beamfootprints[0,:,:].cpu(),aspect=0.9), plt.colorbar()
plt.title("beam footprints with mask applied")

# Binarise the footprint map: values >= 2 become 1, everything else 0.
beam_footprints_binary = beamfootprints.clone()
beam_footprints_binary[beam_footprints_binary<2] = 0
beam_footprints_binary[beam_footprints_binary>=2] = 1


plt.subplot(1,2,2)
plt.imshow(beam_footprints_binary[0,:,:].cpu(),aspect=0.9), plt.colorbar()
plt.title("beam footprints binary mask")



In [None]:
def generate_per_spoke_modulated_siemens_star(
    width=512, height=512,
    num_spokes=36,
    r_inner=50, r_outer=200,
    modulated_spokes=None  # dict: {index: (r_inner_alt, r_outer_alt)}
):
    """Render a binary Siemens star whose spokes can have individual radii.

    The full angle is split into ``num_spokes`` equal sectors, counted from
    angle -pi (the negative x direction). The first half of each sector is a
    spoke and the second half a gap. A spoke pixel is set when its radius lies
    in ``[r_inner, r_outer]``, or in the sector's override radii.

    Args:
        width, height: Image size in pixels; the star is centred at
            ``(width // 2, height // 2)``.
        num_spokes: Number of sectors (spoke + gap pairs).
        r_inner, r_outer: Default inner/outer radius in pixels.
        modulated_spokes: Optional ``{sector_index: (r_inner_alt, r_outer_alt)}``
            overriding the radii of individual sectors. Indices outside
            ``range(num_spokes)`` have no effect.

    Returns:
        float32 tensor of shape ``(1, height, width)`` with values 0 or 1.
    """
    if modulated_spokes is None:
        modulated_spokes = {}

    # Pixel coordinates relative to the image centre.
    y, x = np.ogrid[:height, :width]
    cx, cy = width // 2, height // 2
    x = x - cx
    y = y - cy

    # Polar coordinates; theta_norm maps the angle range [-pi, pi] onto [0, 1].
    r = np.sqrt(x**2 + y**2)
    theta = np.arctan2(y, x)
    theta_norm = (theta + np.pi) / (2 * np.pi)
    sector_pos = theta_norm * num_spokes

    # Sector index per pixel. theta == pi (theta_norm == 1) is the same direction
    # as theta == -pi, so wrap it into sector 0 rather than the non-existent
    # sector num_spokes (which made that ray ignore sector 0's modulation).
    spoke_idx = sector_pos.astype(int) % num_spokes

    # First half of every sector is a spoke, second half a gap.
    is_spoke = (sector_pos % 1) < 0.5

    # Per-pixel radius limits, defaulting to the global values.
    r_outer_map = np.full_like(r, r_outer, dtype=float)
    r_inner_map = np.full_like(r, r_inner, dtype=float)

    # Apply custom radii for each modulated sector.
    for idx, (ri_alt, ro_alt) in modulated_spokes.items():
        mask = spoke_idx == idx
        r_outer_map[mask] = ro_alt
        r_inner_map[mask] = ri_alt

    radius_mask = (r >= r_inner_map) & (r <= r_outer_map)
    siemens_star = is_spoke & radius_mask

    return torch.tensor(siemens_star.astype(np.float32)).unsqueeze(0)


In [None]:
def create_GT_structure(width,height,bsx=1550,bsy=-170,downscaley = 1,downscalez = 1,height_voxels=2):
    """Rasterise the rectangles listed in the global ``erclogo`` into a height map.

    Each row of ``erclogo`` describes one rectangle: columns 0 and 2 give its
    start and extent along the image width (divided by
    ``voxel_size[1]*downscaley`` and offset by ``bsx``); columns 1 and 3 give
    its start and extent along the image height (divided by
    ``slab_thickness*downscalez`` and shifted so the smallest start maps to
    ``-bsy``). Covered pixels are set to 1, the image is flipped vertically and
    scaled by ``height_voxels``.

    Returns:
        Tensor of shape ``(1, height, width)``.
    """
    image_tensor = torch.zeros(1, 1, height, width)  # 1 channel (grayscale), 1 batch
    n_patches = erclogo.shape[0]
    if n_patches:
        y_scale = (voxel_size[1]*downscaley).cpu().numpy()
        z_scale = (slab_thickness*downscalez).cpu().numpy()
        z_offset = min(erclogo[:,1])/z_scale + bsy
        for n1 in range(n_patches):
            col_start = round((erclogo[n1,0])/y_scale + bsx)
            row_start = round((erclogo[n1,1])/z_scale - z_offset)
            col_stop = round((erclogo[n1,0]+erclogo[n1,2])/y_scale + bsx)
            row_stop = round((erclogo[n1,1]+erclogo[n1,3])/z_scale - z_offset)
            # Binary image: covered pixels are simply set to 1 (overlaps stay 1).
            image_tensor[:, :, row_start:row_stop, col_start:col_stop] = 1
    image_tensor = image_tensor.squeeze().flip(0)*height_voxels
    return image_tensor.unsqueeze(0)
    

In [None]:
#create ground truth structure.
# ground_truth_image = create_GT_structure(xr_z_size,xr_y_size,bsx=450,bsy=-50,downscaley = 10,downscalez = 1,height_voxels=1)
# ground_truth_image = generate_fully_modulated_siemens_star(
    # width=xr_z_size, height=xr_y_size,num_spokes=35,r_inner=15, r_outer=200,r_inner_alt=80, r_outer_alt=200,n_modulate=5)
# Per-spoke radius overrides {spoke index: (r_inner, r_outer)} in pixels.
# With num_spokes=25 the valid indices are 0-24, which is presumably why the
# last entry is 24 rather than continuing the step-of-3 pattern.
modulated = {
    1: (60, 150),
    4: (70, 160),
    7: (80, 170),
    10: (90, 180),
    13: (100, 190),
    16: (110, 200),
    19: (120, 200),
    22: (130, 200),
    24: (140, 200),
    # 25: (100, 200),
    # 28: (110, 200),
    # 31: (120, 200),
    # 34: (130, 200),

    
    
}
# Binary Siemens star on the reconstruction grid (xr_y_size rows x xr_z_size columns).
ground_truth_image = generate_per_spoke_modulated_siemens_star(width=xr_z_size,
                                                               height=xr_y_size,
                                                               num_spokes=25,
                                                               r_inner=15, 
                                                               r_outer=200,
                                                               modulated_spokes=modulated)
# ground_truth_image = generate_custom_modulated_siemens_star(
#     width=xr_z_size, height=xr_y_size,num_spokes=35,r_inner=15, r_outer=200,r_inner_alt=80, r_outer_alt=150,modulated_spoke_indices=[1,4,8,13,19,26,34])
ny, nx = ground_truth_image.squeeze().shape
# Axis extent for the plot: columns scaled by slab_thickness (z), rows by
# voxel_size[1] (y), both multiplied by 1e6 for the um axis labels.
extent = [0, nx * slab_thickness.cpu()*1e6, 0, ny * voxel_size[1].cpu()*1e6]
plt.imshow(ground_truth_image.squeeze().cpu(),extent=extent,aspect='auto'),plt.colorbar()
plt.title("Ground Truth")
plt.ylabel("y (μm)")
plt.xlabel("z (μm)")


In [None]:
# Simulate the "measured" data from the ground truth. For every scan, the
# ground-truth height map is cropped to that scan's footprint, expanded to a
# voxel structure, propagated with multislice_3stage, and the amplitude of the
# cropped exit wave is stored as GTs[:, :, scan]. No gradients are needed here.
with torch.no_grad():
    GTs = torch.zeros_like(GT)
    this_GT_str = ground_truth_image
    for iii in range(num_scans):
        print("creating sim data diffrac. patterns:", (iii+1), "/", num_scans, end="\r")

        # footprint of this scan in the padded ground-truth grid
        gt_y_lo = int(new_yposindex[iii][0] + yzpad)
        gt_y_hi = int(new_yposindex[iii][1] + yzpad)
        gt_z_lo = int(new_zposindex[iii][0] + yzpad)
        gt_z_hi = int(new_zposindex[iii][1] + yzpad)
        gt_footprint = this_GT_str[0, gt_y_lo:gt_y_hi, gt_z_lo:gt_z_hi]

        this_xr = fill_3d_tensor(gt_footprint, params_size[0], oversample_structure, oversample_factor)
        sim_probe_in = combine_probe_modes(probes_param, 0)

        gt_scan_group = int(scan_identifier_list[iii])  # selects the hx/hz shift for this scan
        waveout, _ = multislice_3stage(
            this_xr,
            full_sim_size, params_size, voxel_size, slab_thickness,
            inc_angle_list[iii], slab_pad_pre, slab_pad_post, sim_probe_in,
            probe_buffer, 1, shift_amount, substrate_layers,
            init_xr_substrateamt, top_buffer, 0,
            subpixel_shifts_y[iii], subpixel_shifts_z[iii],
            hx_shift[gt_scan_group], hz_shift[gt_scan_group],
        )

        # crop the exit wave with the EWCI margins and keep only its amplitude
        out1 = torch.abs(waveout[EWCI[0]:-EWCI[1], EWCI[2]:-EWCI[3]])
        GTs[:, :, iii] = out1
            

In [None]:
# Sanity check of the simulated data for a single scan. Top: log-amplitude of
# its simulated diffraction pattern. Bottom: the ground-truth height window
# that produced it. Both panels use the same scan index, so they describe the
# same measurement.
scan_to_show = min(4, num_scans - 1)  # scan 4 by default, clamped for short scan lists
with torch.no_grad():
    plt.subplot(2,1,1)
    # +1e-9 avoids log(0) = -inf on zero-amplitude pixels (same offset as the reconstruction display)
    plt.imshow(torch.log(GTs[:,:,scan_to_show] + 1e-9).cpu()),plt.colorbar()
    plt.subplot(2,1,2)
    plt.imshow(this_GT_str[0,int(new_yposindex[scan_to_show][0]+yzpad):int(new_yposindex[scan_to_show][1]+yzpad),
                           int(new_zposindex[scan_to_show][0]+yzpad):int(new_zposindex[scan_to_show][1]+yzpad)].cpu())

## Run the Reconstruction

In [None]:
# Main reconstruction loop.
#
# Every outer iteration i visits all scans once, in a freshly shuffled order.
# For each scan iii the current height map `xr` is cropped to that scan's
# footprint, expanded with fill_3d_tensor, propagated with multislice_3stage,
# and the amplitude of the cropped exit wave is compared with the simulated
# pattern GTs[:, :, iii] (mean squared amplitude error + optional TV terms).
#
# Structure update: with do_gradient_accumulation the per-scan gradients are
# summed into xrg and applied once per iteration (divided by beamfootprints
# when using_n_scans_per_pixel); otherwise the optimizer steps after every
# scan. `optimizer` is stepped while i < 500, `optimizer_adam` afterwards.
# Probe and scan-offset gradients are accumulated per iteration; note that
# probe_optimizer.step() and scan_positions_optim.step() are currently
# commented out, so those parameters are only clamped/masked below.
total_sim_start_time = time.time()  # start of the timing reported after the loop
psr = probes_param.size(0)/probes_param.size(1)  # probe rows/cols; used by make_prop_mask and the display aspect
wave_deletion_mask = 1
fig = plt.figure(figsize=[10,10])
divergence_count = 0
#start the loop
for i in range(num_iters):

    mstime1 = time.time()
    scan_orders = list(range(num_scans))
    random.shuffle(scan_orders)
    scan_counter = 0
    #zero grad "storage" tensors
    probes_cumulative_grad = torch.zeros_like(probes_param)
    spsz_cumulative = torch.zeros_like(subpixel_shifts_z)
    spsy_cumulative = torch.zeros_like(subpixel_shifts_y)
    hx_cumulative = torch.zeros_like(hx_shift)
    hz_cumulative = torch.zeros_like(hz_shift)

    xrg *= 0

    for iii in scan_orders:
        # footprint of this scan in the padded structure grid (used for the
        # forward crop, per-scan TV and the gradient bookkeeping below)
        scan_y0 = int(new_yposindex[iii][0]+yzpad)
        scan_y1 = int(new_yposindex[iii][1]+yzpad)
        scan_z0 = int(new_zposindex[iii][0]+yzpad)
        scan_z1 = int(new_zposindex[iii][1]+yzpad)
        scan_group = int(scan_identifier_list[iii])  # selects the hx/hz shift for this scan

        #zero optimizers
        optimizer.zero_grad()
        optimizer_adam.zero_grad()
        probe_optimizer.zero_grad()
        prop_optim.zero_grad()
        if optimize_scan_offsets == 1:
            scan_positions_optim.zero_grad()

        # probe for this scan (single- and multi-mode paths are identical here)
        resized_probe = combine_probe_modes(probes_param,iii)
        probes_in_unshifted = resized_probe
        probes_in_shifted = probes_in_unshifted

        # Forward pass
        this_xr = fill_3d_tensor(xr[0,scan_y0:scan_y1,scan_z0:scan_z1],params_size[0],oversample_structure,oversample_factor)

        #actually do the multislice:
        waveout,_ = multislice_3stage(
            this_xr,
            full_sim_size, params_size, voxel_size, slab_thickness,
            inc_angle_list[iii], slab_pad_pre, slab_pad_post, probes_in_shifted,
            probe_buffer,wave_deletion_mask,shift_amount,substrate_layers,
            init_xr_substrateamt,top_buffer,0,subpixel_shifts_y[iii],subpixel_shifts_z[iii],hx_shift[scan_group],hz_shift[scan_group])
        #crop exit wave
        out1 = torch.abs(waveout[EWCI[0]:-EWCI[1],EWCI[2]:-EWCI[3]])
        with torch.no_grad():
            out1.data[torch.isnan(out1.data)] = 0

        #apply Total Variation Penalties / other regularizers
        if i < tv_start:
            tvx = 0
            tvy = 0
            tvz = 0
            tvsy = 0
            tvsz = 0
        else:
            tvx = 0  # TV along the first axis is currently disabled in both modes
            if use_per_scan_TV == True:
                # xr[0, ...] is 2-D (y, z): y differences are along dim 0, z along dim 1
                scan_window = xr[0,scan_y0:scan_y1,scan_z0:scan_z1]
                tvy = torch.nansum(torch.abs(torch.diff(scan_window,dim=0)**2))*tvy_alpha
                tvz = torch.nansum(torch.abs(torch.diff(scan_window,dim=1)**2))*tvz_alpha
            else:
                tvy = torch.nanmean(torch.abs(torch.diff(xr[0,:,:],n=1,dim=0)**2))**0.5*tvy_alpha
                tvz = torch.nanmean(torch.abs(torch.diff(xr[0,:,:],n=1,dim=1)**2))**0.5*tvz_alpha
        tv_total = (tvx+tvy+tvz)*tvt_alpha

        this_GT = GTs[:,:,iii]
        flux_ratio = 1  # no per-scan flux normalisation is applied; also used by the display below
        # calculate the loss (the use_noise path currently computes the same expression)
        loss = torch.nanmean(torch.abs(torch.abs(this_GT) - out1)**2)+tv_total

        allocated_memory = torch.cuda.memory_allocated()
        backtime1 = time.time()

        ### the all important backward function ###
        loss.backward()

        #accumulate scan-offset gradients (only after the first 10 iterations)
        if optimize_scan_offsets == 1:
            if i > 10:
                spsz_cumulative += subpixel_shifts_z.grad
                spsy_cumulative += subpixel_shifts_y.grad
                hx_cumulative += hx_shift.grad
                hz_cumulative += hz_shift.grad

        if do_gradient_accumulation == True:
            #store this scan's structure gradient into xrg; applied after the scan loop
            with torch.no_grad():
                this_xrg = xr.grad[:,scan_y0:scan_y1,scan_z0:scan_z1]
                this_xrg[torch.isnan(this_xrg)] = 0
                xrg[:,scan_y0:scan_y1,scan_z0:scan_z1] += this_xrg
        else:
            with torch.no_grad():
                this_xrg = xr.grad[0,scan_y0:scan_y1,scan_z0:scan_z1]
                # lr = target / typical gradient magnitude. mean(|g|) keeps the
                # denominator positive (a signed mean can make the lr negative and
                # reverse the update); matches the probe-lr formula below.
                optimizer.param_groups[0]['lr'] = grads_target/(torch.mean(torch.abs(this_xrg))+torch.std(this_xrg))

            if i < 500:
                optimizer.step()
            else:
                optimizer_adam.step()
            xr.data = torch.clamp(xr.data,0,params_size[0])

        #accumulate grads for the probe
        with torch.no_grad():
            probes_cumulative_grad += probes_param.grad
            probes_cumulative_grad[torch.isnan(probes_cumulative_grad)] = 0

        backtime2 = time.time()
        scan_counter += 1
        loss_tracker[i,iii] = loss.item()
        # scan_counter was already incremented, so it runs 1..num_scans here
        print_progress_bar(scan_counter, num_scans)
        ### end loop over scans###

    #reapply accumulated gradients and actually update the structure
    if do_gradient_accumulation == True:
        optimizer.zero_grad()
        optimizer_adam.zero_grad()
        if using_n_scans_per_pixel == True:
            xr.grad = xrg/beamfootprints
        else:
            xr.grad = xrg

        if i < 500:
            optimizer.step()
        else:
            optimizer_adam.step()

        with torch.no_grad():
            xr.data = torch.clamp(xr.data,0,params_size[0])
            xr.data[torch.isnan(xr.data)] = 0

    # scan-offset gradients are staged on the tensors; the optimizer step is disabled,
    # so only the clamps below modify these parameters here
    subpixel_shifts_y.grad = spsy_cumulative
    subpixel_shifts_z.grad = spsz_cumulative
    hx_shift.grad = hx_cumulative
    hz_shift.grad = hz_cumulative
    # scan_positions_optim.step()
    subpixel_shifts_z.data = torch.clamp(subpixel_shifts_z.data,-2e-7,2e-7)
    subpixel_shifts_y.data = torch.clamp(subpixel_shifts_y.data,-2e-6,2e-6)
    hx_shift.data = torch.clamp(hx_shift.data,-2e-6,2e-6)
    hz_shift.data = torch.clamp(hz_shift.data,-2e-7,2e-7)
    #transfer the scan-averaged probe gradient to the actual tensor
    probes_param.grad = probes_cumulative_grad/num_scans

    #only update the probe after 'probestart'
    if i > probestart:
        with torch.no_grad():
            if i < 8: #only have this high LR for the first few iterations.
                probe_optimizer.param_groups[0]['lr'] = probe_grads_target/(torch.mean(torch.abs(probes_cumulative_grad/num_scans))+torch.std(probes_cumulative_grad/num_scans))
        # probe_optimizer.step()
        probes_param.data[torch.isnan(probes_param)] = 0

    with torch.no_grad():
        if probe_is_FFT == 1:
            pass
        else:
            # every probe mode is multiplied by a mask built from mode 0, clamped to [1e-8, 1]
            for n3 in range(probes_param.size(2)):
                probes_param[:,:,n3].data *= torch.clamp(make_prop_mask(probes_param[:,:,0].data,95,3,psr),1e-8,1)

    # divergence control: count iterations whose mean loss rose; past divergence_thr,
    # halve learning rates and gradient targets and weaken the TV weight
    if i > divergence_count_start:
        mean_loss_per_iter = np.mean(loss_tracker,1)
        if mean_loss_per_iter[i-1] < mean_loss_per_iter[i]:
            divergence_count += 1
            print("loss increasing, count:",divergence_count)

        if divergence_count > divergence_thr:
            grads_target *= 0.5
            probe_grads_target *= 0.5
            optimizer.param_groups[0]['lr'] *= 0.5
            optimizer_adam.param_groups[0]['lr'] *= 0.5
            probe_optimizer.param_groups[0]['lr'] *= 0.5

            tvt_alpha *= 0.7
            if using_rprop == 1:
                optimizer.param_groups[0]['step_sizes'] = (optimizer.param_groups[0]['step_sizes'][0]*0.8,
                                                           optimizer.param_groups[0]['step_sizes'][1])

            print('loss increasing, reducing lr, tv')
            divergence_count = 0

    xr.data = torch.clamp(xr.data,0,params_size[0])

    mstime2 = time.time()

    # per-iteration report (loss and TV values refer to the last scan visited)
    hz_shift_formatted = " ".join("%2.2E" % x.item() for x in hz_shift.data)
    hx_shift_formatted = " ".join("%2.2E" % x.item() for x in hx_shift.data)
    print()
    print("=======")
    print(i+1, "/", num_iters, "%2.5E" % np.mean(loss_tracker[i]),
          "mem used:",
          allocated_memory // (1024 * 1024), "MB",
          "MS fwd time", "%2.2F" % (mstime2-mstime1),
          "Bkwd time", "%2.2F" % ((backtime2-backtime1)))
    # compare against this_GT (= GTs[:,:,iii]), the target the loss actually used
    typical_loss = torch.mean(torch.abs(torch.abs(this_GT) - torch.abs(out1))**2).item()
    print("typical loss", "%2.2E" % typical_loss,
          "typical tvx", "%2.2E" % tvx,
          "typical tvy", "%2.2E" % tvy,
          "typical tvz", "%2.2E" % tvz,
          "typical tv_total", "%2.2E" % tv_total,
          "calculated learning rate", "%2.2E" % (optimizer.param_groups[0]['lr']),
          "calculated probe learning rate", "%2.2E" % (probe_optimizer.param_groups[0]['lr'])
         )
    print("=======")
    if i > 0:
        # close the previous iteration's figure before drawing a new one
        plt.close(fig)
        plt.close()
        fig = plt.figure(figsize=[10,10])

    ax1 = fig.add_subplot(2,2,1)
    crop_y = dzy+yzpad
    crop_z = dzz+yzpad
    if not ((dzz == 0) & (dzy == 0) & (yzpad == 0)):
        # explicit stop indices: a stop of -0 would give an empty slice when only one crop is zero
        im1 = ax1.imshow(torch.sum(xr[:,crop_y:xr.size(1)-crop_y,crop_z:xr.size(2)-crop_z].detach().cpu(),dim=0), interpolation='none', aspect=2.5)
    else:
        im1 = ax1.imshow(xr[0,:,:].detach().cpu(), interpolation='none', aspect=1)

    fig.colorbar(im1, ax=ax1)  # Add colorbar to the first subplot

    ax2 = fig.add_subplot(2,2,2)
    im2 = ax2.imshow(torch.abs(combine_probe_modes(probes_param,iii)).detach().cpu(), vmin=0, vmax=probe_vmax, interpolation='none', aspect=1/psr)

    ax3 = fig.add_subplot(2,2,3)
    ax3.imshow(torch.log((torch.abs(out1)*flux_ratio)+1e-9).detach().cpu(), vmin=recon_FFT_vmin, vmax=recon_FFT_vmax, interpolation='None')

    ax4 = fig.add_subplot(2,2,4)
    ax4.imshow(torch.log(torch.abs(this_GT)+1e-9).detach().cpu(), vmin=recon_FFT_vmin, vmax=recon_FFT_vmax, interpolation='None')
    display(fig)

total_sim_end_time = time.time()
print()
print("finished")

print("total time for",num_iters,"iters:", "%2.2f" % ((total_sim_end_time-total_sim_start_time)/60),"min")

In [None]:
# Side-by-side comparison of the final reconstruction and the ground truth.
# Both maps are multiplied by voxel_size[0]*1e9 (colorbar labelled "height (nm)")
# and drawn on the same physical extent, which is derived from the reconstruction's shape.
with torch.no_grad():
    plt.figure(figsize=[15,5])

    ny, nx = xr.squeeze().shape
    extent = [0, nx * slab_thickness.cpu()*1e6, 0, ny * voxel_size[1].cpu()*1e6]

    _height_to_nm = voxel_size[0]*1e9

    plt.subplot(1,2,1)
    plt.imshow((xr*_height_to_nm).squeeze().cpu(), extent=extent, aspect='auto')
    cbar1 = plt.colorbar()
    cbar1.set_label("height (nm)")
    plt.title("reconstruction")
    plt.ylabel("y (μm)")
    plt.xlabel("z (μm)")

    plt.subplot(1,2,2)
    plt.imshow((ground_truth_image*_height_to_nm).squeeze().cpu(), extent=extent, aspect='auto')
    plt.title("Ground Truth")
    plt.ylabel("y (μm)")
    plt.xlabel("z (μm)")
    cbar2 = plt.colorbar()
    cbar2.set_label("height (nm)")



In [None]:

# Loss history: mean over scans per iteration with a ±1 std error bar, plus
# faint per-scan traces. The x-range stops a few iterations past the last
# iteration run (i), so an interrupted run is still readable.
plt.figure(figsize=[10,5])

shifted_losses = loss_tracker + 1e-20  # tiny offset kept from the original plot
lossmeans = np.mean(shifted_losses, 1)
losserrs = np.std(shifted_losses, 1)

plt.errorbar(np.arange(num_iters), lossmeans, losserrs)
plt.plot(loss_tracker, alpha=0.1)
plt.xlim([0, i + 5])
plt.title("Loss")
plt.xlabel("n iterations")

In [None]:
# Save the reconstruction together with the simulation settings, the loss
# history and the ground truth the data were simulated from.
save_filepath = '/your_output_directory/'  # placeholder: set to a writable directory
save_filename = 'demo_script_output'
today_date = datetime.today().strftime('%d-%m-%Y')
# NOTE(review): mean_radius is not set anywhere in this part of the notebook and
# looks left over from a blob-based simulation; confirm it is defined (or change
# the filename pattern) before running this cell.
full_filename = os.path.join(save_filepath, f"{save_filename}_{mean_radius*40}_nm_{today_date}.pt")
print("full filename",full_filename)

metadata = {
    "slab_thickness": slab_thickness,
    "simulation_size": full_sim_size,
    "voxel_size": voxel_size,
    "probes": probes_param.detach(),
    "oversample_factor": oversample_factor,
    "num_inc_angles": len(np.unique(inc_angle_list)),
    "inc_angles": np.unique(inc_angle_list),
    "loss_function": loss_tracker,
    # the ground truth actually used to simulate GTs in this notebook
    "GT": ground_truth_image.detach(),
}

# Save tensor and metadata (create the output directory if it does not exist yet)
os.makedirs(save_filepath, exist_ok=True)
torch.save({"recon": xr.detach(), "metadata": metadata}, full_filename)

