# Global AMICO Solver with Total Variation Regularization

This example demonstrates how to use `dmipy-jax` to perform global microstructure reconstruction using the `GlobalAMICOSolver`. 
Unlike standard voxel-wise fitting, this solver optimizes the entire volume simultaneously, imposing a spatial regularity constraint (Total Variation) that reduces noise while preserving edges.

In [None]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from dmipy_jax.io.connectome2 import load_connectome2_mri
from dmipy_jax.core.modeling_framework import JaxMultiCompartmentModel
from dmipy_jax.core.acquisition import SimpleAcquisitionScheme as JaxAcquisition
from dmipy_jax.signal_models.cylinder_models import RestrictedCylinder
from dmipy_jax.signal_models.sphere_models import SphereGPD
from dmipy_jax.inverse.global_amico import GlobalAMICOSolver

## 1. Load Data
We use the **Connectome 2.0** dataset (ds006181). We load a small patch for this demonstration.

In [None]:
# Request a small patch: index 70 on the first axis and the 40..79 window
# on the next two axes (a single 40x40 slice).
try:
    data_dict = load_connectome2_mri(
        voxel_slice=(slice(70, 71), slice(40, 80), slice(40, 80))
    )
except FileNotFoundError:
    print("Dataset not found. Please ensure ds006181 is present.")
    # No synthetic fallback is generated here: data_dict is set to None so
    # that every `if data_dict:` guard in this notebook skips its body.
    data_dict = None

if not data_dict:
    print("Skipping execution (no data).")
else:
    # Drop the leading singleton slice axis: (1, 40, 40, N_meas) -> (40, 40, N_meas)
    dwi = jnp.squeeze(data_dict['dwi'], axis=0)
    scheme = data_dict['scheme']
    print(f"Data Shape: {dwi.shape}")

    # Normalise each voxel by the mean of its low-b (< 50) measurements.
    # The 1e-6 floor avoids dividing by zero where that mean vanishes.
    b0_mask = scheme.bvalues < 50
    b0_map = jnp.mean(dwi[..., b0_mask], axis=-1, keepdims=True)
    dwi_norm = dwi / jnp.maximum(b0_map, 1e-6)

## 2. Setup Multi-Compartment Model
We define a model consisting of:
1. **Restricted Cylinder**: Representing intra-axonal water.
2. **Sphere**: A fixed-diameter sphere (`SphereGPD`) representing extra-axonal/free water.

In [None]:
if data_dict:
    # Helper for directions
    def get_fibonacci_sphere(samples=1):
        """Return ``samples`` points on the unit sphere along a Fibonacci spiral.

        The y coordinate is spaced linearly from +1 down to -1 and each
        successive point is rotated about the y axis by the golden angle.

        Args:
            samples: Number of points to generate.

        Returns:
            Float array of shape ``(samples, 3)`` holding ``[x, y, z]`` rows.
            ``samples == 1`` yields the single pole ``[0, 1, 0]`` (the first
            point of every larger set); ``samples <= 0`` yields an empty
            ``(0, 3)`` array.
        """
        if samples <= 0:
            # Keep the result 2-D so column indexing (pts[:, 0]) still works.
            return np.empty((0, 3))

        golden_angle = np.pi * (3. - np.sqrt(5.))
        idx = np.arange(samples)
        if samples == 1:
            # The linear spacing below divides by (samples - 1); a single
            # sample has no spacing, so place it at the starting pole.
            y = np.ones(1)
        else:
            y = 1 - (idx / float(samples - 1)) * 2
        # Floor at zero so rounding can never hand sqrt a negative value.
        radius = np.sqrt(np.maximum(0.0, 1 - y * y))
        theta = golden_angle * idx
        return np.stack(
            [np.cos(theta) * radius, y, np.sin(theta) * radius], axis=1
        )

    def cart2sphere(pts):
        """Convert Cartesian points to ``(theta, phi)`` angle pairs.

        Args:
            pts: Array indexed as ``pts[:, 0]``, ``pts[:, 1]``, ``pts[:, 2]``
                for the x, y and z coordinates.

        Returns:
            Array whose column 0 is ``arccos(z / r)`` and column 1 is
            ``arctan2(y, x)``, with ``r`` the Euclidean norm of each row.
        """
        xs = pts[:, 0]
        ys = pts[:, 1]
        zs = pts[:, 2]
        norms = np.sqrt(xs**2 + ys**2 + zs**2)
        # Clip so rounding cannot push the ratio outside arccos' domain.
        cos_theta = np.clip(zs / norms, -1, 1)
        polar = np.arccos(cos_theta)
        azimuth = np.arctan2(ys, xs)
        return np.stack((polar, azimuth), axis=1)

    # Dictionary grid: 32 candidate orientations (as angle pairs from
    # cart2sphere) and 6 cylinder diameters spaced linearly from 1e-6 to 8e-6.
    n_dirs = 32
    dirs_cart = get_fibonacci_sphere(n_dirs)
    dirs_sphere = cart2sphere(dirs_cart)
    mu_grid = list(dirs_sphere)  # one (theta, phi) row per orientation
    diameters = np.linspace(1e-6, 8e-6, 6)

    # Two compartments: a restricted cylinder and a fixed-diameter sphere.
    mc_model = JaxMultiCompartmentModel([
        RestrictedCylinder(lambda_par=1.7e-9),
        SphereGPD(diffusion_constant=3.0e-9, diameter=15e-6),
    ])

    # Acquisition with timing (delta / Delta passed explicitly).
    acq = JaxAcquisition(
        scheme.bvalues,
        scheme.gradient_directions,
        delta=0.012,
        Delta=0.043,
        b0_threshold=50,
    )

    # AMICO grid. The cylinder diameter may be exposed as 'diameter_1' or as
    # 'diameter' depending on the model's parameter names, so look it up.
    if 'diameter_1' in mc_model.parameter_names:
        rc_diam_name = 'diameter_1'
    else:
        rc_diam_name = 'diameter'
    grid = {'mu': mu_grid}
    grid[rc_diam_name] = diameters

    solver = GlobalAMICOSolver(mc_model, acq, grid)
    print(f"Dictionary Size: {solver.dict_matrix.shape}")

## 3. Perform Global Fit
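We fit the model using the `fit_global` method. The `lambda_tv` parameter weights the Total Variation penalty: larger values give smoother maps, while smaller values stay closer to the measured data.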

In [None]:
if data_dict:
    print("Fitting with TV Regularization...")
    # Options forwarded unchanged to fit_global; lambda_tv is the smoothness
    # weight described above. NOTE(review): lambda_l1 is presumably a
    # sparsity weight on the coefficients -- confirm against the solver.
    fit_kwargs = {
        'lambda_tv': 0.01,
        'lambda_l1': 1e-3,
        'maxiter': 50,
        'display': True,
    }
    coeffs = solver.fit_global(dwi_norm, **fit_kwargs)
    print("Fit complete.")

## 4. Visualize Results
We can visualize the coefficients or derived metrics like Mean Axon Diameter.

In [None]:
if data_dict:
    # Mean diameter map: coefficient-weighted average of the grid diameters
    # over the cylinder atoms.
    n_mu, n_diam = len(mu_grid), len(diameters)
    n_cyl_atoms = n_mu * n_diam

    # One diameter per cylinder atom.
    # NOTE(review): np.tile assumes the dictionary is ordered with orientation
    # as the outer index and diameter as the inner one -- confirm against
    # GlobalAMICOSolver's atom ordering.
    diams_per_atom = np.tile(diameters, n_mu)
    diams_param_vec = jnp.array(diams_per_atom)

    # Leading n_cyl_atoms coefficients are taken as the cylinder weights.
    w_cyl = coeffs[..., :n_cyl_atoms]
    denom = jnp.sum(w_cyl, axis=-1)
    # Where the total cylinder weight is (near) zero, divide by 1 instead so
    # the map stays finite rather than blowing up.
    denom = jnp.where(denom < 1e-6, 1.0, denom)
    mrd = jnp.sum(w_cyl * diams_param_vec, axis=-1) / denom

    # Display scaled by 1e6, labelled in microns.
    plt.figure(figsize=(8, 6))
    plt.imshow(mrd * 1e6, cmap='viridis')
    plt.title('Reconstructed Mean Diameter')
    plt.axis('off')
    plt.colorbar(label='Mean Diameter (microns)')
    plt.show()