In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", 'cpu')

from ImageD11.unitcell import unitcell
from ImageD11.parameters import AnalysisSchema
from ImageD11.columnfile import columnfile

In [None]:
import ImageD11
asc = ImageD11.unitcell.Phases.from_json('pars.json')
ucell = asc.unitcells['Cu']
ucell

In [None]:
pars = asc.geometry_pars_obj

In [None]:
npx = 2_048

In [None]:
pars.set('distance', 5e3)  # 130 mm
pars.set('wavelength', 0.1771)
pars.set('y_center', npx//2)
pars.set('z_center', npx//2)
pars.set('chi', 0)
pars.set('wedge', 0)
pars.set('y_size', 1)
pars.set('z_size', 1)
pars.set('tilt_x', 0)
pars.set('tilt_y', 0)
pars.set('tilt_z', 0)

In [None]:
pars.get_parameters()

In [None]:
import Dans_Diffraction as dif

crystal = dif.Crystal('EntryWithCollCode7954_scaled.cif')
crystal.Scatter.setup_scatter(scattering_type='xray',
                              wavelength_a=asc.geometry_pars_obj.get('wavelength'), 
                              powder_units='twotheta',
                              min_twotheta=0.1,
                              max_twotheta=30,
                              min_theta=-np.inf,
                              max_theta=np.inf
                             )

In [None]:
ref_ucell = ucell

In [None]:
# ref_ucell = unitcell([2.8, 2.8, 2.8, 90, 90, 90], 229)  # BCC Fe

In [None]:
ref_ucell.makerings(1.5)
hkls = []
mults = []
for i, d in enumerate(ref_ucell.ringds):
    hklring = ref_ucell.ringhkls[d]
    mults.append( len(hklring) )
    print(i, len(hklring),hklring[0],d)
    hkls += list(hklring)

hkls = np.array(hkls)

In [None]:
intensities = crystal.Scatter.intensity(hkls)

In [None]:
NX = 75
# NY = 30
# NZ = 5
NY = NZ = NX

In [None]:
from scipy.spatial.transform import Rotation as R

ng = 1

U = R.random(ng).as_matrix()

U.shape

In [None]:
nvoxels = NX * NY * NZ

In [None]:
U = np.broadcast_to(U, (nvoxels, 3, 3))
U.shape

In [None]:
U[0], U[1]

In [None]:
UB = U @ ref_ucell.B

In [None]:
inten_reshape = np.broadcast_to(intensities, (len(UB),len(hkls))).T
inten_reshape.shape

In [None]:
gves = (UB @ hkls.T).transpose(2, 0, 1)
gves.shape

In [None]:
import transform as mytrans

In [None]:
import importlib
importlib.reload(mytrans)

In [None]:
# generate discrete voxels - just a list of grains with different origins

In [None]:
ijk = np.mgrid[:NX, :NY, :NZ]

In [None]:
ijk.shape

In [None]:
voxel_size = 0.7  # um

In [None]:
tx, ty, tz = ijk * voxel_size
tx -= NX/2 * voxel_size
ty -= NY/2 * voxel_size
tz -= NZ/2 * voxel_size

In [None]:
tx = tx.flatten()
ty = ty.flatten()
tz = tz.flatten()

In [None]:
tx.min(), tx.max()

In [None]:
# tx = np.random.random(ng) * 1000 - 500
# ty = np.random.random(ng) * 1000 - 500
# tz = np.random.random(ng) * 100 - 50

In [None]:
origin_sample = np.column_stack((tx, ty, tz))
# origin_sample = np.zeros((ng,3))
origin_sample.shape

In [None]:
origin_sample = np.broadcast_to(origin_sample, gves.shape)
origin_sample.shape

In [None]:
gves = gves.reshape(-1, 3)
origin_sample = origin_sample.reshape(-1, 3)
inten_reshape = inten_reshape.reshape(-1)

In [None]:
gves.shape, origin_sample.shape, inten_reshape.shape

In [None]:
# simulate a beam energy spread

In [None]:
e_spread = pars.get('wavelength')/1000

In [None]:
n_e = 10

In [None]:
e_bins = np.linspace(pars.get('wavelength')-(e_spread*5), pars.get('wavelength')+(e_spread*5), n_e)

In [None]:
import scipy
e_scales = scipy.stats.norm.pdf(e_bins, pars.get('wavelength'), e_spread)

In [None]:
e_scales/=e_scales.max()

In [None]:
from matplotlib import pyplot as plt
fig, ax = plt.subplots()
ax.plot(e_bins, e_scales)
plt.show()

In [None]:
def concat_interleave(a, b):
    """Interleave ``a`` and ``b``; for equal-length 1-D inputs the result is
    ``[a0, b0, a1, b1, ...]``.

    The inputs are stacked as rows and read out column by column
    (Fortran order), which alternates between them.
    """
    rows = jnp.vstack((a, b))
    # Column-major readout of `rows` is the row-major readout of its transpose.
    return rows.T.reshape(-1)

In [None]:
det_trans, beam_cen_shift, x_distance_shift = mytrans.detector_transforms(pars.get('y_center'), pars.get('y_size'), pars.get('tilt_y'),
                                                                          pars.get('z_center'), pars.get('z_size'), pars.get('tilt_z'),
                                                                          pars.get('tilt_x'),
                                                                          pars.get('distance'),
                                                                          pars.get('o11'), pars.get('o12'), pars.get('o21'),pars.get('o22'))

In [None]:
k_in_norm = jnp.array([1., 0., 0])

In [None]:
all_sc = []
all_fc = []
all_omega = []
all_inten = []

for lam, scale in zip(e_bins, e_scales):
    k_in_norm = mytrans._scale_norm_k(k_in_norm, lam)
    
    (sc1, sc2), (fc1, fc2), (omega1, omega2), valid = mytrans.q_and_origin_sample_to_det(gves, origin_sample, k_in_norm,
                                                                        pars.get('wedge'), pars.get('chi'), lam,
                                                                        det_trans, beam_cen_shift, x_distance_shift)
    
    sc_calc = concat_interleave(sc1[valid], sc2[valid])
    fc_calc = concat_interleave(fc1[valid], fc2[valid])
    omega_calc = concat_interleave(omega1[valid], omega2[valid])
    inten_calc = concat_interleave(inten_reshape[valid], inten_reshape[valid])
    
    m = (sc_calc > 0) & (sc_calc < npx) & (fc_calc > 0) & (fc_calc < npx)
    
    sc = sc_calc[m]
    fc = fc_calc[m]
    omega = omega_calc[m]
    inten = scale*inten_calc[m]

    all_sc.append(sc)
    all_fc.append(fc)
    all_omega.append(omega)
    all_inten.append(inten)

sc = jnp.concatenate(all_sc)
fc = jnp.concatenate(all_fc)
omega = jnp.concatenate(all_omega)
inten = jnp.concatenate(all_inten)

In [None]:
from jax.experimental import sparse

def sparse_histogram_3d(coords, intensities, bins):
    """
    Sum intensities into a sparse 3D histogram defined by bin centres.

    Bin edges are placed half a spacing either side of each centre (assuming
    uniformly spaced centres; a single centre gets a unit-width bin).
    Coordinates beyond the outermost edges are clipped into the first/last
    bin rather than dropped.

    Args:
        coords: (N, 3) array of 3D coordinates.
        intensities: (N,) array of intensity values.
        bins: sequence of 3 1-D arrays, the bin centres for each dimension.

    Returns:
        Array of shape (K, 4), where K is the number of non-empty bins. Each
        row is [bin_center_0, bin_center_1, bin_center_2, intensity_sum],
        with rows sorted lexicographically by bin index.
    """
    coords = jnp.asarray(coords)
    intensities = jnp.asarray(intensities)

    bin_indices = []
    centers_per_dim = []
    for dim in range(3):
        bin_centers = jnp.asarray(bins[dim])
        centers_per_dim.append(bin_centers)
        # Convert centres to edges (assuming uniform spacing).
        if len(bin_centers) >= 2:
            spacing = bin_centers[1] - bin_centers[0]
            bin_edges = jnp.linspace(bin_centers[0] - spacing / 2,
                                     bin_centers[-1] + spacing / 2,
                                     len(bin_centers) + 1)
        else:
            bin_edges = jnp.array([bin_centers[0] - 0.5, bin_centers[0] + 0.5])

        # Index of the bin whose left edge is <= x; out-of-range values are
        # clipped into the first/last bin.
        indices = jnp.searchsorted(bin_edges, coords[:, dim], side='right') - 1
        indices = jnp.clip(indices, 0, len(bin_edges) - 2)
        bin_indices.append(indices)

    # One (i, j, k) bin index triple per input point.
    bin_combinations = jnp.stack(bin_indices, axis=1)

    unique_bins, inverse_indices = jnp.unique(bin_combinations, return_inverse=True, axis=0)
    # Some NumPy/JAX versions return the inverse with extra dimensions when
    # `axis` is given; segment_sum needs a flat (N,) array of segment ids.
    inverse_indices = jnp.ravel(inverse_indices)
    bin_sums = jax.ops.segment_sum(intensities, inverse_indices, num_segments=len(unique_bins))

    return jnp.column_stack([
        centers_per_dim[0][unique_bins[:, 0]],
        centers_per_dim[1][unique_bins[:, 1]],
        centers_per_dim[2][unique_bins[:, 2]],
        bin_sums,
    ])

In [None]:
ostep = 0.1
# coarse
omin = -180
omax = 181
obincens = np.linspace(omin, omax, int((omax-omin)/ostep))
obinedges = np.arange(omin - ostep / 2, omax + ostep / 1.9, ostep)

sc_bins = np.arange(npx)
fc_bins = np.arange(npx)

In [None]:
coords = jnp.column_stack((sc, fc, omega))
bins = (sc_bins, fc_bins, obincens)
res = sparse_histogram_3d(coords, inten, bins)

In [None]:
sc, fc, omega, inten = res.T

In [None]:
m_ff = omega == omega.min()
m_ff.sum()

In [None]:
m_ff = omega == jnp.unique(omega)[100]
m_ff.sum()

In [None]:
im, _, _ = np.histogram2d(sc[m_ff], fc[m_ff], weights=inten[m_ff], bins=(sc_bins, fc_bins))

In [None]:
from matplotlib import pyplot as plt
%matplotlib ipympl

fig, ax = plt.subplots()
ax.imshow(im+1e-16, norm='log', vmin=1e4, vmax=1e6, interpolation='nearest', origin='lower')
ax.set_aspect(1)
plt.show()

In [None]:
# Image summed over all omega frames. histogram2d treats `bins` as EDGES, so
# use half-integer edges to give each of the npx integer pixel centres its
# own bin.
pixel_edges = np.arange(npx + 1) - 0.5
im, _, _ = np.histogram2d(sc, fc, weights=inten, bins=(pixel_edges, pixel_edges))

fig, ax = plt.subplots()
# Rows of `im` follow sc and columns follow fc.
ax.imshow(im+1e-16, norm='log', vmin=1e4, vmax=1e7, interpolation='nearest', origin='lower')
ax.set_aspect(1)
ax.set(xlabel='fc (px)', ylabel='sc (px)', title='All omega frames summed')
plt.show()

In [None]:
# Spots of the selected omega frame (mask m_ff), coloured by summed intensity.
fig, ax = plt.subplots(constrained_layout=True, figsize=(10,10))
points = ax.scatter(fc[m_ff], sc[m_ff], s=1, c=inten[m_ff])
fig.colorbar(points, ax=ax, label='intensity', shrink=0.8)

ax.set_aspect(1)
ax.set(xlim=(0, npx), ylim=(0, npx))
ax.set(xlabel='fc', ylabel='sc',
       title=f"""Random grains: {ng}
Beam center (px): {pars.get('y_center')},{pars.get('z_center')}
Detector distance: {pars.get('distance')/1e6:.4} m
Wavelength: {pars.get('wavelength'):.4} A""")
plt.show()

In [None]:
# All binned spots (every omega frame), coloured by summed intensity.
# A verbatim duplicate of this cell that followed it has been removed.
fig, ax = plt.subplots(constrained_layout=True, figsize=(10,10))
points = ax.scatter(fc, sc, s=1, c=inten)
fig.colorbar(points, ax=ax, label='intensity', shrink=0.8)

ax.set_aspect(1)
ax.set(xlim=(0, npx), ylim=(0, npx))
ax.set(xlabel='fc', ylabel='sc',
       title=f"""Random grains: {ng}
Beam center (px): {pars.get('y_center')},{pars.get('z_center')}
Detector distance: {pars.get('distance')/1e6:.4} m
Wavelength: {pars.get('wavelength'):.4} A""")
plt.show()