# Interactive Tutorial: DBSI Fitting

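This notebook guides you through using `dbsi_toolbox` to load diffusion-weighted imaging (DWI) data, fit the Diffusion Basis Spectrum Imaging (DBSI) model, visualize the resulting parameter maps, and save them as NIfTI files.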

## 1. Setup and Imports

Ensure the toolbox is installed (run `pip install -e .` from the repository root). Let's import the necessary modules.
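
If you are unsure whether the editable install succeeded, a quick check before the import cell gives a clearer message than a traceback. This is a minimal sketch using only the standard library; `is_installed` is a hypothetical helper, not part of `dbsi_toolbox`.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if the top-level package can be found on the current Python path."""
    return importlib.util.find_spec(package_name) is not None


if not is_installed("dbsi_toolbox"):
    print("dbsi_toolbox not found: run `pip install -e .` from the repository root.")
```

`find_spec` only locates the package without importing it, so the check is cheap and has no side effects.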

In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import sys
from time import time

# Import from our toolbox
try:
    from dbsi_toolbox.utils import load_dwi_data_dipy, save_parameter_maps
    from dbsi_toolbox import DBSIModel
except ImportError:
    print("ERROR: Could not import 'dbsi_toolbox'.", file=sys.stderr)
    print("Make sure you have installed the package by running 'pip install -e .' from the repository root.", file=sys.stderr)
    raise  # Stop here: the later cells cannot run without the toolbox

# Configure matplotlib for plotting
%matplotlib inline
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [12, 12]

## 2. Define File Paths

Edit the paths below to point to your data.

In [None]:
# --- EDIT THESE PATHS ---
f_corrected = '<your_corrected_dwi_file_path.nii.gz>'
fbval = '<your_bval_file_path>'
fbvec = '<your_bvec_file_path.bvec>'
f_mask = '<your_brain_mask_file_path.nii.gz>'
output_dir = '<your_output_directory_folder_path>'
prefix = "<id_prefix_for_output_files>"
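
A typo in one of these paths otherwise only surfaces when loading fails. The sketch below is a small pre-flight check using `pathlib`: it reports missing input files and creates the output folder. `missing_inputs` and `ensure_output_dir` are hypothetical helpers; creating the folder up front is a precaution, since we have not checked whether `save_parameter_maps` does it itself.

```python
from __future__ import annotations

from pathlib import Path


def missing_inputs(paths: dict[str, str]) -> list[str]:
    """Return the labels of input files that do not exist on disk."""
    return [label for label, p in paths.items() if not Path(p).is_file()]


def ensure_output_dir(path: str) -> Path:
    """Create the output directory (and any parents) if needed and return it."""
    out = Path(path)
    out.mkdir(parents=True, exist_ok=True)
    return out
```

In the notebook you could call `missing_inputs({'dwi': f_corrected, 'bval': fbval, 'bvec': fbvec, 'mask': f_mask})`, raise if the returned list is non-empty, and then call `ensure_output_dir(output_dir)`.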

## 3. Load Data

We'll use the `load_dwi_data_dipy` utility function to load the DWI volume, its affine, the gradient table (b-values and b-vectors), and the brain mask in one call.

In [None]:
try:
    data, affine, gtab, mask = load_dwi_data_dipy(
        f_nifti=f_corrected,
        f_bval=fbval,
        f_bvec=fbvec,
        f_mask=f_mask
    )
    print(f"\n✓ Data loaded. Voxels in mask: {np.sum(mask):,}")
except FileNotFoundError as e:
    print(f"ERROR: File not found. Check your paths.\n{e}", file=sys.stderr)
    raise  # The remaining cells depend on data, affine, gtab and mask
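
`load_dwi_data_dipy` hands back a DIPY gradient table (`gtab`) whose `bvals` and `bvecs` feed the fit. If you want to inspect the raw gradient files yourself, the sketch below reads the common FSL text layout with plain NumPy: one row of b-values and three rows of vector components, one column per volume. `read_fsl_gradients` and `count_b0` are hypothetical helpers that assume this layout; in practice DIPY's `dipy.io.gradients.read_bvals_bvecs` does the same job.

```python
import numpy as np


def read_fsl_gradients(bval_file, bvec_file):
    """Read FSL-style b-values (1 x N) and b-vectors (3 x N).

    Returns bvals with shape (N,) and bvecs with shape (N, 3).
    Assumes a file with exactly three rows is in FSL (one row per axis) layout.
    """
    bvals = np.atleast_1d(np.loadtxt(bval_file, dtype=float)).ravel()
    bvecs = np.atleast_2d(np.loadtxt(bvec_file, dtype=float))
    # FSL stores one row per axis; transpose to one row per volume.
    if bvecs.shape[0] == 3:
        bvecs = bvecs.T
    if bvecs.shape != (bvals.size, 3):
        raise ValueError(f"bvecs shape {bvecs.shape} does not match {bvals.size} b-values")
    return bvals, bvecs


def count_b0(bvals, b0_threshold=50.0):
    """Number of volumes treated as b=0 (b-value at or below the threshold)."""
    return int(np.sum(bvals <= b0_threshold))
```

The 50 s/mm² cut-off mirrors DIPY's default `b0_threshold`; adjust it to your acquisition.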

## 4. Initialize and Fit the Model

We'll create a `DBSIModel` instance and run `fit_volume` on every voxel inside the mask. This is the most time-consuming step.
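
For reference, the original DBSI formulation (Wang et al., 2011) models the signal for the $k$-th diffusion gradient as a sum of anisotropic (fiber) tensor components plus an integral over a spectrum of isotropic diffusivities. We assume `DBSIModel` follows this general form, but the notebook does not show its exact parameterization:

$$
S_k = \sum_{i=1}^{N_{\text{Aniso}}} f_i \, e^{-|\mathbf{b}_k| \lambda_{\perp i}} \, e^{-|\mathbf{b}_k| (\lambda_{\parallel i} - \lambda_{\perp i}) \cos^2 \psi_{ik}} + \int_a^b f(D) \, e^{-|\mathbf{b}_k| D} \, dD
$$

Here $f_i$, $\lambda_{\parallel i}$ and $\lambda_{\perp i}$ are the fraction, axial diffusivity and radial diffusivity of the $i$-th fiber, $\psi_{ik}$ is the angle between the $k$-th gradient and that fiber's principal direction, and $f(D)$ is the isotropic spectrum on $[a, b]$. Summing $f(D)$ over low diffusivities gives a restricted (cellularity) fraction and over high diffusivities a free-water fraction, which is how maps such as `f_restricted` and `f_water` are usually obtained.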

In [None]:
from dbsi_toolbox import DBSIModel 

# 1. Instantiate the two-step DBSI model
model = DBSIModel(
    reg_lambda=0.01,       # Regularization weight for the linear spectrum fit
    filter_threshold=0.01  # Threshold applied before NNLS fitting to filter out small (noise-level) values
)

# 2. Run the fit
print(f"Starting Two-Step DBSI on {np.sum(mask):,} voxels...")
start_time = time()

param_maps = model.fit_volume(
    volume=data, 
    bvals=gtab.bvals, 
    bvecs=gtab.bvecs, 
    mask=mask,
    show_progress=True
)

end_time = time()
print(f"\n✓ Two-Step Fitting complete in {end_time - start_time:.2f} seconds.")
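
The isotropic part of a DBSI-style fit comes down to estimating non-negative weights $w_j$ on a fixed grid of diffusivities $D_j$ such that $S(b) \approx \sum_j w_j e^{-b D_j}$. Below is a minimal, self-contained sketch of that idea using SciPy's `nnls` with ridge (Tikhonov) regularization, meant to illustrate what `reg_lambda` and `filter_threshold` plausibly control. `fit_isotropic_spectrum` is hypothetical, covers only the isotropic part, and is not the toolbox's implementation. In particular, it applies the threshold to the fitted weights, whereas the toolbox comment says its threshold acts before its NNLS step, so the real role may differ.

```python
import numpy as np
from scipy.optimize import nnls


def fit_isotropic_spectrum(signal, bvals, diffusivities, reg_lambda=0.01, filter_threshold=0.01):
    """Ridge-regularized NNLS fit of S(b) ~ sum_j w_j * exp(-b * D_j) (illustrative sketch).

    Minimizes ||A w - s||^2 + reg_lambda * ||w||^2 subject to w >= 0, then zeroes
    weights below filter_threshold (an assumed reading of that parameter).
    """
    A = np.exp(-np.outer(bvals, diffusivities))  # (n_volumes, n_diffusivities)
    # Stack sqrt(lambda) * I under A so plain NNLS solves the regularized problem.
    A_aug = np.vstack([A, np.sqrt(reg_lambda) * np.eye(len(diffusivities))])
    s_aug = np.concatenate([signal, np.zeros(len(diffusivities))])
    weights, _ = nnls(A_aug, s_aug)
    weights[weights < filter_threshold] = 0.0  # drop noise-level components
    return weights


# Synthetic, noise-free voxel: 60% restricted (D = 0.2e-3 mm^2/s) + 40% free water (D = 3.0e-3 mm^2/s).
b = np.array([0, 250, 500, 1000, 1500, 2000, 2500, 3000], dtype=float)  # s/mm^2
s = 0.6 * np.exp(-b * 0.2e-3) + 0.4 * np.exp(-b * 3.0e-3)
D_grid = np.linspace(0.0, 3.0e-3, 31)  # mm^2/s, step 0.1e-3
w = fit_isotropic_spectrum(s, b, D_grid)
# Illustrative restricted cut-off: grid points up to 0.3e-3 (0.35e-3 avoids float round-off).
restricted = w[D_grid < 0.35e-3].sum()
print(f"total weight {w.sum():.3f}, restricted fraction {restricted:.3f}")
```

The augmented-matrix trick lets an unregularized NNLS solver handle the penalty: stacking $\sqrt{\lambda} I$ under $A$ and zeros under $s$ adds exactly $\lambda \lVert w \rVert^2$ to the squared residual.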

## 5. Visualize the Results

One advantage of a notebook is that you can inspect the results immediately. Let's plot a few key maps from a single slice as a quality check.

In [None]:
# Choose a slice to display (e.g., the middle slice)
z_slice_index = data.shape[2] // 2

print(f"Displaying slice Z = {z_slice_index}")

# Helper function for plotting
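def plot_map(ax, img, title, cmap='gray', vmin=0, vmax=None):
    if vmax is None:
        # Clip the display range at the 98th percentile of non-zero voxels for better contrast
        positive = img[img > 0]
        vmax = np.percentile(positive, 98) if positive.size else 1.0
    # Transpose so x runs left-right and y bottom-top (with origin='lower')
    im = ax.imshow(img.T, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
    ax.set_title(title, color='white')
    ax.axis('off')
    plt.colorbar(im, ax=ax, shrink=0.8)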

# Get the key maps
f_fiber = param_maps['f_fiber'][:, :, z_slice_index]
f_restricted = param_maps['f_restricted'][:, :, z_slice_index]
f_water = param_maps['f_water'][:, :, z_slice_index]
r_squared = param_maps['R2'][:, :, z_slice_index]

# Create a 2x2 plot
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Plot f_fiber (Fiber Fraction)
plot_map(axes[0, 0], f_fiber, 'f_fiber (Fiber Fraction)', cmap='bone', vmax=1.0)

# Plot f_restricted (Cellularity/Inflammation)
plot_map(axes[0, 1], f_restricted, 'f_restricted (Cellularity)', cmap='hot', vmax=0.5)

# Plot f_water (Edema)
plot_map(axes[1, 0], f_water, 'f_water (Edema/CSF)', cmap='Blues', vmax=0.8)

# Plot R-Squared (Fit Quality)
plot_map(axes[1, 1], r_squared, 'R² (Fit Quality)', cmap='inferno', vmin=0, vmax=1.0)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.suptitle(f'DBSI Maps - Slice {z_slice_index}', color='white', fontsize=16)
plt.show()
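
Visual inspection is easier to compare across subjects when paired with a number. A hedged sketch of a numeric quality summary: `summarize_fit_quality` is a hypothetical helper that reports the median R² inside the mask and the fraction of masked voxels above a chosen threshold.

```python
import numpy as np


def summarize_fit_quality(r2, mask, threshold=0.8):
    """Return (median R^2, fraction of masked voxels with R^2 >= threshold).

    NaNs (e.g. failed fits) are ignored in the median and count as below threshold.
    """
    values = np.asarray(r2, dtype=float)[np.asarray(mask, dtype=bool)]
    if values.size == 0:
        raise ValueError("mask selects no voxels")
    median = float(np.nanmedian(values))
    frac_good = float(np.mean(values >= threshold))
    return median, frac_good
```

In this notebook the call would be `summarize_fit_quality(param_maps['R2'], mask)`; the 0.8 threshold is an arbitrary example, not a recommended cut-off.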

## 6. Save the Maps

Finally, let's save all the output maps to NIfTI files.
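
Because every map is written with the DWI `affine`, each one should share the DWI's (x, y, z) voxel grid. A quick sanity check before saving catches shape mistakes early; `check_map_shapes` is a hypothetical helper, sketched here with NumPy.

```python
import numpy as np


def check_map_shapes(param_maps, spatial_shape):
    """Return the names of maps whose first three dimensions differ from spatial_shape."""
    expected = tuple(spatial_shape)
    return [name for name, arr in param_maps.items() if np.shape(arr)[:3] != expected]
```

With the notebook's variables, `check_map_shapes(param_maps, data.shape[:3])` should return an empty list.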

In [None]:
save_parameter_maps(
    param_maps=param_maps, 
    affine=affine, 
    output_dir=output_dir, 
    prefix=prefix
)

print(f"\n✓ Processing complete. Results saved in: {output_dir}")