# Assemble a Diffraction Intensity Volume using Ground-truth Orientations

In [None]:
import numpy as np

from tqdm import tqdm

import h5py as h5

from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D


## Utility functions

In [None]:
# Adapted from: https://sscc.nimh.nih.gov/pub/dist/bin/linux_gcc32/meica.libs/nibabel/quaternions.py
def quat2mat(q):
    ''' Calculate rotation matrix corresponding to quaternion

    Parameters
    ----------
    q : 4 element array-like
      Quaternion in ``(w, x, y, z)`` order (scalar part first).

    Returns
    -------
    M : (3,3) array
      Rotation matrix corresponding to input quaternion *q*. A quaternion
      whose squared norm is below machine epsilon yields the identity.

    Notes
    -----
    Rotation matrix applies to column vectors, and is applied to the
    left of coordinate vectors.  The algorithm here allows non-unit
    quaternions: the result is the same as for the normalized quaternion.

    References
    ----------
    Algorithm from
    http://en.wikipedia.org/wiki/Rotation_matrix#Quaternion

    Examples
    --------
    >>> import numpy as np
    >>> M = quat2mat([1, 0, 0, 0]) # Identity quaternion
    >>> np.allclose(M, np.eye(3))
    True
    >>> M = quat2mat([0, 1, 0, 0]) # 180 degree rotn around axis 0
    >>> np.allclose(M, np.diag([1, -1, -1]))
    True
    '''
    w, x, y, z = q
    Nq = w*w + x*x + y*y + z*z
    # `np.float` was only an alias for the builtin `float` and was removed in
    # NumPy 1.24; using `float` directly gives the same epsilon on all versions.
    FLOAT_EPS = np.finfo(float).eps
    if Nq < FLOAT_EPS:
        # Degenerate (near-zero) quaternion: no meaningful rotation.
        return np.eye(3)

    # Dividing by the squared norm makes the formula valid for non-unit q.
    s = 2.0 / Nq
    X = x * s
    Y = y * s
    Z = z * s
    wX = w * X
    wY = w * Y
    wZ = w * Z
    xX = x * X
    xY = x * Y
    xZ = x * Z
    yY = y * Y
    yZ = y * Z
    zZ = z * Z

    return np.array(
           [[ 1.0 - (yY + zZ), xY - wZ, xZ + wY ],
            [ xY + wZ, 1.0 - (xX + zZ), yZ - wX ],
            [ xZ - wY, yZ + wX, 1.0 - (xX + yY) ]])

def build_empty_intensity_grid(grid_size=128):
    '''Build the integer-spaced cubic grid the intensity volume is assembled on.

    Parameters
    ----------
    grid_size : int, optional
        Number of grid points along each axis (default 128). Each axis runs
        over the consecutive integers ``1 - grid_size // 2`` ...
        ``grid_size - grid_size // 2`` (i.e. -63 ... 64 for the default).
        This range matches the ``+ size // 2 - 1`` pixel offset used by
        ``interpolate_oriented_intensity_using_diffraction_pattern``.

    Returns
    -------
    intensity_coords : (grid_size**3, 3) float array
        Flattened grid coordinates, ordered as produced by ``np.meshgrid``
        with its default ``'xy'`` indexing.
    intensity_vals : (grid_size**3,) float array
        Zero-initialised intensity accumulator, one entry per coordinate.

    Raises
    ------
    ValueError
        If ``grid_size`` is smaller than 1.
    '''
    grid_size = int(grid_size)
    if grid_size < 1:
        raise ValueError("grid_size must be >= 1, got {}".format(grid_size))

    # `num` must be an int: modern NumPy rejects the float `128.` used before.
    axis = np.linspace(1 - grid_size // 2, grid_size - grid_size // 2, grid_size)

    # Same axis for x, y and z; default 'xy' indexing keeps the original ordering.
    x, y, z = np.meshgrid(axis, axis, axis)

    intensity_coords = np.column_stack((x.ravel(), y.ravel(), z.ravel()))

    intensity_vals = np.zeros(len(intensity_coords))

    return intensity_coords, intensity_vals

def interpolate_oriented_intensity_using_diffraction_pattern(oriented_intensity_coords, intensity_vals, diffraction_pattern):
    '''Accumulate one diffraction pattern into the intensity volume, in place.

    A grid point contributes when its rotated z coordinate rounds to 0, i.e.
    it lies on the slice sampled by this pattern. Its rounded (x, y)
    coordinates are mapped to pixel
    ``(x + height // 2 - 1, y + width // 2 - 1)``. If that pixel lies inside
    the pattern, the pixel value is added to the point's entry of
    ``intensity_vals``. Points that map outside the pattern are ignored;
    negative indices are not wrapped.

    Parameters
    ----------
    oriented_intensity_coords : (N, 3) array
        Grid coordinates after rotation into the pattern's frame.
    intensity_vals : (N,) numpy array
        Intensity accumulator; updated in place.
    diffraction_pattern : (H, W) array
        Pattern whose pixels are added onto the on-slice grid points.

    Returns
    -------
    None
    '''
    coords = np.asarray(oriented_intensity_coords)
    pattern = np.asarray(diffraction_pattern)
    diffraction_pattern_height = pattern.shape[0]
    diffraction_pattern_width = pattern.shape[1]

    # Vectorized replacement for a per-point Python loop (~2M points per call).
    # np.rint rounds half to even, exactly like the builtin round() did.
    on_slice = np.flatnonzero(np.rint(coords[:, 2]) == 0)

    diffraction_pattern_x = (np.rint(coords[on_slice, 0]).astype(np.intp)
                             + (diffraction_pattern_height // 2 - 1))
    diffraction_pattern_y = (np.rint(coords[on_slice, 1]).astype(np.intp)
                             + (diffraction_pattern_width // 2 - 1))

    inside = ((diffraction_pattern_x >= 0)
              & (diffraction_pattern_x < diffraction_pattern_height)
              & (diffraction_pattern_y >= 0)
              & (diffraction_pattern_y < diffraction_pattern_width))

    # Each index in `on_slice` appears once, so fancy-index `+=` accumulates
    # correctly without needing np.add.at.
    intensity_vals[on_slice[inside]] += pattern[diffraction_pattern_x[inside],
                                                diffraction_pattern_y[inside]]


## Reconstruct the intensity

In [None]:
dataset_name = "3iyf-10K-mixed-hit-99"
downsampled_images_output_subdir = "downsample-128x128"

dataset_size = 10000

# Number of patterns merged into the volume. Only the first 2 are used (as in
# the original run); set this to `dataset_size` to merge the whole dataset.
n_patterns_to_merge = 2

downsampled_h5_file = "/reg/data/ana03/scratch/deebanr/{}/dataset/{}/cspi_synthetic_dataset_diffraction_patterns_3iyf-10K-mixed-hit_uniform_quat_dataset-size={}_diffraction-pattern-shape=1024x1040.hdf5".format(dataset_name, downsampled_images_output_subdir, dataset_size)
h5_file = "/reg/data/ana03/scratch/deebanr/{}/dataset/cspi_synthetic_dataset_diffraction_patterns_3iyf-10K-mixed-hit_uniform_quat_dataset-size={}_diffraction-pattern-shape=1024x1040.hdf5".format(dataset_name, dataset_size)

intensity_coords, intensity_vals = build_empty_intensity_grid()

# Context managers close both HDF5 files even on error; previously the
# downsampled file was never closed.
with h5.File(downsampled_h5_file, 'r') as downsampled_h5_file_handle, \
        h5.File(h5_file, 'r') as h5_file_handle:
    # Look the datasets up once instead of on every iteration.
    downsampled_diffraction_patterns = downsampled_h5_file_handle["downsampled_diffraction_patterns"]
    orientations = h5_file_handle["orientations"]

    for dataset_index in tqdm(range(n_patterns_to_merge)):
        diffraction_pattern = downsampled_diffraction_patterns[dataset_index]
        orientation = orientations[dataset_index]

        rotation_matrix_3d = quat2mat(orientation)
        # Row-vector product coords @ R applies R^T (the inverse of R) to each
        # coordinate. NOTE(review): confirm this direction matches how the
        # ground-truth orientations were used to simulate the patterns.
        oriented_intensity_coords = np.dot(intensity_coords, rotation_matrix_3d)

        interpolate_oriented_intensity_using_diffraction_pattern(oriented_intensity_coords, intensity_vals, diffraction_pattern)


## Plot the reconstructed intensity

In [None]:
# Keep only the grid points that received a strictly positive intensity so the
# scatter plot below skips the empty remainder of the volume.
nonzero_idx = np.nonzero(intensity_vals > 0)
nonzero_intensity_coords = intensity_coords[nonzero_idx]
nonzero_intensity_vals = intensity_vals[nonzero_idx]


In [None]:
fig = pyplot.figure()
# `Axes3D(fig)` no longer adds itself to the figure in Matplotlib >= 3.6
# (deprecated since 3.4), which leaves an empty plot. Requesting a 3D
# subplot works on old and new versions; the `Axes3D` import above registers
# the '3d' projection on older Matplotlib.
ax = fig.add_subplot(111, projection='3d')

ax.view_init(45, -60)
# Colour each non-empty voxel by its accumulated intensity.
ax.scatter(nonzero_intensity_coords[:, 0], nonzero_intensity_coords[:, 1], nonzero_intensity_coords[:, 2], c=nonzero_intensity_vals)
pyplot.show()
