The goal of this notebook is to examine an FOD output and visualize it without any purpose-built visualization tool, which requires me to prove that I understand how to interpret the FOD coefficients. It succeeds in using vtk.js to render FODs that look exactly like the ones shown by mrview.

In [None]:
from pathlib import Path
import numpy as np
import nibabel as nib
from scipy.special import sph_harm  # deprecated since SciPy 1.15; sph_harm_y replaces it (with a different argument order)
from IPython.display import display, Javascript

In [None]:
fod_path = Path('csd_output_mrtrix_msmt/fod/WM/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi_wmfod.nii.gz')
fod_dipy_path = Path('csd_output_dipy/fod/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi_fod_mrtrixResponse.nii.gz')
dwi_path = Path('extracted_images/NDARINV1JXDFV9Z_baselineYear1Arm1_ABCD-MPROC-DTI_20161206184105/sub-NDARINV1JXDFV9Z/ses-baselineYear1Arm1/dwi/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi.nii')

In [None]:
dwi = nib.load(dwi_path)
dwi_array = dwi.get_fdata()

In [None]:
fod = nib.load(fod_path)
fod_array = fod.get_fdata()
fod_dipy = nib.load(fod_dipy_path)
fod_dipy_array = fod_dipy.get_fdata()

The MRtrix FOD image can be viewed in mrview, overlaid on the DWI, as follows:

```sh
mrview extracted_images/NDARINV1JXDFV9Z_baselineYear1Arm1_ABCD-MPROC-DTI_20161206184105/sub-NDARINV1JXDFV9Z/ses-baselineYear1Arm1/dwi/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi.nii -odf.load_sh csd_output_mrtrix_msmt/fod/WM/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi_wmfod.nii.gz
```

The DIPY-generated FODs (from ordinary CSD) can be viewed in the same way:

```sh
mrview extracted_images/NDARINV1JXDFV9Z_baselineYear1Arm1_ABCD-MPROC-DTI_20161206184105/sub-NDARINV1JXDFV9Z/ses-baselineYear1Arm1/dwi/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi.nii -odf.load_sh csd_output_dipy/fod/sub-NDARINV1JXDFV9Z_ses-baselineYear1Arm1_run-01_dwi_fod_mrtrixResponse.nii.gz
```

In order to get the voxel indices displayed in the mrview voxel info area to match array indices of `dwi_array`, the first axis needs to be reversed:

In [None]:
dwi_array = dwi_array[::-1]
fod_array = fod_array[::-1]
fod_dipy_array = fod_dipy_array[::-1]

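I do this transformation because I want to use mrview to compare my FOD renders with the FOD shapes as they are interpreted by the framework that generated them (MRtrix). I don't know for sure why the first index gets reversed. Maybe it has to do with the minus sign in the affine: a negative first column would mean that the first voxel axis points toward the subject's left, and MRtrix may be flipping such axes on load so that they point as close as possible to RAS.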

In [None]:
dwi.affine

🤷
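
Whatever the reason for the flip, reversing an axis keeps each voxel at the same world position only if the affine is adjusted to match: index `n - 1 - i` of the flipped array is index `i` of the original, so the first column of the affine changes sign and the translation moves to what used to be the last slice. The notebook never uses the affine after flipping, so this is not needed here; it is a minimal numpy sketch of the bookkeeping in case the flipped arrays are ever saved. `flip_first_axis`, `voxel_to_world`, and the example affine are hypothetical, not taken from the ABCD data.

```python
import numpy as np

def flip_first_axis(array, affine):
    """Reverse axis 0 of `array` and return an affine that keeps every voxel at the same world position."""
    n = array.shape[0]
    flipped = array[::-1]
    new_affine = affine.copy()
    # New index i' corresponds to old index n - 1 - i', so
    # A[:, 0] * (n - 1 - i') + t  ==  (-A[:, 0]) * i' + (t + A[:, 0] * (n - 1))
    new_affine[:3, 3] = affine[:3, 3] + affine[:3, 0] * (n - 1)
    new_affine[:3, 0] = -affine[:3, 0]
    return flipped, new_affine

def voxel_to_world(affine, ijk):
    """Apply a 4x4 voxel-to-world affine to one voxel index."""
    return (affine @ np.append(np.asarray(ijk, dtype=float), 1.0))[:3]

# Hypothetical LAS-style affine: negative x step, 1.7 mm isotropic voxels
affine = np.array([[-1.7, 0.0, 0.0, 100.0],
                   [0.0, 1.7, 0.0, -120.0],
                   [0.0, 0.0, 1.7, -80.0],
                   [0.0, 0.0, 0.0, 1.0]])
volume = np.random.default_rng(0).random((10, 12, 8))
flipped, new_affine = flip_first_axis(volume, affine)

i, j, k = 3, 5, 2
n = volume.shape[0]
# Same data value and same world position, reached through a different index along axis 0
assert flipped[n - 1 - i, j, k] == volume[i, j, k]
assert np.allclose(voxel_to_world(new_affine, (n - 1 - i, j, k)), voxel_to_world(affine, (i, j, k)))
```

With this hypothetical LAS-style affine, the flipped affine's first column becomes positive, so the first array axis then points toward +x (patient right in RAS world coordinates).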

In [None]:
num_theta = 200
num_phi = 200
thetas = np.linspace(0,np.pi,num=num_theta,endpoint=True) # Note that we include endpoint for theta! The total number of them is still num_theta
phis = np.linspace(0,2*np.pi,num=num_phi,endpoint=False)
th, ph = np.meshgrid(thetas, phis)

In [None]:
sphere_points = np.stack([np.sin(th) * np.cos(ph), np.sin(th) * np.sin(ph), np.cos(th)], axis=-1).reshape(-1,3)
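
Because `np.meshgrid` defaults to `indexing='xy'`, `th` and `ph` have shape `(num_phi, num_theta)`, so after flattening, the point for theta index `i` and phi index `j` lands at row `i + j*num_theta` of `sphere_points`. That is the mapping the polygon-building cell relies on. A small self-contained check on a coarse grid (the sizes 5 and 7 and the helper `flat_index` are arbitrary choices for illustration):

```python
import numpy as np

n_theta, n_phi = 5, 7  # arbitrary small grid sizes for the check
thetas = np.linspace(0, np.pi, num=n_theta, endpoint=True)
phis = np.linspace(0, 2 * np.pi, num=n_phi, endpoint=False)
th, ph = np.meshgrid(thetas, phis)  # default indexing='xy' -> shape (n_phi, n_theta)
points = np.stack([np.sin(th) * np.cos(ph), np.sin(th) * np.sin(ph), np.cos(th)], axis=-1).reshape(-1, 3)

def flat_index(i, j):
    """Row of `points` for theta index i and phi index j."""
    return i + j * n_theta

assert th.shape == (n_phi, n_theta)
for i in range(n_theta):
    for j in range(n_phi):
        t, p = thetas[i], phis[j]
        expected = [np.sin(t) * np.cos(p), np.sin(t) * np.sin(p), np.cos(t)]
        assert np.allclose(points[flat_index(i, j)], expected)

# Every sample lies on the unit sphere
assert np.allclose(np.linalg.norm(points, axis=1), 1.0)
```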

In [None]:
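f = lambda i,j : i+j*num_theta # map from pair of theta_index,phi_index to a flattened single index
polys_list = []
# theta includes both poles, so it must not wrap: the last ring (theta = pi) is not joined back to theta = 0.
# phi excludes 2*pi, so it does wrap around via the modulo.
for i in range(num_theta - 1):
    for j in range(num_phi):
        polys_list += [4, f(i,j), f(i+1, j), f(i+1, (j+1)%num_phi), f(i, (j+1)%num_phi)]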
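
Quad connectivity for this grid can also be built without a Python loop. This numpy sketch (the helper name `quad_cells` is hypothetical) produces the flat VTK-style cell array `[4, a, b, c, d, 4, ...]`; it stops one ring short of the last theta, so the south-pole ring is not stitched back to the north pole, while phi still wraps around because `phis` excludes 2π.

```python
import numpy as np

def quad_cells(n_theta, n_phi):
    """Flat cell array [4, a, b, c, d, ...] for a theta/phi grid flattened as i + j*n_theta."""
    # i runs over theta rings (no wrap), j over phi samples; 'ij' keeps the i-outer, j-inner order
    i, j = np.meshgrid(np.arange(n_theta - 1), np.arange(n_phi), indexing="ij")
    j_next = (j + 1) % n_phi  # phi wraps around
    corners = np.stack([
        i + j * n_theta,
        (i + 1) + j * n_theta,
        (i + 1) + j_next * n_theta,
        i + j_next * n_theta,
    ], axis=-1).reshape(-1, 4)
    counts = np.full((corners.shape[0], 1), 4)  # every cell is a quad
    return np.hstack([counts, corners]).reshape(-1)

cells = quad_cells(200, 200)
assert cells.size == 5 * 199 * 200
```

Calling `.tolist()` on the result before formatting it into the JavaScript string keeps the values printing as plain integers.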

In [None]:
def sph_harm_l_m(l_max):
    for l in range(0,l_max+1,2):
        for m in range(-l,l+1):
            yield l,m
l, m = np.array(list(sph_harm_l_m(8)),dtype=int).T
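
With only even degrees stored, an expansion up to `l_max` has (l_max+1)(l_max+2)/2 coefficients (45 for l_max = 8), and in the order produced by `sph_harm_l_m` the term (l, m) sits at index l(l+1)/2 + m. A sketch that recovers `l_max` from the length of the coefficient axis instead of hard-coding 8, and checks the index formula against a copy of the generator (the helper names `num_coefficients`, `l_max_from_num_coefficients`, and `sh_index` are hypothetical):

```python
import math

def sph_harm_l_m(l_max):
    """Even-degree (l, m) pairs in the same order as the notebook's generator."""
    for l in range(0, l_max + 1, 2):
        for m in range(-l, l + 1):
            yield l, m

def num_coefficients(l_max):
    return (l_max + 1) * (l_max + 2) // 2

def l_max_from_num_coefficients(n):
    """Invert n = (l_max+1)(l_max+2)/2; raise if n is not a valid even-order count."""
    l_max = (math.isqrt(8 * n + 1) - 3) // 2
    if l_max < 0 or l_max % 2 or num_coefficients(l_max) != n:
        raise ValueError(f"{n} is not a valid number of even-order SH coefficients")
    return l_max

def sh_index(l, m):
    """Position of term (l, m) in the coefficient vector."""
    return l * (l + 1) // 2 + m

pairs = list(sph_harm_l_m(8))
assert len(pairs) == num_coefficients(8) == 45
assert all(sh_index(l, m) == idx for idx, (l, m) in enumerate(pairs))
assert l_max_from_num_coefficients(45) == 8
```

In the notebook this could be used as `l_max_from_num_coefficients(fod_array.shape[-1])`, which would also catch an FOD image with an unexpected number of volumes.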

In [None]:
# follows formula at https://mrtrix.readthedocs.io/en/latest/concepts/spherical_harmonics.html#storage-conventions
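def sph_harm_real(m,l,ph,th):
    # scipy's sph_harm takes (order m, degree l, azimuthal angle, polar angle), hence ph before th
    y = sph_harm(m,l,ph,th)
    ynegm = sph_harm(-m,l,ph,th)
    y = np.where(m<0,np.sqrt(2)*np.imag(ynegm),y)  # m < 0: sqrt(2) * Im[Y_l^{-m}]
    y = np.where(m>0,np.sqrt(2)*np.real(y),y)      # m > 0: sqrt(2) * Re[Y_l^m]; m = 0 keeps Y_l^0
    return np.real_if_close(y)  # imaginary parts are now ~0, so drop the complex dtype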
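
For reference, this is the piecewise definition that `sph_harm_real` implements, written in terms of SciPy's complex harmonics Y_l^m (θ is the polar angle and φ the azimuth, matching `th` and `ph`), together with the FOD amplitude that `view_voxel` evaluates from the stored coefficients c_lm:

$$
\tilde{Y}_l^m(\theta,\phi) =
\begin{cases}
\sqrt{2}\,\operatorname{Im}\!\left[Y_l^{-m}(\theta,\phi)\right] & m < 0 \\
Y_l^0(\theta,\phi) & m = 0 \\
\sqrt{2}\,\operatorname{Re}\!\left[Y_l^{m}(\theta,\phi)\right] & m > 0
\end{cases}
\qquad
F(\theta,\phi) = \sum_{l\ \text{even}} \sum_{m=-l}^{l} c_{lm}\,\tilde{Y}_l^m(\theta,\phi)
$$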

In [None]:
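# shape (n_coeffs, num_phi*num_theta): row r is the real SH basis function (l[r], m[r]) evaluated at every sphere point
sph_harm_vals = sph_harm_real(m[:,np.newaxis],l[:,np.newaxis],ph.reshape(1,-1),th.reshape(1,-1))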

In [None]:
vtk_js_viewer_code = """
const script = document.createElement('script');
script.src = 'https://unpkg.com/vtk.js';
script.onload = () => {
  const renderWindow = vtk.Rendering.Core.vtkRenderWindow.newInstance();
  const renderer = vtk.Rendering.Core.vtkRenderer.newInstance({ background: [0,0,0] });
  const actor = vtk.Rendering.Core.vtkActor.newInstance();
  const mapper = vtk.Rendering.Core.vtkMapper.newInstance();

  const polydata = vtk.Common.DataModel.vtkPolyData.newInstance();
  polydata.getPoints().setData(Float32Array.from(pointsDataFromPython), 3);
  polydata.getPolys().setData(Uint32Array.from(polysDataFromPython));
  const normalsFilter = vtk.Filters.Core.vtkPolyDataNormals.newInstance();
  normalsFilter.setInputData(polydata);

  renderWindow.addRenderer(renderer);
  renderer.addActor(actor);
  actor.setMapper(mapper);
  mapper.setInputConnection(normalsFilter.getOutputPort());
  renderer.resetCamera();
  
  const openGLRenderWindow = vtk.Rendering.OpenGL.vtkRenderWindow.newInstance();
  renderWindow.addView(openGLRenderWindow);
  
  const container = document.createElement('div');
  container.style.width = '800px';
  container.style.height = '600px';
  element.appendChild(container);
  openGLRenderWindow.setContainer(container);
  
  const { width, height } = container.getBoundingClientRect();
  openGLRenderWindow.setSize(width, height);
  
  const interactor = vtk.Rendering.Core.vtkRenderWindowInteractor.newInstance();
  interactor.setView(openGLRenderWindow);
  interactor.initialize();
  interactor.bindEvents(container);
  
  const interactorStyle = vtk.Interaction.Style.vtkInteractorStyleTrackballCamera.newInstance();
  interactor.setInteractorStyle(interactorStyle);
  
  renderWindow.render();
};
document.head.appendChild(script);
"""

def view_voxel(i,j,k,fod_array):
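    fod_vals = (fod_array[i,j,k] @ sph_harm_vals)  # FOD amplitude along each sampled direction
    scaled_sphere_pts = fod_vals[:,np.newaxis] * sphere_points  # push each unit direction out to that radius
    
    # .tolist() gives plain Python floats; list() would give NumPy scalars, whose repr
    # ("np.float64(...)" in NumPy >= 2) is not valid JavaScript
    js_code = f"""
    const pointsDataFromPython = {scaled_sphere_pts.reshape(-1).tolist()};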
    const polysDataFromPython = {polys_list};
    """
    js_code += vtk_js_viewer_code
    display(Javascript(js_code))
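
`view_voxel` turns the coefficients into a glyph by evaluating the amplitude F along every sampled direction and pushing the corresponding unit vector out to that radius. A numpy-only sketch of that step (the function name `fod_glyph_points` is hypothetical), checked on the simplest possible FOD: only the l = 0 coefficient is set. The l = 0 basis function is the constant 1 / (2 * sqrt(pi)), so choosing c_00 = 2 * sqrt(pi) should give a unit sphere.

```python
import numpy as np

def fod_glyph_points(coeffs, basis, directions):
    """Scale unit `directions` (N, 3) by the FOD amplitude coeffs @ basis, with basis of shape (n_coeffs, N)."""
    amplitude = coeffs @ basis  # F evaluated at each direction, shape (N,)
    # A negative amplitude pushes the point through the origin to the opposite side.
    return amplitude[:, np.newaxis] * directions

rng = np.random.default_rng(0)
directions = rng.normal(size=(500, 3))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)  # random unit vectors

n_coeffs = 45  # l_max = 8
basis = np.zeros((n_coeffs, directions.shape[0]))
basis[0] = 1.0 / (2.0 * np.sqrt(np.pi))  # Y_0^0 is constant; other rows are irrelevant here because their coefficients are zero
coeffs = np.zeros(n_coeffs)
coeffs[0] = 2.0 * np.sqrt(np.pi)

glyph = fod_glyph_points(coeffs, basis, directions)
assert np.allclose(np.linalg.norm(glyph, axis=1), 1.0)
```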

In [None]:
view_voxel(70,70,84,fod_array)

In [None]:
view_voxel(70,70,84,fod_dipy_array)