In [1]:
import os
# Set before importing jax so XLA does not pre-allocate a large block of GPU memory
# up front (this notebook holds several 100M-row arrays at once).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
# Use float64 so the JAX results can be compared against ImageD11's NumPy (float64) output.
jax.config.update("jax_enable_x64", True)


from ImageD11.parameters import AnalysisSchema
from ImageD11.columnfile import columnfile

In [2]:
# Start from ImageD11's default analysis schema and keep only the geometry parameters.
default_schema = AnalysisSchema.from_default()
pars = default_schema.geometry_pars_obj

In [3]:
# Non-zero detector tilts and goniometer angles. The tilts are in radians
# (detector_rotation_matrix passes them straight to cos/sin); chi and wedge are in
# degrees (compute_g_from_k_matr converts them with jnp.radians).
for par_name, par_value in (
    ('tilt_x', 0.00123),
    ('tilt_y', -0.0345),
    ('tilt_z', 0.02),
    ('chi', 1),
    ('wedge', -3),
):
    pars.set(par_name, par_value)

In [4]:
# Synthetic peak list: detector pixel coordinates (fc, sc) in [0, 2048) and omega in [0, 360).
# Each float64 column is ~800 MB at this size; reduce nrows for a quick run.
nrows = 100_000_000
# Seeded generator so the row-0 g-vector spot checks further down are reproducible.
rng = np.random.default_rng(42)
fc = rng.random(nrows) * 2048
sc = rng.random(nrows) * 2048
om = rng.random(nrows) * 360

In [5]:
cf = columnfile(new=True)
cf.nrows = nrows

# Register the synthetic data under the titles used by later cells (e.g. cf.omega).
for column, title in ((fc, 'fc'), (sc, 'sc'), (om, 'omega')):
    cf.addcolumn(column, title)

In [6]:
# Attach the modified geometry. Later cells read wavelength/wedge/chi back through
# cf.parameters.get(...).
cf.parameters = pars

In [7]:
# ImageD11 (NumPy) reference computation over all nrows. Adds the derived columns
# (xl ... gz) listed by cf.titles in the next cell.
cf.updateGeometry()

In [8]:
# Columns after updateGeometry(): the three inputs plus the derived geometry columns.
cf.titles

['fc', 'sc', 'omega', 'xl', 'yl', 'zl', 'tth', 'eta', 'ds', 'gx', 'gy', 'gz']

In [9]:
import ImageD11.transform

In [10]:
# Reference detector rotation matrix from ImageD11, for comparison with the JAX version below.
tilt_angles = [pars.get(name) for name in ('tilt_x', 'tilt_y', 'tilt_z')]
ImageD11.transform.detector_rotation_matrix(*tilt_angles)

array([[ 9.99205060e-01, -1.99867662e-02, -3.44931565e-02],
       [ 1.99562335e-02,  9.99800099e-01, -1.22926776e-03],
       [ 3.45108303e-02,  5.39937080e-04,  9.99404178e-01]])

In [11]:
@jax.jit
def detector_rotation_matrix(tilt_x, tilt_y, tilt_z):
    """JAX version of ImageD11.transform.detector_rotation_matrix.

    Builds right-handed rotations about z, y and x from the tilt angles (radians;
    they are passed straight to cos/sin). Returns the 3x3 product
    Rx(tilt_x) @ Ry(tilt_y) @ Rz(tilt_z).
    """
    cos_z, sin_z = jnp.cos(tilt_z), jnp.sin(tilt_z)
    cos_y, sin_y = jnp.cos(tilt_y), jnp.sin(tilt_y)
    cos_x, sin_x = jnp.cos(tilt_x), jnp.sin(tilt_x)

    rot_z = jnp.array([[cos_z, -sin_z, 0],  # note this is r.h.
                       [sin_z, cos_z, 0],
                       [0, 0, 1]], float)
    rot_y = jnp.array([[cos_y, 0, sin_y],
                       [0, 1, 0],
                       [-sin_y, 0, cos_y]], float)
    rot_x = jnp.array([[1, 0, 0],
                       [0, cos_x, -sin_x],
                       [0, sin_x, cos_x]], float)

    # Same association as before: (Rx @ Ry) @ Rz.
    combined = (rot_x @ rot_y) @ rot_z
    return combined

In [12]:
tilts = (pars.get('tilt_x'), pars.get('tilt_y'), pars.get('tilt_z'))
rmat = detector_rotation_matrix(*tilts)
# Validate the JAX port against ImageD11's reference implementation instead of only
# displaying the reference matrix.
assert np.allclose(rmat, ImageD11.transform.detector_rotation_matrix(*tilts)), \
    "JAX detector_rotation_matrix disagrees with ImageD11.transform.detector_rotation_matrix"

In [13]:
# 3x3 elements * 8 bytes = 72, i.e. the matrix is 64-bit (jax_enable_x64 took effect).
rmat.nbytes

72

In [14]:
def compute_k_vector(tth, eta, wvln):
    """Scattering vector k for a single peak (JAX version of ImageD11's compute_k_vectors).

    Parameters
    ----------
    tth, eta : scattering angle 2-theta and azimuth eta, in degrees.
    wvln : wavelength.

    Returns
    -------
    jnp array [kx, ky, kz] with |k| = 2 sin(theta) / wvln.
    """
    theta = jnp.radians(tth) / 2
    eta_rad = jnp.radians(eta)
    sin_t = jnp.sin(theta)
    cos_t = jnp.cos(theta)
    d_star = 2 * sin_t / wvln

    kx = -d_star * sin_t  # this is negative x
    ky = -d_star * cos_t * jnp.sin(eta_rad)  # HFP eta convention (sign as in ImageD11)
    kz = d_star * cos_t * jnp.cos(eta_rad)

    return jnp.stack([kx, ky, kz])

In [15]:
# Batch the single-peak function over tth/eta (axis 0); the wavelength (None) is shared.
_batched_k_vector = jax.vmap(compute_k_vector, in_axes=(0, 0, None))
compute_k_vectors = jax.jit(_batched_k_vector)

In [29]:
wavelength = cf.parameters.get('wavelength')
k_jax = compute_k_vectors(cf.tth, cf.eta, wavelength)  # one row [kx, ky, kz] per peak

In [17]:
# ImageD11 returns k as (3, N); the vmapped JAX version is (N, 3), hence the transpose.
wavelength = cf.parameters.get('wavelength')
k_matches = np.allclose(k_jax, ImageD11.transform.compute_k_vectors(cf.tth, cf.eta, wavelength).T)
# Fail loudly under Restart & Run All instead of silently displaying False.
assert k_matches, "JAX compute_k_vectors disagrees with ImageD11.transform.compute_k_vectors"
k_matches

True

In [26]:
def compute_g_from_k_matr(k, omega, wedge, chi):
    """Rotate a single k vector into the sample frame: g = R(omega) @ C(chi) @ W(wedge) @ k.

    omega, wedge and chi are in degrees. k is a length-3 vector (one row of the
    (N, 3) k array when vmapped). Compared against ImageD11's updateGeometry()
    g-vectors later in this notebook.
    """
    def _cos_sin(angle_deg):
        # cos/sin of an angle given in degrees.
        rad = jnp.radians(angle_deg)
        return jnp.cos(rad), jnp.sin(rad)

    com, som = _cos_sin(omega)
    cwed, swed = _cos_sin(wedge)
    cchi, schi = _cos_sin(chi)

    omega_rot = jnp.array([[com, som, 0],
                           [-som, com, 0],
                           [0, 0, 1]])
    chi_rot = jnp.array([[1, 0, 0],
                         [0, cchi, schi],
                         [0, -schi, cchi]])
    wedge_rot = jnp.array([[cwed, 0, swed],
                           [0, 1, 0],
                           [-swed, 0, cwed]])

    # Same left-to-right evaluation order as R @ C @ W @ k.
    return omega_rot @ chi_rot @ wedge_rot @ k

In [27]:
# Batch over k rows and omega (axis 0); wedge and chi (None) are shared by every peak.
_batched_g_from_k = jax.vmap(compute_g_from_k_matr, in_axes=(0, 0, None, None))
compute_g_from_k_all = jax.jit(_batched_g_from_k)

In [30]:
g_first = compute_g_from_k_all(k_jax, cf.omega, cf.parameters.get('wedge'), cf.parameters.get('chi'))[0]
# Check row 0 against ImageD11's updateGeometry() columns with an assertion instead of
# comparing two cell outputs by eye.
assert np.allclose(g_first, [cf.gx[0], cf.gy[0], cf.gz[0]]), \
    "JAX g-vector for row 0 disagrees with ImageD11 (cf.gx/gy/gz)"
g_first

Array([ 0.07132756,  0.46433902, -0.81642077], dtype=float64)

In [31]:
# ImageD11's g-vector for the first peak, for comparison with the JAX result above.
tuple(column[0] for column in (cf.gx, cf.gy, cf.gz))

(0.07132756258253843, 0.46433902223834955, -0.8164207661221281)

In [None]:
# The JAX k-vectors live in `k_jax` with shape (N, 3). ImageD11's NumPy routines use
# (3, N), so convert to a NumPy array and transpose. This way the benchmark below times
# NumPy code on NumPy input.
kt = np.asarray(k_jax).T

In [None]:
%%time

g = ImageD11.transform.compute_g_from_k(kt, cf.omega, cf.parameters.get('wedge'), cf.parameters.get('chi'))

In [None]:
%%time

g2 = compute_g_from_k_all(k, cf.omega, cf.parameters.get('wedge'), cf.parameters.get('chi'))