In [None]:
### NOTEBOOK EXAMPLE ###
# Two-step DBSI pipeline: load DWI data, estimate SNR, calibrate, fit, save maps.

import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from dbsi_toolbox.utils import load_dwi_data_dipy, save_parameter_maps, estimate_snr
from dbsi_toolbox.calibration import optimize_dbsi_params
from dbsi_toolbox.twostep import DBSI_TwoStep

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Root folder holding the preprocessed inputs; results are written to a
# subfolder of it. Change this one value to point the notebook elsewhere.
data_dir = '/DWI_preprocessed_data'

f_nifti = os.path.join(data_dir, 'DWI_preprocessed_N4.nii.gz')           # DWI volume (NIfTI)
f_bval  = os.path.join(data_dir, 'DWI_corrected.bval')                   # b-values file
f_bvec  = os.path.join(data_dir, 'DWI_corrected.bvec')                   # b-vectors file
f_mask  = os.path.join(data_dir, 'DWI_preprocessed_brain_mask.nii.gz')   # brain mask (NIfTI)
out_dir = os.path.join(data_dir, 'twostep_dbsi_results')

# exist_ok=True replaces the racy "check exists, then create" pattern and makes
# re-running this cell harmless when the folder is already there.
os.makedirs(out_dir, exist_ok=True)

# ==========================================
# 2. LOADING AND PRELIMINARY ANALYSIS
# ==========================================
# Read the DWI volume, gradient table and brain mask in a single call.
# `affine` is kept because it is passed on to save_parameter_maps at the end.
print(">>> Loading data...")
data, affine, gtab, mask = load_dwi_data_dipy(
    f_nifti,
    f_bval,
    f_bvec,
    f_mask,
)
print("Data loaded: {}".format(data.shape))

# The SNR estimate is handed to the calibration step below (snr_estimate=...).
print(">>> Estimating SNR...")
snr_val = estimate_snr(data, gtab, affine, mask)
print("Estimated SNR: {:.2f}".format(snr_val))

# ==========================================
# 3. PARAMETER CALIBRATION
# ==========================================
# Choose the number of isotropic bases and the regularisation weight from the
# candidate grids below; presumably each (bases, lambda) pair is scored over
# n_monte_carlo simulated draws at the measured SNR (see optimize_dbsi_params).
print("\n>>> Starting Automatic Calibration (Standard Backend)...")
best_params = optimize_dbsi_params(
    gtab.bvals,
    gtab.bvecs,
    snr_estimate=snr_val,
    n_monte_carlo=1000,
    bases_grid=[25, 50, 75, 100],
    lambdas_grid=[0.01, 0.1, 0.25, 0.5],
    verbose=True,
)
opt_bases, opt_lambda = best_params['n_bases'], best_params['reg_lambda']
print("\nOPTIMAL PARAMETERS -> Bases: {}, Lambda: {}".format(opt_bases, opt_lambda))

# ==========================================
# 4. INITIALIZATION AND FITTING
# ==========================================
print("\n>>> Initializing Model...")
model = DBSI_TwoStep(
    n_iso_bases=opt_bases,
    reg_lambda=opt_lambda,
    iso_diffusivity_range=(0.0, 3.0e-3),
)

# --- Acquisition scheme ---
# Normalise the gradient scheme ONCE: b-values as a flat 1-D array and
# b-vectors with one row per measurement. The same arrays feed both the
# design-matrix plot and the volumetric fit, so what is visualised is exactly
# what is fitted (previously only the plot received the normalised vectors).
flat_bvals = np.asarray(gtab.bvals).flatten()
current_bvecs = np.asarray(gtab.bvecs)
if current_bvecs.shape[0] == 3 and current_bvecs.shape[1] != 3:
    # Stored as (3, n_measurements): transpose to (n_measurements, 3).
    current_bvecs = current_bvecs.T

# --- Design matrix visualization ---
# NOTE(review): _build_design_matrix is a private helper of the spectrum model
# and may change between dbsi_toolbox versions.
design_matrix = model.spectrum_model._build_design_matrix(flat_bvals, current_bvecs)

plt.figure(figsize=(10, 6))
plt.imshow(design_matrix, aspect='auto', cmap='viridis')
plt.title(f"DBSI Design Matrix ({opt_bases} isotropic bases)")
plt.xlabel("Bases (Fibers + Isotropic)")
plt.ylabel("Measurements")
plt.colorbar(label="Expected Signal")
plt.show()

# --- Start Fitting ---
print("\n>>> STARTING VOLUMETRIC FITTING (Standard Mode - Scipy)...")
# Note: Without Numba you will see a tqdm progress bar. It will be slower but accurate.
maps = model.fit_volume(data, flat_bvals, current_bvecs, mask=mask)

# ==========================================
# 5. SAVING
# ==========================================
# Hand the fitted maps and the input affine to save_parameter_maps, targeting
# out_dir; the 'dbsi' prefix presumably ends up in the output file names.
print("\n>>> Saving results to: {}".format(out_dir))
save_parameter_maps(maps, affine, prefix='dbsi', output_dir=out_dir)
print("Done.")