In [1]:
import os
# Must be set before jax is imported: disables XLA's up-front device-memory
# preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
# float64 everywhere so JAX results can be compared to ImageD11 with np.allclose.
jax.config.update("jax_enable_x64", True)
# Run on the CPU backend only.
jax.config.update("jax_platforms", 'cpu')


from ImageD11.parameters import AnalysisSchema
from ImageD11.columnfile import columnfile

In [2]:
# Geometry parameter object from ImageD11's default schema (read below via
# pars.get / pars.parameters, modified via pars.set).
pars = AnalysisSchema.from_default().geometry_pars_obj

In [3]:
# Non-zero tilts, rotation-axis angles and sample origin, so none of these
# terms is trivially zero in the comparisons below.
for _name, _value in (
    ('tilt_x', 0.00123),
    ('tilt_y', -0.0345),
    ('tilt_z', 0.02),
    ('chi', 1),
    ('wedge', -3),
    ('t_x', 1),
    ('t_y', 2),
    ('t_z', 3),
):
    pars.set(_name, _value)
del _name, _value

In [4]:
# Synthetic peaks: random detector positions (fc, sc) and rotation angles.
# Seeded so that "Restart & Run All" reproduces the same data and outputs.
SEED = 42
N_PIXELS = 2048  # assumed detector width/height in pixels - TODO confirm against pars

nrows = 100_000

rng = np.random.default_rng(SEED)
fc = rng.random(nrows) * N_PIXELS
sc = rng.random(nrows) * N_PIXELS
om = rng.random(nrows) * 360

In [5]:
# In-memory columnfile holding the synthetic peaks.
cf = columnfile(new=True)
cf.nrows = nrows
for _column, _title in ((fc, 'fc'), (sc, 'sc'), (om, 'omega')):
    cf.addcolumn(_column, _title)
del _column, _title

In [6]:
# Attach the modified geometry; presumably read by cf.updateGeometry() next.
cf.parameters = pars

In [7]:
# Derive lab positions, angles and g-vectors from fc/sc/omega (new columns are
# listed in the next cell's output).
cf.updateGeometry()

In [8]:
# Columns available after updateGeometry().
cf.titles

['fc', 'sc', 'omega', 'xl', 'yl', 'zl', 'tth', 'eta', 'ds', 'gx', 'gy', 'gz']

In [9]:
# ImageD11 reference g-vectors, one row per peak: shape (nrows, 3).
gvecs = np.stack((cf.gx, cf.gy, cf.gz), axis=1)

In [10]:
import ImageD11.transform, ImageD11.gv_general

In [11]:
import transform as mytrans

In [12]:
# Development aid: pick up edits to the local `transform` module without
# restarting the kernel (%autoreload 2 would do this automatically).
import importlib
importlib.reload(mytrans)

<module 'transform' from '/home/esrf/james1997a/Code/Anri/anri/sandbox/transform.py'>

# Full pipeline test: (sc, fc, omega, sample origin) to g-vectors, with the origin (t_x, t_y, t_z) set to non-zero values above

In [21]:
# Whole pipeline in one call: detector (sc, fc) + omega + sample origin -> g.
gvecs_me = mytrans.det_to_g(
    cf.sc, cf.fc, cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength',
                            'y_center', 'y_size', 'tilt_y',
                            'z_center', 'z_size', 'tilt_z',
                            'tilt_x', 'distance',
                            'o11', 'o12', 'o21', 'o22')),
)

In [22]:
# JAX result, shape (nrows, 3).
gvecs_me

Array([[-0.32244418,  0.0760074 ,  0.69370081],
       [-0.10574296, -0.09161964,  0.03334961],
       [ 1.04111012, -0.50065258, -1.00153655],
       ...,
       [-1.08158685, -0.86094558,  1.05477111],
       [ 0.81374755,  1.11029796,  1.51385406],
       [-0.295523  ,  0.88655418, -0.96764699]], dtype=float64)

In [23]:
# ImageD11 reference, shape (nrows, 3).
gvecs

array([[-0.32244418,  0.0760074 ,  0.69370081],
       [-0.10574296, -0.09161964,  0.03334961],
       [ 1.04111012, -0.50065258, -1.00153655],
       ...,
       [-1.08158685, -0.86094558,  1.05477111],
       [ 0.81374755,  1.11029796,  1.51385406],
       [-0.295523  ,  0.88655418, -0.96764699]])

In [24]:
# The JAX pipeline must reproduce ImageD11's g-vectors.
assert np.isclose(gvecs_me, gvecs).all()

# Lab <-> Detector

In [25]:
%%time

# ImageD11 reference: detector pixels -> lab xyz (transposed relative to
# mytrans; compared via .T below).
xyz_id11 = ImageD11.transform.compute_xyz_lab((cf.sc, cf.fc), **pars.parameters)

CPU times: user 61 ms, sys: 0 ns, total: 61 ms
Wall time: 2.15 ms


In [27]:
%%time

xyz_me = mytrans.det_to_xyz_lab(cf.sc, cf.fc, pars.get('y_center'), pars.get('y_size'), pars.get('tilt_y'),
                                              pars.get('z_center'), pars.get('z_size'), pars.get('tilt_z'),
                                              pars.get('tilt_x'),
                                              pars.get('distance'),
                                              pars.get('o11'), pars.get('o12'), pars.get('o21'),pars.get('o22'))

CPU times: user 643 μs, sys: 0 ns, total: 643 μs
Wall time: 448 μs


In [28]:
# ImageD11 returns (3, nrows); mytrans returns (nrows, 3).
assert np.isclose(xyz_id11.T, xyz_me).all()

In [29]:
# Inverse transform: lab xyz -> detector, columns (sc, fc).
v_det_me = mytrans.xyz_lab_to_det(
    cf.xl, cf.yl, cf.zl,
    *(pars.get(k) for k in ('y_center', 'y_size', 'tilt_y',
                            'z_center', 'z_size', 'tilt_z',
                            'tilt_x', 'distance',
                            'o11', 'o12', 'o21', 'o22')),
)

In [30]:
# Column 0 is sc, column 1 is fc.
assert np.isclose(v_det_me[:, 0], cf.sc).all()
assert np.isclose(v_det_me[:, 1], cf.fc).all()

# Lab <-> tth, eta, omega

In [31]:
# ImageD11 reference: lab xyz (stacked as (3, nrows)) + omega -> (tth, eta).
tth_id11, eta_id11 = ImageD11.transform.compute_tth_eta_from_xyz(np.stack((cf.xl, cf.yl, cf.zl)), cf.omega, **pars.parameters)

In [32]:
# mytrans: lab xyz (nrows, 3) + omega + sample origin -> (tth, eta).
tth_me, eta_me = mytrans.xyz_lab_to_tth_eta(
    jnp.column_stack((cf.xl, cf.yl, cf.zl)),
    cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi')),
)

In [33]:
assert np.isclose(tth_me, tth_id11).all()
assert np.isclose(eta_me, eta_id11).all()

In [34]:
# ImageD11: (tth, eta, omega) -> detector position. Note the return order is (fc, sc).
fc_id11, sc_id11 = ImageD11.transform.compute_xyz_from_tth_eta(cf.tth, cf.eta, cf.omega, **pars.parameters)

In [35]:
assert np.isclose(fc_id11, cf.fc).all()
assert np.isclose(sc_id11, cf.sc).all()

In [36]:
# mytrans: (tth, eta, omega) -> detector. Returns (sc, fc), the reverse of ImageD11.
sc_me, fc_me = mytrans.tth_eta_omega_to_det(
    cf.tth, cf.eta, cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength',
                            'y_center', 'y_size', 'tilt_y',
                            'z_center', 'z_size', 'tilt_z',
                            'tilt_x', 'distance',
                            'o11', 'o12', 'o21', 'o22')),
)

In [37]:
assert np.isclose(sc_me, cf.sc).all()
assert np.isclose(fc_me, cf.fc).all()

In [38]:

# Detector -> lab xyz from the raw random arrays `sc`, `fc` (the same values
# stored in cf.sc / cf.fc). Validate the result instead of only showing its shape:
# it must agree with the lab positions written by cf.updateGeometry().
dxyzl = mytrans.det_to_xyz_lab(sc, fc, pars.get('y_center'), pars.get('y_size'), pars.get('tilt_y'),
                                              pars.get('z_center'), pars.get('z_size'), pars.get('tilt_z'),
                                              pars.get('tilt_x'),
                                              pars.get('distance'),
                                              pars.get('o11'), pars.get('o12'), pars.get('o21'),pars.get('o22'))
assert np.allclose(dxyzl, np.column_stack((cf.xl, cf.yl, cf.zl)))

In [39]:
# One lab-frame xyz row per peak.
dxyzl.shape

(100000, 3)

In [40]:
# Compare ImageD11.compute_grain_origins with mytrans.sample_to_lab, using
# xl/yl/zl merely as arbitrary test vectors. ImageD11 returns (3, nrows).
t_id11 = ImageD11.transform.compute_grain_origins(cf.omega, pars.get('wedge'), pars.get('chi'), cf.xl, cf.yl, cf.zl)

In [41]:
# mytrans equivalent, taking and returning (nrows, 3).
t_me = mytrans.sample_to_lab(np.column_stack((cf.xl, cf.yl, cf.zl)), cf.omega, pars.get('wedge'), pars.get('chi'))

In [42]:
assert np.isclose(t_id11, t_me.T).all()

In [47]:
%%time

tth_id11, (eta1_id11, eta2_id11), (omega1_id11, omega2_id11) = ImageD11.transform.uncompute_g_vectors(gvecs.T, pars.get('wavelength'), pars.get('wedge'), pars.get('chi'))

CPU times: user 533 ms, sys: 6.48 ms, total: 539 ms
Wall time: 31.5 ms


In [48]:
%%time

tth_me, (eta1_me, eta2_me), (omega1_me, omega2_me) = mytrans.g_to_tth_eta_omega(gvecs, pars.get('wedge'), pars.get('chi'), pars.get('wavelength'), )

CPU times: user 320 ms, sys: 5.05 ms, total: 325 ms
Wall time: 213 ms


In [49]:
assert np.isclose(tth_id11, tth_me).all()
assert np.isclose(eta1_id11, eta1_me).all()
assert np.isclose(eta2_id11, eta2_me).all()
assert np.isclose(omega1_id11, omega1_me).all()
assert np.isclose(omega2_id11, omega2_me).all()

In [50]:
# Same computation as the earlier cell, repeated before inspecting shapes.
t_id11 = ImageD11.transform.compute_grain_origins(cf.omega, pars.get('wedge'), pars.get('chi'), cf.xl, cf.yl, cf.zl)

In [51]:
# Same computation as the earlier cell, repeated before inspecting shapes.
t_me = mytrans.sample_to_lab(np.column_stack((cf.xl, cf.yl, cf.zl)), cf.omega, pars.get('wedge'), pars.get('chi'))

In [52]:
# ImageD11 layout: components first.
t_id11.shape

(3, 100000)

In [53]:
# mytrans layout: rows first.
t_me.shape

(100000, 3)

In [54]:
assert np.isclose(t_id11, t_me.T).all()

In [55]:
%%time

tth_id11, (eta1_id11, eta2_id11), (omega1_id11, omega2_id11) = ImageD11.transform.uncompute_g_vectors(gvecs.T, pars.get('wavelength'), pars.get('wedge'), pars.get('chi'))

CPU times: user 575 ms, sys: 15.8 ms, total: 591 ms
Wall time: 33.3 ms


In [56]:
# Two-theta values (ImageD11).
tth_id11

array([12.55926516,  2.34531498, 25.12971977, ..., 28.648915  ,
       33.85193998, 22.06978978])

In [57]:
%%time

tth_me, (eta1_me, eta2_me), (omega1_me, omega2_me) = mytrans.g_to_tth_eta_omega(gvecs, pars.get('wedge'), pars.get('chi'), pars.get('wavelength'))

CPU times: user 0 ns, sys: 742 μs, total: 742 μs
Wall time: 522 μs


In [58]:
# Two-theta values (mytrans).
tth_me

Array([12.55926516,  2.34531498, 25.12971977, ..., 28.648915  ,
       33.85193998, 22.06978978], dtype=float64)

In [59]:
assert np.isclose(tth_id11, tth_me).all()
assert np.isclose(eta1_id11, eta1_me).all()
assert np.isclose(eta2_id11, eta2_me).all()
assert np.isclose(omega1_id11, omega1_me).all()
assert np.isclose(omega2_id11, omega2_me).all()

In [60]:
# First eta solution (mytrans).
eta1_me

Array([ 24.81834276,  77.50977577, 132.32951041, ...,  51.18759884,
        38.78171996, 137.35821893], dtype=float64)

In [61]:
# First eta solution (ImageD11).
eta1_id11

array([ 24.81834276,  77.50977577, 132.32951041, ...,  51.18759884,
        38.78171996, 137.35821893])

In [62]:
# ImageD11 reference k-vectors from (tth, eta, wavelength); returned as (3, nrows).
k_id11 = ImageD11.transform.compute_k_vectors(cf.tth, cf.eta, pars.get('wavelength'))

In [63]:
# ImageD11 k-vectors, components first.
k_id11

array([[-0.08408603, -0.00294359, -0.3326114 , ..., -0.43021346,
        -0.59570555, -0.25748682],
       [-0.32073882,  0.13922682,  1.13769147, ..., -1.31279195,
        -1.22609488,  0.92782587],
       [ 0.69355918,  0.03598828, -0.96574662, ...,  1.05597973,
         1.52595118, -0.93941074]])

In [64]:
# mytrans k-vectors; returned as (nrows, 3).
k_me = mytrans.tth_eta_to_k(cf.tth, cf.eta, pars.get('wavelength'))

In [65]:
# mytrans k-vectors, rows first.
k_me

Array([[-0.08408603, -0.32073882,  0.69355918],
       [-0.00294359,  0.13922682,  0.03598828],
       [-0.3326114 ,  1.13769147, -0.96574662],
       ...,
       [-0.43021346, -1.31279195,  1.05597973],
       [-0.59570555, -1.22609488,  1.52595118],
       [-0.25748682,  0.92782587, -0.93941074]], dtype=float64)

In [66]:
# ImageD11 layout.
k_id11.shape

(3, 100000)

In [67]:
# mytrans layout.
k_me.shape

(100000, 3)

In [68]:
assert np.isclose(k_id11.T, k_me).all()

In [69]:
# Omega solutions for each g-vector: ImageD11.gv_general.g_to_k vs
# mytrans.omega_solns_for_g. Both return (omega1, omega2, valid-mask).
# ImageD11's `pre` matrix is (C @ W).T, with W the wedge and C the chi matrix.

W =  mytrans.wedgemat(pars.get('wedge'))
C =  mytrans.chimat(pars.get('chi'))
pre = (C @ W).T

oms_id11 = ImageD11.gv_general.g_to_k(gvecs.T, pars.get('wavelength'), axis=[0,0,1], pre=pre, post=None)

In [70]:
# (omega1, omega2, valid) from ImageD11.
oms_id11

(array([  60.69632233,  130.35772382, -133.7054595 , ...,  112.03377579,
         -64.58801045,    5.45637236]),
 array([ -85.9125816 ,  -47.21417087,   81.53729172, ...,  -30.82771048,
         165.61804221, -143.6559431 ]),
 array([ True,  True,  True, ...,  True,  True,  True]))

In [71]:
# Same solve in mytrans: rotation axis +z, same `pre`, identity `post`.
oms_me = mytrans.omega_solns_for_g(gvecs, pars.get('wavelength'), np.array([0,0,1]), pre, jnp.eye(3))

In [72]:
# (omega1, omega2, valid) from mytrans.
oms_me

(Array([  60.69632233,  130.35772382, -133.7054595 , ...,  112.03377579,
         -64.58801045,    5.45637236], dtype=float64),
 Array([ -85.9125816 ,  -47.21417087,   81.53729172, ...,  -30.82771048,
         165.61804221, -143.6559431 ], dtype=float64),
 Array([ True,  True,  True, ...,  True,  True,  True], dtype=bool))

In [73]:
# Entries 0 and 1 are the two omega solutions (floats, compared approximately);
# entry 2 is a boolean validity mask, which must match exactly.
assert np.allclose(oms_id11[0], oms_me[0])
assert np.allclose(oms_id11[1], oms_me[1])
assert np.array_equal(oms_id11[2], oms_me[2])

In [74]:
# Next: build k-vectors with ImageD11, then check that mytrans.lab_to_sample
# reproduces ImageD11's compute_g_from_k.

In [75]:
# ImageD11 reference: k-vectors, then g-vectors via compute_g_from_k (both (3, nrows)).
k_id11 = ImageD11.transform.compute_k_vectors(cf.tth, cf.eta, pars.get('wavelength'))
g_id11 = ImageD11.transform.compute_g_from_k(k_id11, cf.omega, pars.get('wedge'), pars.get('chi'))

In [76]:
# ImageD11 g-vectors, components first.
g_id11

array([[-0.32244418, -0.10574296,  1.04111012, ..., -1.08158685,
         0.81374755, -0.295523  ],
       [ 0.0760074 , -0.09161964, -0.50065258, ..., -0.86094558,
         1.11029796,  0.88655418],
       [ 0.69370081,  0.03334961, -1.00153655, ...,  1.05477111,
         1.51385406, -0.96764699]])

In [77]:
# mytrans: rotate k (transposed to (nrows, 3)) with lab_to_sample to obtain g.
g_me = mytrans.lab_to_sample(k_id11.T, cf.omega, pars.get('wedge'), pars.get('chi'))

In [78]:
# mytrans g-vectors, rows first.
g_me

Array([[-0.32244418,  0.0760074 ,  0.69370081],
       [-0.10574296, -0.09161964,  0.03334961],
       [ 1.04111012, -0.50065258, -1.00153655],
       ...,
       [-1.08158685, -0.86094558,  1.05477111],
       [ 0.81374755,  1.11029796,  1.51385406],
       [-0.295523  ,  0.88655418, -0.96764699]], dtype=float64)

In [79]:
assert np.isclose(g_id11.T, g_me).all()

In [80]:
# Same call as two cells above, kept under a separate name for comparison
# with gv_general.k_to_g.
g_id11_trans = ImageD11.transform.compute_g_from_k(k_id11, cf.omega, pars.get('wedge'), pars.get('chi'))

In [81]:
# g-vectors from ImageD11.transform.
g_id11_trans

array([[-0.32244418, -0.10574296,  1.04111012, ..., -1.08158685,
         0.81374755, -0.295523  ],
       [ 0.0760074 , -0.09161964, -0.50065258, ..., -0.86094558,
         1.11029796,  0.88655418],
       [ 0.69370081,  0.03334961, -1.00153655, ...,  1.05477111,
         1.51385406, -0.96764699]])

In [82]:
# gv_general.k_to_g matches transform.compute_g_from_k when post = C @ W
# (chi matrix times wedge matrix) and the rotation axis is -z.

W =  mytrans.wedgemat(pars.get('wedge'))
C =  mytrans.chimat(pars.get('chi'))

post = C @ W

g_id11_gvgeneral = ImageD11.gv_general.k_to_g(k_id11, cf.omega, axis=np.array([0., 0., -1]), pre=None, post=post)

In [83]:
# g-vectors from ImageD11.gv_general.
g_id11_gvgeneral

array([[-0.32244418, -0.10574296,  1.04111012, ..., -1.08158685,
         0.81374755, -0.295523  ],
       [ 0.0760074 , -0.09161964, -0.50065258, ..., -0.86094558,
         1.11029796,  0.88655418],
       [ 0.69370081,  0.03334961, -1.00153655, ...,  1.05477111,
         1.51385406, -0.96764699]])

In [84]:
assert np.isclose(g_id11_trans, g_id11_gvgeneral).all()

In [85]:
# ImageD11: (tth, eta, omega) -> g directly. Argument order is
# (wavelength, wedge, chi), unlike mytrans's (wedge, chi, wavelength).
g_id11 = ImageD11.transform.compute_g_vectors(
    cf.tth, cf.eta, cf.omega,
    *(pars.get(k) for k in ('wavelength', 'wedge', 'chi')),
)

In [86]:
# ImageD11 g-vectors, components first.
g_id11

array([[-0.32244418, -0.10574296,  1.04111012, ..., -1.08158685,
         0.81374755, -0.295523  ],
       [ 0.0760074 , -0.09161964, -0.50065258, ..., -0.86094558,
         1.11029796,  0.88655418],
       [ 0.69370081,  0.03334961, -1.00153655, ...,  1.05477111,
         1.51385406, -0.96764699]])

In [87]:
assert np.isclose(gvecs, g_id11.T).all()

In [93]:
# mytrans: (tth, eta, omega) -> g, shape (nrows, 3).
g_me = mytrans.tth_eta_omega_to_g(
    cf.tth, cf.eta, cf.omega,
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength')),
)

In [94]:
# mytrans g-vectors, rows first.
g_me

Array([[-0.32244418,  0.0760074 ,  0.69370081],
       [-0.10574296, -0.09161964,  0.03334961],
       [ 1.04111012, -0.50065258, -1.00153655],
       ...,
       [-1.08158685, -0.86094558,  1.05477111],
       [ 0.81374755,  1.11029796,  1.51385406],
       [-0.295523  ,  0.88655418, -0.96764699]], dtype=float64)

In [95]:
assert np.isclose(gvecs, g_me).all()

In [96]:
# lab_to_sample followed by sample_to_lab must be the identity.
assert np.isclose(
    gvecs,
    mytrans.sample_to_lab(
        mytrans.lab_to_sample(gvecs, cf.omega, *(pars.get(k) for k in ('wedge', 'chi'))),
        cf.omega,
        *(pars.get(k) for k in ('wedge', 'chi')),
    ),
).all()

In [99]:
# g-vectors -> (tth, eta, omega); each g has two (eta, omega) solutions.
tth, (eta_one, eta_two), (omega1, omega2) = mytrans.g_to_tth_eta_omega(
    gvecs, *(pars.get(k) for k in ('wedge', 'chi', 'wavelength'))
)

In [100]:
# tth is unique. For (eta, omega) the true values must be one of the two returned
# solution PAIRS (eta_one with omega1, eta_two with omega2). Matching eta and omega
# independently could accept a mixed pair. Reduce with vectorised .all() rather than
# the builtin all(), which iterates over the arrays element by element in Python.
assert np.allclose(tth, cf.tth)
matches_first = np.isclose(cf.eta, eta_one) & np.isclose(cf.omega, omega1 % 360)
matches_second = np.isclose(cf.eta, eta_two) & np.isclose(cf.omega, omega2 % 360)
assert (matches_first | matches_second).all()

In [102]:
# Forward again with the first solution pair; must reproduce the original g-vectors.
gvecs_loop = mytrans.tth_eta_omega_to_g(
    tth, eta_one, omega1, *(pars.get(k) for k in ('wedge', 'chi', 'wavelength'))
)
assert np.isclose(gvecs_loop, gvecs).all()

In [104]:
# Forward again with the second solution pair; must also reproduce the g-vectors.
gvecs_loop = mytrans.tth_eta_omega_to_g(
    tth, eta_two, omega2, *(pars.get(k) for k in ('wedge', 'chi', 'wavelength'))
)
assert np.isclose(gvecs_loop, gvecs).all()

In [105]:
# Lab xyz + omega + sample origin -> k-vectors, shape (nrows, 3).
k_me = mytrans.xyz_lab_to_k(
    np.column_stack((cf.xl, cf.yl, cf.zl)),
    cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength')),
)

In [106]:
# Matches the k-vectors computed from (tth, eta) earlier.
k_me

Array([[-0.08408603, -0.32073882,  0.69355918],
       [-0.00294359,  0.13922682,  0.03598828],
       [-0.3326114 ,  1.13769147, -0.96574662],
       ...,
       [-0.43021346, -1.31279195,  1.05597973],
       [-0.59570555, -1.22609488,  1.52595118],
       [-0.25748682,  0.92782587, -0.93941074]], dtype=float64)

In [107]:
# k-vectors -> detector (sc, fc).
sc_me, fc_me = mytrans.k_to_det(
    k_me, cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength',
                            'y_center', 'y_size', 'tilt_y',
                            'z_center', 'z_size', 'tilt_z',
                            'tilt_x', 'distance',
                            'o11', 'o12', 'o21', 'o22')),
)

In [108]:
# Slow-coordinate pixel positions recovered from k.
sc_me

Array([ 706.36364282, 1095.59005447, 1737.43641177, ...,  420.49921996,
         60.67763906, 1707.04487671], dtype=float64)

In [109]:
# k-vectors -> lab xyz on the detector, shape (nrows, 3).
xyz_me = mytrans.k_to_xyz_lab(
    k_me, cf.omega,
    jnp.array(tuple(pars.get(k) for k in ('t_x', 't_y', 't_z'))),
    *(pars.get(k) for k in ('wedge', 'chi', 'wavelength',
                            'y_center', 'y_size', 'tilt_y',
                            'z_center', 'z_size', 'tilt_z',
                            'tilt_x', 'distance',
                            'o11', 'o12', 'o21', 'o22')),
)

In [110]:
assert np.isclose(xyz_me, np.column_stack((cf.xl, cf.yl, cf.zl))).all()

In [111]:
# Lab xyz recovered from k-vectors.
xyz_me

Array([[151958.95858793, -14208.67191001,  30730.29224767],
       [152561.66508866,   6047.4053718 ,   1566.62441148],
       [153248.65484073,  54801.98503509, -46516.87271179],
       ...,
       [152230.30356463, -64802.00555688,  52130.04119328],
       [151273.9033695 , -63557.26217465,  79101.27067311],
       [153391.83355705,  43703.99769104, -44244.85757835]],      dtype=float64)