In [1]:
import os
# Set before jax is imported below.
# NOTE(review): presumably this stops the XLA client from preallocating
# device memory at start-up — confirm against the JAX documentation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
# Switch on the "jax_enable_x64" flag and restrict "jax_platforms" to 'cpu'.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", 'cpu')


# ImageD11 provides the parameter object and the columnfile container used below.
from ImageD11.parameters import AnalysisSchema
from ImageD11.columnfile import columnfile

In [2]:
# Start from the default analysis schema and keep only its geometry-parameters object.
pars = AnalysisSchema.from_default().geometry_pars_obj

In [3]:
# Override the default geometry with non-trivial tilts, chi/wedge and
# translations; the values are applied in the order they are listed.
_geometry_overrides = (
    ('tilt_x', 0.00123),
    ('tilt_y', -0.0345),
    ('tilt_z', 0.02),
    ('chi', 1),
    ('wedge', -3),
    ('t_x', 1),
    ('t_y', 2),
    ('t_z', 3),
)
for _par_name, _par_value in _geometry_overrides:
    pars.set(_par_name, _par_value)

In [4]:
# Number of synthetic rows to generate.
nrows = 100_000

# Fixed seed so that Restart Kernel -> Run All reproduces the exact same
# inputs (and therefore the same pass/fail result for every check below).
SEED = 0
rng = np.random.default_rng(SEED)

# Uniform random values: fc and sc in [0, 2048), om in [0, 360).
# NOTE(review): presumably fc/sc are detector pixel coordinates and om is
# the rotation angle in degrees — confirm against the detector in use.
fc = rng.random(nrows) * 2048
sc = rng.random(nrows) * 2048
om = rng.random(nrows) * 360

In [5]:
# Build an empty columnfile and attach the three synthetic input columns.
cf = columnfile(new=True)
cf.nrows = nrows

for _column_values, _column_title in ((fc, 'fc'), (sc, 'sc'), (om, 'omega')):
    cf.addcolumn(_column_values, _column_title)

In [6]:
# Attach the geometry parameters object to the columnfile.
cf.parameters = pars

In [7]:
# Derive the geometry columns; afterwards cf.titles also lists
# xl, yl, zl, tth, eta, ds, gx, gy and gz.
cf.updateGeometry()

In [8]:
# Display the column names now present on the columnfile.
cf.titles

['fc', 'sc', 'omega', 'xl', 'yl', 'zl', 'tth', 'eta', 'ds', 'gx', 'gy', 'gz']

In [9]:
# Gather the gx, gy, gz columns side by side: one g-vector per row.
gvecs = np.stack((cf.gx, cf.gy, cf.gz), axis=1)

In [10]:
import ImageD11.transform, ImageD11.gv_general

In [11]:
import transform as mytrans

In [12]:
import importlib
# Re-import the local `transform` module (bound to `mytrans`) so that edits
# made to it on disk are picked up without restarting the kernel.
importlib.reload(mytrans)

<module 'transform' from '/home/esrf/james1997a/Code/Anri/anri/sandbox/transform.py'>

# Test k-vectors to (tth, eta)

In [15]:
# Reference k-vectors from ImageD11, built from the tth/eta columns and the wavelength.
k_id11 = ImageD11.transform.compute_k_vectors(cf.tth, cf.eta, pars.get('wavelength'))

In [18]:
# Invert the reference k-vectors with the local implementation.
# NOTE(review): k_id11 is transposed here, so presumably mytrans expects one
# k-vector per row whereas ImageD11 returns one per column — confirm.
tth_me, eta_me = mytrans.k_to_tth_eta(k_id11.T, pars.get('wavelength'))

In [19]:
# Round-trip check: the local k -> (tth, eta) must reproduce the columnfile
# values. np.testing.assert_allclose is used instead of a bare
# `assert np.allclose(...)` so that a failure reports the mismatching values
# and the check is not stripped under `python -O`. rtol/atol/equal_nan are
# pinned to np.allclose's defaults, so the tolerance is unchanged.
np.testing.assert_allclose(tth_me, cf.tth, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="tth mismatch")
np.testing.assert_allclose(eta_me, cf.eta, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="eta mismatch")

# Test k to (sc, fc)

In [21]:
# Project the reference k-vectors back onto the detector with the local
# implementation. The geometry values are looked up by name and handed over
# positionally, in exactly the order listed.
sc_me, fc_me = mytrans.k_to_det(
    k_id11.T,
    cf.omega,
    jnp.array(tuple(pars.get(axis) for axis in ('t_x', 't_y', 't_z'))),
    *(
        pars.get(key)
        for key in (
            'wedge', 'chi', 'wavelength',
            'y_center', 'y_size', 'tilt_y',
            'z_center', 'z_size', 'tilt_z',
            'tilt_x',
            'distance',
            'o11', 'o12', 'o21', 'o22',
        )
    ),
)

In [22]:
# The detector coordinates recovered from the k-vectors must match the
# original sc/fc inputs. np.testing.assert_allclose reports the offending
# values on failure and is not stripped under `python -O`; rtol/atol/equal_nan
# are pinned to np.allclose's defaults, so the tolerance is unchanged.
np.testing.assert_allclose(sc_me, cf.sc, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="sc mismatch (k -> det)")
np.testing.assert_allclose(fc_me, cf.fc, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="fc mismatch (k -> det)")

# Test (xl, yl, zl) to (sc, fc)

In [26]:
# Map the lab-frame coordinates back to detector coordinates with the local
# implementation; the result is transposed before being unpacked into sc and
# fc. Detector geometry values are looked up by name and passed positionally
# in the order listed.
sc_me, fc_me = mytrans.xyz_lab_to_det(
    cf.xl,
    cf.yl,
    cf.zl,
    *(
        pars.get(key)
        for key in (
            'y_center', 'y_size', 'tilt_y',
            'z_center', 'z_size', 'tilt_z',
            'tilt_x',
            'distance',
            'o11', 'o12', 'o21', 'o22',
        )
    ),
).T

In [29]:
# The detector coordinates recovered from (xl, yl, zl) must match the
# original sc/fc inputs. np.testing.assert_allclose reports the offending
# values on failure and is not stripped under `python -O`; rtol/atol/equal_nan
# are pinned to np.allclose's defaults, so the tolerance is unchanged.
np.testing.assert_allclose(sc_me, cf.sc, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="sc mismatch (xyz_lab -> det)")
np.testing.assert_allclose(fc_me, cf.fc, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="fc mismatch (xyz_lab -> det)")

# Test uncompute g: g-vectors to (tth, eta, omega)

In [33]:
# ImageD11 reference: for the g-vectors (transposed on the way in) this
# yields tth plus a pair of eta arrays and a pair of omega arrays.
tth_id11, (eta1_id11, eta2_id11), (omega1_id11, omega2_id11) = (
    ImageD11.transform.uncompute_g_vectors(
        gvecs.T,
        *(pars.get(key) for key in ('wavelength', 'wedge', 'chi')),
    )
)

In [30]:
# Local implementation of the same inversion; note that it takes gvecs
# untransposed and its parameter order is (wedge, chi, wavelength).
tth_me, (eta1_me, eta2_me), (omega1_me, omega2_me) = (
    mytrans.g_to_tth_eta_omega(
        gvecs,
        *(pars.get(key) for key in ('wedge', 'chi', 'wavelength')),
    )
)

In [34]:
# The local g -> (tth, eta, omega) inversion must agree with ImageD11 for tth
# and for both eta/omega solutions. np.testing.assert_allclose reports the
# offending values on failure and is not stripped under `python -O`;
# rtol/atol/equal_nan are pinned to np.allclose's defaults, so the tolerance
# is unchanged.
np.testing.assert_allclose(tth_id11, tth_me, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="tth mismatch")
np.testing.assert_allclose(eta1_id11, eta1_me, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="eta (solution 1) mismatch")
np.testing.assert_allclose(eta2_id11, eta2_me, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="eta (solution 2) mismatch")
np.testing.assert_allclose(omega1_id11, omega1_me, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="omega (solution 1) mismatch")
np.testing.assert_allclose(omega2_id11, omega2_me, rtol=1e-5, atol=1e-8,
                           equal_nan=False, err_msg="omega (solution 2) mismatch")