In [None]:
from ptyrodactyl import simul, tools, workflows
import jax
import jax.numpy as jnp

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Preview one frame from a previously simulated grain-boundary dataset.
# np.load on an .npz returns a lazily-reading NpzFile that keeps its file
# handle open until closed; copy every stored array out inside a `with` block
# so the handle is released. The result is a plain dict, so `gb500["data"]`
# (and any other key) keeps working in later cells.
# NOTE(review): this reads all arrays in the archive eagerly — fine if "data"
# dominates the file, revisit if the .npz carries other large arrays.
GB500_FILE = "../../data/mos2_500x500_GB_tiled_stem4d.npz"
PREVIEW_INDEX = 10  # index along the first axis of gb500["data"] to display
with np.load(GB500_FILE) as gb500_npz:
    gb500 = {key: gb500_npz[key] for key in gb500_npz.files}
fig, ax = plt.subplots()
ax.imshow(gb500["data"][PREVIEW_INDEX, :, :])
ax.set_title(f'gb500["data"][{PREVIEW_INDEX}]')
plt.show()

In [None]:
# Show which devices JAX can see in this kernel
# (presumably relevant to force_parallel in the simulation cell — confirm).
available_devices = jax.devices()
available_devices

In [None]:
# Structure file (XYZ format, per its extension) parsed by simul.parse_crystal;
# the path is relative to the notebook's working directory.
mos2_gb_file: str = "../../data/for_CBED_MoS2_poly.xyz"

In [None]:
mos2_data: tools.CrystalData = simul.parse_crystal(mos2_gb_file)
# Later cells slice columns 0, 1 and 2 of `positions` as x/y/z coordinates,
# so fail fast here if the parsed table is not 2-D with at least 3 columns.
positions_shape = mos2_data.positions.shape
assert len(positions_shape) == 2 and positions_shape[1] >= 3, (
    f"unexpected positions shape {positions_shape}"
)
positions_shape

In [None]:
# Split the atom coordinate table into one array per column (0 -> x, 1 -> y, 2 -> z).
x_pos, y_pos, z_pos = (mos2_data.positions[:, col] for col in range(3))

In [None]:
# Extent of the structure along each axis, shown as {axis: (min, max)} so the
# six numbers are labelled; compare against the scan window chosen below.
{
    axis: (float(jnp.min(coords)), float(jnp.max(coords)))
    for axis, coords in (("x", x_pos), ("y", y_pos), ("z", z_pos))
}


In [None]:
# Parameters passed to workflows.crystal2stem4d; units follow each name's suffix.
voltage_kv: int = 60                    # beam voltage, kV
cbed_aperture_mrad: float = 5.0         # CBED aperture, mrad
cbed_extent_mrad: float = 50.0          # CBED angular extent, mrad (presumably the detector half-range — confirm)
cbed_shape: tuple = (256, 256)          # CBED pattern size in pixels
real_space_pixel_size_ang: float = 0.1  # real-space sampling, Å per pixel
slice_thickness: float = 1.0            # unit not stated — presumably Å, TODO confirm


In [None]:
# Square raster of (y, x) probe positions.
# Use fewer scan positions for faster testing (40x40 = 1600 positions);
# set SCAN_POINTS_PER_AXIS = 150 for full resolution once testing is complete.
# NOTE(review): the window is in the same units as the crystal positions —
# check it against the x/y extents displayed above.
SCAN_POINTS_PER_AXIS = 40
SCAN_START = -50.0
SCAN_STOP = 100.0

scan_y = jnp.linspace(SCAN_START, SCAN_STOP, SCAN_POINTS_PER_AXIS)
scan_x = jnp.linspace(SCAN_START, SCAN_STOP, SCAN_POINTS_PER_AXIS)
# indexing="ij" makes yy vary down rows and xx across columns (matching their
# names), so after ravel() positions run in row-major raster order:
# x varies fastest, y slowest. The default "xy" indexing would transpose this.
yy, xx = jnp.meshgrid(scan_y, scan_x, indexing="ij")
scan_positions = jnp.stack([yy.ravel(), xx.ravel()], axis=-1)
assert scan_positions.shape == (SCAN_POINTS_PER_AXIS * SCAN_POINTS_PER_AXIS, 2)
print(f"Number of scan positions: {scan_positions.shape[0]}")
scan_positions[:5]

In [None]:
# Run the 4D-STEM workflow on the parsed crystal at every scan position above.
# NOTE(review): force_parallel=True presumably spreads the work over the
# devices reported by jax.devices() — confirm in workflows.crystal2stem4d.
cbed_simulation = workflows.crystal2stem4d(
    crystal_data=mos2_data,
    scan_positions=scan_positions,
    # beam and CBED geometry
    voltage_kv=voltage_kv,
    cbed_aperture_mrad=cbed_aperture_mrad,
    cbed_extent_mrad=cbed_extent_mrad,
    cbed_shape=cbed_shape,
    # sampling of the specimen
    real_space_pixel_size_ang=real_space_pixel_size_ang,
    slice_thickness=slice_thickness,
    # execution
    force_parallel=True,
)