In [1]:
import os
# Ask XLA not to preallocate device memory. This is written to the environment
# before `import jax` below so it is already in place when JAX starts up.
# NOTE(review): this setting concerns accelerator memory; with the CPU-only
# platform selected below it is presumably a no-op — confirm before removing.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
# Enable 64-bit types in JAX (it defaults to 32-bit). Presumably needed so the
# JAX results can be compared with the NumPy/ImageD11 values via np.allclose.
jax.config.update("jax_enable_x64", True)
# Restrict JAX to the CPU platform.
jax.config.update("jax_platforms", 'cpu')

from ImageD11.parameters import AnalysisSchema
from ImageD11.columnfile import columnfile

In [2]:
# Geometry-parameters object taken from ImageD11's default analysis schema;
# the cells below override individual values on it.
pars = AnalysisSchema.from_default().geometry_pars_obj

In [3]:
# Values written onto `pars`, in the order they are applied.
# Presumably chosen to be non-trivial so that every term of the geometry
# is exercised by the comparisons further down.
geometry_overrides = {
    'tilt_x': 0.00123,
    'tilt_y': -0.0345,
    'tilt_z': 0.02,
    'chi': 1,
    'wedge': -3,
    't_x': 1,
    't_y': 2,
    't_z': 3,
}
for par_name, par_value in geometry_overrides.items():
    pars.set(par_name, par_value)

In [4]:
# Number of synthetic peaks to generate.
nrows = 100_000

# Fixed seed: every Restart & Run All exercises the same synthetic peaks, so a
# failing comparison further down can be reproduced exactly.
SEED = 0
rng = np.random.default_rng(SEED)

# Uniform random inputs: fc and sc in [0, 2048), om in [0, 360).
fc = rng.random(nrows) * 2048
sc = rng.random(nrows) * 2048
om = rng.random(nrows) * 360

In [5]:
# Build a fresh columnfile holding the synthetic peaks. The row count is set
# first, then the three input columns are added in order.
cf = columnfile(new=True)
cf.nrows = nrows

for column_name, column_values in (('fc', fc), ('sc', sc), ('omega', om)):
    cf.addcolumn(column_values, column_name)

In [6]:
# Attach the geometry parameters configured above to the columnfile.
cf.parameters = pars

In [7]:
# Have ImageD11 compute its geometry-derived columns. The titles listing in the
# next cell shows xl, yl, zl, tth, eta, ds, gx, gy, gz following the three
# input columns; these are the reference values for the tests below.
cf.updateGeometry()

In [8]:
# Display the column names now present in the columnfile.
cf.titles

['fc', 'sc', 'omega', 'xl', 'yl', 'zl', 'tth', 'eta', 'ds', 'gx', 'gy', 'gz']

In [9]:
# Reference g-vectors from ImageD11: gx, gy, gz placed side by side as the
# columns of a single array.
reference_components = (cf.gx, cf.gy, cf.gz)
gvecs = np.column_stack(reference_components)

In [10]:
import ImageD11.transform, ImageD11.gv_general

In [11]:
import transform as mytrans

In [12]:
import importlib
# Re-execute the local `transform` module so that edits made to transform.py
# since the first import are picked up without restarting the kernel.
importlib.reload(mytrans)

<module 'transform' from '/home/esrf/james1997a/Code/Anri/anri/sandbox/transform.py'>

# Test (xl, yl, zl)

In [13]:
# Detector geometry parameters, listed in the positional order in which they
# are handed to det_to_xyz_lab after the (sc, fc) pixel coordinates.
detector_par_names = (
    'y_center', 'y_size', 'tilt_y',
    'z_center', 'z_size', 'tilt_z',
    'tilt_x',
    'distance',
    'o11', 'o12', 'o21', 'o22',
)
detector_par_values = [pars.get(par_name) for par_name in detector_par_names]

xyz_me = mytrans.det_to_xyz_lab(cf.sc, cf.fc, *detector_par_values)

In [14]:
# Compare each column of the result with ImageD11's lab coordinates.
# np.testing.assert_allclose is used instead of a bare `assert np.allclose(...)`
# because on failure it reports how many elements differ and by how much.
# rtol / atol / equal_nan are pinned to np.allclose's defaults, so the
# pass/fail criterion is the same as before.
for column_index, (column_name, expected) in enumerate(
    (('xl', cf.xl), ('yl', cf.yl), ('zl', cf.zl))
):
    np.testing.assert_allclose(
        xyz_me[:, column_index], expected,
        rtol=1e-5, atol=1e-8, equal_nan=False,
        err_msg=f"det_to_xyz_lab column {column_index} disagrees with ImageD11 '{column_name}'",
    )

# Test (tth, eta)

In [15]:
# ImageD11's lab-frame coordinates, stacked as the columns of one array, so
# this step is tested on reference inputs rather than on xyz_me.
xyz_lab = jnp.stack((cf.xl, cf.yl, cf.zl), axis=1)
# The three translation parameters gathered into a single array.
translation = jnp.stack((pars.get('t_x'), pars.get('t_y'), pars.get('t_z')), axis=0)

tth_me, eta_me = mytrans.xyz_lab_to_tth_eta(
    xyz_lab,
    cf.omega,
    translation,
    pars.get('wedge'),
    pars.get('chi'),
)

In [16]:
# Check both angles against ImageD11. np.testing.assert_allclose replaces the
# bare `assert np.allclose(...)` so a failure reports the size and count of the
# mismatches; tolerances are pinned to np.allclose's defaults so the pass/fail
# criterion is unchanged.
for angle_name, actual, expected in (('tth', tth_me, cf.tth), ('eta', eta_me, cf.eta)):
    np.testing.assert_allclose(
        actual, expected,
        rtol=1e-5, atol=1e-8, equal_nan=False,
        err_msg=f"xyz_lab_to_tth_eta '{angle_name}' disagrees with ImageD11",
    )

# Test g-vectors

In [17]:
# Scalar parameters passed after the per-peak angle columns.
wedge, chi, wavelength = (pars.get(key) for key in ('wedge', 'chi', 'wavelength'))

# g-vectors computed from ImageD11's own tth/eta/omega columns, so this step is
# tested on reference inputs rather than on tth_me/eta_me.
gvecs_me = mytrans.tth_eta_omega_to_g(cf.tth, cf.eta, cf.omega, wedge, chi, wavelength)

In [18]:
# np.testing.assert_allclose replaces the bare `assert np.allclose(...)` so a
# failure reports how many g-vector components differ and by how much.
# Tolerances are pinned to np.allclose's defaults (same pass/fail criterion).
np.testing.assert_allclose(
    gvecs_me, gvecs,
    rtol=1e-5, atol=1e-8, equal_nan=False,
    err_msg="tth_eta_omega_to_g disagrees with ImageD11 (gx, gy, gz)",
)

# Full pipeline test: sc, fc, omega, origins to g-vectors

In [21]:
# The three translation parameters as one array.
t_xyz = jnp.array((pars.get('t_x'), pars.get('t_y'), pars.get('t_z')))

# Detector geometry parameters, in the positional order det_to_g receives them.
detector_par_names = (
    'y_center', 'y_size', 'tilt_y',
    'z_center', 'z_size', 'tilt_z',
    'tilt_x',
    'distance',
    'o11', 'o12', 'o21', 'o22',
)

# End-to-end call: raw detector coordinates and omega straight to g-vectors.
# This rebinds gvecs_me (also assigned by the tth_eta_omega_to_g test above),
# so the assertion that follows checks this result.
# NOTE(review): the final [1, 0, 0] argument looks like a direction vector
# (beam axis?) — confirm against det_to_g's signature.
gvecs_me = mytrans.det_to_g(
    cf.sc, cf.fc, cf.omega,
    t_xyz,
    pars.get('wedge'), pars.get('chi'), pars.get('wavelength'),
    *[pars.get(par_name) for par_name in detector_par_names],
    jnp.array([1., 0, 0]),
)

In [22]:
# np.testing.assert_allclose replaces the bare `assert np.allclose(...)` so a
# failure reports how many g-vector components differ and by how much.
# Tolerances are pinned to np.allclose's defaults (same pass/fail criterion).
np.testing.assert_allclose(
    gvecs_me, gvecs,
    rtol=1e-5, atol=1e-8, equal_nan=False,
    err_msg="det_to_g (full pipeline) disagrees with ImageD11 (gx, gy, gz)",
)