# 12 - Deterministic Tractography

This notebook demonstrates how to perform deterministic tractography using Orientation Distribution Functions (ODFs), primarily leveraging the `diffusemri` library's capabilities. 

**What is Deterministic Tractography?**
Deterministic tractography is a method used to reconstruct white matter pathways (streamlines or tracts) in the brain. It works by starting from seed points and iteratively following the most likely fiber orientation within each voxel. This orientation is typically derived from the peaks of an ODF, often estimated using models like Constrained Spherical Deconvolution (CSD).

The `diffusemri` library provides tools to perform these steps, including fitting ODF models and then running the tracking algorithm.

In [None]:
import os
import shutil
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

# Dipy imports
from dipy.core.gradients import gradient_table, generate_bvecs
from dipy.data import get_sphere # For ODF concepts, though not directly used in this example's tracking function call
from dipy.io.streamline import save_trk # For saving the generated streamlines
from dipy.io.stateful_tractogram import StatefulTractogram, Space # For TRK saving context
# For optional advanced visualization (not run in this notebook by default):
# from dipy.viz import actor, window 

# diffusemri library imports
from models.csd import CsdModel # Used here to generate a GFA map for stopping criterion
from tracking.deterministic import track_deterministic_oudf 

# For conceptual CLI examples:
# from cli.run_tracking import main as tracking_cli_main

# Scratch directory for every file this notebook writes; it is recreated empty on each run
# so results from a previous execution can never leak into this one.
TEMP_DIR = "temp_deterministic_tractography_example"

if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)  # discard leftovers from an earlier run
os.makedirs(TEMP_DIR)

print("Temporary directory for examples: " + os.path.abspath(TEMP_DIR))

# Helper function for plotting slices
def show_slice(data_vol, slice_idx=None, title="", cmap='gray', vmin=None, vmax=None):
    data_to_show = None
    if data_vol.ndim == 3:
        s_idx = slice_idx if slice_idx is not None else data_vol.shape[2] // 2
        data_to_show = data_vol[:, :, s_idx]
    elif data_vol.ndim == 2:
        data_to_show = data_vol
    else:
        print(f"Cannot display data with {data_vol.ndim} dimensions with this plotter.")
        return
    plt.figure(figsize=(6,5))
    plt.imshow(data_to_show.T, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.xlabel("X voxel index"); plt.ylabel("Y voxel index")
    plt.colorbar(label="Metric Value")
    plt.show()

## Part 1: Preparing Input Data (Synthetic HARDI Data & ODF Model for Stopping Criteria)

Deterministic tractography requires:
1.  Diffusion-weighted imaging (DWI) data, typically HARDI (High Angular Resolution Diffusion Imaging), along with its gradient table (b-values and b-vectors).
2.  A model to estimate ODFs (e.g., CSD). The `track_deterministic_oudf` function in `diffusemri` fits its own CSD model internally.
3.  Seed points or regions from which to start tracking.
4.  A stopping criterion, often a map of a scalar metric (like FA or GFA) and a threshold.

We'll create synthetic HARDI data and then fit a CSD model to get a GFA map to use as a stopping criterion.

In [None]:
# Define dimensions and affine for HARDI synthetic data
dims_hardi = (20, 20, 8)  # x, y, z (small for faster processing)
affine_hardi = np.diag([2.0, 2.0, 2.5, 1.0])  # 2mm x 2mm x 2.5mm voxels

# Create a HARDI scheme (e.g., 1 b0, 32 directions at b=1000)
num_b0s_hardi = 1
num_dirs_hardi = 32
bval_shell_hardi = 1000.0

bvals_hardi_np = np.concatenate(([0]*num_b0s_hardi, np.ones(num_dirs_hardi) * bval_shell_hardi))
# NOTE(review): generate_bvecs may itself be randomized depending on the dipy version;
# seed/pass an RNG there too if fully reproducible gradient directions are needed.
bvecs_hardi_np = np.vstack((np.zeros((num_b0s_hardi, 3)), generate_bvecs(num_dirs_hardi)))
gtab_hardi = gradient_table(bvals_hardi_np, bvecs_hardi_np, b0_threshold=50)
print(f"HARDI GradientTable: {len(bvals_hardi_np)} volumes, b-value for DWI: {bval_shell_hardi}")

# Signal model parameters: a single fiber bundle along the X-axis in an isotropic background
S0_val = 150.0
d_parallel = 0.0018  # Diffusivity along fibers
d_perpendicular = 0.0004 # Diffusivity perpendicular to fibers
d_iso_background = 0.0010 # Isotropic diffusivity for background
noise_seed = 42  # Fixed seed so the synthetic data (and all downstream results) are reproducible

center_y, center_z = dims_hardi[1] // 2, dims_hardi[2] // 2
bundle_radius_sq = (min(dims_hardi[1], dims_hardi[2]) / 3.5)**2

# The noise-free signal depends only on the gradient (b-value, b-vector) and on whether the
# voxel lies inside the bundle. So compute one signal profile per tissue type (one value per
# volume) instead of re-evaluating the model for every voxel.
bvals_gtab = np.asarray(gtab_hardi.bvals, dtype=np.float64)
bvecs_gtab = np.asarray(gtab_hardi.bvecs, dtype=np.float64)
is_b0 = bvals_gtab == 0

# Fiber along X-axis: principal eigenvector [1,0,0]
adc_bundle = (bvecs_gtab[:, 0]**2 * d_parallel +
              (bvecs_gtab[:, 1]**2 + bvecs_gtab[:, 2]**2) * d_perpendicular)
signal_bundle = np.where(is_b0, S0_val, S0_val * np.exp(-bvals_gtab * adc_bundle))
# More isotropic background
signal_background = np.where(is_b0, S0_val, S0_val * np.exp(-bvals_gtab * d_iso_background))

# Bundle membership depends only on (y, z): a cylinder running along X through the volume center
y_idx, z_idx = np.meshgrid(np.arange(dims_hardi[1]), np.arange(dims_hardi[2]), indexing="ij")
in_bundle_yz = (y_idx - center_y)**2 + (z_idx - center_z)**2 < bundle_radius_sq  # shape (Y, Z)
in_bundle = np.broadcast_to(in_bundle_yz, dims_hardi)  # same cross-section for every x

# Trailing volume axis: pick the bundle or background profile per voxel -> shape (X, Y, Z, n_volumes)
dwi_data_hardi = np.where(in_bundle[..., np.newaxis], signal_bundle, signal_background).astype(np.float32)

# Add some noise (Gaussian, sigma = 4% of S0) with a local, seeded generator, then clip negatives
rng = np.random.default_rng(noise_seed)
dwi_data_hardi += rng.normal(loc=0.0, scale=S0_val * 0.04, size=dwi_data_hardi.shape).astype(np.float32)
dwi_data_hardi[dwi_data_hardi < 0] = 0

print(f"Generated synthetic HARDI DWI data with shape: {dwi_data_hardi.shape}")
show_slice(dwi_data_hardi[..., 0], slice_idx=dims_hardi[2]//2, title=f"Synthetic HARDI DWI (b0, Slice Z={dims_hardi[2]//2})")

In [None]:
print("\nFitting CSD model to generate GFA map for stopping criterion...")
# For real data, auto-response estimation or a well-validated fixed response is crucial.
# For this synthetic data, a simple fixed response will suffice for GFA generation.
# Presumably (eigenvalues, S0) as in dipy's CSD response format — a prolate tensor matching the bundle.
fixed_response_tracto = (np.array([d_parallel, d_perpendicular, d_perpendicular]), S0_val)
csd_model_for_gfa = CsdModel(gtab_hardi, response=fixed_response_tracto, sh_order_max=6)

# Create a simple mask (all true for this synthetic data)
mask_hardi_np = np.ones(dims_hardi, dtype=bool)
gfa_map_for_stopping = None

# Only the fit and the GFA lookup live in the try block. A failure while *plotting* the map
# must not silently discard a valid GFA map in favour of the dummy fallback.
# NOTE: the fallback map is all 1.0, so the GFA threshold (0.15) can never trigger a stop with it.
try:
    csd_fit_for_gfa = csd_model_for_gfa.fit(dwi_data_hardi, mask=mask_hardi_np)
    gfa_map_for_stopping = csd_model_for_gfa.gfa # Access GFA via property on the fitted model
except Exception as e:
    print(f"Error fitting CSD or generating GFA: {e}")
    print("Using a dummy GFA (mask) for stopping criterion.")
    gfa_map_for_stopping = mask_hardi_np.astype(np.float32) # Fallback
else:
    if gfa_map_for_stopping is None:
        print("GFA map could not be generated from CSD fit. Using a dummy stopping map.")
        gfa_map_for_stopping = mask_hardi_np.astype(np.float32) # Fallback
    else:
        print("CSD model fitted and GFA map generated.")
        show_slice(gfa_map_for_stopping, slice_idx=dims_hardi[2]//2, title=f"GFA Map (for stopping, Slice Z={dims_hardi[2]//2})", vmin=0, vmax=0.8)

## Part 2: Deterministic Tractography

With the DWI data, gradient table, seed points, and a stopping criterion map, we can now perform deterministic tractography.

In [None]:
# Seed ROI: a 2 x 2 patch (y and z indices center-1 and center) in one YZ plane located at
# x = dims_hardi[0] // 2 - 1, intended to sit inside the synthetic bundle that runs along X.
seed_x_start = dims_hardi[0] // 2 - 1
seed_y_center = dims_hardi[1] // 2
seed_z_center = dims_hardi[2] // 2

seed_mask_np = np.zeros(dims_hardi, dtype=bool)
seed_y_span = slice(seed_y_center - 1, seed_y_center + 1)
seed_z_span = slice(seed_z_center - 1, seed_z_center + 1)
seed_mask_np[seed_x_start, seed_y_span, seed_z_span] = True

num_seed_voxels = seed_mask_np.sum()
print(f"Number of seed voxels defined: {num_seed_voxels}")
if not num_seed_voxels:
    print("Warning: No seed voxels are set. Tractography will not generate streamlines.")

# Stopping criterion: tracking is meant to stop where the GFA map falls below this value
stopping_threshold_gfa = 0.15
print(f"Using GFA threshold for stopping: < {stopping_threshold_gfa}")

In [None]:
print("\nRunning Deterministic Tractography...")
# `track_deterministic_oudf` (imported from the project's `tracking.deterministic` module)
# fits its own CSD model internally to get ODFs and then tracks.

streamlines_list = [] # To store the output; stays empty if tracking is skipped or fails
if gfa_map_for_stopping is None or num_seed_voxels == 0:
    print("Skipping tractography due to missing GFA map or no seed voxels.")
else:
    # Only the tracking call is in this try block. A failure while *saving* is then reported
    # as a save error instead of being mislabelled as a tracking error.
    try:
        streamlines_list = track_deterministic_oudf(
            dwi_data=dwi_data_hardi,
            gtab=gtab_hardi,
            seeds=seed_mask_np,  # Can be a boolean mask or an N x 3 array of seed coordinates
            affine=affine_hardi,
            metric_map_for_stopping=gfa_map_for_stopping, # Map used for stopping (e.g., GFA, FA)
            stopping_threshold_value=stopping_threshold_gfa, # Threshold on the metric_map
            sh_order=6,          # Spherical Harmonics order for internal CSD fit
            response=fixed_response_tracto, # Provide the response for internal CSD
            # step_size=0.5,     # Default is 0.5 mm
            # model_peak_threshold=0.3, # Relative threshold for peak extraction
            min_length=5,        # Minimum streamline length in mm (roughly 2-3 voxels here)
            max_length=100       # Maximum streamline length in mm
        )
    except Exception as e:
        print(f"An error occurred during deterministic tractography: {e}")
    else:
        if streamlines_list is None:  # normalise so len() below and later cells are safe
            streamlines_list = []
        n_streamlines = len(streamlines_list)
        print(f"Deterministic tractography generated {n_streamlines} streamlines.")

        # Check len() rather than truthiness: this works for lists and for array-like
        # containers. A NumPy array would raise on `if streamlines_list:`.
        if n_streamlines > 0:
            trk_output_filepath = os.path.join(TEMP_DIR, "deterministic_streamlines.trk")
            # StatefulTractogram needs a reference image for the voxel grid (affine + shape);
            # the GFA map is presumably on the same grid as the DWI data.
            reference_nifti_image = nib.Nifti1Image(gfa_map_for_stopping, affine_hardi)
            try:
                # NOTE(review): assumes the tracker returns points in world (RAS+ mm) coordinates — confirm
                sft = StatefulTractogram(streamlines_list, reference_nifti_image, Space.RASMM)
                save_trk(sft, trk_output_filepath)
            except Exception as e:
                print(f"Tractography succeeded, but saving the streamlines failed: {e}")
            else:
                print(f"Streamlines saved to: {trk_output_filepath}")
        else:
            print("No streamlines were generated.")

## Part 3: Visualizing Streamlines (Conceptual)

Visualizing 3D tractography streamlines effectively usually requires specialized software such as:
*   **MRtrix `mrview`**
*   **TrackVis**
*   **DIPY Horizon** (DIPY's interactive visualization tool, launched from the command line with `dipy_horizon ...`)
*   **FURY** (the Python 3D visualization library used by DIPY, for programmatic rendering; can be awkward to use inside notebooks).

A very basic 2D projection is generally not very informative for complex 3D structures but can give a rough idea if only a few simple streamlines are generated.

In [None]:
# Rough XY projection, only worthwhile for a handful of streamlines.
MAX_STREAMLINES_TO_PLOT = 100  # above this the simple 2D plot becomes unreadable

# globals().get(...) tolerates the tracking cell not having been run. Counting with len()
# avoids truthiness, which is ambiguous for array-like streamline containers.
_streamlines_to_plot = globals().get('streamlines_list')
n_streamlines_to_plot = 0 if _streamlines_to_plot is None else len(_streamlines_to_plot)

if 0 < n_streamlines_to_plot < MAX_STREAMLINES_TO_PLOT:
    plt.figure(figsize=(7,7))
    for sl in _streamlines_to_plot:
        # Project onto XY plane (plot x vs y coordinates)
        # Note: coordinates are whatever space the tracker returned (presumably world mm)
        plt.plot(sl[:, 0], sl[:, 1], 'b-', alpha=0.5)
    plt.title(f"Basic 2D Projection of {n_streamlines_to_plot} Streamlines (XY plane)")
    plt.xlabel("X coordinate (mm)"); plt.ylabel("Y coordinate (mm)")
    plt.axis('equal') # Ensure aspect ratio is maintained
    plt.grid(True)
    plt.show()
else:
    if n_streamlines_to_plot == 0:
        print("\nStreamline visualization not attempted: no streamlines are available.")
    else:
        print(f"\nStreamline visualization not attempted: {n_streamlines_to_plot} streamlines are too many for a simple plot.")
    print("Please use specialized software like MRtrix mrview, TrackVis, or DIPY Horizon to view the saved .trk file.")

## Part 4: CLI Usage (Conceptual)

Deterministic tractography can also be initiated using the `run_tracking.py` script from the command line. This script would typically take paths to your DWI data, b-values, b-vectors, seed mask, stopping criteria map, and output file name.

In [None]:
# Raw string on purpose: in a normal string literal, a backslash at the end of a line is a
# line continuation. The shell "\" continuations would then be silently dropped and the
# whole command printed on a single line.
CLI_EXAMPLE_COMMAND = r"""
python cli/run_tracking.py det_oudf \
  --dwi_file /path/to/your/dwi.nii.gz \
  --bval_file /path/to/your/bvals.bval \
  --bvec_file /path/to/your/bvecs.bvec \
  --seed_input /path/to/your/seed_mask.nii.gz \
  --stopping_criteria_map /path/to/your/gfa_or_fa_map.nii.gz \
  --stopping_threshold 0.15 \
  --output_tractogram /path/to/output/deterministic_tracts.trk \
  --sh_order 6 \
  --min_length 10 \
  --max_length 200
"""

print("Conceptual CLI command for deterministic tractography:")
print(CLI_EXAMPLE_COMMAND)
# Note: The actual CLI subcommand and arguments might vary based on cli/run_tracking.py implementation.

## Cleanup

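Remove the temporary directory and its contents, including the saved `.trk` file. Skip this cell if you want to open the tractogram in an external viewer first.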

In [None]:
# Delete the scratch directory created at the start of the notebook, with everything in it.
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
    print("Cleaned up temporary directory: " + TEMP_DIR)