# 11 - Other Diffusion Models: CSD, DKI, and Q-Ball (Dipy-based)

This notebook explores several other important diffusion models beyond DTI and NODDI. Their implementations in `diffusemri` are primarily wrappers around functionality provided by the Dipy library. These models allow a more advanced characterization of diffusion properties:

*   **Diffusion Kurtosis Imaging (DKI):** Extends DTI to quantify the non-Gaussianity of water diffusion, providing metrics sensitive to microstructural complexity. Requires multi-shell data.
*   **Constrained Spherical Deconvolution (CSD):** A technique to estimate the fiber Orientation Distribution Function (fODF) in voxels containing multiple fiber populations (crossing fibers). Typically requires HARDI (High Angular Resolution Diffusion Imaging) data.
*   **Q-Ball Imaging (QBI):** Another method for reconstructing ODFs from HARDI data, often used to resolve crossing fibers.

The `diffusemri` library provides convenient classes that wrap Dipy's implementations of these models.

In [None]:
import os
import shutil
import numpy as np
import nibabel as nib # Though direct NIfTI I/O might not be prominent here
import matplotlib.pyplot as plt

# Dipy imports
from dipy.core.gradients import gradient_table, generate_bvecs
from dipy.data import get_sphere # For ODF visualization/peak extraction concepts
from dipy.reconst import dki as dipy_dki # For DKI model constants if needed
from dipy.sims.voxel import multi_tensor_dki # For more accurate DKI synthetic data (optional, advanced)
from dipy.sims.voxel import multi_tensor # For CSD/QBall synthetic data (optional, advanced)

# diffusemri library imports
from models.dki import DkiModel
from models.csd import CsdModel # Ensure this is the class name for CSD model wrapper
from models.qball import QballModel

# Set up a temporary directory
TEMP_DIR = "temp_other_models_example"
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
os.makedirs(TEMP_DIR)

print(f"Temporary directory for examples: {os.path.abspath(TEMP_DIR)}")

# Helper function for plotting slices
def show_slice(data_vol, slice_idx=None, title="", cmap='viridis', vmin=None, vmax=None):
    data_to_show = None
    if data_vol.ndim == 3:
        s_idx = slice_idx if slice_idx is not None else data_vol.shape[2] // 2
        data_to_show = data_vol[:, :, s_idx]
    elif data_vol.ndim == 2:
        data_to_show = data_vol
    else:
        print(f"Cannot display data with {data_vol.ndim} dimensions with this plotter.")
        return
    plt.figure(figsize=(6,5))
    plt.imshow(data_to_show.T, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.xlabel("X voxel index"); plt.ylabel("Y voxel index")
    plt.colorbar(label="Metric Value")
    plt.show()

## Part 1: Diffusion Kurtosis Imaging (DKI)

DKI requires multi-shell data (at least two non-zero b-values, ideally more, plus b0s) to estimate parameters related to the deviation of water diffusion from a Gaussian distribution.
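Why at least two non-zero b-values: along a single direction, the kurtosis signal model used in the simulation below, ln(S/S0) = -bD + b²D²K/6, has two unknowns besides S0 (the apparent diffusivity D and kurtosis K). A minimal numpy sketch, assuming noise-free signals along one direction, shows that two shells give a solvable 2×2 linear system in (D, D²K), while one shell would leave it underdetermined:

```python
import numpy as np


def solve_dk_two_shells(s0, s_b1, s_b2, b1=1000.0, b2=2000.0):
    """Recover apparent diffusivity D and kurtosis K from two shells.

    Uses ln(S/S0) = -b*D + (b**2 / 6) * D**2 * K, which is linear in
    the unknowns x1 = D and x2 = D**2 * K.
    """
    y = np.log(np.array([s_b1, s_b2]) / s0)
    # Each row [-b, b^2 / 6] multiplies the unknown vector [D, D^2 K]
    A = np.array([[-b1, b1**2 / 6.0],
                  [-b2, b2**2 / 6.0]])
    x1, x2 = np.linalg.solve(A, y)
    D = x1
    K = x2 / D**2
    return D, K


def dki_signal(b, s0, d, k):
    """Forward model for a single direction (same form as the simulation below)."""
    return s0 * np.exp(-b * d + (b**2 / 6.0) * d**2 * k)


# Forward-simulate one voxel, then invert it
S0, D_true, K_true = 150.0, 0.0012, 0.9
D_est, K_est = solve_dk_two_shells(S0,
                                   dki_signal(1000.0, S0, D_true, K_true),
                                   dki_signal(2000.0, S0, D_true, K_true))
print(f"D = {D_est:.6f} mm^2/s, K = {K_est:.3f}")
```

With noisy data and many directions, Dipy solves the full tensor version of this problem by (weighted) least squares; the two-shell system above is only the one-direction, exact-data case.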

In [None]:
# Define dimensions and affine for DKI synthetic data
dims_dki = (12, 12, 3)  # x, y, z slices
affine_dki = np.diag([2.0, 2.0, 2.0, 1.0]) # 2mm isotropic voxels

# Create a multi-shell scheme for DKI (b0, b=1000, b=2000).
# The DKI model has 22 unknowns per voxel (S0, 6 diffusion-tensor and
# 15 kurtosis-tensor elements), so the scheme needs more than 22 volumes;
# 15 directions per shell is a common minimum.
num_b0s_dki = 2
bvals_dki_list = ([0]*num_b0s_dki +  # b0 volumes
                  [1000]*15 +        # 15 directions at b=1000
                  [2000]*15)         # 15 directions at b=2000
bvals_dki_np = np.array(bvals_dki_list, dtype=float)
total_volumes_dki = len(bvals_dki_np)

bvecs_dki_np = np.zeros((total_volumes_dki, 3))
if total_volumes_dki - num_b0s_dki > 0:  # If there are DWI volumes
    # Generate b-vectors for the DWI shells (excluding b0s)
    bvecs_dki_np[num_b0s_dki:] = generate_bvecs(total_volumes_dki - num_b0s_dki)

# Pass bvecs by keyword: recent Dipy versions deprecate passing it positionally
gtab_dki = gradient_table(bvals_dki_np, bvecs=bvecs_dki_np, b0_threshold=50)
print(f"DKI GradientTable: {total_volumes_dki} volumes, unique b-values: {np.unique(gtab_dki.bvals)}")

# Create simplified synthetic DKI data
# Note: Accurately simulating DKI signal is complex. This is a heuristic for demonstration.
# For precise simulations, consider Dipy's `dipy.sims.voxel.multi_tensor_dki`.
S0_dki = 150.0
mean_diffusivity = 0.0012  # An assumed MD for all voxels for simplicity
mk_background = 0.2       # Mean Kurtosis for background voxels
mk_roi = 0.9              # Mean Kurtosis for a central Region of Interest (ROI)

dwi_data_dki = np.zeros(dims_dki + (total_volumes_dki,), dtype=np.float32)
x_c, y_c = dims_dki[0]//2, dims_dki[1]//2

for i in range(dims_dki[0]):
    for j in range(dims_dki[1]):
        for k_slice in range(dims_dki[2]):
            is_roi = ( (i-x_c)**2 + (j-y_c)**2 < (dims_dki[0]/4)**2 and k_slice == 1 ) # ROI in slice 1
            current_mk = mk_roi if is_roi else mk_background
            for vol_idx in range(total_volumes_dki):
                b_val = gtab_dki.bvals[vol_idx]
                if b_val == 0:
                    dwi_data_dki[i,j,k_slice,vol_idx] = S0_dki
                else:
                    # Simplified DKI signal: S = S0 * exp(-b*MD + (1/6)*b^2*MD^2*MK)
                    # This is the cumulant expansion of ln(S) truncated at b^2, with the same
                    # D and K in every direction (no anisotropy), so it is an oversimplification.
                    signal = S0_dki * np.exp(-b_val * mean_diffusivity + 
                                           (1/6) * (b_val**2) * (mean_diffusivity**2) * current_mk)
                    dwi_data_dki[i,j,k_slice,vol_idx] = signal

# Add some noise
dwi_data_dki += np.random.normal(loc=0, scale=S0_dki * 0.03, size=dwi_data_dki.shape)
dwi_data_dki[dwi_data_dki < 0] = 0 # Ensure no negative signals
dwi_data_dki = dwi_data_dki.astype(np.float32)

print(f"Generated synthetic DKI DWI data with shape: {dwi_data_dki.shape}")
show_slice(dwi_data_dki[..., 0], slice_idx=1, title="Synthetic DKI DWI (b0, Slice 1)")
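The nested loops above are easy to follow but slow for larger grids. A sketch of the same noise-free simulation using NumPy broadcasting (same ROI definition and signal model; the function name `simulate_dki_volume` is hypothetical, not part of `diffusemri`):

```python
import numpy as np


def simulate_dki_volume(dims, bvals, s0, md, mk_background, mk_roi, roi_slice=1):
    """Vectorized, noise-free version of the nested-loop DKI simulation."""
    nx, ny, nz = dims
    i, j, k = np.meshgrid(np.arange(nx), np.arange(ny), np.arange(nz), indexing="ij")
    x_c, y_c = nx // 2, ny // 2
    # Disc-shaped ROI restricted to one slice, as in the loop version
    roi = ((i - x_c)**2 + (j - y_c)**2 < (nx / 4)**2) & (k == roi_slice)
    mk = np.where(roi, mk_roi, mk_background)      # shape (nx, ny, nz)
    b = np.asarray(bvals, dtype=float)             # shape (n_vols,)
    # (nx, ny, nz, 1) broadcasts against (n_vols,) -> (nx, ny, nz, n_vols);
    # b = 0 gives a log-attenuation of 0, i.e. the signal equals s0
    log_att = -b * md + (b**2 / 6.0) * md**2 * mk[..., None]
    return (s0 * np.exp(log_att)).astype(np.float32)
```

For example, `simulate_dki_volume(dims_dki, gtab_dki.bvals, S0_dki, mean_diffusivity, mk_background, mk_roi)` reproduces the loop's output before noise is added.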

In [None]:
print("\nFitting DKI model...")
# The DkiModel from diffusemri.models.dki is a wrapper around Dipy's DKI implementation.
dki_model_wrapper = DkiModel(gtab_dki) 

# A mask is optional for Dipy's DKI fitting but generally recommended for real data.
mask_dki_np = np.ones(dims_dki, dtype=bool) # Process all voxels for this synthetic example

try:
    # The .fit() method of our wrapper calls Dipy's dkimodel.fit()
    dki_fit_results = dki_model_wrapper.fit(dwi_data_dki, mask=mask_dki_np)
    print("DKI model fitting completed.")

    # Access DKI metrics through properties of the DkiModel wrapper instance
    # These properties internally access the Dipy DkiFit object's methods (e.g., dki_fit_results.mk())
    mk_map = dki_model_wrapper.mk # Mean Kurtosis
    # ak_map = dki_model_wrapper.ak # Axial Kurtosis
    # rk_map = dki_model_wrapper.rk # Radial Kurtosis
    # fa_from_dki = dki_model_wrapper.fa # FA from the tensor part of DKI

    if mk_map is not None:
        show_slice(mk_map, slice_idx=1, title="Mean Kurtosis (MK) - Slice 1", cmap='hot', vmin=0, vmax=1.2)
    else:
        print("MK map not found or not computed.")
    
    # print(f"FA from DKI (sample value): {fa_from_dki[dims_dki[0]//2, dims_dki[1]//2, 1] if fa_from_dki is not None else 'N/A'}")

except Exception as e:
    print(f"An error occurred during DKI fitting or metric extraction: {e}")
    print("Ensure Dipy is correctly installed with all its dependencies.")
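The DKI cell above fits every voxel, while its comments recommend a mask for real data. A crude sketch, assuming that thresholding the mean b0 image at a fraction of its maximum is adequate (Dipy's `median_otsu` in `dipy.segment.mask` is the more robust option); `simple_b0_mask` is a hypothetical helper, not part of `diffusemri` or Dipy:

```python
import numpy as np


def simple_b0_mask(dwi, bvals, b0_threshold=50.0, frac=0.2):
    """Crude brain mask: mean b0 signal above `frac` of its maximum.

    Hypothetical helper for illustration only.
    """
    b0_idx = np.asarray(bvals) <= b0_threshold
    if not b0_idx.any():
        raise ValueError("No b0 volumes found at or below b0_threshold")
    mean_b0 = dwi[..., b0_idx].mean(axis=-1)
    return mean_b0 > frac * mean_b0.max()
```

Usage: `mask_dki_np = simple_b0_mask(dwi_data_dki, gtab_dki.bvals)`. On this synthetic data the b0 signal is uniform, so nearly every voxel is kept.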

## Part 2: Constrained Spherical Deconvolution (CSD)

CSD is used to estimate fODFs, which is particularly useful for resolving crossing fibers. It typically requires HARDI data (a single, relatively high b-value shell with many gradient directions).
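How many directions are "many"? Both CSD and Q-Ball represent ODFs in a real, symmetric (even-order) spherical harmonic basis, which for maximum order L has (L+1)(L+2)/2 coefficients. As a rule of thumb, the number of DWI directions should be at least that count. A small sketch comparing orders against the 32-direction scheme simulated below:

```python
def n_sh_coeffs(sh_order_max):
    """Number of coefficients in a real, symmetric (even-order) SH basis."""
    if sh_order_max < 0 or sh_order_max % 2:
        raise ValueError("sh_order_max must be a non-negative even integer")
    return (sh_order_max + 1) * (sh_order_max + 2) // 2


num_dirs = 32  # matches the HARDI scheme simulated in this notebook
for order in (2, 4, 6, 8):
    n = n_sh_coeffs(order)
    status = "OK" if n <= num_dirs else "too few directions"
    print(f"order {order}: {n:2d} coefficients -> {status}")
```

Order 6 (28 coefficients) fits within 32 directions; order 8 (45 coefficients) does not.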

In [None]:
# Define dimensions and affine for CSD/QBall HARDI synthetic data
dims_hardi = (12, 12, 3)  # x, y, z slices
affine_hardi = np.diag([2.0, 2.0, 2.0, 1.0])

# Create a HARDI-like scheme (e.g., 1 b0, 32 directions at b=1500)
num_b0s_hardi = 1
num_dirs_hardi = 32
bval_shell_hardi = 1500.0

bvals_hardi_np = np.concatenate(([0]*num_b0s_hardi, np.ones(num_dirs_hardi) * bval_shell_hardi))
bvecs_hardi_np = np.vstack((np.zeros((num_b0s_hardi, 3)), generate_bvecs(num_dirs_hardi)))
gtab_hardi = gradient_table(bvals_hardi_np, bvecs=bvecs_hardi_np, b0_threshold=50)
print(f"CSD/HARDI GradientTable: {len(bvals_hardi_np)} volumes, b-value for DWI: {bval_shell_hardi}")

# Create simplified synthetic CSD data
# Simulating crossing fibers accurately is complex. Here's a very basic approach.
# For precise simulations, use Dipy's `dipy.sims.voxel.multi_tensor`.
S0_csd = 180.0
d_parallel = 0.0018  # Diffusivity along fibers
d_perpendicular = 0.0004 # Diffusivity perpendicular to fibers

dwi_data_csd = np.zeros(dims_hardi + (len(bvals_hardi_np),), dtype=np.float32)
x_c_csd, y_c_csd = dims_hardi[0]//2, dims_hardi[1]//2

for i in range(dims_hardi[0]):
    for j in range(dims_hardi[1]):
        for k_slice in range(dims_hardi[2]):
            # Central region in slice 1 for simulated crossing fibers
            is_crossing_region = ( (i-x_c_csd)**2 + (j-y_c_csd)**2 < (dims_hardi[0]/5)**2 and k_slice == 1 )
            
            for vol_idx in range(len(bvals_hardi_np)):
                b_val = gtab_hardi.bvals[vol_idx]
                b_vec = gtab_hardi.bvecs[vol_idx]
                
                if b_val == 0:
                    dwi_data_csd[i,j,k_slice,vol_idx] = S0_csd
                else:
                    if is_crossing_region:
                        # Simulate two fiber populations: one along X [1,0,0], one along Y [0,1,0]
                        adc_fiber1 = (b_vec[0]**2 * d_parallel + 
                                      (b_vec[1]**2 + b_vec[2]**2) * d_perpendicular)
                        adc_fiber2 = (b_vec[1]**2 * d_parallel + 
                                      (b_vec[0]**2 + b_vec[2]**2) * d_perpendicular)
                        signal = S0_csd * (0.5 * np.exp(-b_val * adc_fiber1) + 
                                          0.5 * np.exp(-b_val * adc_fiber2))
                    else:
                        # Single fiber population (e.g., along X-axis)
                        adc_single_fiber = (b_vec[0]**2 * d_parallel + 
                                            (b_vec[1]**2 + b_vec[2]**2) * d_perpendicular)
                        signal = S0_csd * np.exp(-b_val * adc_single_fiber)
                    dwi_data_csd[i,j,k_slice,vol_idx] = signal

dwi_data_csd += np.random.normal(loc=0, scale=S0_csd * 0.03, size=dwi_data_csd.shape)
dwi_data_csd[dwi_data_csd < 0] = 0
dwi_data_csd = dwi_data_csd.astype(np.float32)

print(f"Generated synthetic CSD/HARDI DWI data with shape: {dwi_data_csd.shape}")
show_slice(dwi_data_csd[...,0], slice_idx=1, title="Synthetic CSD/HARDI DWI (b0, Slice 1)")
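The ADC expressions in the loop above are the quadratic form gᵀDg for tensors aligned with the x or y axis. A sketch that generalizes this to a fiber pointing in any direction, assuming a cylindrically symmetric tensor D = d_perp·I + (d_par − d_perp)·e·eᵀ. The helper names are hypothetical, and the defaults mirror `S0_csd`, `d_parallel` and `d_perpendicular`:

```python
import numpy as np


def fiber_signal(bval, bvec, fiber_dir, s0=180.0, d_par=0.0018, d_perp=0.0004):
    """Signal of one cylindrically symmetric tensor oriented along `fiber_dir`."""
    e = np.asarray(fiber_dir, dtype=float)
    e = e / np.linalg.norm(e)
    g = np.asarray(bvec, dtype=float)
    # Tensor D = d_perp * I + (d_par - d_perp) * e e^T
    D = d_perp * np.eye(3) + (d_par - d_perp) * np.outer(e, e)
    adc = g @ D @ g
    return s0 * np.exp(-bval * adc)


def crossing_signal(bval, bvec, dirs, weights):
    """Weighted sum of fiber compartments (weights should sum to 1)."""
    return sum(w * fiber_signal(bval, bvec, d) for w, d in zip(weights, dirs))
```

For example, `crossing_signal(1500.0, gtab_hardi.bvecs[5], [(1, 0, 0), (1, 1, 0)], [0.5, 0.5])` gives the signal of a 45° crossing for one gradient direction; with `[(1, 0, 0), (0, 1, 0)]` it matches the loop's crossing region.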

In [None]:
print("\nFitting CSD model...")
# The CsdModel wrapper from diffusemri.models.csd uses Dipy's CSD implementation.
# It can auto-estimate the response function if `response=None` (can be slow).
# For faster execution in a notebook, providing a fixed response is often better for demos.

# Option 1: Auto-estimate response function (can be slow for larger data)
# from dipy.reconst.csdeconv import auto_response_ssst
# response_auto, ratio_auto = auto_response_ssst(gtab_hardi, dwi_data_csd, roi_radii=6) # roi_radii might need adjustment
# csd_model_wrapper_auto = CsdModel(gtab_hardi, response=response_auto)

# Option 2: Use a fixed, typical single-fiber response for demonstration
# (Eigenvalues [d_parallel, d_perpendicular, d_perpendicular], S0 value)
fixed_response_csd = (np.array([d_parallel, d_perpendicular, d_perpendicular]), S0_csd)
csd_model_wrapper = CsdModel(gtab_hardi, response=fixed_response_csd, sh_order_max=6) # sh_order_max for CSD

mask_csd_np = np.ones(dims_hardi, dtype=bool)

try:
    csd_fit_results = csd_model_wrapper.fit(dwi_data_csd, mask=mask_csd_np)
    print("CSD model fitting completed.")

    # Get Generalized Fractional Anisotropy (GFA) from the CSD fit
    gfa_map_csd = csd_model_wrapper.gfa # Access GFA via property
    if gfa_map_csd is not None:
        show_slice(gfa_map_csd, slice_idx=1, title="GFA (from CSD) - Slice 1", cmap='viridis', vmin=0, vmax=0.8)
    else:
        print("GFA map from CSD not found or not computed.")

    # ODFs can also be obtained from csd_model_wrapper.odf(sphere) for visualization or tractography
    # example_sphere = get_sphere('repulsion724')
    # odfs_csd = csd_model_wrapper.odf(example_sphere) # Shape (X,Y,Z, N_sphere_vertices)
    # if odfs_csd is not None: print(f"ODFs shape: {odfs_csd.shape}")

except Exception as e:
    print(f"An error occurred during CSD fitting: {e}")
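A quick sanity check on the fixed response used in the CSD cell: automatic single-fiber response estimation in Dipy (e.g. `auto_response_ssst`) selects voxels whose FA exceeds a threshold (its `fa_thr` parameter, typically 0.7). Computing the FA of the eigenvalues `[d_parallel, d_perpendicular, d_perpendicular]` shows they describe such a voxel. The `fractional_anisotropy` helper below is written out for illustration:

```python
import numpy as np


def fractional_anisotropy(evals):
    """FA from three tensor eigenvalues (standard DTI definition)."""
    l1, l2, l3 = np.asarray(evals, dtype=float)
    num = np.sqrt((l1 - l2)**2 + (l2 - l3)**2 + (l3 - l1)**2)
    den = np.sqrt(l1**2 + l2**2 + l3**2)
    return np.sqrt(0.5) * num / den


fa_response = fractional_anisotropy([0.0018, 0.0004, 0.0004])
print(f"FA of the fixed response: {fa_response:.3f}")
```

The result (about 0.74) is above 0.7, so these eigenvalues are a plausible single-fiber response for this synthetic data.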

## Part 3: Q-Ball Imaging (QBI)

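Q-Ball Imaging is another technique for reconstructing ODFs from HARDI data. The `diffusemri` library wraps a Dipy Q-Ball model, described here as the Constant Solid Angle (CSA) type. Note that in Dipy itself the CSA variant is `CsaOdfModel`, whereas `QballModel` (the class named in the code comment below) is the regularized analytical Q-Ball; the wrapper's source shows which of the two it uses.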

In [None]:
print("\nFitting Q-Ball model...")
# We can reuse the HARDI data (dwi_data_csd, gtab_hardi) created for the CSD example.

# The QballModel wrapper from diffusemri.models.qball uses Dipy's QballModel.
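# sh_order_max is an important parameter; 6 or 8 are common. Order 8 needs 45 SH
# coefficients, more than the 32 DWI directions simulated here, so 6 (28 coefficients) is used.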
qball_model_wrapper = QballModel(gtab_hardi, sh_order_max=6)

mask_qball_np = np.ones(dims_hardi, dtype=bool) # Reuse or create a mask

try:
    qball_fit_results = qball_model_wrapper.fit(dwi_data_csd, mask=mask_qball_np)
    print("Q-Ball model fitting completed.")

    # Get Generalized Fractional Anisotropy (GFA) from the Q-Ball fit
    gfa_map_qball = qball_model_wrapper.gfa # Access GFA via property
    if gfa_map_qball is not None:
        show_slice(gfa_map_qball, slice_idx=1, title="GFA (from Q-Ball) - Slice 1", cmap='viridis', vmin=0, vmax=0.8)
    else:
        print("GFA map from Q-Ball not found or not computed.")

    # ODFs can also be obtained: qball_model_wrapper.odf(sphere)
    # example_sphere = get_sphere('repulsion724')
    # odfs_qball = qball_model_wrapper.odf(example_sphere) # Shape (X,Y,Z, N_sphere_vertices)
    # if odfs_qball is not None: print(f"Q-Ball ODFs shape: {odfs_qball.shape}")

except Exception as e:
    print(f"An error occurred during Q-Ball fitting: {e}")
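The commented-out `odf(sphere)` lines in the CSD and Q-Ball cells return ODF values sampled on sphere vertices. Turning those samples into fiber directions requires peak extraction, which Dipy provides (e.g. `peaks_from_model` in `dipy.direction`). A simplified greedy sketch in plain numpy, not Dipy's algorithm, assuming an antipodally symmetric ODF sampled on unit vectors: it keeps the largest values that are separated by at least a minimum angle and exceed a fraction of the maximum.

```python
import numpy as np


def greedy_peaks(odf, vertices, n_peaks=3, min_sep_deg=25.0, rel_thr=0.5):
    """Pick up to n_peaks directions from ODF samples (simplified sketch).

    odf:      (N,) ODF values at each vertex
    vertices: (N, 3) unit vectors
    """
    order = np.argsort(odf)[::-1]  # vertex indices, largest ODF value first
    cos_thr = np.cos(np.deg2rad(min_sep_deg))
    peaks = []
    for idx in order:
        if odf[idx] < rel_thr * odf[order[0]]:
            break  # remaining values are too small relative to the maximum
        v = vertices[idx]
        # |cos| treats v and -v as the same axis (antipodal symmetry)
        if all(abs(v @ p) < cos_thr for p in peaks):
            peaks.append(v)
            if len(peaks) == n_peaks:
                break
    return np.array(peaks)
```

For instance, `greedy_peaks(odfs_qball[6, 6, 1], example_sphere.vertices)` would return the dominant axes of one voxel, once the commented-out ODF lines are enabled.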

## Discussion

*   **Generalized Fractional Anisotropy (GFA):** GFA is a scalar measure derived from ODFs, somewhat analogous to FA from DTI. It quantifies the anisotropy of the ODF, with higher values indicating more directed ODFs (sharper peaks).

*   **Orientation Distribution Functions (ODFs):** Both CSD and Q-Ball models estimate ODFs for each voxel. These ODFs represent the angular distribution of fiber orientations and are essential for advanced tractography algorithms that can resolve crossing fibers.

*   **Model Choice & Data Requirements:**
    *   **DKI** needs multi-shell data to capture non-Gaussian effects.
    *   **CSD** typically requires at least one high b-value shell with good angular resolution (many directions). Multi-shell data additionally enables multi-tissue CSD, which this basic `CsdModel` wrapper does not cover directly: unless extended, it wraps Dipy's single-tissue CSD.
    *   **Q-Ball** also needs HARDI data (single or multiple shells).

*   **Computational Cost:** CSD, especially with response function estimation, can be more computationally intensive than DTI or basic Q-Ball.
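GFA is computed from ODF values sampled on a sphere as the ratio of their standard deviation to their root-mean-square. A numpy sketch of that definition (Dipy ships its own `gfa` in `dipy.reconst.odf`; this version is written out for illustration):

```python
import numpy as np


def gfa(odf_samples):
    """Generalized fractional anisotropy of ODF samples along the last axis.

    GFA = std(psi) / rms(psi)
        = sqrt(n * sum((psi - mean)^2) / ((n - 1) * sum(psi^2)))
    """
    psi = np.asarray(odf_samples, dtype=float)
    n = psi.shape[-1]
    diff = psi - psi.mean(axis=-1, keepdims=True)
    num = n * np.sum(diff**2, axis=-1)
    den = (n - 1) * np.sum(psi**2, axis=-1)
    # Return 0 where the ODF is identically zero (e.g. outside the mask)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.where(den > 0, num / den, 0.0)
    return np.sqrt(ratio)
```

An isotropic ODF (all samples equal) gives GFA = 0, and an ODF concentrated on a single sample gives GFA = 1, which is why sharper peaks push GFA upward.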

## CLI Usage

Currently, the `diffusemri` library primarily provides dedicated CLI tools for DTI (`run_dti_fit.py`) and PyTorch-based NODDI (`run_noddi_fit.py`) fitting.

The Dipy-based models (DKI, CSD, Q-Ball) demonstrated in this notebook are accessed via their Python class wrappers (`DkiModel`, `CsdModel`, `QballModel`). If you need to run these models in a command-line fashion, you would typically write a custom Python script that:
1.  Parses command-line arguments (e.g., input DWI, bvals, bvecs, mask, output prefix).
2.  Loads the data using `data_io` utilities or `nibabel`/`dipy` directly.
3.  Instantiates the appropriate model class (`DkiModel`, `CsdModel`, `QballModel`).
4.  Calls the `.fit()` method.
5.  Saves the desired output maps (e.g., MK, GFA, ODFs as SH coefficients) to files, likely using `nibabel`.
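A minimal sketch of such a script for DKI, under these assumptions: the script and its flags are hypothetical; the `DkiModel` constructor, the `.fit(data, mask=...)` call and the `.mk` property are used exactly as in this notebook, and no other wrapper API is assumed; the `models.dki` import path is the one used at the top of this notebook. Gradient files are read with Dipy's `read_bvals_bvecs`.

```python
import argparse


def build_parser():
    """Argument parser for a hypothetical custom DKI fitting script."""
    p = argparse.ArgumentParser(description="Fit DKI with diffusemri's DkiModel wrapper.")
    p.add_argument("--dwi", required=True, help="4D DWI NIfTI file")
    p.add_argument("--bval", required=True, help="b-values file")
    p.add_argument("--bvec", required=True, help="b-vectors file")
    p.add_argument("--mask", default=None, help="optional 3D mask NIfTI file")
    p.add_argument("--out_prefix", required=True, help="prefix for output maps")
    return p


def run(args):
    """Steps 2-5 of the list above; imports are deferred so the parser works alone."""
    import nibabel as nib
    import numpy as np
    from dipy.core.gradients import gradient_table
    from dipy.io.gradients import read_bvals_bvecs
    from models.dki import DkiModel  # import path as used in this notebook

    img = nib.load(args.dwi)
    data = img.get_fdata(dtype=np.float32)
    bvals, bvecs = read_bvals_bvecs(args.bval, args.bvec)
    gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=50)
    if args.mask:
        mask = nib.load(args.mask).get_fdata() > 0
    else:
        mask = np.ones(data.shape[:3], dtype=bool)

    model = DkiModel(gtab)
    model.fit(data, mask=mask)
    mk = model.mk  # MK property, as accessed earlier in this notebook
    if mk is None:
        raise RuntimeError("MK map was not computed")
    out_img = nib.Nifti1Image(np.asarray(mk, dtype=np.float32), img.affine)
    nib.save(out_img, f"{args.out_prefix}_mk.nii.gz")
```

From the script's entry point, call `run(build_parser().parse_args())`. Deferring the imports into `run()` keeps the argument parsing testable without Dipy or the `diffusemri` models installed.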

There isn't a generic `run_dipy_model_fit.py` in `diffusemri` at this time that covers all these models with a unified CLI interface, as their parameters and typical use-cases can vary significantly.

## Cleanup

Remove the temporary directory.

In [None]:
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
    print(f"Cleaned up temporary directory: {TEMP_DIR}")