# Simple Forward Model
Now that we have some crystallographic functions and can handle the detector geometry, we can perform a basic forward model of a single crystal to reassure ourselves that this wasn't all for nothing!

In [None]:
import time
import urllib.request
from pathlib import Path

import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as jR
from matplotlib import pyplot as plt

import anri

# Start a wall-clock timer; the last cell of the notebook reports the elapsed run time.
start = time.time()

In `Anri`, all fundamental functions and transforms are written for single vectors.  
This significantly simplifies the functions themselves and keeps them easy to understand.  
Additionally, when forward-simulating many grains or voxels you will likely have more complicated array shapes, so for now it makes sense to leave the broadcasting to the user or to another part of the program.  
I'm still grappling with the best way to expose vmapped functions in the API, so for now I will declare them manually here:

In [None]:
# Anri's functions act on single vectors, so each one is wrapped with jax.vmap.
# `in_axes` has one entry per positional argument: 0 maps over that argument's
# leading axis, None passes the same value to every mapped call.

# (hkls, B): many hkls, a single B matrix
hkl_B_to_q_crystal_vec = jax.vmap(anri.crystal.hkl_B_to_q_crystal, in_axes=(0, None))
# (q, etasign, k_in): many q vectors, shared eta sign and incoming wavevector
omega_solns_vec = jax.vmap(anri.diffract.omega_solns, in_axes=(0, None, None))
# (q, omega, wedge, chi, dty, y0): one omega per q vector, shared geometry
sample_to_lab_vec = jax.vmap(anri.geom.sample_to_lab, in_axes=(0, 0, None, None, None, None))
# (q_lab, k_in): many q vectors, shared incoming wavevector
q_lab_to_k_out_vec = jax.vmap(anri.diffract.q_lab_to_k_out, in_axes=(0, None))
# (k_out, origin, sc, fc, norm): many rays, shared origin and detector basis
raytrace_to_det_vec = jax.vmap(anri.geom.raytrace_to_det, in_axes=(0, None, None, None, None))
# (q_lab, wavelength): many q vectors, shared wavelength
q_lab_to_tth_eta_vec = jax.vmap(anri.diffract.q_lab_to_tth_eta, in_axes=(0, None))

## Crystallography

Let's take a CIF from Dan's Diffraction again.

In [None]:
cif_path = "https://github.com/DanPorter/Dans_Diffraction/raw/refs/heads/master/Dans_Diffraction/Structures/Iron.cif"
cif_local = Path("Iron.cif")
# Only download when there is no local copy yet, so re-running the notebook
# does not hit the network every time.
# NOTE: delete Iron.cif to force a fresh download (e.g. after an interrupted one).
if not cif_local.exists():
    urllib.request.urlretrieve(cif_path, cif_local)
struc = anri.crystal.Structure.from_cif(str(cif_local))

We generate the hkls up to `dsmax` for our chosen wavelength:

In [None]:
# Simulation settings.
# NOTE(review): presumably wavelength is in angstrom and dsmax in 1/angstrom — confirm against anri.
wavelength = 0.3
dsmax = 2.0
# Generate the hkls up to dsmax for this wavelength; later cells read the
# ring tables (rings_dict, ringhkls_arr, ringtth, ringmult) from `struc`.
struc.make_hkls(wavelength=wavelength, dsmax=dsmax)

In [None]:
# Peek at the entry for ring index 0 (the first ring) in the structure's ring table.
struc.rings_dict[0]

Now we can generate some scattering vectors in the crystal frame:

In [None]:
# Convert every ring hkl (one per row) into a crystal-frame scattering vector using the B matrix.
q_crystal = hkl_B_to_q_crystal_vec(struc.ringhkls_arr, struc.B)

Let's generate a random orientation.

In [None]:
# Fixed seed so that Restart & Run All reproduces the same grain orientation
# (and therefore the same peaks); change SEED to explore other orientations.
SEED = 42
key = jax.random.key(SEED)
# Euler angles in degrees, each drawn uniformly from [-90, 90).
# Note: uniform Euler angles do not give a uniformly distributed random rotation,
# but any orientation is fine for this demonstration.
random_euler = jax.random.uniform(key, shape=(3,), minval=-90.0, maxval=90.0)
U = jR.from_euler('XYZ', random_euler, degrees=True).as_matrix()
U

We can rotate the scattering vectors into the sample frame:

In [None]:
# Rotate each crystal-frame q vector (one per row) by U; the transposes apply U to row vectors.
q_sample = (U @ q_crystal.T).T

In [None]:
# Expect (number of ring hkls, 3): one sample-frame q vector per row.
q_sample.shape

## Ewald condition

Now we can determine the omega angles required to diffract:

In [None]:
chi = 0.0
wedge = 0.0
dty = 0.0
y0 = 0.0

# define incoming wavevector in the lab frame (beam along lab +X)
k_in_lab = jnp.array([1., 0, 0])
k_in_lab_norm = anri.diffract.scale_norm_k(k_in_lab, wavelength)

# map it into the sample frame
# NOTE(review): the 0.0 is presumably omega, by analogy with sample_to_lab's argument order — confirm.
k_in_sample_norm = anri.geom.lab_to_sample(k_in_lab_norm, 0.0, wedge, chi, dty, y0)

# Solve for the diffracting omega once per eta sign; valid1/valid2 are used
# below as boolean masks to keep only the solutions flagged as valid.
# etasign +1:
omega1, valid1 = omega_solns_vec(q_sample, 1.0, k_in_sample_norm)
# etasign -1:
omega2, valid2 = omega_solns_vec(q_sample, -1.0, k_in_sample_norm)

# Stack both solution sets; the q vectors are duplicated so they stay row-aligned
# with omega/valid. A new name is used instead of overwriting `q_sample`, so
# re-running this cell does not keep doubling it (and the earlier shape stays accurate).
omega = jnp.concatenate([omega1, omega2])
valid = jnp.concatenate([valid1, valid2])
q_sample_both = jnp.concatenate([q_sample, q_sample])

omega_valid = omega[valid]
q_sample_valid = q_sample_both[valid]

## Into the lab frame
With the omega angles determined, we can rotate `q_sample` into the lab frame:

In [None]:
# Rotate each valid sample-frame q into the lab frame at its own diffracting omega.
q_lab = sample_to_lab_vec(q_sample_valid, omega_valid, wedge, chi, dty, y0)

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
# Project the lab-frame q vectors onto the Y-Z plane, i.e. looking along the beam (lab X).
ax.scatter(q_lab[:, 1], q_lab[:, 2])
ax.set_aspect('equal')
ax.set_xlabel('Lab Y')
ax.set_ylabel('Lab Z')
plt.show()

## Into the detector

Now we can forward-project them into the detector!

Let's set up the detector transforms:

In [None]:
# Detector geometry parameters.
# NOTE(review): meanings below are inferred from the parameter names
# (ImageD11-style convention) — confirm against anri.geom.detector_transforms.
# presumably the beam centre on the detector
y_center, z_center = 1000.0, 1000.0
# presumably the pixel sizes
y_size, z_size = 75.0, 75.0
# presumably the detector tilts
tilt_x, tilt_y, tilt_z = 0.0, 0.0, 0.0
# presumably the sample-to-detector distance
distance = 180e3
# presumably the detector orientation (flip) matrix entries; identity here
o11, o12 = 1, 0
o21, o22 = 0, 1

det_trans, beam_cen_shift, x_distance_shift = anri.geom.detector_transforms(
    y_center, y_size, tilt_y,
    z_center, z_size, tilt_z,
    tilt_x, distance,
    o11, o12, o21, o22,
)

We get the detector basis vectors in the lab frame:

In [None]:
# Detector basis vectors expressed in the lab frame (names suggest slow axis, fast axis and plane normal).
sc_lab, fc_lab, norm_lab = anri.geom.detector_basis_vectors_lab(det_trans, beam_cen_shift, x_distance_shift)

Now we can map into detector space:

In [None]:
# Every ray starts at the lab-frame origin.
origin_lab = jnp.zeros(3)

# Outgoing wavevector for each diffracted peak, from its lab-frame q and the incoming beam.
k_out = q_lab_to_k_out_vec(q_lab, k_in_lab_norm)
# Intersect each outgoing ray with the detector plane to get slow/fast detector coordinates.
sc, fc = raytrace_to_det_vec(k_out, origin_lab, sc_lab, fc_lab, norm_lab)

## Results

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(fc, sc)
ax.set_aspect('equal')
ax.set_xlabel('Detector fast')
ax.set_ylabel('Detector slow')
# Restrict the view to a sensible detector extent (presumably pixels) for these settings.
ax.set(xlim=(0, 2048), ylim=(0, 2048))
plt.show()

In [None]:
# Scattering angle (2theta) and azimuth (eta) for each simulated peak.
tth, eta = q_lab_to_tth_eta_vec(q_lab, wavelength)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(tth, eta, label='Peaks')
# Draw the expected ring positions across the full eta range of the simulated
# peaks (instead of a fixed narrow band), so every peak can be checked against
# its ring; fall back to the old band if there are no peaks.
if eta.size:
    eta_lo, eta_hi = float(eta.min()), float(eta.max())
else:
    eta_lo, eta_hi = -25, 25
ax.vlines(struc.ringtth, eta_lo, eta_hi, color='red', label='Unit cell')
ax.set(xlabel=r'$2\theta$', ylabel=r'$\eta$', title='Simulated peaks vs. expected ring positions')
ax.legend(loc='upper right')
plt.show()

In [None]:

# Count the simulated peaks belonging to the first ring. Rather than a hard-coded
# 2theta cutoff (which silently breaks if `wavelength` or the structure changes),
# cut halfway between the first and second ring positions (same units as tth,
# since both are plotted on the same axis above).
# NOTE(review): assumes struc.ringtth is sorted ascending (ring 0 = lowest 2theta) — confirm.
first_ring_cut = (
    (struc.ringtth[0] + struc.ringtth[1]) / 2 if len(struc.ringtth) > 1 else float('inf')
)
print(f'Computed peaks in the first ring: {(tth < first_ring_cut).sum()}\n2x multiplicity of first ring (Friedel pairs): {struc.ringmult[0] * 2}')

In [None]:
# Report the total wall-clock time since `start` was recorded.
end = time.time()
elapsed = end - start
print(f'Took {elapsed:.1f} seconds')