In [None]:
# Cell 1: Imports and Setup
# NOTE: this is a Jupyter notebook exported as a script. The `%matplotlib` line
# below is IPython magic, so this file must be run in IPython/Jupyter, not plain python.
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

# Add package root to path if running locally
# This appends the parent of the current working directory, which presumably
# holds the `dbsi` package when the notebook is launched from a subdirectory
# (verify). It must run before the `dbsi` imports below.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from dbsi.model import DBSI_Fused
from dbsi.utils.tools import load_data

# Visualization config
# Render figures inline in the notebook and set a default figure size in inches
# (Cell 5 passes its own figsize, overriding this default for its plot).
%matplotlib inline
plt.rcParams['figure.figsize'] = [15, 6]
print("Libraries loaded successfully.")

# Cell 2: Configuration
# Point DATA_DIR at your dataset; every input path and the output directory
# below are derived from it.
DATA_DIR = '../data_example'

# Input files, unpacked in order: DWI volume, b-values, b-vectors, brain mask.
f_dwi, f_bval, f_bvec, f_mask = (
    os.path.join(DATA_DIR, fname)
    for fname in ('dwi.nii.gz', 'dwi.bval', 'dwi.bvec', 'mask.nii.gz')
)

# Destination for the DBSI maps written in Cell 6 (created if missing).
output_dir = os.path.join(DATA_DIR, 'dbsi_results')
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory ready: {output_dir}")

# Cell 3: Data Loading
# load_data is called with verbose=True, which prints the acquisition protocol summary.
print(">>> Loading Data...")
try:
    data, affine, bvals, bvecs, mask = load_data(f_dwi, f_bval, f_bvec, f_mask, verbose=True)
except Exception as e:
    print(f"Error: {e}")
    # Re-raise instead of continuing: every later cell needs data/affine/bvals/
    # bvecs/mask. Swallowing the error would only resurface as a misleading
    # NameError in Cell 4, and "Run All" would keep executing past the failure.
    raise
print(f"Data loaded. Volume shape: {data.shape}")

# Cell 4: Pipeline Execution
# The fit method handles SNR estimation, MC calibration, and fitting.
# DBSI_Fused is imported from dbsi.model; its implementation is not shown here.
print("\n>>> Starting DBSI Fusion Pipeline...")

# Initialize model
# NOTE(review): the exact meaning of enable_step2 (and calibrate below) is
# defined in dbsi.model — confirm there.
model = DBSI_Fused(enable_step2=True)

# Run fitting (returns 4D array with 6 channels)
# Cells 5 and 6 read the last axis as channels 0-5, in this order:
# fiber, restricted, hindered, water, axial diffusivity, radial diffusivity.
results = model.fit(data, bvals, bvecs, mask, calibrate=True)

print(">>> Analysis Complete.")

# Cell 5: Visualization
# Show the four DBSI fraction maps (channels 0-3) for the middle slice along
# the third axis, one panel per map, each rotated 90 degrees for display.
mid_slice = results.shape[2] // 2

# One entry per panel, in display order: (channel, colormap, upper colour limit, title).
# Channel 1 = restricted (inflammation marker), 2 = hindered (edema marker),
# 3 = free water (CSF). The restricted map uses a tighter 0-0.5 range.
panel_specs = [
    (0, 'jet', 1, 'Fiber Fraction (Axons)'),
    (1, 'hot', 0.5, 'Restricted Fraction (Inflammation)'),
    (2, 'viridis', 1, 'Hindered Fraction (Edema)'),
    (3, 'Blues', 1, 'Free Water (CSF)'),
]

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for panel_ax, (channel, cmap_name, upper, title) in zip(axes, panel_specs):
    slice_2d = results[:, :, mid_slice, channel]
    image = panel_ax.imshow(np.rot90(slice_2d), cmap=cmap_name, vmin=0, vmax=upper)
    panel_ax.set_title(title)
    plt.colorbar(image, ax=panel_ax, fraction=0.046)

# Hide ticks and frames only after every panel and colorbar exists.
for panel_ax in axes:
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

# Cell 6: Save to Disk
# Write each channel of `results` as its own NIfTI file, using the affine from Cell 3.
print(f"\n>>> Saving maps to: {output_dir}")

# Channel order of the last axis of `results`. Cell 5 reads channels 0-3 in the
# same order. Output files are named dbsi_<name>.nii.gz.
map_names = [
    'fiber_fraction',
    'restricted_fraction',
    'hindered_fraction',
    'water_fraction',
    'axial_diffusivity',
    'radial_diffusivity',
]

# Validate the shape before writing anything. Otherwise a mismatch would raise
# an IndexError partway through the loop and leave a partial set of maps on disk.
if results.ndim != 4 or results.shape[-1] < len(map_names):
    raise ValueError(
        f"Expected a 4D results array with at least {len(map_names)} channels, "
        f"got shape {results.shape}"
    )

for i, name in enumerate(map_names):
    out_img = nib.Nifti1Image(results[..., i], affine)
    fname = os.path.join(output_dir, f'dbsi_{name}.nii.gz')
    nib.save(out_img, fname)
    print(f"Saved: {fname}")

print("\nAll files saved successfully.")