# White noise object dataset

> Georgios Varnavides | Oct 24 2025  
>
> Sample: white-noise object  
> Sampling conditions: 2 A^-1 max-scattering angle, Nyquist sampling  
> Imaging conditions: 1 A^-1 semiangle, 200 A defocus, 100 A astigmatism

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import quantem as em
from scipy.ndimage import rotate

In [3]:
# parameters
n = 96  # pixels per side of the square object / probe / diffraction grids
k_max = 2 # inverse Angstroms (maximum scattering angle)
k_probe = 1 # inverse Angstroms (probe-forming aperture semiangle)

energy = 300e3  # presumably eV (i.e. 300 keV) -- confirm against the quantem helper
wavelength = em.core.utils.utils.electron_wavelength_angstrom(energy)
sampling = 1 / k_max / 2 # Angstroms (real-space pixel size, 1 / (2 * k_max))
reciprocal_sampling = 2 * k_max / n # inverse Angstroms (Fourier-space pixel size)

scan_step_size = 1 # pixels
# Number of scan positions per axis. The scan grid is built further down with
# np.arange(0., n, scan_step_size), which has ceil(n / scan_step_size) entries;
# a floor division here would under-count whenever the step does not divide n
# and the final reshape to (sx, sy, n, n) would fail.
sx = sy = int(np.ceil(n / scan_step_size))
phi0 = 1  # Fourier amplitude handed to white_noise_object_2D

# Probe aberrations, looked up by key in prepare_probe.
aberrations = {
    "C10": 200,  # 200 A defocus (per the notebook header)
    "C12": 100,  # 100 A astigmatism (per the notebook header)
    "phi12": np.deg2rad(11),  # astigmatism angle, radians
}

In [4]:
def white_noise_object_2D(n, phi0, rng=None):
    """Create a real-valued (n, n) array whose FFT has constant amplitude and random phase.

    A random phase map is drawn from a standard normal distribution (in units
    of full turns, i.e. it is multiplied by 2*pi), then antisymmetrised under
    k -> -k so that the Fourier array ``phi0 * exp(2j*pi*phase)`` is Hermitian
    and its inverse FFT is real. The phase of the DC component, and of the
    highest-frequency row/column when ``n`` is even, is set to zero because
    those components map onto themselves under k -> -k and must be real.

    Parameters
    ----------
    n : int
        Side length of the square output array.
    phi0 : float
        Real amplitude shared by every Fourier component.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness; anything exposing ``standard_normal(shape)``.
        When omitted, NumPy's global random state is used (``np.random.randn``),
        which is the original behaviour.

    Returns
    -------
    numpy.ndarray
        Real array of shape ``(n, n)``; ``abs(np.fft.fft2(result))`` equals
        ``abs(phi0)`` everywhere up to floating-point error.
    """

    evenQ = n % 2 == 0

    # indices of the positive frequencies and of their negative partners,
    # ordered so that pos_ind[i] pairs with neg_ind[i] == n - pos_ind[i]
    pos_ind = np.arange(1, (n if evenQ else n + 1) // 2)
    neg_ind = np.flip(np.arange(n // 2 + 1, n))

    # random phase (units of 2*pi)
    if rng is None:
        arr = np.random.randn(n, n)
    else:
        arr = rng.standard_normal((n, n))

    # enforce phase(-k) = -phase(k)
    # top-left // bottom-right
    arr[pos_ind[:, None], pos_ind[None, :]] = -arr[neg_ind[:, None], neg_ind[None, :]]
    # bottom-left // top-right
    arr[pos_ind[:, None], neg_ind[None, :]] = -arr[neg_ind[:, None], pos_ind[None, :]]
    # kx=0
    arr[0, pos_ind] = -arr[0, neg_ind]
    # ky=0
    arr[pos_ind, 0] = -arr[neg_ind, 0]

    # zero the phase of components which don't have a distinct k -> -k partner
    if evenQ:
        arr[n // 2, :] = 0  # highest spatial frequency
        arr[:, n // 2] = 0  # highest spatial frequency

    arr[0, 0] = 0  # DC component

    # fourier-array: constant amplitude phi0, phase 2*pi*arr
    arr = np.exp(2j * np.pi * arr) * phi0

    # inverse FFT; the imaginary part is only floating-point residue, drop it
    arr = np.fft.ifft2(arr).real

    return arr

# potential
# Draw one white-noise realisation and use it as the phase of a
# unit-amplitude complex object, exp(i * potential).
potential = white_noise_object_2D(n,phi0)
complex_obj = np.exp(1j*potential)

In [5]:
# we build probe in Fourier space, using a soft aperture

# Spatial-frequency axes in FFT order (zero frequency first). kx and ky are
# the same array object; kx is broadcast along axis 0 and ky along axis 1.
kx = ky = np.fft.fftfreq(n,sampling)
k2 = kx[:,None]**2 + ky[None,:]**2  # squared radial frequency
k  = np.sqrt(k2)                    # radial frequency
phi = np.arctan2(ky[None,:],kx[:,None])  # azimuthal angle, arctan2(ky, kx)

# Soft-edged circular aperture of radius k_probe: the clipped ramp goes
# linearly from 1 to 0 across a band one reciprocal pixel wide centred on
# k_probe. The square root makes the aperture *intensity* follow that linear
# ramp, the array itself being an amplitude.
aperture_fourier = np.sqrt(
    np.clip(
        (k_probe - k)/reciprocal_sampling + 0.5,
        0,
        1,
    ),
)

def prepare_probe(aberrations):
    """Build the aberrated probe from a dictionary of aberration coefficients.

    Reads the module-level grids ``k``, ``k2``, ``phi``, the ``wavelength``,
    the soft aperture ``aperture_fourier`` and the grid size ``n``.

    Parameters
    ----------
    aberrations : dict
        May contain "C10" (defocus), "C12" and "phi12" (astigmatism magnitude
        and angle). Missing keys default to 0.0.

    Returns
    -------
    chi : numpy.ndarray
        Aberration phase surface on the Fourier grid.
    probe_array_fourier : numpy.ndarray
        ``aperture_fourier * exp(-1j * chi)``, scaled so that the sum of its
        squared magnitudes is 1.
    probe_array : numpy.ndarray
        Inverse FFT of ``probe_array_fourier`` multiplied by ``n``.
    """
    defocus = aberrations.get("C10", 0.0)
    astig = aberrations.get("C12", 0.0)
    astig_angle = aberrations.get("phi12", 0.0)

    # accumulate the aberration function term by term
    chi = np.zeros_like(k)
    chi += k2 * wavelength * np.pi * defocus
    chi += k2 * wavelength * np.pi * (astig * np.cos(2 * (phi - astig_angle)))
    # higher-order terms, currently disabled:
    # chi += k**3 * wavelength**2 * 2 / 3 * np.pi * (aberrations.get("C21",0.0) * np.cos(phi - aberrations.get("phi21",0.0)))
    # chi += k**3 * wavelength**2 * 2 / 3 * np.pi * (aberrations.get("C23",0.0) * np.cos(3*(phi - aberrations.get("phi23",0.0))))
    # chi += k**4 * wavelength**3 * 2 / 4 * np.pi * aberrations.get("C30",0.0)
    phase_factor = np.exp(-1j * chi)

    # apertured, aberrated probe in Fourier space, normalised to unit power
    fourier_probe = aperture_fourier * phase_factor
    norm = np.sqrt(np.sum(np.abs(fourier_probe) ** 2))
    fourier_probe = fourier_probe / norm

    # back to real space; the factor n compensates ifft2's 1/n**2 scaling so
    # that, on an n x n grid, the real-space probe also has unit total intensity
    real_probe = np.fft.ifft2(fourier_probe) * n

    return chi, fourier_probe, real_probe

# Build the probe for the defocus + astigmatism values set in `aberrations`.
chi, probe_array_fourier, probe_array = prepare_probe(aberrations)

In [6]:
def return_patch_indices(positions_px,roi_shape,obj_shape):
    """Return wrapped row/column indices of the object patch at each position.

    Parameters
    ----------
    positions_px : numpy.ndarray
        Array whose columns 0 and 1 hold the row and column position of each
        patch, in pixels. Positions are rounded to the nearest integer.
    roi_shape : sequence of int
        Patch size (rows, columns).
    obj_shape : sequence of int
        Object size (rows, columns); indices wrap around periodically.

    Returns
    -------
    row, col : numpy.ndarray
        Integer arrays of shape (N, roi_shape[0], 1) and (N, 1, roi_shape[1])
        that broadcast against each other for fancy indexing. Offsets follow
        FFT ordering (0, 1, ..., -2, -1), so each position indexes the patch's
        [0, 0] element.
    """
    n_rows, n_cols = obj_shape[0], obj_shape[1]
    roi_rows, roi_cols = roi_shape[0], roi_shape[1]

    # FFT-ordered integer offsets within a patch
    row_offsets = np.fft.fftfreq(roi_rows, d=1 / roi_rows).astype("int")
    col_offsets = np.fft.fftfreq(roi_cols, d=1 / roi_cols).astype("int")

    # nearest-pixel patch origins
    row_centers = np.round(positions_px[:, 0]).astype("int")
    col_centers = np.round(positions_px[:, 1]).astype("int")

    # periodic wrap onto the object grid
    row = np.mod(row_centers[:, None, None] + row_offsets[None, :, None], n_rows)
    col = np.mod(col_centers[:, None, None] + col_offsets[None, None, :], n_cols)

    return row, col

def simulate_exit_waves(complex_obj, probe, row, col):
    """Extract the object patches and multiply each of them by the probe.

    Parameters
    ----------
    complex_obj : numpy.ndarray
        Object array, indexed as ``complex_obj[row, col]``.
    probe : numpy.ndarray
        Probe array, broadcast against the extracted patches.
    row, col : numpy.ndarray
        Index arrays selecting one patch per scan position.

    Returns
    -------
    obj_patches : numpy.ndarray
        The object values under each patch.
    exit_waves : numpy.ndarray
        ``obj_patches * probe``.
    """
    obj_patches = complex_obj[row, col]
    return obj_patches, obj_patches * probe

def simulate_intensities(complex_obj, probe, row, col):
    """Simulate the diffraction intensity for every patch.

    The exit waves come from ``simulate_exit_waves``; each one is Fourier
    transformed over its last two axes and the squared magnitude is taken.

    Returns
    -------
    obj_patches, exit_waves, fourier_exit_waves, intensities : numpy.ndarray
        Every intermediate of the forward model, in that order. The
        intensities are in FFT order (zero frequency at index [0, 0]).
    """
    patches_and_waves = simulate_exit_waves(complex_obj, probe, row, col)
    obj_patches, exit_waves = patches_and_waves

    fourier_exit_waves = np.fft.fft2(exit_waves)
    intensities = np.abs(fourier_exit_waves) ** 2

    return obj_patches, exit_waves, fourier_exit_waves, intensities


def sum_patches_base(patches, positions_px, roi_shape, object_shape):
    """Accumulate real-valued patches back onto an object-sized grid.

    Each patch is added at its (rounded) position with periodic wrap-around,
    using the same FFT-ordered offsets as ``return_patch_indices``; values
    landing on the same object pixel are summed.

    Parameters
    ----------
    patches : numpy.ndarray
        Patch values; flattened in C order and expected to line up with
        (position, patch row, patch column).
    positions_px : numpy.ndarray
        Columns 0 and 1 hold the row and column position of each patch.
    roi_shape : sequence of int
        Patch size (rows, columns).
    object_shape : sequence of int
        Shape of the output grid (rows, columns).

    Returns
    -------
    numpy.ndarray
        Array of shape ``object_shape`` holding the summed patch values.
    """
    n_rows, n_cols = object_shape[0], object_shape[1]

    row_centers = np.round(positions_px[:, 0]).astype("int")
    col_centers = np.round(positions_px[:, 1]).astype("int")

    row_offsets = np.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int")
    col_offsets = np.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int")

    # wrapped 2D indices, then flattened to row-major linear indices
    rows = (row_centers[:, None, None] + row_offsets[None, :, None]) % n_rows
    cols = (col_centers[:, None, None] + col_offsets[None, None, :]) % n_cols
    linear_index = (rows * n_cols + cols).ravel()

    # weighted bincount sums every contribution that hits the same pixel
    summed = np.bincount(
        linear_index,
        weights=patches.ravel(),
        minlength=np.prod(object_shape),
    )
    return np.reshape(summed, object_shape)

def sum_patches(patches, positions_px, roi_shape, obj_shape):
    """Accumulate patches onto the object grid, supporting complex values.

    Real input is passed straight to ``sum_patches_base``. Complex input is
    split into real and imaginary parts, which are accumulated separately and
    recombined.
    """
    if not np.iscomplexobj(patches):
        return sum_patches_base(patches, positions_px, roi_shape, obj_shape)

    real_sum = sum_patches_base(patches.real, positions_px, roi_shape, obj_shape)
    imag_sum = sum_patches_base(patches.imag, positions_px, roi_shape, obj_shape)
    return real_sum + 1.0j * imag_sum

In [7]:
# Raster of probe positions (in pixels) covering the whole object, flattened
# to an (N, 2) array of (row, column) pairs with the row index varying slowest.
x = y = np.arange(0., n, scan_step_size)
xx, yy = np.meshgrid(x, y, indexing='ij')
positions = np.column_stack((xx.ravel(), yy.ravel()))

# Patches are as large as the object itself, so every patch is a periodically
# shifted copy of the full object.
sim_row, sim_col = return_patch_indices(positions, (n, n), (n, n))

# Forward model; only the diffraction intensities are kept by name.
_, _, _, intensities = simulate_intensities(
    complex_obj=complex_obj,
    probe=probe_array,
    row=sim_row,
    col=sim_col,
)

In [9]:
rotation_angle_deg = -13
rotation_angle_rad = np.deg2rad(rotation_angle_deg)

_pattern_axes = (-1, -2)  # the two diffraction-space axes of every pattern

# 1) centre the zero frequency so the rotation is about the pattern centre
rotated_intensities = np.fft.fftshift(intensities, axes=(-1, -2))

# 2) rotate each pattern in-plane by -rotation_angle_deg using cubic spline
#    interpolation, keeping the original array shape
# NOTE(review): cubic interpolation can overshoot, so slightly negative
# intensities may appear -- confirm whether downstream code tolerates that.
rotated_intensities = rotate(
    rotated_intensities,
    -rotation_angle_deg,
    axes=(-1, -2),
    reshape=False,
    order=3,
)

# 3) undo the centring to return to FFT ordering
rotated_intensities = np.fft.ifftshift(rotated_intensities, axes=(-1, -2))
del _pattern_axes

In [11]:
# Package the rotated patterns as a 4D-STEM dataset: the flat list of
# patterns is reshaped to (scan x, scan y, n, n) and each pattern is
# fftshift-ed so that the zero frequency sits at the centre of the stored
# diffraction pattern.
dataset = em.core.datastructures.Dataset4dstem.from_array(
    np.fft.fftshift(rotated_intensities.reshape((sx,sy,n,n)),axes=(-1,-2)),
    # real-space scan step (Angstroms) for the two scan axes, then the
    # reciprocal-space pixel size (inverse Angstroms) for the two pattern axes
    sampling=[scan_step_size*sampling,scan_step_size*sampling,reciprocal_sampling, reciprocal_sampling],
    units=['A','A','A^-1','A^-1']
)

# NOTE(review): mode='o' presumably means "overwrite an existing file" -- confirm
# against the quantem save() documentation.
dataset.save("../data/white-noise-object_defocus+stig.zip",mode='o')