# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import os

import scipy as sp
from scipy.special import j0, jv
from scipy.optimize import bisect
import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
from scipy.ndimage import binary_dilation

from math import *
np.seterr(over='ignore')

from DTIFuncs import *

import dill as pickle

from dipy.data import get_fnames
from dipy.io.gradients import read_bvals_bvecs
from dipy.core.sphere import disperse_charges, Sphere, HemiSphere

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D  # for 3D plotting

from tqdm.auto import tqdm

import torch
from torch import Tensor
from sbi.inference import SNPE
from sbi import analysis as analysis


Bessel = False

from joblib import Parallel, delayed


In [None]:
# Global plot style: hide the top/right axis spines and legend frames.
plt.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
})

## Bessel params

In [None]:
def j1_derivative(x):
    """Derivative of J1(x) using the identity: J1'(x) = 0.5 * (J0(x) - J2(x))."""
    return 0.5 * (j0(x) - j2(x))

def j2(x):
    """Bessel function J_2(x)."""
    return jv(2, x)

def j1prime_zeros(n, x_max=100, step=0.1):
    """
    Find the first n positive roots of J1'(x) by scanning from x=0 to x_max.

    The grid x_k = k * step (0 <= x_k < x_max, the same points as
    ``np.arange(0.0, x_max, step)``) is walked lazily, so the cost depends
    only on how far the scan has to go to find ``n`` roots, not on
    ``x_max``. Each bracketing sign change is refined with
    ``scipy.optimize.bisect``; a grid point where J1' is exactly zero is
    taken as a root directly (a sign-product test alone would skip it).

    Parameters
    ----------
    n     : int
        Number of roots to find. ``n <= 0`` returns an empty list.
    x_max : float
        Upper (exclusive) end of the scan.
    step  : float
        Grid spacing used to detect sign changes.

    Returns
    -------
    zeros : list of float
        Up to n roots (x > 0) of J1'(x) in increasing order; fewer if
        x_max is reached first.
    """
    zeros = []
    if n <= 0:
        return zeros

    # Number of grid points np.arange(0.0, x_max, step) would produce.
    n_points = int(np.ceil(x_max / step))

    x_prev = 0.0
    f_prev = j1_derivative(x_prev)
    for k in range(1, n_points):
        x_curr = k * step
        f_curr = j1_derivative(x_curr)
        if f_curr == 0:
            # Root exactly on the grid; the next interval has f_prev == 0
            # and is therefore not counted a second time.
            zeros.append(x_curr)
        elif f_prev * f_curr < 0:
            zeros.append(bisect(j1_derivative, x_prev, x_curr))
        if len(zeros) == n:
            break
        x_prev, f_prev = x_curr, f_curr

    return zeros

# Roots of J1'(x); CombSignal_poisson uses the first 10 of them. The k-th
# root is roughly (k - 1/4) * pi, so the 100th lies near x = 313 and a scan
# limit of 1000 is ample. (The previous limit, 10e6 = 1e7 with step 0.01,
# would be a 1e9-point grid if j1prime_zeros builds the grid up front.)
n_roots = 100
Bessel_roots = np.array(j1prime_zeros(n_roots, x_max=1000, step=0.01))
if len(Bessel_roots) < n_roots:
    raise RuntimeError(f"found only {len(Bessel_roots)} of {n_roots} roots of J1'")
Bessel = True

# Functions

In [None]:
def CombSignal_poisson(bvecs, bvals, Delta, delta, params):
    """
    Compute the combined (hindered + restricted) diffusion signal, vectorized.

    Parameters
    ----------
    bvecs : (M, 3) array_like
        Gradient directions. Rows with zero norm have no direction; their
        fibre angle is set to 0 (their b-value is expected to be 0).
    bvals : (M,) array_like
        b-values, one per row of ``bvecs``.
    Delta, delta : float
        Acquisition timing parameters (scalars).
    params : sequence of 7 items
        params[0] : (N, 2) array of fibre directions as spherical angles (theta, phi)
        params[1] : Dpar (scalar)
        params[2] : Dperp (scalar)
        params[3] : D, hindered-compartment parameters passed to ``vals_to_mat``
        params[4] : (N+1,) fractions; the first weights the hindered
                    compartment, the rest one per fibre
        params[5] : mean of the radius distribution; the radius weights are
                    Poisson probabilities with rate ``mean * 1e4``
        params[6] : S0 (scalar signal scale)

    Returns
    -------
    Signal : (M,) ndarray
        S0 * (fracs[0] * hindered + sum_i fracs[i + 1] * restricted_i).

    Notes
    -----
    Uses the module-level ``Bessel_roots`` (roots of J1') and ``vals_to_mat``
    (from DTIFuncs).
    """
    from scipy.special import gammaln  # log(k!) without overflow

    V_angles, Dpar, Dperp, D, fracs, mean, S0 = params
    V_angles = np.asarray(V_angles)
    fracs = np.asarray(fracs, dtype=float)
    bvecs = np.asarray(bvecs)  # (M, 3)
    bvals = np.asarray(bvals)  # (M,)

    # --- 1. Fibre unit vectors from spherical angles ---
    theta_fibers = V_angles[:, 0]
    phi_fibers = V_angles[:, 1]
    V_unit = np.column_stack((np.sin(theta_fibers) * np.cos(phi_fibers),
                              np.sin(theta_fibers) * np.sin(phi_fibers),
                              np.cos(theta_fibers)))  # (N, 3)

    # --- 2. Angle between each fibre and each gradient direction ---
    bvec_norms = np.linalg.norm(bvecs, axis=1)
    no_direction = bvec_norms == 0
    safe_bvec_norms = np.where(no_direction, 1.0, bvec_norms)
    cos_angles = np.clip((V_unit @ bvecs.T) / safe_bvec_norms, -1.0, 1.0)  # (N, M)
    Angs = np.arccos(cos_angles)
    # Only zero-norm rows get angle 0. The first row is not special-cased, so
    # subsets whose first row is a diffusion-weighted direction are simulated
    # correctly.
    Angs[:, no_direction] = 0.0
    # The signal depends only on the axis, so fold angles into [0, pi/2].
    Angs = np.where(Angs > np.pi / 2, np.pi - Angs, Angs)

    # --- 3. Poisson weights over the radius grid ---
    R_vals = np.arange(0.0001, 0.01, 0.0001)
    # Round rather than truncate: R_vals * 1e4 is not exactly integral in
    # floating point, and astype(int) alone can map a radius to k - 1.
    transR = np.rint(R_vals * 10000).astype(int)
    lam = mean * 10000
    # Normalised Poisson pmf lam**k * exp(-lam) / k!, computed in log space.
    log_w = transR * np.log(lam) - gammaln(transR + 1)
    weights = np.exp(log_w - np.max(log_w))
    weights /= np.sum(weights)

    # --- 4. Per-radius sum over the first m Bessel terms ---
    m = 10
    br2 = Bessel_roots[:m] ** 2
    br6 = br2 ** 3
    R2 = R_vals[:, None] ** 2  # (nR, 1)
    if Dperp == 0:
        # num / den tends to 0 as Dperp -> 0; avoid evaluating 0 / 0.
        sumterm_R = np.zeros(R_vals.shape)
    else:
        rate = Dperp * br2 / R2  # (nR, m)
        num = (2 * rate * delta - 2
               + 2 * np.exp(-rate * delta)
               + 2 * np.exp(-rate * Delta)
               - np.exp(-rate * (Delta - delta))
               - np.exp(-rate * (Delta + delta)))
        den = Dperp ** 2 * br6 * (br2 - 1) / R2 ** 3
        sumterm_R = np.sum(num / den, axis=1)  # (nR,)

    # --- 5. Restricted compartment ---
    #   exp(-b cos^2(theta) Dpar) * sum_R w(R) exp(-2 b sin^2(theta) sumterm(R) / ((Delta - delta/3) delta^2))
    base = np.exp(-bvals * np.cos(Angs) ** 2 * Dpar)  # (N, M)
    x = -2 * bvals * np.sin(Angs) ** 2 / ((Delta - delta / 3) * delta ** 2)  # (N, M)
    restricted_integral = np.exp(x[..., None] * sumterm_R) @ weights  # (N, M)
    Res = base * restricted_integral
    restricted_signal = fracs[1:] @ Res  # (M,)

    # --- 6. Hindered compartment: exp(-b * g^T Dh g) ---
    dh = vals_to_mat(D)
    s = np.sum((bvecs @ dh) * bvecs, axis=1)  # (M,)
    hindered_signal = np.exp(-bvals * s)

    # --- 7. Combine compartments and scale by S0 ---
    return S0 * (fracs[0] * hindered_signal + restricted_signal)


In [None]:
def GenRicciNoise(signal,S0,snr):
    """
    Corrupt ``signal`` with Rician noise.

    Two independent zero-mean Gaussian fields with standard deviation
    ``S0 / snr`` are drawn from NumPy's global RNG, the first added to the
    signal (real channel) and the second used as the imaginary channel.
    The magnitude of the resulting complex value is returned.
    """
    std = S0 / snr
    shape = signal.shape
    real_part = signal + np.random.normal(0, std, size=shape)
    imag_part = np.random.normal(0, std, size=shape)
    return np.sqrt(real_part ** 2 + imag_part ** 2)


def AddNoise(signal,S0,snr):
    """Add Rician noise to ``signal``; thin alias for ``GenRicciNoise``."""
    return GenRicciNoise(signal, S0, snr)

def SpherAng(v_in):
    """
    Convert a 3-vector to spherical angles on the upper (z >= 0) hemisphere.

    A vector with negative z is replaced by its antipode first (fibre
    directions are axial), so the polar angle lies in [0, pi/2]. The azimuth
    comes from ``arctan2`` and lies in [-pi, pi].

    Parameters
    ----------
    v_in : array_like of length 3
        Any 3-vector; lists are accepted as well as arrays.

    Returns
    -------
    (theta, phi) : tuple of float
        Polar and azimuthal angles. A zero vector returns (0.0, 0.0).
    """
    v = np.asarray(v_in)
    if v[2] < 0:
        v = -v  # flip to the top hemisphere

    r = np.linalg.norm(v)
    if r == 0:
        # Degenerate vector: no direction, return a fixed convention.
        return 0.0, 0.0

    x, y, z = v
    theta = np.arccos(z / r)
    phi = np.arctan2(y, x)
    return theta, phi

## Simulator

In [None]:
def Simulator_new(params,bvecs,bvals,Delta,S0=1):
    """
    Simulate and concatenate the signal for several diffusion times.

    ``params`` is the flat parameter vector used throughout the notebook:
    [theta, phi, Dpar, Dperp, 6 hindered-tensor values, hindered fraction,
    radius-distribution mean]. ``bvecs``, ``bvals`` and ``Delta`` are
    parallel sequences with one entry per diffusion time. Each triple is
    passed to CombSignal_poisson (with the module-level ``delta``), and the
    per-time signals are stacked end to end.
    """
    hindered_frac = params[10]
    model_params = [np.array([params[:2]]), params[2], params[3], params[4:10],
                    [hindered_frac, 1 - hindered_frac], params[11], S0]
    per_time = [CombSignal_poisson(shell_vecs, shell_vals, big_delta, delta, model_params)
                for shell_vecs, shell_vals, big_delta in zip(bvecs, bvals, Delta)]
    return np.hstack(per_time)

In [None]:
def residuals(params,TrueSig,bvecs,bvals,Delta):
    """Return ``TrueSig`` minus the model signal for ``params`` (with S0 fixed at 1)."""
    return TrueSig - Simulator_new(params, bvecs, bvals, Delta, S0=1)

In [None]:
def residuals_S0(params,TrueSig,bvecs,bvals,Delta):
    """Like ``residuals``, but S0 is taken from the last entry of ``params``."""
    model_signal = Simulator_new(params, bvecs, bvals, Delta, S0=params[-1])
    return TrueSig - model_signal

## Error

In [None]:
def Errors(TrueSig,TrueParams,GuessParams,Delta,bvecs,bvals):
    """
    Score a fitted parameter vector against the ground truth.

    Returns a 9-tuple of absolute errors, in this order:
      Res        -- L2 norm of ``residuals(GuessParams, TrueSig, ...)``
      alpha_err  -- parameter 11 (radius-distribution mean)
      angle_err1 -- parameter 0 (theta), plain absolute difference
      angle_err2 -- parameter 1 (phi), plain absolute difference (no 2*pi wrap)
      Dpar_err   -- parameter 2
      Dperp_err  -- parameter 3
      MD_err     -- mean eigenvalue of the hindered tensor (parameters 4..9,
                    built with vals_to_mat)
      FA_err     -- FracAni of that tensor
      Frac_err   -- parameter 10 (hindered fraction)
    """
    def abs_diff(k):
        return np.abs(GuessParams[k] - TrueParams[k])

    def md_and_fa(dti_vals):
        eigvals = np.linalg.eigh(vals_to_mat(dti_vals))[0]
        md = eigvals.mean()
        return md, FracAni(eigvals, md)

    res_norm = np.linalg.norm(residuals(GuessParams, TrueSig, bvecs, bvals, Delta))
    md_guess, fa_guess = md_and_fa(GuessParams[4:10])
    md_true, fa_true = md_and_fa(TrueParams[4:10])

    return (res_norm,
            abs_diff(11),
            abs_diff(0),
            abs_diff(1),
            abs_diff(2),
            abs_diff(3),
            np.abs(md_guess - md_true),
            np.abs(fa_guess - fa_true),
            abs_diff(10))

# Basic parameters

In [None]:
# Timing parameters in seconds (b-values are in s/mm^2 and diffusivities in
# mm^2/s, so times must be in s): Delta = 17, 35, 61 ms; delta = 7 ms.
Delta = [0.017, 0.035, 0.061]   # s, one entry per diffusion time
delta = 0.007                   # s, gradient pulse duration

In [None]:
np.random.seed(10)
n_pts = 90
# Random initial directions, then spread evenly over the hemisphere by
# electrostatic repulsion (5000 iterations).
theta = np.pi * np.random.random(n_pts)
phi = 2 * np.pi * np.random.random(n_pts)
hsph_initial = HemiSphere(theta=theta, phi=phi)
hsph_updated, potential = disperse_charges(hsph_initial, 5000)
vertices = hsph_updated.vertices

# One scheme copy (91 rows): b=0 first, then 44 directions at b=2000 and
# the remaining 46 at b=4000.
values = np.ones(45)
bvecs = np.vstack([np.zeros((1, 3)), vertices])
bvals = np.concatenate([[0.0], 2000 * values[:-1], np.full(46, 4000.0)])

# Repeat the 91-row scheme three times, once per diffusion time.
bvecs = np.tile(bvecs, (3, 1))
bvals = np.tile(bvals, 3)

# Simulation verification

## Full set of acquisitions

### SBI - training

In [None]:
# Prior draws for training the full-scheme SBI network. Every draw below
# advances the global NumPy RNG seeded here, so their order (including the
# sig2S draw, which is not part of TrainParams) fixes all later values.
np.random.seed(10)
NumSamps = 200000

# Directions
# Normalised Gaussian vectors, converted by SpherAng to (theta, phi) with
# the z component flipped to be non-negative.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
VS = np.vstack([x1,y1,z1])
VS = (VS/np.linalg.norm(VS,axis=0)).T
AngsS = np.array([SpherAng(v) for v in VS])

#Diffusion of restricted
# Uniform on [0, 5e-3).
DparS  = np.random.rand(NumSamps)*5e-3
DperpS = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six random coefficients turned into hindered-tensor values by the DTIFuncs
# helpers (ComputeDTI, ForceLowFA, mat_to_vals) — presumably ForceLowFA caps
# the anisotropy; confirm in DTIFuncs.
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHindS = np.array([ComputeDTI(p) for p in Params])
DHindS = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHindS])

# Mean of the radius distribution (CombSignal_poisson uses mean * 1e4 as the Poisson rate).
meanS = np.random.rand(NumSamps)*0.005+1e-4
# NOTE(review): sig2S is drawn but not passed to the simulator; it only shifts the RNG stream.
sig2S = np.random.rand(NumSamps) * (4e-7 - 9e-8) + 9e-8

#Fraction of hindered
fracS  = np.random.rand(NumSamps)
# Columns: theta, phi, Dpar, Dperp, hindered-tensor values (6 expected, matching
# params[4:10] in Simulator_new), hindered fraction, mean.
TrainParams = np.column_stack([AngsS,DparS,DperpS,DHindS,fracS,meanS])



In [None]:
# Directory where trained posteriors are cached.
network_path = './Networks/'
# exist_ok makes the cell re-runnable and avoids a check-then-create race.
os.makedirs(network_path, exist_ok=True)

In [None]:
# Load the cached full-scheme posterior, or simulate the training set,
# train NPE on it and cache the result.
# NOTE: dill/pickle loading can execute arbitrary code from the file — only
# load networks you created yourself.
if os.path.exists(f"{network_path}/Full_Sim_50_200k_poisson.pickle"):
    with open(f"{network_path}/Full_Sim_50_200k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:
    np.random.seed(10)
    torch.manual_seed(10)
    TrainSigS = []
    NoisyTrainSigS = []
    for i in tqdm(range(NumSamps)):
        v = np.array([AngsS[i]])
        dpar = DparS[i]
        dperp = DperpS[i]
        
        dh   = DHindS[i]
        f    = [fracS[i],1-fracS[i]]
    
        a = meanS[i]
        s = sig2S[i]  # not used below
        s0 = 1
        
        Noise = 50  # SNR given to AddNoise (noise std = s0 / Noise)
        
        # One call per diffusion time; each slice is one 91-row copy of the scheme.
        TrainSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSigS.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        
        NoisyTrainSigS.append(AddNoise(TrainSigS[-1],s0,Noise))
    NoisyTrainSigS = np.array(NoisyTrainSigS)

    # sbi expects float32 tensors: parameters theta and observations x.
    Obs = torch.tensor(NoisyTrainSigS).float()
    Par = torch.tensor(TrainParams).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)

    with open(f"{network_path}/Full_Sim_50_200k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

 Training neural network. Epochs trained: 459

### Evaluation set

In [None]:
# Held-out evaluation set: TestSamps ground-truth parameter vectors drawn from
# the same ranges as the training prior (different seed), their noiseless
# full-scheme signals (TestSig) and noisy copies at four SNRs (NoisyTestSig).
np.random.seed(12)
TestSamps = 20

# Directions
x1  = np.random.randn(TestSamps)
y1  = np.random.randn(TestSamps)
z1  =  np.random.randn(TestSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
Dpar  = np.random.rand(TestSamps)*5e-3
Dperp = np.random.rand(TestSamps)*5e-3

#Diffusion of hindered
Params_abc =  np.random.rand(TestSamps,3)*0.14-0.07
Params_rest =  np.random.rand(TestSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(TestSamps)

mean = np.random.rand(TestSamps)*0.005+1e-4
sig2 = np.random.rand(TestSamps) * (4e-7 - 9e-8) + 9e-8

# NOTE(review): S0Rand is not used in this section.
S0Rand =np.ones(TestSamps)

# Same 12-column layout as TrainParams.
TestParams = np.column_stack([Angs,Dpar,Dperp,DHind,frac,mean])

TestSig = []
NoisyTestSig = []
for i in tqdm(range(TestSamps)):
    v = np.array([Angs[i]])
    dpar = Dpar[i]
    dperp = Dperp[i]
    
    dh   = DHind[i]
    f    = [frac[i],1-frac[i]]

    a = mean[i]
    s = sig2[i]
    # Gamma distribution with mean a and variance s (shape a^2/s, scale s/a).
    alpha     = a * a / s
    scale = s / a
    rv = stats.gamma(a=alpha,scale=scale)
    
    # Debug figure per sample: normalised gamma pdf on a 30-point radius grid
    # (top) and the three per-Delta signals (bottom).
    # NOTE(review): CombSignal_poisson weights radii with Poisson probabilities,
    # not this gamma pdf, so the top panel is not the weighting actually used.
    R = np.linspace(0.0001,0.005, 30)
    weights = rv.pdf(R)
    weights = weights/np.sum(weights)
    fig, ax = plt.subplots(2,1)
    ax[0].plot(R,weights)
    s0 = 1

    TestSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
    ax[1].plot(TestSig1)
    ax[1].plot(TestSig2)
    ax[1].plot(TestSig3)
    fig.suptitle('a')
    plt.show()
    TestSig.append(np.hstack([TestSig1,TestSig2,TestSig3]))
    # Four noisy realisations at SNR 2, 10, 20, 30.
    Noisy = []
    for Noise in [2,10,20,30]:
        Noisy.append(AddNoise(TestSig[-1],s0,Noise))
    NoisyTestSig.append(Noisy)
# (TestSamps, 4, 273) -> (4, TestSamps, 273): first index is the SNR level.
NoisyTestSig = np.array(NoisyTestSig)
NoisyTestSig = np.swapaxes(NoisyTestSig,0,1)
TestSig = np.array(TestSig)

In [None]:
# Non-linear least-squares (NLLS) fits on the full acquisition, one per
# (noise level, test sample), all started from the same initial guess.
# NOTE(review): the next line overwrites the test-set ``mean`` array with
# NumSamps new draws that are not used below; it still advances the NumPy
# RNG and so changes the random initial guess drawn afterwards.
mean = np.random.rand(NumSamps)*0.005+1e-4
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
# Fixed initial orientation: phi = 0 and cos_theta = 0, i.e. theta = pi/2
# (the upper theta bound below); the random alternatives are commented out.
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T

mean_guess = np.random.rand()*0.005 + 1e-4

frac_guess = np.random.rand()
# 12-element initial guess in the Simulator_new parameter order.
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess]).squeeze()
# bounds[0] holds lower bounds, bounds[1] upper bounds; one column per parameter.
bounds = np.array([[-np.inf,np.inf]]*12).T
bounds[:,0] = [0,np.pi/2]
bounds[:,1] = [-np.pi,np.pi]
bounds[:,2] = [0,5e-3]
bounds[:,3] = [0,5e-3]
bounds[:,4] = [-5e-3,5e-3]
bounds[:,5] = [-5e-3,5e-3]
bounds[:,6] = [-5e-3,5e-3]
bounds[:,7] = [-5e-3,5e-3]
bounds[:,8] = [-5e-3,5e-3]
bounds[:,9] = [-5e-3,5e-3]
bounds[:,10] = [0,1]
bounds[:,11] = [1e-4,0.005+1e-4]
LS_result = np.zeros([4,20,12])
# The 273-row scheme split into its three 91-row diffusion-time copies.
bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]
# LS_result[j, i] holds the fit for noise level j and test sample i.
for i in tqdm(range(20)):
    for j in range(4):
        result = sp.optimize.least_squares(residuals, guess, args=[NoisyTestSig[j,i],bve_split,bva_split,Delta],
                                      bounds=bounds,verbose=1,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        LS_result[j,i] = result.x

In [None]:
# Error metrics of the full-scheme NLLS fits: LS_Errors[noise_level, sample]
# is the tuple returned by Errors, scored against the noiseless signal.
LS_Errors = np.array([
    [Errors(sig, n_true, n_guess, Delta, bve_split, bva_split)
     for n_guess, n_true, sig in zip(fits, TestParams, TestSig)]
    for fits in tqdm(LS_result)
])

In [None]:
# SBI estimates on the full acquisition: posterior mean of 1000 samples for
# every (noise level, test sample) pair, computed in parallel.
def fit_SBI(i,j):
    """Return ``(i, j, posterior mean)`` for ``NoisyTestSig[i, j]`` under ``posterior``."""
    torch.manual_seed(10)  # identical sampling seed for every fit
    posterior_samples_1 = posterior.sample((1000,), x=NoisyTestSig[i,j],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

n_noise, n_test = NoisyTestSig.shape[:2]
y_indx = np.repeat(np.arange(n_test), n_noise)   # test-sample index
x_indx = np.tile(np.arange(n_noise), n_test)     # noise-level index
indices = np.column_stack([x_indx,y_indx])

# Use joblib to parallelize the sampling tasks
results = Parallel(n_jobs=-1)(
    delayed(fit_SBI)(i, j) for i, j in tqdm(indices)
)

# One slot per inferred parameter. The posterior was trained on TrainParams,
# which has the same 12 columns as TestParams; a wider last axis cannot
# receive the 12-element estimate (the assignment would fail to broadcast).
SBI_Res = np.zeros([n_noise, n_test, TestParams.shape[1]])
for i, j, x in results:
    SBI_Res[i, j] = np.asarray(x)
    # Clip the second-to-last parameter (index 10, the hindered fraction) to
    # [0, 100]; the upper limit is far above any plausible fraction.
    SBI_Res[i, j, -2] = np.clip(SBI_Res[i, j, -2], 0, 100)

# Error metrics against the noiseless full-scheme signal.
SBI_Errors = []
for N in tqdm(SBI_Res):
    temp = []
    for n_guess,n_true,sig in zip(N,TestParams,TestSig):
        temp.append(Errors(sig,n_true,n_guess,Delta,bve_split,bva_split))
    SBI_Errors.append(temp)
SBI_Errors = np.array(SBI_Errors)

## Minimum Acquisitions

In [None]:
# Minimal acquisition: pick a few well-spread directions per non-zero
# b-value by greedy farthest-point sampling, then take those directions
# (plus the b=0 row) from each diffusion time's copy of the scheme.
def _farthest_point_order(points, n_select, start=0):
    """Greedy farthest-point selection; return ``n_select`` row indices into ``points``.

    Starts from row ``start`` and repeatedly adds the row whose minimum
    Euclidean distance to the rows already chosen is largest (ties go to
    the lowest index). Antipodal directions (g and -g) count as far apart,
    not as equivalent.
    """
    # Distances between the rows of ``points`` itself, so that every index
    # below refers to the same array.
    dist = squareform(pdist(points))
    selected = [start]
    for _ in range(n_select - 1):
        remaining = [k for k in range(len(points)) if k not in selected]
        min_dist = dist[np.ix_(remaining, selected)].min(axis=1)
        selected.append(remaining[int(np.argmax(min_dist))])
    return selected


n_scheme = n_pts + 1   # rows per diffusion time: b=0 followed by n_pts directions
n_per_shell = 3        # directions kept per non-zero b-value

# Row positions of each shell within the first scheme copy.
rows2000 = np.flatnonzero(bvals[:n_scheme] == 2000)
rows4000 = np.flatnonzero(bvals[:n_scheme] == 4000)
bvecs2000 = bvecs[rows2000]
bvecs4000 = bvecs[rows4000]

sel2000 = _farthest_point_order(bvecs2000, n_per_shell)
sel4000 = _farthest_point_order(bvecs4000, n_per_shell)
bvecs2000_selected = bvecs2000[sel2000]
bvecs4000_selected = bvecs4000[sel4000]

MinIdices = np.concatenate([rows2000[sel2000], rows4000[sel4000]])
# The b=0 row once, then the selected directions for each of the three
# diffusion times.
# NOTE: a cached Dev_Sim_50_200k_poisson.pickle is only valid for the
# DevilIndices it was trained with; retrain it if this selection changes.
DevilIndices = np.hstack([0, MinIdices, MinIdices + n_scheme, MinIdices + 2 * n_scheme])
bvecs_Dev = bvecs[DevilIndices]
bvals_Dev = bvals[DevilIndices]

In [None]:
# Prior draws for the minimal-acquisition (Dev) network. The draw order up
# to meanS matches the full-scheme prior cell with the same seed, so (assuming
# the DTIFuncs helpers draw no random numbers) AngsS, DparS, DperpS, DHindS
# and meanS repeat those values. There is no sig2S draw here, so fracS differs.
np.random.seed(10)
NumSamps = 200000

# Directions
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
VS = np.vstack([x1,y1,z1])
VS = (VS/np.linalg.norm(VS,axis=0)).T
AngsS = np.array([SpherAng(v) for v in VS])

#Diffusion of restricted
DparS  = np.random.rand(NumSamps)*5e-3
DperpS = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHindS = np.array([ComputeDTI(p) for p in Params])
DHindS = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHindS])

meanS = np.random.rand(NumSamps)*0.005+1e-4

#Fraction of hindered
fracS  = np.random.rand(NumSamps)
# Same 12-column layout as the full-scheme TrainParams.
TrainParams = np.column_stack([AngsS,DparS,DperpS,DHindS,fracS,meanS])

In [None]:
# Load the cached minimal-acquisition posterior, or simulate training data on
# the DevilIndices subset only, train NPE and cache it.
# NOTE: dill/pickle loading can execute arbitrary code — only load files you created.
if os.path.exists(f"{network_path}/Dev_Sim_50_200k_poisson.pickle"):
    with open(f"{network_path}/Dev_Sim_50_200k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:

    np.random.seed(10)
    torch.manual_seed(10)
    TrainSigS = []
    NoisyTrainSigS = []
    for i in tqdm(range(NumSamps)):
        v = np.array([AngsS[i]])
        dpar = DparS[i]
        dperp = DperpS[i]
        
        dh   = DHindS[i]
        f    = [fracS[i],1-fracS[i]]
    
        a = meanS[i]
        s0 = 1
        
        Noise = 50  # SNR given to AddNoise
        
        # bvecs_Dev layout: [:7] = b=0 row + 6 directions (Delta[0]);
        # [7:13] and [13:] = 6 directions each for Delta[1] and Delta[2],
        # with no b=0 row.
        TrainSig1 = CombSignal_poisson(bvecs_Dev[:7],bvals_Dev[:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs_Dev[7:13],bvals_Dev[7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs_Dev[13:],bvals_Dev[13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSigS.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        
        NoisyTrainSigS.append(AddNoise(TrainSigS[-1],s0,Noise))
    NoisyTrainSigS = np.array(NoisyTrainSigS)


    Obs = torch.tensor(NoisyTrainSigS).float()
    Par = torch.tensor(TrainParams).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMin = inference.build_posterior(density_estimator)

    with open(f"{network_path}/Dev_Sim_50_200k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)

In [None]:
# SBI estimates on the minimal acquisition: ``posteriorMin`` only sees the
# DevilIndices entries of each noisy full-scheme test signal.
def fit_SBI(i,j):
    """Return ``(i, j, posterior mean)`` for the reduced signal of test case (i, j)."""
    torch.manual_seed(10)  # identical sampling seed for every fit
    observed = NoisyTestSig[i, j][DevilIndices]
    draws = posteriorMin.sample((1000,), x=observed, show_progress_bars=False)
    return i, j, draws.mean(axis=0)

# (noise-level, test-sample) pairs, sample-major.
y_indx = np.repeat(np.arange(20), 4)
x_indx = np.tile(np.arange(4), 20)
indices = np.column_stack([x_indx, y_indx])

results = Parallel(n_jobs=-1)(delayed(fit_SBI)(i, j) for i, j in tqdm(indices))

SBI_Res = np.zeros([4, 20, 12])
for i, j, x in results:
    SBI_Res[i, j] = x
    # Second-to-last parameter (hindered fraction) clipped to [0, 100].
    SBI_Res[i, j, -2] = np.clip(SBI_Res[i, j, -2], 0, 100)

# Errors are scored on the full noiseless signal, even though the fit saw only the subset.
SBI_Errors_Min = np.array([
    [Errors(sig, n_true, n_guess, Delta, bve_split, bva_split)
     for n_guess, n_true, sig in zip(fits, TestParams, TestSig)]
    for fits in tqdm(SBI_Res)
])

In [None]:
# Box constraints for the minimal-acquisition NLLS fits. These are the same
# as for the full-scheme fit: theta in [0, pi/2] and phi in [-pi, pi] follow
# the hemisphere convention of SpherAng (which produced the ground-truth
# angles). Errors compares raw theta/phi, so differing angle ranges would make
# the two fits' angle errors incomparable.
bounds = np.array([[-np.inf, np.inf]] * 12).T   # row 0: lower, row 1: upper
bounds[:, 0] = [0, np.pi / 2]          # theta
bounds[:, 1] = [-np.pi, np.pi]         # phi
bounds[:, 2] = [0, 5e-3]               # Dpar
bounds[:, 3] = [0, 5e-3]               # Dperp
bounds[:, 4:10] = [[-5e-3], [5e-3]]    # hindered-tensor values
bounds[:, 10] = [0, 1]                 # hindered fraction
bounds[:, 11] = [1e-4, 0.005 + 1e-4]   # radius-distribution mean

LS_result = np.zeros([4, 20, 12])
# DevilIndices subset split per diffusion time: b=0 + 6 dirs, 6 dirs, 6 dirs.
bve_splitd = [bvecs_Dev[:7], bvecs_Dev[7:13], bvecs_Dev[13:]]
bva_splitd = [bvals_Dev[:7], bvals_Dev[7:13], bvals_Dev[13:]]
# LS_result[j, i] holds the fit for noise level j and test sample i.
for i in tqdm(range(20)):
    for j in range(4):
        result = sp.optimize.least_squares(residuals, guess, args=[NoisyTestSig[j,i][DevilIndices],bve_splitd,bva_splitd,Delta],
                                      bounds=bounds,verbose=1,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        LS_result[j, i] = result.x

In [None]:
# Error metrics of the minimal-acquisition NLLS fits, scored on the full
# noiseless signal: LS_Errors_Min[noise_level, sample] is an Errors tuple.
LS_Errors_Min = np.array([
    [Errors(sig, n_true, n_guess, Delta, bve_split, bva_split)
     for n_guess, n_true, sig in zip(fits, TestParams, TestSig)]
    for fits in tqdm(LS_result)
])

In [None]:
# Example SBI fit for test sample i at the last noise level (SNR 30): the
# posterior mean from the full-scheme network (top panel) and from the
# minimal-acquisition network (bottom panel), both re-simulated on the full
# scheme and compared with the noiseless true signal.
i = 2
posterior_samples_1 = posterior.sample((10000,), x=NoisyTestSig[-1][i],show_progress_bars=False)
fit_params_full = np.array(posterior_samples_1).mean(axis=0)

GuessSig_full = Simulator_new(fit_params_full,bve_split,bva_split,Delta)

posterior_samples_1 = posteriorMin.sample((10000,), x=NoisyTestSig[-1][i][DevilIndices],show_progress_bars=False)
fit_params_Min = np.array(posterior_samples_1).mean(axis=0)

GuessSig_min = Simulator_new(fit_params_Min,bve_split,bva_split,Delta)


# Plot every third measurement of each 91-row diffusion-time segment.
fig,ax = plt.subplots(2,1,figsize=(10,4))
xsig = np.hstack([np.arange(273)[:90:3],np.arange(273)[91:181:3],np.arange(273)[182::3]])
Sig1 = np.hstack([TestSig[i][:90:3],TestSig[i][91:181:3],TestSig[i][182::3]])
gSig1 = np.hstack([GuessSig_full[:90:3],GuessSig_full[91:181:3],GuessSig_full[182::3]])
gSig2 = np.hstack([GuessSig_min[:90:3],GuessSig_min[91:181:3],GuessSig_min[182::3]])
ax[0].plot(xsig,Sig1,lw=3,c='k',label='True signal')
ax[1].plot(xsig,Sig1,lw=3,c='k')
ax[0].plot(xsig,gSig1,lw=3,alpha=0.7,c='lightseagreen',label='SBI fit signal')
ax[1].plot(xsig,gSig2,lw=3,alpha=0.7,c='lightseagreen')


# Gray bands at x in [0, 7] and [45, 50] on the bottom panel.
# NOTE(review): these x-ranges are hard-coded measurement indices; confirm
# which measurements they are meant to highlight.
ax[1].fill_betweenx(np.arange(0,1.1,0.1),0*np.ones(11),7*np.ones(11),color='gray',alpha=0.5)
ax[1].fill_betweenx(np.arange(0,1.1,0.1),45*np.ones(11),50*np.ones(11),color='gray',alpha=0.5)
ax[0].legend(handlelength=1, # fine-tune the legend's position
    frameon=False,
    fontsize=24,ncols=2,bbox_to_anchor=(0.8,0.2),columnspacing=0.5,loc=1)
ax[1].axis('off')
ax[0].axis('off')


In [None]:
# Example NLLS fit for test sample i at the last noise level (SNR 30): the
# full-scheme fit (top panel) and the minimal-acquisition fit (bottom panel),
# both re-simulated on the full scheme. Uses whichever ``bounds`` array was
# defined most recently in the session.
i = 2

result = sp.optimize.least_squares(residuals, guess, args=[NoisyTestSig[-1,i],bve_split,bva_split,Delta],
                              bounds=bounds,verbose=1,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
fit_params_full = result.x

GuessSig_full = Simulator_new(fit_params_full,bve_split,bva_split,Delta)

result = sp.optimize.least_squares(residuals, guess, args=[NoisyTestSig[-1,i][DevilIndices],bve_splitd,bva_splitd,Delta],
                              bounds=bounds,verbose=1,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
fit_params_Min = result.x

GuessSig_min = Simulator_new(fit_params_Min,bve_split,bva_split,Delta)


# Plot every third measurement of each 91-row diffusion-time segment.
fig,ax = plt.subplots(2,1,figsize=(10,4))
xsig = np.hstack([np.arange(273)[:90:3],np.arange(273)[91:181:3],np.arange(273)[182::3]])
Sig1 = np.hstack([TestSig[i][:90:3],TestSig[i][91:181:3],TestSig[i][182::3]])
gSig1 = np.hstack([GuessSig_full[:90:3],GuessSig_full[91:181:3],GuessSig_full[182::3]])
gSig2 = np.hstack([GuessSig_min[:90:3],GuessSig_min[91:181:3],GuessSig_min[182::3]])
ax[0].plot(xsig,Sig1,lw=3,c='k',label='True signal')
ax[1].plot(xsig,Sig1,lw=3,c='k')
ax[0].plot(xsig,gSig1,lw=3,alpha=0.7,c='darkorange',label='NLLS fit signal')
ax[1].plot(xsig,gSig2,lw=3,alpha=0.7,c='darkorange')


# Same hard-coded gray bands as the SBI example figure.
ax[1].fill_betweenx(np.arange(0,1.1,0.1),0*np.ones(11),7*np.ones(11),color='gray',alpha=0.5)
ax[1].fill_betweenx(np.arange(0,1.1,0.1),45*np.ones(11),50*np.ones(11),color='gray',alpha=0.5)
ax[0].legend(handlelength=1, # fine-tune the legend's position
    frameon=False,
    fontsize=24,ncols=2,bbox_to_anchor=(0.8,0.2),columnspacing=0.5,loc=1)
ax[1].axis('off')
ax[0].axis('off')
ax[1].set_ylim([0,1])


In [None]:


# -----------------------------
# Parameters
# -----------------------------
r = 1.0  # sphere radius
vector = np.array([-0.5, -1, 1])   # arbitrary vector
n = vector / np.linalg.norm(vector)  # unit vector in the direction of 'vector'
intersection = n * r  # intersection of the vector with the sphere

# Circle parameters (geodesic circle on the sphere)
circle_angle_deg = 15  # angular radius in degrees (not used by the circles drawn below)
# Angular radius (radians) of the first circle: mean angle_err1 (column 2 of
# the Errors tuple) of the SBI full-scheme fits at the last noise level.
alpha1 = SBI_Errors[-1][:, 2].mean()

# -----------------------------
# Orthonormal basis of the tangent plane at n
# -----------------------------
# A circle of angular radius alpha around the unit vector n is
#   P(t) = cos(alpha) * n + sin(alpha) * (cos(t) * u + sin(t) * w),
# with u and w orthonormal and perpendicular to n.
u = np.cross(n, [0, 0, 1])
if np.linalg.norm(u) < 1e-8:
    # n is (anti)parallel to the z-axis, where this cross product vanishes;
    # use the x-axis instead.
    u = np.cross(n, [1, 0, 0])
u = u / np.linalg.norm(u)

# w completes the basis; it is unit length because n and u are orthonormal.
w = np.cross(n, u)

# -----------------------------
# Create the sphere mesh
# -----------------------------
# 500 x 500 grid of points on the sphere of radius r.
phi = np.linspace(0, 2 * np.pi, 500)  # azimuthal angle
theta = np.linspace(0, np.pi, 500)      # polar angle

phi, theta = np.meshgrid(phi, theta)
x_sphere = r * np.sin(theta) * np.cos(phi)
y_sphere = r * np.sin(theta) * np.sin(phi)
z_sphere = r * np.cos(theta)

# -----------------------------
# Plot everything
# -----------------------------
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')


# Plot the vector (using quiver)
ax.quiver(0, 0, 0, intersection[0], intersection[1], intersection[2],
          color='r', linewidth=2, arrow_length_ratio=0.1)

# Circle 1 (paleturquoise): SBI, full acquisition — radius alpha1.
# Create points around the circle
t_vals = np.linspace(0, 2 * np.pi, 200)
circle_points = np.array([
    np.cos(alpha1) * n + np.sin(alpha1) * (np.cos(t) * u + np.sin(t) * w)
    for t in t_vals
])

ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2], color='paleturquoise', linewidth=2,ls='--')

circle_angle_deg = 15  # angular radius in degrees (unused)
# Circle 2 (lightseagreen): SBI, minimal acquisition — mean angle_err1 at the last noise level.
alpha2 = [(S[:,2].mean()) for S in SBI_Errors_Min][-1]

# Plot the circle on the sphere
# Create points around the circle
t_vals = np.linspace(0, 2 * np.pi, 100)
circle_points = np.array([
    np.cos(alpha2) * n + np.sin(alpha2) * (np.cos(t) * u + np.sin(t) * w)
    for t in t_vals
])

ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2], color='lightseagreen', linewidth=2,ls='--')

# Circle 3 (sandybrown): NLLS, full acquisition.
alpha3 = [(S[:,2].mean()) for S in LS_Errors][-1]

# Plot the circle on the sphere
# Create points around the circle
t_vals = np.linspace(0, 2 * np.pi, 100)
circle_points = np.array([
    np.cos(alpha3) * n + np.sin(alpha3) * (np.cos(t) * u + np.sin(t) * w)
    for t in t_vals
])

ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2], color='sandybrown', linewidth=2,ls='--')

# Circle 4 (darkorange): NLLS, minimal acquisition.
alpha4 = [(S[:,2].mean()) for S in LS_Errors_Min][-1]

# Plot the circle on the sphere
# Create points around the circle
t_vals = np.linspace(0, 2 * np.pi, 100)
circle_points = np.array([
    np.cos(alpha4) * n + np.sin(alpha4) * (np.cos(t) * u + np.sin(t) * w)
    for t in t_vals
])

ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2], color='darkorange', linewidth=2,ls='--')

# Use the same limits on all three axes
max_range = r * 1.2
for axis in 'xyz':
    getattr(ax, 'set_{}lim'.format(axis))((-max_range, max_range))


dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot < np.cos(alpha3)) + (dot > np.cos(alpha4))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# -----------------------------
# Plot everything
# -----------------------------
dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot > np.cos(alpha3)) + (dot < np.cos(alpha4))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# Plot the spherical cap (inside the circle) with transparency
ax.plot_surface(x_sphere_masked, y_sphere_masked, z_sphere_masked,
                color='darkorange',shade=False, alpha=0.5, rstride=2, cstride=2, edgecolor='none')

dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot > np.cos(alpha2)) + (dot < np.cos(alpha3))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# Plot the spherical cap (inside the circle) with transparency
ax.plot_surface(x_sphere_masked, y_sphere_masked, z_sphere_masked,
                color='sandybrown',shade=False, alpha=0.5, rstride=2, cstride=2, edgecolor='none')

dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot > np.cos(alpha1)) + (dot < np.cos(alpha2))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# Plot the spherical cap (inside the circle) with transparency
ax.plot_surface(x_sphere_masked, y_sphere_masked, z_sphere_masked,
                color='lightseagreen',shade=False, alpha=0.5, rstride=2, cstride=2, edgecolor='none')

dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot < np.cos(alpha1))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# Plot the spherical cap (inside the circle) with transparency
ax.plot_surface(x_sphere_masked, y_sphere_masked, z_sphere_masked,
                color='paleturquoise',alpha=0.5,linewidth=0,rstride=1, cstride=1, shade=False,)

dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
mask = (dot > np.cos(alpha4))
# Mask out points that are not in the spherical cap (set them to NaN)
x_sphere_masked = np.where(mask, np.nan, x_sphere)
y_sphere_masked = np.where(mask, np.nan, y_sphere)
z_sphere_masked = np.where(mask, np.nan, z_sphere)

# Plot the spherical cap (inside the circle) with transparency
ax.plot_surface(x_sphere_masked, y_sphere_masked, z_sphere_masked,
                color='gray', alpha=0.2, rstride=2, cstride=2, edgecolor='none')

ax.axis('equal')
ax.axis('off')
ax.view_init(elev=20, azim=-80)

minLS_patch = mpatches.Patch(color='darkorange', label='Minimum NLLS')
fullLS_patch = mpatches.Patch(color='sandybrown', label='Full NLLS')

minSBI_patch = mpatches.Patch(color='lightseagreen', label='Minimum SBI')
fullSBI_patch = mpatches.Patch(color='paleturquoise', label='Full SBI')

ax.legend(
    handles=[minLS_patch,minSBI_patch,fullLS_patch,fullSBI_patch],
    loc='lower left',         # base location  # fine-tune the legend's position
    frameon=False, ncols=2,
    bbox_to_anchor=(0.18, 0.09),fontsize=18,
    columnspacing=0.5,
    handlelength=0.8,
)
ax.set_title('Average angle diff.',x=0.52, y=0.825,fontsize=24)

In [None]:
# Display the four angular radii (Full SBI, Min SBI, Full NLLS, Min NLLS) used in the sphere plot
(alpha1, alpha2, alpha3, alpha4)

In [None]:
# Per-SNR distributions of error column 1 (scaled by 1000) for four methods:
# boxplots overlaid with Student-t-jittered raw points.
g_pos = np.array([0, 2, 4, 6]) * 2   # left edge of each SNR group
jitter = 0.04

fig, ax = plt.subplots(figsize=(8, 4))

# Colors
BG_WHITE = "#fbf9f4"
GREY_LIGHT = "#b4aea9"
GREY50 = "#7F7F7F"
BLUE_DARK = "#1B2838"
BLUE = "#2a475e"
BLACK = "#282724"
GREY_DARK = "#747473"
RED_DARK = "#850e00"

medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

# (errors, offset inside the SNR group, box/whisker colour, point colour)
series = [
    (SBI_Errors, 0.0, 'turquoise', 'paleturquoise'),
    (SBI_Errors_Min, 0.5, 'cadetblue', 'darkturquoise'),
    (LS_Errors, 1.5, 'sandybrown', 'peachpuff'),
    (LS_Errors_Min, 2.0, 'darkorange', 'orange'),
]
for errors, offset, box_color, dot_color in series:
    values = errors[:, :, 1] * 1000   # one row per SNR group
    POSITIONS = g_pos + offset
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        values.T,
        positions=POSITIONS,
        showfliers=False,  # Do not show the outliers beyond the caps.
        showcaps=False,    # Do not show the caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )
    # Jitter is drawn group by group, in the same RNG order as before
    for pos, y in zip(POSITIONS, values):
        x = pos + stats.t(df=6, scale=jitter).rvs(len(y))
        ax.scatter(x, y, s=100, color=dot_color, alpha=0.8)

ax.set_xticks(g_pos + 1, ['2', '10', '20', '30'], fontsize=24)   # centre of each group
ax.set_xlabel('SNR', fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

minLS_patch = mpatches.Patch(color='darkorange', label='Minimum NLLS')
fullLS_patch = mpatches.Patch(color='sandybrown', label='Full NLLS')

minSBI_patch = mpatches.Patch(color='lightseagreen', label='Minimum SBI')
fullSBI_patch = mpatches.Patch(color='paleturquoise', label='Full SBI')

ax.legend(
    handles=[minLS_patch, minSBI_patch, fullLS_patch, fullSBI_patch],
    loc='lower left',         # base location; bbox_to_anchor fine-tunes it
    frameon=False, ncols=2,
    bbox_to_anchor=(0.12, 0.8), fontsize=24,
    columnspacing=0.5,
    handlelength=0.8,
)

In [None]:
# Per-SNR distributions of the last error column for four methods:
# boxplots overlaid with Student-t-jittered raw points. GREY_DARK is defined
# in the preceding boxplot cell.
g_pos = np.array([0, 2, 4, 6]) * 2   # left edge of each SNR group
jitter = 0.04

fig, ax = plt.subplots(figsize=(8, 4))

medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

# (errors, offset inside the SNR group, box/whisker colour, point colour)
series = [
    (SBI_Errors, 0.0, 'turquoise', 'paleturquoise'),
    (SBI_Errors_Min, 0.5, 'cadetblue', 'darkturquoise'),
    (LS_Errors, 1.5, 'sandybrown', 'peachpuff'),
    (LS_Errors_Min, 2.0, 'darkorange', 'orange'),
]
for errors, offset, box_color, dot_color in series:
    values = errors[:, :, -1]   # one row per SNR group
    POSITIONS = g_pos + offset
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        values.T,
        positions=POSITIONS,
        showfliers=False,  # Do not show the outliers beyond the caps.
        showcaps=False,    # Do not show the caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )
    # Jitter is drawn group by group, in the same RNG order as before
    for pos, y in zip(POSITIONS, values):
        x = pos + stats.t(df=6, scale=jitter).rvs(len(y))
        ax.scatter(x, y, s=100, color=dot_color, alpha=0.8)

ax.set_xticks(g_pos + 1, ['2', '10', '20', '30'], fontsize=24)   # centre of each group
ax.set_xlabel('SNR', fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

ax.set_ylim(-0.1, 1)


In [None]:
# Per-SNR distributions of error column -4 for four methods:
# boxplots overlaid with Student-t-jittered raw points. GREY_DARK is defined
# in an earlier boxplot cell.
g_pos = np.array([0, 2, 4, 6]) * 2   # left edge of each SNR group
jitter = 0.04

fig, ax = plt.subplots(figsize=(8, 4))

medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

# (errors, offset inside the SNR group, box/whisker colour, point colour)
series = [
    (SBI_Errors, 0.0, 'turquoise', 'paleturquoise'),
    (SBI_Errors_Min, 0.5, 'cadetblue', 'darkturquoise'),
    (LS_Errors, 1.5, 'sandybrown', 'peachpuff'),
    (LS_Errors_Min, 2.0, 'darkorange', 'orange'),
]
for errors, offset, box_color, dot_color in series:
    values = errors[:, :, -4]   # one row per SNR group
    POSITIONS = g_pos + offset
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        values.T,
        positions=POSITIONS,
        showfliers=False,  # Do not show the outliers beyond the caps.
        showcaps=False,    # Do not show the caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )
    # Jitter is drawn group by group, in the same RNG order as before
    for pos, y in zip(POSITIONS, values):
        x = pos + stats.t(df=6, scale=jitter).rvs(len(y))
        ax.scatter(x, y, s=100, color=dot_color, alpha=0.8)

ax.set_xticks(g_pos + 1, ['2', '10', '20', '30'], fontsize=24)   # centre of each group
ax.set_xlabel('SNR', fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

In [None]:
# Per-SNR distributions of error column -2 for four methods:
# boxplots overlaid with Student-t-jittered raw points. GREY_DARK is defined
# in an earlier boxplot cell.
g_pos = np.array([0, 2, 4, 6]) * 2   # left edge of each SNR group
jitter = 0.04

fig, ax = plt.subplots(figsize=(8, 4))

medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

# (errors, offset inside the SNR group, box/whisker colour, point colour)
series = [
    (SBI_Errors, 0.0, 'turquoise', 'paleturquoise'),
    (SBI_Errors_Min, 0.5, 'cadetblue', 'darkturquoise'),
    (LS_Errors, 1.5, 'sandybrown', 'peachpuff'),
    (LS_Errors_Min, 2.0, 'darkorange', 'orange'),
]
for errors, offset, box_color, dot_color in series:
    values = errors[:, :, -2]   # one row per SNR group
    POSITIONS = g_pos + offset
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        values.T,
        positions=POSITIONS,
        showfliers=False,  # Do not show the outliers beyond the caps.
        showcaps=False,    # Do not show the caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )
    # Jitter is drawn group by group, in the same RNG order as before
    for pos, y in zip(POSITIONS, values):
        x = pos + stats.t(df=6, scale=jitter).rvs(len(y))
        ax.scatter(x, y, s=100, color=dot_color, alpha=0.8)

ax.set_xticks(g_pos + 1, ['2', '10', '20', '30'], fontsize=24)   # centre of each group
ax.set_xlabel('SNR', fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)


In [None]:
import shutil

# Audible "done" notification via the macOS `say` command; skipped on systems
# where `say` is not on PATH instead of emitting a shell error.
if shutil.which('say'):
    os.system('say Finished this part')

# Real data

In [None]:
import pymatreader as pmt
import matplotlib.pyplot as plt
from dipy.segment.mask import median_otsu

from dipy.io.image import load_nifti

In [None]:
# Load the measured signal, gradient directions, b-values and anatomical mask.
Dir = '/Users/maximilianeggl/Dropbox/PostDoc/Silvia/HealthyCat/done/Ctrl055_R01_28/'
dat = pmt.read_mat(os.path.join(Dir, 'data_loaded.mat'))
bvecs, bvals = dat['direction'], dat['bval']

# Timing parameters: one Delta per acquisition block, one shared delta
# (presumably seconds -- TODO confirm).
Delta = [0.017, 0.035, 0.061]
delta = 0.007
FixedParams = {'bvals': bvals, 'bvecs': bvecs, 'Delta': Delta, 'delta': delta}
n_pts = 90  # bvecs/bvals are split into consecutive blocks of n_pts + 1 volumes

S_mask = load_nifti(os.path.join(Dir, 'mask_055.nii.gz'), return_img=True)[0]

In [None]:
data = dat['data']
axial_middle = data.shape[2] // 2  # middle index along the third axis

# Brain mask from median Otsu thresholding on the first ten volumes,
# keeping the full field of view (no autocrop).
maskdata, mask = median_otsu(
    data,
    vol_idx=range(10),
    median_radius=5,
    numpass=1,
    autocrop=False,
    dilate=2,
)

In [None]:
# First volume of the masked data at index 41 of the third axis
axial_view = maskdata[:, :, 41, 0]
plt.imshow(axial_view)

In [None]:
# First volume at index 51 of the second axis, transposed and flipped vertically for display
coronal_view = maskdata[:, 51, :, 0].T[::-1]
plt.imshow(coronal_view)

## SBI

### Full Set

In [None]:
np.random.seed(12)
NumSamps = 200000

# Directions: isotropic Gaussian draws normalised to unit length
# (three separate draws, evaluated in x, y, z order).
x1, y1, z1 = (np.random.randn(NumSamps) for _ in range(3))
V = np.stack([x1, y1, z1], axis=1)
V = V / np.linalg.norm(V, axis=1, keepdims=True)
Angs = np.array([SpherAng(v) for v in V])

# Diffusion of restricted: uniform on [0, 5e-3)
Dpar = np.random.rand(NumSamps) * 5e-3
Dperp = np.random.rand(NumSamps) * 5e-3

# Diffusion of hindered: six uniform parameters mapped through the project
# helpers ComputeDTI -> ForceLowFA -> mat_to_vals
Params_abc = np.random.rand(NumSamps, 3) * 0.14 - 0.07    # uniform on [-0.07, 0.07)
Params_rest = np.random.rand(NumSamps, 3) * 0.03 - 0.015  # uniform on [-0.015, 0.015)
Params = np.hstack([Params_abc, Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

# Fraction of hindered
frac = np.random.rand(NumSamps)

mean = np.random.rand(NumSamps) * 0.005 + 1e-4   # uniform on [1e-4, 5.1e-3)

S0Rand = np.random.rand(NumSamps) * 2475 + 25    # uniform on [25, 2500)

In [None]:
# One row per simulated sample. The first three columns are V; the network
# below is trained on TrainParams[:, 3:].
TrainParams = np.column_stack((V, Angs, Dpar, Dperp, DHind, frac, mean, S0Rand))

In [None]:
net_file = f"{network_path}/Full_Dat_50_200k_poisson.pickle"
if os.path.exists(net_file):
    # Reuse the previously trained posterior (local, trusted pickle).
    with open(net_file, "rb") as handle:
        posterior = pickle.load(handle)
else:
    # The acquisition is three consecutive blocks of n_pts + 1 volumes,
    # simulated with Delta[0], Delta[1] and Delta[2] respectively.
    shell_len = n_pts + 1
    shells = (
        (slice(None, shell_len), Delta[0]),
        (slice(shell_len, 2 * shell_len), Delta[1]),
        (slice(2 * shell_len, None), Delta[2]),
    )

    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar, dperp, dh = Dpar[i], Dperp[i], DHind[i]
        f = [frac[i], 1 - frac[i]]
        a, s0 = mean[i], S0Rand[i]
        Noise = 50  # fixed noise level passed to AddNoise (the "_50_" in the file name)

        # Simulate each block in order and concatenate the clean signal
        TrainSig.append(np.hstack([
            CombSignal_poisson(bvecs[sl], bvals[sl], shell_Delta, delta, [v, dpar, dperp, dh, f, a, s0])
            for sl, shell_Delta in shells
        ]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1], s0, Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()
    Par = torch.tensor(TrainParams[:, 3:]).float()  # V (first three columns) is not inferred

    # Neural posterior estimation: register simulations, train, build posterior
    inference = SNPE()
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)

    with open(net_file, "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
# Compute the mask where the sum is not zero
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)

# Define the function for optimization
def optimize_pixel(i, j):
    torch.manual_seed(10)  # If required
    posterior_samples_1 = posterior.sample((1000,), x=maskdata[i, 54, j, :],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

# Initialize NoiseEst with the appropriate shape
ArrShape = mask.shape

# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)


NoiseEst = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst
for i, j, x in results:
    NoiseEst[i, j] = x

for i, j, x in results:
    NoiseEst[i, j,-2] = np.clip(NoiseEst[i, j,-2],0,100)
    NoiseEst[i, j,-3] = np.clip(NoiseEst[i, j,-3],0,1)

In [None]:
NoiseEst2 = np.copy(NoiseEst)

# Binary version of the anatomical mask on coronal slice 54
mask1 = np.ones_like(S_mask[:, 54, :])
mask1[S_mask[:, 54, :] == 0] = 0
structure = np.ones((3, 3), dtype=bool)

# Apply dilation. Increase 'iterations' to make the mask even fatter.
fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

# Column -2 is kept only inside the dilated mask and where 1 - frac > 0.1
comb_mask = fat_mask & ((1 - NoiseEst2[..., -3]) > 0.1)

mask_CC = (1 - NoiseEst2[..., -3]) < 0.3  # not used in this cell

# Blank every parameter outside the signal mask. np.nan is used because
# `from math import *` does not bind the module name `math`.
NoiseEst2[~mask] = np.nan

NoiseEst2[~comb_mask, -2] = np.nan

In [None]:
NoiseEst2 = np.copy(NoiseEst)
mask_CC = (1 - NoiseEst2[..., -3]) < 0.3  # not used in this cell

# Blank everything outside the signal mask (np.nan: the name `math` is not
# bound by `from math import *`); comb_mask comes from the previous cell.
NoiseEst2[~mask] = np.nan
NoiseEst2[~comb_mask, -2] = np.nan

plt.subplots(figsize=(12, 12))
# Column -1 (S0) as a grey background, column -2 (mean) overlaid in 'hot'
plt.imshow(np.flipud(NoiseEst2[..., -1].T), cmap='gray')
im = plt.imshow(np.flipud(NoiseEst2[..., -2].T), cmap='hot', vmin=0, vmax=0.005)
cbar = plt.colorbar(im, fraction=0.035, pad=0.01, format=ticker.FormatStrFormatter('%.e'))
cbar.ax.tick_params(labelsize=14)
plt.axis('off')


In [None]:
# Posterior-mean parameter estimates for every brain voxel of the middle axial slice.

# Compute the mask where the sum is not zero
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)


def _posterior_mean(signal):
    """Return the mean of 1000 posterior samples conditioned on one voxel's signal."""
    torch.manual_seed(10)  # same seed for every voxel -> reproducible estimates
    posterior_samples_1 = posterior.sample((1000,), x=signal, show_progress_bars=False)
    return posterior_samples_1.mean(axis=0)


def optimize_pixel(i, j):
    """Return ``(i, j, posterior mean)`` for voxel ``(i, j, axial_middle)`` (interactive use)."""
    return i, j, _posterior_mean(maskdata[i, j, axial_middle, :])


def _voxel_task(i, j, signal):
    # The voxel signal is passed as an argument. A __main__ function that
    # references the global `maskdata` gets that whole volume serialised to the
    # workers together with the function.
    return i, j, _posterior_mean(signal)


# Use joblib to parallelize the per-voxel posterior sampling
results = Parallel(n_jobs=-1)(
    delayed(_voxel_task)(i, j, maskdata[i, j, axial_middle, :]) for i, j in tqdm(indices)
)

ArrShape = mask.shape
n_params = 13  # length of the posterior parameter vector (presumably TrainParams[:, 3:])
NoiseEst = np.zeros(list(ArrShape) + [n_params])

# Assign the estimates; voxels outside the mask stay 0
for i, j, x in results:
    NoiseEst[i, j] = x

# Clip column -2 (mean) to [0, 100] and column -3 (frac) to [0, 1].
# Unassigned voxels are 0 and unaffected by the clip.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

In [None]:
NoiseEst2 = np.copy(NoiseEst)

# Blank every parameter outside the signal mask. np.nan is used because
# `from math import *` does not bind the module name `math`.
NoiseEst2[~mask] = np.nan

# Hide column -2 (mean) where 1 - frac < 0.3
NoiseEst2[(1 - NoiseEst2[..., -3]) < 0.3, -2] = np.nan

In [None]:
# Map of 1 - frac (column -3) over the axial slice on a fixed [0, 1] colour range
plt.subplots(figsize=(12, 12))
one_minus_frac = 1 - NoiseEst2[..., -3]
im = plt.imshow(one_minus_frac, vmin=0, vmax=1, cmap='hot')
cbar = plt.colorbar(im, fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=14)
plt.axis('off')

### Min Set

In [None]:
np.random.seed(12)
NumSamps = 10000

# Directions: isotropic Gaussian draws normalised to unit length
# (three separate draws, evaluated in x, y, z order).
x1, y1, z1 = (np.random.randn(NumSamps) for _ in range(3))
V = np.stack([x1, y1, z1], axis=1)
V = V / np.linalg.norm(V, axis=1, keepdims=True)
Angs = np.array([SpherAng(v) for v in V])

# Diffusion of restricted: uniform on [0, 5e-3)
Dpar = np.random.rand(NumSamps) * 5e-3
Dperp = np.random.rand(NumSamps) * 5e-3

# Diffusion of hindered: six uniform parameters mapped through the project
# helpers ComputeDTI -> ForceLowFA -> mat_to_vals
Params_abc = np.random.rand(NumSamps, 3) * 0.14 - 0.07    # uniform on [-0.07, 0.07)
Params_rest = np.random.rand(NumSamps, 3) * 0.03 - 0.015  # uniform on [-0.015, 0.015)
Params = np.hstack([Params_abc, Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

# Fraction of hindered
frac = np.random.rand(NumSamps)

mean = np.random.rand(NumSamps) * 0.005 + 1e-4   # uniform on [1e-4, 5.1e-3)
# `scale` is not part of TrainParams, but its draw advances the RNG, so the
# S0Rand values below depend on it being drawn here.
scale = np.random.rand(NumSamps) * 0.0009 + 0.0001

S0Rand = np.random.rand(NumSamps) * 2475 + 25    # uniform on [25, 2500)
TrainParams = np.column_stack((V, Angs, Dpar, Dperp, DHind, frac, mean, S0Rand))

In [None]:
# Reduced ("Dev") protocol: from each of the three Delta blocks of the
# gradient table (rows 0:91, 91:182, 182:), keep 3 well-spread directions at
# b=2000 and 3 at b=4000, plus row 0 (presumably the b=0 volume).
#
# Indices are computed relative to each block/shell and then offset, so they
# stay correct even when the same direction occurs in several blocks or in
# both shells. Distances are computed among the candidates of one shell only.
# NOTE: DevIndices may differ from earlier runs of this notebook; posteriors
# cached from those runs (e.g. Dev_Dat_*.pickle) were trained on that subset
# and should be regenerated.


def select_farthest_bvecs(bvecs, bvals, start, stop, bval, n_extra):
    """Greedily pick well-spread directions of one shell inside one Delta block.

    Starting from the first row with ``bvals == bval`` in ``start:stop``,
    repeatedly add the candidate whose distance to the closest
    already-picked direction is largest (farthest-point sampling).

    Parameters
    ----------
    bvecs : array_like, shape (N, 3)
        Full gradient-direction table.
    bvals : array_like, shape (N,)
        b-values matching ``bvecs`` row by row.
    start : int
        First row of the block to search.
    stop : int or None
        End (exclusive) of the block; ``None`` means the end of the table.
    bval : float
        b-value of the shell to subsample.
    n_extra : int
        Directions added after the first; ``n_extra + 1`` are returned.

    Returns
    -------
    list of int
        Row indices into the FULL table, in selection order.

    Raises
    ------
    ValueError
        If the block holds fewer than ``n_extra + 1`` directions at ``bval``.
    """
    bvecs = np.asarray(bvecs)
    bvals = np.asarray(bvals)
    # Global row numbers of the candidates (no vector look-up needed).
    candidates = start + np.flatnonzero(bvals[start:stop] == bval)
    n_pick = n_extra + 1
    if len(candidates) < n_pick:
        raise ValueError(
            f"only {len(candidates)} directions with b={bval} in rows "
            f"{start}:{stop}; {n_pick} requested"
        )
    # Pairwise distances among the candidates only, so row/column k is
    # candidates[k].
    dist = squareform(pdist(bvecs[candidates]))
    picked = [0]
    closest = dist[0].copy()  # distance from each candidate to the picked set
    for _ in range(n_extra):
        closest[picked] = -np.inf  # never pick the same direction twice
        nxt = int(np.argmax(closest))
        picked.append(nxt)
        closest = np.minimum(closest, dist[nxt])
    return [int(candidates[p]) for p in picked]


# First direction of each (block, b-value) pair plus 2 more -> 3 per pair.
n_extra_dirs = 2
per_block = []
for start, stop in ((0, 91), (91, 182), (182, None)):
    block_idx = []
    for bval in (2000, 4000):
        block_idx += select_farthest_bvecs(bvecs, bvals, start, stop, bval, n_extra_dirs)
    per_block.append(block_idx)
true_indices1, true_indices2, true_indices3 = per_block

# 1 + 6 + 6 + 6 = 19 rows; later cells split them as [:7], [7:13], [13:].
DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
bvecs_Dev = bvecs[DevIndices]
bvals_Dev = bvals[DevIndices]

In [None]:
# Draw NumSamps=100000 random parameter sets (seed 12) for training the
# reduced-protocol posterior in the next cell. The statement order fixes the
# RNG stream, so do not reorder or drop any draw below.
np.random.seed(12)
NumSamps = 100000

# Directions: normalized Gaussian vectors (uniform on the unit sphere), also
# converted to angles with SpherAng.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Both diffusivities uniform in [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six random values per sample, passed through the project helpers
# ComputeDTI, ForceLowFA and mat_to_vals.
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# `mean` uniform in [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4
# Unused in TrainParams, but it advances the RNG before S0Rand is drawn.
scale = np.random.rand(NumSamps)*0.0009+0.0001

# S0 uniform in [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25
# Columns: V (3), Angs, Dpar, Dperp, DHind, frac, mean, S0Rand.
TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand])

In [None]:
# Reduced-protocol posterior `posteriorMin`. Load the cached pickle if it
# exists; otherwise simulate NumSamps noisy signals on the bvecs_Dev/bvals_Dev
# subset, train an NPE posterior on them, and cache it.
# NOTE(review): the file name says "200k" while NumSamps above is 100000;
# confirm which simulation set the cached file was trained on. The cache is
# also tied to the DevIndices in effect when it was trained.
if os.path.exists(f"{network_path}/Dev_Dat_50_200k_poisson.pickle"):
    with open(f"{network_path}/Dev_Dat_50_200k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        #s = sig2[i]
        s0 = S0Rand[i]
        
        # Fixed noise level (the randomized variant is commented out).
        Noise = 50#np.random.rand()*30 + 20
                
        # One simulation per Delta block. With the 3+3-per-block selection,
        # rows 0:7 (row 0 + 6 picks) use Delta[0], rows 7:13 Delta[1], and
        # rows 13: Delta[2].
        TrainSig1 = CombSignal_poisson(bvecs_Dev[:7],bvals_Dev[:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs_Dev[7:13],bvals_Dev[7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs_Dev[13:],bvals_Dev[13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)


    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the 3 Cartesian direction columns (V); the posterior is over the
    # remaining TrainParams columns (Angs onward).
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorMin = inference.build_posterior(density_estimator)
    with open(f"{network_path}/Dev_Dat_50_200k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)

In [None]:
# Sagittal slice 54: posterior-mean estimates from the reduced-protocol
# posterior `posteriorMin`, conditioned only on the DevIndices volumes.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_pixel(i, j):
    """Return ``(i, j, mean)`` of 500 ``posteriorMin`` draws for voxel (i, 54, j).

    The torch RNG is reseeded to 10 on every call, before sampling.
    """
    torch.manual_seed(10)
    voxel_signal = maskdata[i, 54, j, DevIndices]
    draws = posteriorMin.sample((500,), x=voxel_signal, show_progress_bars=False)
    return i, j, draws.mean(axis=0)


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

NoiseEst_Min = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_Min[i, j] = x

# Clip channel -2 to [0, 100] and channel -3 to [0, 1]. Unfitted voxels are 0,
# already inside both ranges, so whole-array clipping matches per-voxel clipping.
NoiseEst_Min[..., -2] = np.clip(NoiseEst_Min[..., -2], 0, 100)
NoiseEst_Min[..., -3] = np.clip(NoiseEst_Min[..., -3], 0, 1)

In [None]:
# Display copy of the sagittal reduced-protocol SBI estimates: voxels outside
# `mask` become NaN in every channel, and channel -2 is also blanked outside
# `comb_mask` (defined elsewhere in this notebook).
# np.nan is used instead of math.nan: the visible imports only do
# `from math import *`, which does not bind the name `math` itself.
NoiseEst2_Min = np.copy(NoiseEst_Min)
NoiseEst2_Min[~mask] = np.nan

NoiseEst2_Min[~comb_mask, -2] = np.nan

In [None]:
# Sagittal overlay: channel -1 (the S0 column) as a gray background with
# channel -2 (the `mean` column) drawn on top in 'hot'. NaN voxels in the top
# layer let the background show through; flipud(.T) orients the slice.
plt.subplots(figsize=(12,12))
plt.imshow(np.flipud(NoiseEst2_Min[...,-1].T),cmap='gray')
im = plt.imshow(np.flipud(NoiseEst2_Min[...,-2].T),cmap='hot',vmin=0,vmax=0.005)
# '%.e' labels the colorbar ticks in scientific notation with no decimals.
cbar = plt.colorbar(im,fraction=0.035, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
cbar.ax.tick_params(labelsize=14)
plt.axis('off')

In [None]:
# IPython magic: render figures inline (static) in the notebook.
%matplotlib inline

In [None]:
# Axial slice: posterior-mean estimates from the reduced-protocol posterior
# `posteriorMin`, conditioned only on the DevIndices volumes.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_pixel(i, j):
    """Return ``(i, j, mean)`` of 1000 ``posteriorMin`` draws for axial voxel (i, j).

    The torch RNG is reseeded to 10 on every call, before sampling.
    """
    torch.manual_seed(10)
    voxel_signal = maskdata[i, j, axial_middle, DevIndices]
    draws = posteriorMin.sample((1000,), x=voxel_signal, show_progress_bars=False)
    return i, j, draws.mean(axis=0)


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

NoiseEst_Min = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_Min[i, j] = x

# Clip channel -2 to [0, 100] and channel -3 to [0, 1]. Unfitted voxels are 0,
# already inside both ranges, so whole-array clipping matches per-voxel clipping.
NoiseEst_Min[..., -2] = np.clip(NoiseEst_Min[..., -2], 0, 100)
NoiseEst_Min[..., -3] = np.clip(NoiseEst_Min[..., -3], 0, 1)

In [None]:
# Axial map of 1 - channel -3 (the `frac` column) from the reduced-protocol
# SBI estimates.
plt.subplots(figsize=(12,12))
im = plt.imshow(1-NoiseEst_Min[...,-3],cmap='hot',vmin=0,vmax=1)
# This figure needs its own colorbar; restyling `cbar` alone would only touch
# a colorbar left over from another figure. Same layout as the full-protocol map.
cbar = plt.colorbar(im,fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=14)
plt.axis('off')


## NLLS

### Full Set

In [None]:
# Initial guess and box bounds shared by the NLLS (least_squares) fits below.
# Parameter order matches TrainParams[:, 3:]:
# [theta, phi, Dpar, Dperp, 6 hindered-tensor values, frac, mean, S0].
# The random draws depend on seed 133 and on their order here.
np.random.seed(133)
# NOTE(review): S0 is not used in this cell.
S0 = 2000
mean_guess = np.random.rand()*0.005+1e-4
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
# Fixed starting angles: phi = 0 and cos_theta = 0, so theta = pi/2.
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]).squeeze()
# Shape (2, 13): row 0 = lower bounds, row 1 = upper bounds, one column per
# parameter, as accepted by scipy.optimize.least_squares.
bounds = np.array([[-np.inf,np.inf]]*13).T
bounds[:,0] = [0,np.pi/2]
bounds[:,1] = [-np.pi,np.pi]
bounds[:,2] = [0,5e-3]
bounds[:,3] = [0,5e-3]
bounds[:,4] = [-5e-3,5e-3]
bounds[:,5] = [-5e-3,5e-3]
bounds[:,6] = [-5e-3,5e-3]
bounds[:,7] = [-5e-3,5e-3]
bounds[:,8] = [-5e-3,5e-3]
bounds[:,9] = [-5e-3,5e-3]
bounds[:,10] = [0,1]
bounds[:,11] = [1e-4,0.005+1e-4]
bounds[:,12] = [25,2500]

# Split the full gradient table into three Delta blocks of n_pts+1 rows each
# (the last block takes whatever rows remain).
bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]

In [None]:
# Sagittal slice 54: full-protocol NLLS fit (`residuals_S0`, with S0 fitted as
# the last parameter) for every in-brain voxel.
# NOTE: `residuals_S0` is defined further down in this notebook; run that
# cell before this one.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Recompute the voxel list for this sagittal mask. An `indices` left over from
# an axial-slice cell would index the wrong voxels.
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded NLLS fit of voxel (i, 54, j); returns ``(i, j, fitted params)``.

    Starts from the shared `guess`/`bounds` and fits all volumes, split into
    the three Delta blocks `bve_split`/`bva_split`.
    """
    result = sp.optimize.least_squares(residuals_S0, guess, args=[maskdata[i, 54, j, :],bve_split,bva_split,Delta],
                              bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)


NoiseEst_LS = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst_LS
for i, j, x in results:
    NoiseEst_LS[i, j] = x

In [None]:
# Display copy of the sagittal NLLS estimates: voxels outside `mask` become
# NaN in every channel, and channel -2 is also blanked outside `comb_mask`
# (defined elsewhere in this notebook).
# np.nan is used instead of math.nan: the visible imports only do
# `from math import *`, which does not bind the name `math` itself.
NoiseEst2_LS = np.copy(NoiseEst_LS)
NoiseEst2_LS[~mask] = np.nan

NoiseEst2_LS[~comb_mask, -2] = np.nan

In [None]:
# Sagittal overlay for the full-protocol NLLS fit: channel -1 (S0 column) in
# gray, channel -2 (`mean` column) on top in 'hot'; flipud(.T) orients it.
plt.subplots(figsize=(12,12))
plt.imshow(np.flipud(NoiseEst2_LS[...,-1].T),cmap='gray')
im = plt.imshow(np.flipud(NoiseEst2_LS[...,-2].T),cmap='hot',vmin=0,vmax=0.005)
cbar = plt.colorbar(im,fraction=0.03, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
cbar.ax.tick_params(labelsize=32)
plt.axis('off')


In [None]:
# Axial slice: full-protocol NLLS fit with `residuals_LS_real_dat` for every
# in-brain voxel (non-zero summed signal).
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded NLLS fit of axial voxel (i, j); returns ``(i, j, fitted params)``.

    Starts from the shared `guess` within `bounds`.
    """
    voxel_signal = maskdata[i, j, axial_middle, :]
    fit = sp.optimize.least_squares(
        residuals_LS_real_dat,
        guess,
        args=[voxel_signal],
        bounds=bounds,
        verbose=0,
        xtol=1e-12,
        gtol=1e-12,
        ftol=1e-12,
        jac='3-point',
    )
    return i, j, fit.x


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)

NoiseEst_LS = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_LS[i, j] = x

In [None]:
# Display copy of the axial NLLS estimates with voxels outside `mask` set to
# NaN in every channel.
# np.nan is used instead of math.nan: the visible imports only do
# `from math import *`, which does not bind the name `math` itself.
NoiseEst2_LS = np.copy(NoiseEst_LS)
NoiseEst2_LS[~mask] = np.nan

In [None]:
# Axial map of 1 - channel -3 (the `frac` column) from the full-protocol NLLS fit.
plt.subplots(figsize=(12,12))
im = plt.imshow(1-NoiseEst2_LS[...,-3],vmin=0,vmax=1,cmap='hot')
# A negative pad pulls the colorbar in toward the image.
cbar = plt.colorbar(im,fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
plt.axis('off')



### Min Set

In [None]:
# Reduced ("Dev") protocol with k + 1 directions per (Delta block, b-value):
# from each block of the gradient table (rows 0:91, 91:182, 182:), keep k + 1
# well-spread directions at b=2000 and k + 1 at b=4000, plus row 0
# (presumably the b=0 volume).
# Indices are computed relative to each block/shell and then offset, so they
# stay correct when a direction repeats across blocks or shells. Distances are
# computed among the candidates of one shell only.
k = 5#np.arange(2,28,2)[::-1][10]


def select_farthest_bvecs(bvecs, bvals, start, stop, bval, n_extra):
    """Greedily pick well-spread directions of one shell inside one Delta block.

    Starting from the first row with ``bvals == bval`` in ``start:stop``,
    repeatedly add the candidate whose distance to the closest
    already-picked direction is largest (farthest-point sampling).

    Parameters
    ----------
    bvecs : array_like, shape (N, 3)
        Full gradient-direction table.
    bvals : array_like, shape (N,)
        b-values matching ``bvecs`` row by row.
    start : int
        First row of the block to search.
    stop : int or None
        End (exclusive) of the block; ``None`` means the end of the table.
    bval : float
        b-value of the shell to subsample.
    n_extra : int
        Directions added after the first; ``n_extra + 1`` are returned.

    Returns
    -------
    list of int
        Row indices into the FULL table, in selection order.

    Raises
    ------
    ValueError
        If the block holds fewer than ``n_extra + 1`` directions at ``bval``.
    """
    bvecs = np.asarray(bvecs)
    bvals = np.asarray(bvals)
    # Global row numbers of the candidates (no vector look-up needed).
    candidates = start + np.flatnonzero(bvals[start:stop] == bval)
    n_pick = n_extra + 1
    if len(candidates) < n_pick:
        raise ValueError(
            f"only {len(candidates)} directions with b={bval} in rows "
            f"{start}:{stop}; {n_pick} requested"
        )
    # Pairwise distances among the candidates only, so row/column m is
    # candidates[m].
    dist = squareform(pdist(bvecs[candidates]))
    picked = [0]
    closest = dist[0].copy()  # distance from each candidate to the picked set
    for _ in range(n_extra):
        closest[picked] = -np.inf  # never pick the same direction twice
        nxt = int(np.argmax(closest))
        picked.append(nxt)
        closest = np.minimum(closest, dist[nxt])
    return [int(candidates[p]) for p in picked]


per_block = []
for start, stop in ((0, 91), (91, 182), (182, None)):
    block_idx = []
    for bval in (2000, 4000):
        block_idx += select_farthest_bvecs(bvecs, bvals, start, stop, bval, k)
    per_block.append(block_idx)
true_indices1, true_indices2, true_indices3 = per_block

# 1 + 3 * 2 * (k + 1) rows (37 for k = 5).
DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
print(len(DevIndices))
bvecs_Dev = bvecs[DevIndices]
bvals_Dev = bvals[DevIndices]

In [None]:
# Inspect the selected reduced-protocol gradient directions.
bvecs_Dev

In [None]:
# Inspect the b-values of the second Delta block (rows 13:25 when k = 5).
bvals_Dev[13:25]

In [None]:
def residuals_S0(params,TrueSig,bvecs,bvals,Delta):
    """Residual vector for ``scipy.optimize.least_squares`` with S0 fitted.

    The last entry of ``params`` is used as S0. ``Simulator`` is evaluated on
    the given gradient table(s) and diffusion time(s), and its prediction is
    subtracted from the measured signal.

    Returns
    -------
    ``TrueSig`` minus the simulated signal.
    """
    fitted_s0 = params[-1]
    model_signal = Simulator(params, bvecs, bvals, Delta, S0=fitted_s0)
    return TrueSig - model_signal

In [None]:
# Voxel (0, 0) of the axial slice, used by the single-fit check below.
i,j = 0,0

In [None]:
# IPython post-mortem debugger on the most recent exception (debugging leftover).
%debug

In [None]:
# One-off NLLS fit of voxel (i, j) on the DevIndices volumes. The split
# [:13], [13:25], [25:] matches the k = 5 selection (13 + 12 + 12 rows).
sp.optimize.least_squares(residuals_S0, guess, args=[maskdata[i, j,axial_middle, DevIndices],[bvecs_Dev[:13],bvecs_Dev[13:25],bvecs_Dev[25:]]
                                                                                              ,[bvals_Dev[:13],bvals_Dev[13:25],bvals_Dev[25:]],Delta],
                          bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')

In [None]:
# NOTE(review): scratch line. The trailing comma makes `args` a 1-tuple that
# wraps the list, and NoisyTestSig / bvecs_split / bvals_split are not
# defined in this part of the notebook. Looks like a leftover from the
# simulated-data fits.
args=[NoisyTestSig[j,i],bvecs_split,bvals_split,Delta],

In [None]:
# Axial slice: NLLS fit (`residuals_S0`, S0 fitted as the last parameter) on
# the reduced DevIndices protocol for every in-brain voxel.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
indices = np.argwhere(mask)

# DevIndices is [row 0] + block-1 picks + block-2 picks + block-3 picks, with
# k + 1 directions per b-value (2000 and 4000) in each Delta block. The split
# points are derived from k so each part stays aligned with Delta[0..2]
# (k = 5 gives [:13], [13:25], [25:]). They are computed once here instead of
# on every voxel.
n_per_block = 2 * (k + 1)
split1 = n_per_block + 1
split2 = split1 + n_per_block
if len(DevIndices) != split2 + n_per_block:
    raise ValueError(
        f"DevIndices has {len(DevIndices)} entries, expected "
        f"{split2 + n_per_block} for k={k}; rerun the direction-selection cell"
    )
bvecs_Dev_split = [bvecs_Dev[:split1], bvecs_Dev[split1:split2], bvecs_Dev[split2:]]
bvals_Dev_split = [bvals_Dev[:split1], bvals_Dev[split1:split2], bvals_Dev[split2:]]


def optimize_pixel_LS(i, j):
    """Bounded NLLS fit of axial voxel (i, j); returns ``(i, j, fitted params)``."""
    result = sp.optimize.least_squares(
        residuals_S0,
        guess,
        args=[maskdata[i, j, axial_middle, DevIndices], bvecs_Dev_split, bvals_Dev_split, Delta],
        bounds=bounds,
        verbose=0,
        xtol=1e-12,
        gtol=1e-12,
        ftol=1e-12,
        jac='3-point',
    )
    return i, j, result.x


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)


NoiseEst_LS_Min = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst_LS_Min
for i, j, x in results:
    NoiseEst_LS_Min[i, j] = x

In [None]:
# Inspect the number of extra directions per (block, b-value) currently in use.
k

In [None]:
# Inspect the size of the reduced protocol (1 + 6 * (k + 1) for the k-based selection).
len(DevIndices)

In [None]:
# Rebuild NoiseEst_LS_Min from the current `results`/`ArrShape`. This repeats
# the tail of the fitting cell above without refitting.
NoiseEst_LS_Min = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst
for i, j, x in results:
    NoiseEst_LS_Min[i, j] = x

In [None]:
# Copy of the reduced-protocol NLLS estimates. Background masking is disabled
# (commented out), so voxels outside the brain keep their value of 0, which
# the SSIM comparison below relies on.
NoiseEst2_LS_Min = np.copy(NoiseEst_LS_Min)

#for i in range(13):
#    NoiseEst2_LS_Min[~mask,i] = math.nan

In [None]:
# Compare one parameter channel of the reduced-protocol NLLS fit with the
# full-protocol SBI estimate using SSIM, averaged inside the first subject's
# median-Otsu outline at the axial mid-slice.
# Channel index in the TrainParams[:, 3:] layout: -1 S0Rand, -2 mean, -3 frac.
i = -2
ref_img = NoiseEst[..., i]
test_img = NoiseEst2_LS_Min[..., i]
# data_range spans both images being compared (not a third array).
data_range = max(np.max(test_img), np.max(ref_img)) - min(np.min(test_img), np.min(ref_img))
score, ssim_map = ssim(test_img, ref_img, data_range=data_range, full=True)
# Average the SSIM map only inside the outline. The global brain `mask` is
# deliberately left untouched here.
masked_ssim = ssim_map[Outlines[0][:, :, axial_middle]].mean()
print(masked_ssim)

# Side by side: reference (SBI full), reduced NLLS fit, SSIM map.
# NOTE(review): the 1 - x display suits channel -3; for i = -2 it shows 1 - mean.
fig,ax = plt.subplots(1,3,figsize=(12,4))
ax[0].imshow(1-NoiseEst[...,i],cmap='hot')
ax[1].imshow(1-NoiseEst2_LS_Min[...,i],cmap='hot')
ax[2].imshow(ssim_map)

In [None]:
# Axial map of 1 - channel -3 (the `frac` column) from the reduced-protocol
# NLLS fit. No vmin/vmax is set, so the color scale follows the data.
plt.subplots(figsize=(12,12))
im = plt.imshow(1-NoiseEst2_LS_Min[...,-3],cmap='hot')
cbar = plt.colorbar(im,fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
plt.axis('off')


In [None]:
# Sagittal slice 54: NLLS fit with `residuals_LS_real_dat_Min` on the
# FullIndices volumes of every in-brain voxel.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded NLLS fit of voxel (i, 54, j); returns ``(i, j, fitted params)``."""
    voxel_signal = maskdata[i, 54, j, FullIndices]
    fit = sp.optimize.least_squares(
        residuals_LS_real_dat_Min,
        guess,
        args=[voxel_signal],
        bounds=bounds,
        verbose=0,
        xtol=1e-12,
        gtol=1e-12,
        ftol=1e-12,
        jac='3-point',
    )
    return i, j, fit.x


ArrShape = mask.shape

# One joblib task per voxel, on all available cores (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)

NoiseEst_LS_Min = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_LS_Min[i, j] = x

In [None]:
# Display copy of the sagittal reduced-protocol NLLS estimates: voxels outside
# `mask` become NaN in every channel, and channel -2 is also blanked outside
# `comb_mask` (defined elsewhere in this notebook).
# np.nan is used instead of math.nan: the visible imports only do
# `from math import *`, which does not bind the name `math` itself.
NoiseEst2_LS_Min = np.copy(NoiseEst_LS_Min)
NoiseEst2_LS_Min[~mask] = np.nan

NoiseEst2_LS_Min[~comb_mask, -2] = np.nan

In [None]:
# Save the unmasked reduced-protocol NLLS estimates to the working directory.
np.save('Temp_LS_Min_CC.npy',NoiseEst_LS_Min)

In [None]:
# Sagittal overlay for the reduced-protocol NLLS fit: channel -1 (S0 column)
# in gray, channel -2 (`mean` column) on top in 'hot'; flipud(.T) orients it.
plt.subplots(figsize=(12,12))
plt.imshow(np.flipud(NoiseEst2_LS_Min[...,-1].T),cmap='gray')
im = plt.imshow(np.flipud(NoiseEst2_LS_Min[...,-2].T),cmap='hot',vmin=0,vmax=0.005)
cbar = plt.colorbar(im,fraction=0.03, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
cbar.ax.tick_params(labelsize=32)
plt.axis('off')


In [None]:
# NOTE(review): identical to the previous plotting cell; presumably kept to
# re-render after a rerun.
plt.subplots(figsize=(12,12))
plt.imshow(np.flipud(NoiseEst2_LS_Min[...,-1].T),cmap='gray')
im = plt.imshow(np.flipud(NoiseEst2_LS_Min[...,-2].T),cmap='hot',vmin=0,vmax=0.005)
cbar = plt.colorbar(im,fraction=0.03, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
cbar.ax.tick_params(labelsize=32)
plt.axis('off')


## Evaluation

In [None]:
# Per-voxel parameter vectors inside `comb_mask` for the four estimators:
# SBI full, SBI reduced, NLLS full, NLLS reduced.
# NOTE(review): if the cells run top to bottom, NoiseEst2 and NoiseEst2_LS
# were last built from axial fits, while NoiseEst2_Min and NoiseEst2_LS_Min
# come from sagittal fits. Confirm that all four share comb_mask's slice.
SBIf = NoiseEst2[comb_mask]
SBIf2 = NoiseEst2_Min[comb_mask]
SBIf_LS = NoiseEst2_LS[comb_mask]
SBIf2_LS = NoiseEst2_LS_Min[comb_mask]

In [None]:
# Raincloud-style comparison (violin + box + jittered points) of absolute
# differences in channel -2 (`mean` column) between: SBI full vs SBI
# reduced, NLLS full vs NLLS reduced, and SBI full vs NLLS full.
g_pos = np.array([0, 1, 2 ])*2
POSITIONS = g_pos
jitter = 0.04
y_data = [abs(SBIf[:,-2]-SBIf2[:,-2]),abs(SBIf_LS[:,-2]-SBIf2_LS[:,-2]),abs(SBIf[:,-2]-SBIf_LS[:,-2])]
# Drop NaN/inf entries (voxels blanked in the source maps). The KDE behind
# violinplot fails on NaN, and boxplot statistics would become NaN.
y_data = [d[np.isfinite(d)] for d in y_data]
x_data = [np.array([POSITIONS[i]] * len(d)) for i, d in enumerate(y_data)]
# Horizontal jitter from a Student-t distribution (df=6) keeps most points
# near their group position.
x_jittered = [x + stats.t(df=6, scale=jitter).rvs(len(x)) for x in x_data]

colors = ['lightseagreen','darkorange','k']
colors2 = ['paleturquoise','sandybrown','gray']
fig,ax = plt.subplots(figsize=(8,4))

# Add violins (KDE outlines only) -----------------------------------
# The violins must be drawn before their bodies can be restyled below.
violins = ax.violinplot(
    y_data,
    positions=POSITIONS,
    widths=0.45,
    bw_method="silverman",
    showmeans=False,
    showmedians=False,
    showextrema=False,
)

# Customize violins (remove fill, customize line, etc.)
for pc in violins["bodies"]:
    pc.set_facecolor("none")
    pc.set_edgecolor('k')
    pc.set_linewidth(1)
    pc.set_alpha(1)
    

# Add boxplots ---------------------------------------------------
# Note that properties about the median and the box are passed
# as dictionaries.
# NOTE(review): GREY_DARK must be defined elsewhere in the notebook; the
# median colors are overridden per box below anyway.
medianprops = dict(
    linewidth=2, 
    color=GREY_DARK,
    solid_capstyle="butt"
)
boxprops = dict(
    linewidth=2, 
    color='turquoise'
)

bplot =  ax.boxplot(
    y_data,
    positions=POSITIONS, 
    showfliers = False, # Do not show the outliers beyond the caps.
    showcaps = False,   # Do not show the caps
    medianprops = medianprops,
    whiskerprops = boxprops,
    boxprops = boxprops
)

# Update the color of each box
for i, box in enumerate(bplot['boxes']):
    box.set_color(colors[i])
    
# Update the color of the medians
for i, median in enumerate(bplot['medians']):
    median.set_color(colors[i])
    
# Update the color of the whiskers.
# Note: Each box has 2 whiskers, so they appear in order.
for i in range(len(POSITIONS)):
    bplot['whiskers'][2*i].set_color(colors[i])
    bplot['whiskers'][2*i+1].set_color(colors[i])
    
# Caps are hidden (showcaps=False), so this list is empty unless caps are re-enabled.
if 'caps' in bplot:
    for i, cap in enumerate(bplot['caps']):
        cap.set_color(colors[i//2])  # since there are 2 caps per box

# Jittered raw points on top.
for x, y,c in zip(x_jittered, y_data,colors2):
    ax.scatter(x, y, s = 100, color=c, alpha=0.5)

# Three Indivs

In [None]:
# Load the three control subjects: gradient directions and b-values from each
# data_loaded.mat, the shared Delta/delta timing, the brain mask NIfTI, and the
# median-Otsu masked volume plus its mask.
# All subject folders live under one root; change it here to run elsewhere.
DATA_ROOT = '/Users/maximilianeggl/Dropbox/PostDoc/Silvia/HealthyCat/done/'
Dirs = ['Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
Masks = ['mask_055.nii.gz','mask_056.nii.gz','mask_057.nii.gz']
BVecs = []
BVals = []
Deltas = []
deltas = []
S_masks = []
Datas = []
Outlines = []
for D,M in tqdm(zip(Dirs,Masks), total=len(Dirs)):
    subj_dir = os.path.join(DATA_ROOT, D)
    dat = pmt.read_mat(os.path.join(subj_dir, 'data_loaded.mat'))
    BVecs.append(dat['direction'])
    BVals.append(dat['bval'])
    # NOTE(review): timing comes from the shared FixedParams, not from `dat`;
    # confirm that all subjects were acquired with the same Delta/delta.
    Deltas.append(FixedParams['Delta'])
    deltas.append(FixedParams['delta'])
    
    m, _, _ = load_nifti(os.path.join(subj_dir, M), return_img=True)
    S_masks.append(m)

    data = dat['data']
    # Overwritten per subject; it ends as the last subject's mid-slice index.
    axial_middle = data.shape[2] // 2
    # Volumes 0-9 drive the median-Otsu brain extraction.
    md, mk = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                                 numpass=1, autocrop=False, dilate=2)
    Datas.append(md)
    Outlines.append(mk)

In [None]:
# Draw NumSamps=600000 random parameter sets (seed 12) for the three-subject
# posterior. The statement order fixes the RNG stream, so do not reorder or
# drop any draw below.
np.random.seed(12)
NumSamps = 600000

# Directions: normalized Gaussian vectors (uniform on the unit sphere), also
# converted to angles with SpherAng.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Both diffusivities uniform in [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six random values per sample, passed through the project helpers
# ComputeDTI, ForceLowFA and mat_to_vals.
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# `mean` uniform in [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4

# S0 uniform in [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25

# Subject index 1, 2 or 3 per sample, drawn uniformly; it selects the
# subject's gradient table during simulation.
Choice = np.random.choice([1,2,3],NumSamps)

In [None]:
# Column-stack all sampled parameters. The first 3 columns hold the unit vector V; the
# training cells use TrainParams[:, 3:] (angles, Dpar, Dperp, hindered tensor, frac, mean,
# S0, subject tag 100*Choice) as the inferred parameters.
TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])

In [None]:
# Posterior trained on the full protocol of all three subjects, cached on disk.
# The same file name is used for loading and saving; previously the save wrote a
# "300k" file while the existence check looked for "600k", so the cache never hit.
posterior_file = f"{network_path}/3Indv_50_600k_poisson.pickle"
if os.path.exists(posterior_file):
    # NOTE: unpickling can execute arbitrary code - only load trusted, locally written files.
    with open(posterior_file, "rb") as handle:
        posterior = pickle.load(handle)
else:
    Noise = 50  # fixed noise level used for every training sample
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]

        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]

        # Subject c's gradient table split into three consecutive blocks of n_pts+1
        # measurements, each simulated with its own diffusion time Delta[k].
        # NOTE(review): b-values come from the global `bvals`, not BVals[c-1] (the NLLS
        # fits use BVals[kk]); confirm the tables agree for all three subjects.
        TrainSig1 = CombSignal_poisson(BVecs[c-1][:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecs[c-1][(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecs[c-1][2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # Append the subject tag (100*c) so the network can condition on the subject.
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)
    with open(posterior_file, "wb") as handle:
        pickle.dump(posterior, handle)

## Minimal gradient subset

In [None]:
def _farthest_point_subset(vecs, n_select):
    """Greedy farthest-point selection on the rows of ``vecs``.

    Starts from row 0, then repeatedly adds the row whose smallest Euclidean
    distance to the rows already chosen is largest (ties -> lowest row index).

    Parameters
    ----------
    vecs : (m, k) array
        Candidate vectors, one per row.
    n_select : int
        Number of rows to pick; capped at ``m``.

    Returns
    -------
    list of int
        Positions into ``vecs`` in selection order (empty if ``vecs`` is empty).
    """
    n_rows = len(vecs)
    if n_rows == 0:
        return []
    # Pairwise distances between the candidates themselves (not the whole table).
    # NOTE(review): distances are Euclidean, so v and -v count as far apart; if the
    # signal model is symmetric under v -> -v, consider an antipodally symmetric metric.
    dist = squareform(pdist(vecs))
    selected = [0]
    while len(selected) < min(n_select, n_rows):
        remaining = [k for k in range(n_rows) if k not in selected]
        min_dist = dist[np.ix_(remaining, selected)].min(axis=1)
        selected.append(remaining[int(np.argmax(min_dist))])
    return selected


def _shell_indices(bve, bva, start, stop, bval, n_select):
    """Absolute row indices into ``bve`` of ``n_select`` spread-out directions
    with b-value ``bval`` inside the acquisition block ``bve[start:stop]``."""
    # Map subset positions straight back to absolute indices instead of searching
    # ``bve`` for an equal row: a direction repeated in another shell/block would
    # otherwise resolve to its first occurrence, i.e. the wrong measurement.
    candidates = np.flatnonzero(bva[start:stop] == bval) + start
    picked = _farthest_point_subset(bve[candidates], n_select)
    return [int(candidates[k]) for k in picked]


# Acquisition blocks of the gradient table (split at rows 91 and 182; one block per
# diffusion time Delta in the signal model), shells sampled in each block, and the
# number of directions kept per shell. Row 0 of the table is always kept too, giving
# 1 + 3 blocks * 2 shells * 3 = 19 measurements laid out as [:7] block 1 (incl. row 0),
# [7:13] block 2, [13:] block 3.
DEV_BLOCKS = [(0, 91), (91, 182), (182, None)]
DEV_SHELLS = [2000, 4000]
DEV_PER_SHELL = 3

IndxArr  = []   # per subject: indices of the selected measurements
BVecsDev = []   # per subject: gradient directions of the selected measurements
BValsDev = []   # per subject: b-values of the selected measurements
for bve,bva in zip(BVecs,BVals):
    DevIndices = [0]
    for start, stop in DEV_BLOCKS:
        for shell in DEV_SHELLS:
            DevIndices += _shell_indices(bve, bva, start, stop, shell, DEV_PER_SHELL)
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Re-draw the training parameters for the reduced-protocol posterior. This is the same
# code under the same seed as the earlier training-parameter cell, so it reproduces
# the same arrays.
np.random.seed(12)
NumSamps = 600000

# Directions
# Normalised isotropic Gaussian vectors -> unit vectors uniformly distributed on the sphere.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T  # shape (NumSamps, 3)
# SpherAng is defined elsewhere; presumably converts each unit vector to spherical angles.
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Both diffusivities are uniform on [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six raw tensor parameters -> ComputeDTI -> ForceLowFA -> mat_to_vals (helpers defined elsewhere).
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# 'mean' parameter, uniform on [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4

# Non-diffusion-weighted signal S0, uniform on [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25

# Subject index 1..3; selects that subject's reduced gradient table (BVecsDev[c-1]).
Choice = np.random.choice([1,2,3],NumSamps)

In [None]:
# Same column layout as the earlier TrainParams: V (3 columns) followed by the parameters
# the posterior infers (TrainParams[:, 3:]), ending with the subject tag 100*Choice.
TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])

In [None]:
    V_angles = np.array([Angs[i]])

In [None]:
# Posterior for the reduced protocol (measurements selected by IndxArr / BVecsDev /
# BValsDev), cached on disk under a single file name used for both load and save.
min_posterior_file = f"{network_path}/Dev_3Indv_50_600k_poisson.pickle"
if os.path.exists(min_posterior_file):
    # NOTE: unpickling can execute arbitrary code - only load trusted, locally written files.
    with open(min_posterior_file, "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    Noise = 50  # fixed noise level used for every training sample
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]

        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]

        # Reduced table layout: [:7] first Delta block (incl. row 0), [7:13] second, [13:] third.
        TrainSig1 = CombSignal_poisson(BVecsDev[c-1][:7],BValsDev[c-1][:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecsDev[c-1][7:13],BValsDev[c-1][7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecsDev[c-1][13:],BValsDev[c-1][13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # Append the subject tag (100*c) so the network can condition on the subject.
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMin = inference.build_posterior(density_estimator)
    with open(min_posterior_file, "wb") as handle:
        pickle.dump(posteriorMin, handle)

## Evaluation of CC

### SBI

In [None]:
# For each subject: posterior-mean SBI estimate (full protocol) at every voxel of the
# subject mask within slice `sl` of the second array axis.
Full_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # 2-D mask: the subject mask (S_masks) restricted to slice sl of the second axis
    mask = sma[:,sl,:]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Posterior mean over 1000 samples for voxel (i, sl, j). The observation is the voxel's
    # full signal with the subject tag 100*(kk+1) appended, mirroring the c*100 tag that was
    # appended to the training observations.
    def optimize_pixel(i, j):
        torch.manual_seed(10)  # same seed for every voxel -> reproducible sampling
        posterior_samples_1 = posterior.sample((1000,), x=np.append(D[i, sl, j, :],100*(kk+1)),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 14 = number of inferred parameters; voxels outside the mask stay 0.
    NoiseEst = np.zeros(list(ArrShape) + [14])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)

In [None]:
# Same as the full-protocol SBI evaluation, but with posteriorMin and only the reduced set
# of measurements IndxArr[kk] of each voxel.
Min_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # 2-D mask: the subject mask (S_masks) restricted to slice sl of the second axis
    mask = sma[:,sl,:]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Posterior mean over 1000 samples for voxel (i, sl, j), using only the reduced
    # measurements plus the subject tag 100*(kk+1).
    def optimize_pixel(i, j):
        torch.manual_seed(10)  # same seed for every voxel -> reproducible sampling
        posterior_samples_1 = posteriorMin.sample((1000,), x=np.append(D[i, sl, j, IndxArr[kk]],100*(kk+1)),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 14 = number of inferred parameters; voxels outside the mask stay 0.
    NoiseEst = np.zeros(list(ArrShape) + [14])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Min_SBI.append(NoiseEst)

In [None]:
# Per-subject evaluation masks: the dilated subject mask, restricted to voxels whose
# full-protocol SBI estimate of column -4 (frac) lies strictly between 0 and 1 - thr.
# The thresholds differ per subject, exactly as in the original analysis.
CMasks = []
structure = np.ones((3, 3), dtype=bool)
for kk, d, thr in zip([0, 1, 2], [54, 52, 54], [0.1, 0.0, 0.3]):
    NoiseEst2 = np.copy(Full_SBI[kk])
    # Blank every parameter channel outside the median-Otsu outline. np.nan is used
    # because the notebook only does `from math import *`, which binds no `math` name.
    NoiseEst2[~Outlines[kk][:,d,:]] = np.nan

    mask1 = np.ones_like(S_masks[kk][:,d,:])
    mask1[S_masks[kk][:,d,:]==0] = 0

    # Apply dilation. Increase 'iterations' to make the mask even fatter.
    fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

    # NaN compares False, so blanked voxels are excluded automatically.
    frac_est = NoiseEst2[...,-4]
    CMasks.append(fat_mask * ((1-frac_est)>thr) * (frac_est>0))

### NLLS

In [None]:
# Shared initial guess and box bounds for the NLLS fits. The random draws below happen in
# a fixed order under seed 133, so do not reorder these lines.
np.random.seed(133)
# NOTE(review): S0 is not used in this cell.
S0 = 2000
mean_guess = np.random.rand()*0.005+1e-4
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
# With cos_theta fixed at 0, theta = pi/2, i.e. exactly on the upper bound of bounds[:,0].
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
# 13-element parameter vector: [theta, phi, Dpar, Dperp, 6 hindered-tensor values, frac, mean, S0].
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]).squeeze()
# bounds has shape (2, 13): row 0 = lower, row 1 = upper, column k = parameter k above.
# NOTE(review): DHind_guess is not clipped to bounds[:,4:10]; least_squares rejects an x0
# outside the bounds - confirm ComputeDTI/ForceLowFA outputs stay within +-5e-3.
bounds = np.array([[-np.inf,np.inf]]*13).T
bounds[:,0] = [0,np.pi/2]
bounds[:,1] = [-np.pi,np.pi]
bounds[:,2] = [0,5e-3]
bounds[:,3] = [0,5e-3]
bounds[:,4] = [-5e-3,5e-3]
bounds[:,5] = [-5e-3,5e-3]
bounds[:,6] = [-5e-3,5e-3]
bounds[:,7] = [-5e-3,5e-3]
bounds[:,8] = [-5e-3,5e-3]
bounds[:,9] = [-5e-3,5e-3]
bounds[:,10] = [0,1]
bounds[:,11] = [1e-4,0.005+1e-4]
bounds[:,12] = [25,2500]

In [None]:
# NLLS fit (full protocol) at every voxel of the subject mask in slice `sl` of the
# second array axis.
Full_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # 2-D mask: the subject mask (S_masks) restricted to slice sl of the second axis
    mask = sma[:,sl,:]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    # Gradient table split into the three consecutive blocks of n_pts+1 measurements
    # (one per Delta), the same split used when the SBI training signals were simulated.
    bve_split = [BVecs[kk][:(n_pts+1)],BVecs[kk][(n_pts+1):2*(n_pts+1)],BVecs[kk][2*(n_pts+1):]]
    bva_split = [BVals[kk][:(n_pts+1)],BVals[kk][(n_pts+1):2*(n_pts+1)],BVals[kk][2*(n_pts+1):]]
    # Bounded least-squares fit of voxel (i, sl, j) from the shared initial guess.
    def optimize_pixel_LS(i, j):
        result = sp.optimize.least_squares(residuals_S0, guess, args=[D[i, sl, j, :],bve_split,bva_split,Delta],
                                  bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 13 = length of the NLLS parameter vector; voxels outside the mask stay 0.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS.append(NoiseEst_LS)

In [None]:
# NLLS fit using only the reduced measurements IndxArr[kk] of each voxel in the subject
# mask (slice `sl` of the second array axis).
Min_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # 2-D mask: the subject mask (S_masks) restricted to slice sl of the second axis
    mask = sma[:,sl,:]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    # Reduced table layout: [:7] first Delta block (incl. row 0), [7:13] second, [13:] third.
    bve_splitd = [BVecsDev[kk][:7],BVecsDev[kk][7:13],BVecsDev[kk][13:]]
    bva_splitd = [BValsDev[kk][:7],BValsDev[kk][7:13],BValsDev[kk][13:]]

    # Bounded least-squares fit of voxel (i, sl, j) from the shared initial guess.
    def optimize_pixel_LS(i, j):
        result = sp.optimize.least_squares(residuals_S0, guess, args=[D[i, sl, j, IndxArr[kk]],bve_splitd,bva_splitd,Delta],
                                  bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 13 = length of the NLLS parameter vector; voxels outside the mask stay 0.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS.append(NoiseEst_LS)

In [None]:
def BoxPlots(y_data, positions, colors, colors2, ax, hatch=False):
    """Draw colour-coded box plots with the raw, horizontally jittered points on top.

    Parameters
    ----------
    y_data : sequence of array-like
        One group of values per box.
    positions : sequence of float
        x position of each box; ``positions[k]`` belongs to ``y_data[k]``.
    colors : sequence
        Edge, median and whisker colour of each box.
    colors2 : sequence
        Colour of the scattered points of each group.
    ax : matplotlib Axes
        Axes to draw into.
    hatch : bool, optional
        If truthy, every box gets a ``'/'`` hatch pattern.
    """
    import numpy as np
    from scipy import stats

    # Horizontal jitter: Student-t noise (df=6, scale 0.02) around each box position.
    # All jitter is drawn up front, group by group, before anything is plotted.
    jitter_dist = stats.t(df=6, scale=0.02)
    x_jittered = []
    for k, group in enumerate(y_data):
        centre = np.array([positions[k]] * len(group))
        x_jittered.append(centre + jitter_dist.rvs(len(centre)))

    bplot = ax.boxplot(
        y_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        showmeans=True,
        medianprops=dict(linewidth=2, color='dimgray', solid_capstyle="butt"),
        whiskerprops=dict(linewidth=2, color='turquoise'),
        boxprops=dict(linewidth=2, facecolor='none', edgecolor='turquoise'),
        patch_artist=True,
    )

    # Per-box colouring: outline (optionally hatched) and median line.
    for k, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[k])
        if hatch:
            box.set_hatch('/')
    for k, median in enumerate(bplot['medians']):
        median.set_color(colors[k])

    # Whiskers come in pairs: entries 2k and 2k+1 belong to box k.
    whiskers = bplot['whiskers']
    for k in range(len(positions)):
        whiskers[2 * k].set_color(colors[k])
        whiskers[2 * k + 1].set_color(colors[k])

    # Caps also come in pairs per box (only visible when caps are drawn).
    if 'caps' in bplot:
        for k, cap in enumerate(bplot['caps']):
            cap.set_color(colors[k // 2])

    for xs, ys, point_colour in zip(x_jittered, y_data, colors2):
        ax.scatter(xs, ys, s=100, color=point_colour, alpha=0.5)

In [None]:
fig, ax = plt.subplots(figsize=(12,4))

# One panel per comparison: (left x position, box colour, point colour, hatched, |diff| x1000
# per subject inside CMasks). The 'mean' parameter is column -3 of the SBI estimates and
# column -2 of the NLLS estimates.
panel_specs = [
    (0, 'lightseagreen', 'paleturquoise', False,
     [1000*abs(Min_SBI[s][CMasks[s]][:,-3]-Full_SBI[s][CMasks[s]][:,-3]) for s in range(3)]),
    (1, 'darkorange', 'peachpuff', False,
     [1000*abs(Min_LS[s][CMasks[s]][:,-2]-Full_SBI[s][CMasks[s]][:,-3]) for s in range(3)]),
    (2, 'darkorange', 'peachpuff', True,
     [1000*abs(Full_LS[s][CMasks[s]][:,-2]-Full_SBI[s][CMasks[s]][:,-3]) for s in range(3)]),
    (3, 'k', 'gray', False,
     [1000*abs(Full_LS[s][CMasks[s]][:,-2]-Min_LS[s][CMasks[s]][:,-2]) for s in range(3)]),
]

for left, box_colour, point_colour, hatched, y_data in panel_specs:
    # Three subjects side by side, 0.25 apart.
    g_pos = left + np.array([0,0.25,0.5])
    colors = [box_colour] * 3
    colors2 = [point_colour] * 3
    BoxPlots(y_data, g_pos, colors, colors2, ax, hatched)

ax.set_xticks([0.25,1.25,2.25,3.25],['SBI Min','NLLS Min','NLLS Full','NLLS Comp'],fontsize =24)

ax.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)


## Evaluation using SSIM

In [None]:
# SSIM section: recompute the full-protocol SBI estimates (replacing the Full_SBI list
# above), now on slice `sl` of the third array axis.
Full_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[42,45,48],S_masks)):
    # Compute the mask where the sum is not zero
    # (voxels whose masked signal is non-zero in at least one measurement; sma is unused here)
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Posterior mean over 1000 samples for voxel (i, j, sl); the subject tag 100*(kk+1)
    # is appended to the observation.
    def optimize_pixel(i, j):
        torch.manual_seed(10)  # same seed for every voxel -> reproducible sampling
        posterior_samples_1 = posterior.sample((1000,), x=np.append(D[i,j,sl, :],100*(kk+1)),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 14 = number of inferred parameters; voxels outside the mask stay 0.
    NoiseEst = np.zeros(list(ArrShape) + [14])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)

In [None]:
# SSIM section: reduced-protocol SBI estimates (posteriorMin, measurements IndxArr[kk])
# on slice `sl` of the third array axis, replacing the Min_SBI list above.
Min_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[42,45,48],S_masks)):
    # Compute the mask where the sum is not zero
    # (voxels whose masked signal is non-zero in at least one measurement; sma is unused here)
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Posterior mean over 1000 samples for voxel (i, j, sl) from the reduced measurements
    # plus the subject tag 100*(kk+1).
    def optimize_pixel(i, j):
        torch.manual_seed(10)  # same seed for every voxel -> reproducible sampling
        posterior_samples_1 = posteriorMin.sample((1000,), x=np.append(D[i,j,sl, IndxArr[kk]],100*(kk+1)),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 14 = number of inferred parameters; voxels outside the mask stay 0.
    NoiseEst = np.zeros(list(ArrShape) + [14])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Min_SBI.append(NoiseEst)

In [None]:
# SSIM section: full-protocol NLLS fits on slice `sl` of the third array axis,
# replacing the Full_LS list above.
Full_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[42,45,48],S_masks)):
    # Compute the mask where the sum is not zero
    # (voxels whose masked signal is non-zero in at least one measurement; sma is unused here)
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Bounded least-squares fit of voxel (i, j, sl) with residuals_LS_3Indv (defined elsewhere).
    def optimize_pixel_LS(i, j):
        result = sp.optimize.least_squares(residuals_LS_3Indv, guess, args=[BVecs[kk],BVals[kk],D[i,j,sl, :]],
                                  bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 13 = length of the NLLS parameter vector; voxels outside the mask stay 0.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS.append(NoiseEst_LS)

In [None]:
# SSIM section: reduced-protocol NLLS fits (measurements IndxArr[kk]) on slice `sl` of the
# third array axis, replacing the Min_LS list above.
Min_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[42,45,48],S_masks)):
    # Compute the mask where the sum is not zero
    # (voxels whose masked signal is non-zero in at least one measurement; sma is unused here)
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Bounded least-squares fit of voxel (i, j, sl) with residuals_LS_3Indv_Min (defined elsewhere).
    def optimize_pixel_LS(i, j):
        result = sp.optimize.least_squares(residuals_LS_3Indv_Min, guess, args=[BVecsDev[kk],BValsDev[kk],D[i, j,sl, IndxArr[kk]]],
                                  bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )
    
    
    # 13 = length of the NLLS parameter vector; voxels outside the mask stay 0.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS.append(NoiseEst_LS)

In [None]:
# Persist both NLLS fits: the reload cells read Full_LS_slice1.npy AND Min_LS_slice1.npy,
# but previously only Full_LS was saved.
# NOTE(review): np.save converts each list to one array, which requires all subjects'
# maps to share a shape - confirm the three volumes have equal dimensions.
np.save('Full_LS_slice1.npy',Full_LS)
np.save('Min_LS_slice1.npy',Min_LS)

In [None]:
# Reload the full-protocol NLLS fits from disk (replaces the in-memory Full_LS).
Full_LS = np.load('Full_LS_slice1.npy')

In [None]:
# Reload the reduced-protocol NLLS fits. NOTE(review): this file must have been written
# by an np.save of Min_LS; the load fails otherwise.
Min_LS = np.load('Min_LS_slice1.npy')

In [None]:
def Par_frac(i,j,Mat):
    """Fractional anisotropy and mean diffusivity of the tensor stored at voxel (i, j).

    Parameters
    ----------
    i, j : int
        Voxel coordinates; returned unchanged so parallel results can be placed back.
    Mat : array
        Per-voxel tensor component vectors; ``Mat[i, j]`` is converted to a matrix with
        ``vals_to_mat`` (defined elsewhere).

    Returns
    -------
    tuple
        ``(i, j, [FA, MD])`` where MD is the mean eigenvalue and FA comes from
        ``FracAni(eigenvalues, MD)`` (defined elsewhere).
    """
    # Eigen-decompose once and reuse the eigenvalues for both metrics.
    eigvals = np.linalg.eigh(vals_to_mat(Mat[i,j]))[0]
    MD = eigvals.mean()
    FA = FracAni(eigvals,MD)
    return i, j, [FA,MD]

In [None]:
KK = [42,45,48]
FA_Full_SBI = []
MD_Full_SBI = []
for jj in range(3):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)
    # Hindered-tensor components are columns 4:10 of the SBI estimates.
    tensor_maps = Full_SBI[jj][...,4:10]

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j, tensor_maps) for i, j in tqdm(indices)
    )

    # Size the maps from this subject's own slice: the global ArrShape left by earlier
    # cells describes only the last subject processed there.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # Par_frac returns (i, j, [FA, MD]) - three items - so unpack the pair explicitly.
    for i, j, (fa_val, md_val) in results:
        temp1[i, j] = fa_val
        temp2[i, j] = md_val

    FA_Full_SBI.append(temp1)
    MD_Full_SBI.append(temp2)

In [None]:
KK = [42,45,48]
FA_Min_SBI = []
MD_Min_SBI = []
for jj in range(3):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)
    # Hindered-tensor components are columns 4:10 of the SBI estimates.
    tensor_maps = Min_SBI[jj][...,4:10]

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j, tensor_maps) for i, j in tqdm(indices)
    )

    # Size the maps from this subject's own slice: the global ArrShape left by earlier
    # cells describes only the last subject processed there.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # Par_frac returns (i, j, [FA, MD]) - three items - so unpack the pair explicitly.
    for i, j, (fa_val, md_val) in results:
        temp1[i, j] = fa_val
        temp2[i, j] = md_val

    FA_Min_SBI.append(temp1)
    MD_Min_SBI.append(temp2)

In [None]:
KK = [42,45,48]
FA_Full_LS = []
MD_Full_LS = []
for jj in range(3):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)
    # Hindered-tensor components are columns 4:10 of the NLLS parameter vector.
    tensor_maps = Full_LS[jj][...,4:10]

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j, tensor_maps) for i, j in tqdm(indices)
    )

    # Size the maps from this subject's own slice: the global ArrShape left by earlier
    # cells describes only the last subject processed there.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # Par_frac returns (i, j, [FA, MD]) - three items - so unpack the pair explicitly.
    for i, j, (fa_val, md_val) in results:
        temp1[i, j] = fa_val
        temp2[i, j] = md_val

    FA_Full_LS.append(temp1)
    MD_Full_LS.append(temp2)

KK = [42,45,48]
FA_Min_LS = []
MD_Min_LS = []
for jj in range(3):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)
    # Hindered-tensor components are columns 4:10 of the NLLS parameter vector.
    tensor_maps = Min_LS[jj][...,4:10]

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j, tensor_maps) for i, j in tqdm(indices)
    )

    # Size the maps from this subject's own slice: the global ArrShape left by earlier
    # cells describes only the last subject processed there.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # Par_frac returns (i, j, [FA, MD]) - three items - so unpack the pair explicitly.
    for i, j, (fa_val, md_val) in results:
        temp1[i, j] = fa_val
        temp2[i, j] = md_val

    FA_Min_LS.append(temp1)
    MD_Min_LS.append(temp2)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr

In [None]:
from skimage.metrics import structural_similarity as ssim

In [None]:
def normalized_cross_correlation(f, g, mask=None):
    """Return the zero-mean normalized cross-correlation of two images.

    This equals the Pearson correlation coefficient of the paired values.

    Parameters
    ----------
    f, g : array_like
        Images to compare. They must have identical shapes; mismatched
        shapes would otherwise broadcast silently and give a wrong answer.
    mask : array_like of bool, optional
        Same shape as ``f``. When given, only elements where ``mask`` is
        True contribute (e.g. an ROI outline).

    Returns
    -------
    float
        NCC in [-1, 1]. Returns 0.0 when there are no elements, or when
        either image is constant over the considered elements.

    Raises
    ------
    ValueError
        If ``f``, ``g`` (and ``mask``, when given) differ in shape.
    """
    f = np.asarray(f, dtype=float)
    g = np.asarray(g, dtype=float)
    if f.shape != g.shape:
        raise ValueError(f"shape mismatch: {f.shape} vs {g.shape}")
    if mask is not None:
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != f.shape:
            raise ValueError(f"mask shape {mask.shape} does not match {f.shape}")
        f = f[mask]
        g = g[mask]
    if f.size == 0:
        return 0.0

    # Subtract the means so the score ignores constant offsets.
    f_centered = f - f.mean()
    g_centered = g - g.mean()

    numerator = np.sum(f_centered * g_centered)
    denominator = np.sqrt(np.sum(f_centered ** 2) * np.sum(g_centered ** 2))

    # A constant image has zero variance, so the NCC is undefined; report 0.
    if denominator == 0:
        return 0.0
    return float(numerator / denominator)

In [None]:
# Masked SSIM agreement of channel 0, for three comparisons per dataset:
#   SBI_comp    -- Min_SBI  vs Full_SBI
#   LS_comp     -- Min_LS   vs Full_LS
#   SBI_LS_comp -- Full_SBI vs Full_LS
# For each pair, the full SSIM map (win_size=15, data_range = joint max minus
# joint min of the two images) is averaged over the outline mask of slice
# KK[i]. The three per-dataset scores are then plotted with their mean.
# NOTE(review): NS1/NS2/ssim_map/i are left in the notebook namespace and
# later cells read them. `core` (the whole-image SSIM) is unused, and the
# second `.mean()` on the already-averaged scalar is a no-op.
jj = 0
SBI_comp = []
KK = [42,45,48]
for i in range(3):
    NS1 = np.copy(Min_SBI[i][...,0])
    NS2 = np.copy(Full_SBI[i][...,0])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_comp.append(masked_ssim.mean())
KK = [42,45,48]
LS_comp = []
for i in range(3):
    NS1 = np.copy(Min_LS[i][...,0])
    NS2 = np.copy(Full_LS[i][...,0])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    LS_comp.append(masked_ssim.mean())
KK = [42,45,48]
SBI_LS_comp = []
for i in range(3):
    NS1 = np.copy(Full_SBI[i][...,0])
    NS2 = np.copy(Full_LS[i][...,0])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim.mean())

# One dot per dataset at each x position; the '_' marker shows the mean.
plt.subplots(figsize=(3,2))
plt.scatter(1.05*np.ones(3),SBI_comp,s=50,c='paleturquoise', edgecolors='lightseagreen')
plt.plot(1.05,np.mean(SBI_comp),'_',ms=20,mew=5,c='lightseagreen')
plt.scatter(np.ones(3)*1.1,LS_comp,s=50,c='peachpuff', edgecolors='darkorange')
plt.plot(1.1,np.mean(LS_comp),'_',ms=20,mew=5,c='darkorange')
plt.scatter(np.ones(3)*1.15,SBI_LS_comp,s=50,c='gray',edgecolors='k')
plt.plot(1.15,np.mean(SBI_LS_comp),'_',ms=20,mew=5,c='k')
plt.xlim([1,1.2])
plt.xticks([1.05,1.1,1.15],['SBI','NLLS','SBI \n NLLS'],fontsize=14)
plt.yticks(fontsize=14)
# NOTE(review): SSIM can be negative; such scores would fall outside this ylim.
plt.ylim([0.0,1])
plt.grid(axis='y')


In [None]:
# Channel 0 of Min_SBI for dataset 0, with a colorbar bound to this image.
plt.colorbar(plt.imshow(Min_SBI[0][..., 0]))

In [None]:
# Channel 0 of Full_SBI for dataset 0, with a colorbar bound to this image.
plt.colorbar(plt.imshow(Full_SBI[0][..., 0]))

In [None]:
# Channel 0 of Full_LS for dataset 0, with a colorbar bound to this image.
plt.colorbar(plt.imshow(Full_LS[0][..., 0]))

In [None]:
# Channel 0 of Min_LS for dataset 0, with a colorbar bound to this image.
plt.colorbar(plt.imshow(Min_LS[0][..., 0]))

In [None]:
# Same masked-SSIM comparison as for channel 0, here for channel 3:
#   SBI_comp    -- Min_SBI  vs Full_SBI
#   LS_comp     -- Min_LS   vs Full_LS
#   SBI_LS_comp -- Full_SBI vs Full_LS
# Each score is the SSIM map (win_size=15, joint data range) averaged over the
# outline mask of slice KK[i].
# NOTE(review): `core` is unused and the second `.mean()` is a no-op;
# NS1/NS2/ssim_map/i remain global for later cells.
jj = 0
SBI_comp = []
KK = [42,45,48]
for i in range(3):
    NS1 = np.copy(Min_SBI[i][...,3])
    NS2 = np.copy(Full_SBI[i][...,3])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_comp.append(masked_ssim.mean())
KK = [42,45,48]
LS_comp = []
for i in range(3):
    NS1 = np.copy(Min_LS[i][...,3])
    NS2 = np.copy(Full_LS[i][...,3])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    LS_comp.append(masked_ssim.mean())
KK = [42,45,48]
SBI_LS_comp = []
for i in range(3):
    NS1 = np.copy(Full_SBI[i][...,3])
    NS2 = np.copy(Full_LS[i][...,3])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim.mean())

# One dot per dataset at each x position; the '_' marker shows the mean.
plt.subplots(figsize=(3,2))
plt.scatter(1.05*np.ones(3),SBI_comp,s=50,c='paleturquoise', edgecolors='lightseagreen')
plt.plot(1.05,np.mean(SBI_comp),'_',ms=20,mew=5,c='lightseagreen')
plt.scatter(np.ones(3)*1.1,LS_comp,s=50,c='peachpuff', edgecolors='darkorange')
plt.plot(1.1,np.mean(LS_comp),'_',ms=20,mew=5,c='darkorange')
plt.scatter(np.ones(3)*1.15,SBI_LS_comp,s=50,c='gray',edgecolors='k')
plt.plot(1.15,np.mean(SBI_LS_comp),'_',ms=20,mew=5,c='k')
plt.xlim([1,1.2])
plt.xticks([1.05,1.1,1.15],['SBI','NLLS','SBI \n NLLS'],fontsize=14)
plt.yticks(fontsize=14)
# NOTE(review): SSIM can be negative; such scores would fall outside this ylim.
plt.ylim([0.0,1])
plt.grid(axis='y')


In [None]:
# Masked-SSIM comparison of the MD maps produced by the Par_frac cells above:
#   SBI_comp    -- MD_Min_SBI  vs MD_Full_SBI
#   LS_comp     -- MD_Min_LS   vs MD_Full_LS
#   SBI_LS_comp -- MD_Full_SBI vs MD_Full_LS
# Each score is the SSIM map (win_size=15, joint data range) averaged over the
# outline mask of slice KK[i].
# NOTE(review): `core` is unused and the second `.mean()` is a no-op;
# NS1/NS2/ssim_map/i remain global for later cells.
jj = 0
SBI_comp = []
KK = [42,45,48]
for i in range(3):
    NS1 = np.copy(MD_Min_SBI[i])
    NS2 = np.copy(MD_Full_SBI[i])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_comp.append(masked_ssim.mean())
KK = [42,45,48]
LS_comp = []
for i in range(3):
    NS1 = np.copy(MD_Min_LS[i])
    NS2 = np.copy(MD_Full_LS[i])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    LS_comp.append(masked_ssim.mean())
KK = [42,45,48]
SBI_LS_comp = []
for i in range(3):
    NS1 = np.copy(MD_Full_SBI[i])
    NS2 = np.copy(MD_Full_LS[i])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim.mean())

# One dot per dataset at each x position; the '_' marker shows the mean.
plt.subplots(figsize=(3,2))
plt.scatter(1.05*np.ones(3),SBI_comp,s=50,c='paleturquoise', edgecolors='lightseagreen')
plt.plot(1.05,np.mean(SBI_comp),'_',ms=20,mew=5,c='lightseagreen')
plt.scatter(np.ones(3)*1.1,LS_comp,s=50,c='peachpuff', edgecolors='darkorange')
plt.plot(1.1,np.mean(LS_comp),'_',ms=20,mew=5,c='darkorange')
plt.scatter(np.ones(3)*1.15,SBI_LS_comp,s=50,c='gray',edgecolors='k')
plt.plot(1.15,np.mean(SBI_LS_comp),'_',ms=20,mew=5,c='k')
plt.xlim([1,1.2])
plt.xticks([1.05,1.1,1.15],['SBI','NLLS','SBI \n NLLS'],fontsize=14)
plt.yticks(fontsize=14)
# NOTE(review): SSIM can be negative; such scores would fall outside this ylim.
plt.ylim([0.0,1])
plt.grid(axis='y')


In [None]:
# Masked-SSIM comparison of a trailing channel: -4 of the SBI arrays and
# -3 of the LS arrays.
#   SBI_comp    -- Min_SBI[..., -4]  vs Full_SBI[..., -4]
#   LS_comp     -- Min_LS[..., -3]   vs Full_LS[..., -3]
#   SBI_LS_comp -- Full_SBI[..., -4] vs Full_LS[..., -3]
# NOTE(review): the channel indices differ between SBI and LS; presumably
# the same parameter under different channel layouts -- confirm.
# NOTE(review): `core` is unused and the second `.mean()` is a no-op.
# NS1/NS2/ssim_map/i remain global, and a later cell zeroes NS2 outside the ROI.
jj = 0
SBI_comp = []
KK = [42,45,48]
for i in range(3):
    NS1 = np.copy(Min_SBI[i][...,-4])
    NS2 = np.copy(Full_SBI[i][...,-4])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_comp.append(masked_ssim.mean())
KK = [42,45,48]
LS_comp = []
for i in range(3):
    NS1 = np.copy(Min_LS[i][...,-3])
    NS2 = np.copy(Full_LS[i][...,-3])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    LS_comp.append(masked_ssim.mean())
KK = [42,45,48]
SBI_LS_comp = []
for i in range(3):
    NS1 = np.copy(Full_SBI[i][...,-4])
    NS2 = np.copy(Full_LS[i][...,-3])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim.mean())

# One dot per dataset at each x position; the '_' marker shows the mean.
plt.subplots(figsize=(3,2))
plt.scatter(1.05*np.ones(3),SBI_comp,s=50,c='paleturquoise', edgecolors='lightseagreen')
plt.plot(1.05,np.mean(SBI_comp),'_',ms=20,mew=5,c='lightseagreen')
plt.scatter(np.ones(3)*1.1,LS_comp,s=50,c='peachpuff', edgecolors='darkorange')
plt.plot(1.1,np.mean(LS_comp),'_',ms=20,mew=5,c='darkorange')
plt.scatter(np.ones(3)*1.15,SBI_LS_comp,s=50,c='gray',edgecolors='k')
plt.plot(1.15,np.mean(SBI_LS_comp),'_',ms=20,mew=5,c='k')
plt.xlim([1,1.2])
plt.xticks([1.05,1.1,1.15],['SBI','NLLS','SBI \n NLLS'],fontsize=14)
plt.yticks(fontsize=14)
# NOTE(review): SSIM can be negative; such scores would fall outside this ylim.
plt.ylim([0.0,1])
plt.grid(axis='y')


In [None]:
# Whole-image SSIM (default window, data_range=1) between FA maps:
#   SBI_comp    -- FA_Full_SBI vs FA_Min_SBI
#   LS_comp     -- FA_Full_LS  vs FA_Min_LS
#   SBI_LS_comp -- FA_Full_LS  vs FA_Full_SBI
# NOTE(review): unlike the channel/MD cells, this score is not restricted to
# the Outlines ROI. The FA maps are 0 outside the mask, which presumably
# inflates the score -- confirm which variant is intended. KK is reassigned
# here but not used.
SBI_comp = []
KK = [42,45,48]
for i in range(3):
    SBI_comp.append(ssim(FA_Full_SBI[i],FA_Min_SBI[i], data_range=1))
KK = [42,45,48]
LS_comp = []
for i in range(3):
    LS_comp.append(ssim(FA_Full_LS[i],FA_Min_LS[i], data_range=1))
KK = [42,45,48]
SBI_LS_comp = []
for i in range(3):
    SBI_LS_comp.append(ssim(FA_Full_LS[i],FA_Full_SBI[i], data_range=1))

# One dot per dataset at each x position; the '_' marker shows the mean.
plt.subplots(figsize=(3,4))
plt.scatter(1.05*np.ones(3),SBI_comp,s=50,c='paleturquoise', edgecolors='lightseagreen')
plt.plot(1.05,np.mean(SBI_comp),'_',ms=20,mew=5,c='lightseagreen')
plt.scatter(np.ones(3)*1.1,LS_comp,s=50,c='peachpuff', edgecolors='darkorange')
plt.plot(1.1,np.mean(LS_comp),'_',ms=20,mew=5,c='darkorange')
plt.scatter(np.ones(3)*1.15,SBI_LS_comp,s=50,c='gray',edgecolors='k')
plt.plot(1.15,np.mean(SBI_LS_comp),'_',ms=20,mew=5,c='k')
plt.xlim([1,1.2])
plt.xticks([1.05,1.1,1.15],['SBI','NLLS','SBI \n NLLS'],fontsize=14)
plt.yticks(fontsize=14)
plt.ylim([0.0,1])


In [None]:
# Index (0-2) into the per-dataset lists, used by the exploratory cells below.
i = 2

In [None]:
# Smallest value in the FA map from Min_LS for the selected dataset.
FA_Min_LS[i].min()

In [None]:
# Show the complement of the ROI outline on slice KK[i] (assumes Outlines is boolean).
plt.imshow(~Outlines[i][:,:,KK[i]])

In [None]:
# Zero NS2 in place outside the ROI. NS2 is whatever image the last SSIM cell
# left behind (assumes Outlines is boolean).
NS2[~Outlines[i][:,:,KK[i]]] = 0

In [None]:
# Per dataset, print the whole-image and ROI-averaged SSIM (win_size=15,
# data_range=1) between FA_Full_LS and FA_Min_LS.
# NOTE(review): `mask` is built but never used in this cell.
for i in range(3):
    score,ssim_map = ssim(FA_Full_LS[i],FA_Min_LS[i], data_range=1,full=True,win_size=15)
    mask = np.zeros_like(ssim_map, dtype=bool)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    print(f"Overall SSIM: {score}")
    print(f"Masked SSIM: {masked_ssim}")

In [None]:
# Per dataset, print the whole-image and ROI-averaged SSIM of channel 0,
# Min_SBI vs Full_SBI.
# NOTE(review): data_range is fixed at 5e-5 here, whereas the plotting cells
# use the joint range of both images -- confirm it matches the values' span.
# `mask` is built but never used.
for i in range(3):
    score,ssim_map = ssim(Min_SBI[i][...,0],Full_SBI[i][...,0], data_range=5e-5,full=True,win_size=15)
    mask = np.zeros_like(ssim_map, dtype=bool)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    print(f"Overall SSIM: {score}")
    print(f"Masked SSIM: {masked_ssim}")

In [None]:
# Per dataset, print the whole-image and ROI-averaged SSIM (data_range=1) of
# channel -3, Min_LS vs Full_LS. NOTE(review): `mask` is built but never used.
for i in range(3):
    score,ssim_map = ssim(Min_LS[i][...,-3],Full_LS[i][...,-3], data_range=1,full=True,win_size=15)
    mask = np.zeros_like(ssim_map, dtype=bool)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    print(f"Overall SSIM: {score}")
    print(f"Masked SSIM: {masked_ssim}")

In [None]:
# Per dataset, print the whole-image and ROI-averaged SSIM of channel 4,
# Min_SBI vs Full_SBI, with a fixed data_range of 5e-5.
# NOTE(review): `mask` is built but never used. ssim_map from the last
# iteration stays global and is edited by later cells.
for i in range(3):
    score,ssim_map = ssim(Min_SBI[i][...,4],Full_SBI[i][...,4], data_range=5e-5,full=True,win_size=15)
    mask = np.zeros_like(ssim_map, dtype=bool)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    print(f"Overall SSIM: {score}")
    print(f"Masked SSIM: {masked_ssim}")

In [None]:
# Difference image (Full_SBI minus Min_SBI) of channel 1, with its colorbar.
plt.colorbar(plt.imshow(Full_SBI[i][..., 1] - Min_SBI[i][..., 1]))

In [None]:
# Channel 1 of Min_SBI for the selected dataset.
plt.imshow(Min_SBI[i][...,1])

In [None]:
# Per dataset, print the Pearson correlation of channel 4 between Min_SBI and
# Full_SBI over the ROI voxels only. The 2-D boolean mask selects rows
# (voxels) x channels, hence the [:, 4] indexing.
# NOTE(review): np.copy is redundant; boolean indexing already returns a copy.
for i in range(3):
    NS1 = np.copy(Min_SBI[i])[Outlines[i][:,:,KK[i]]]
    NS2 = np.copy(Full_SBI[i])[Outlines[i][:,:,KK[i]]]
    print(np.corrcoef(NS1[:,4],NS2[:,4])[0,1])

In [None]:
# Per dataset, print the ROI Pearson correlation of channel 4, Min_SBI vs Full_SBI.
# NOTE(review): identical to the previous cell -- presumably meant to use a
# different channel or the LS arrays; confirm.
for i in range(3):
    NS1 = np.copy(Min_SBI[i])[Outlines[i][:,:,KK[i]]]
    NS2 = np.copy(Full_SBI[i])[Outlines[i][:,:,KK[i]]]
    print(np.corrcoef(NS1[:,4],NS2[:,4])[0,1])

In [None]:
# Per dataset, print the ROI Pearson correlation of channel -4, Min_LS vs Full_LS.
# NOTE(review): other LS cells use channel -3 as the counterpart of SBI's -4;
# confirm -4 is intended here. NS1/NS2 remain global and are inspected below.
for i in range(3):
    NS1 = np.copy(Min_LS[i])[Outlines[i][:,:,KK[i]]]
    NS2 = np.copy(Full_LS[i])[Outlines[i][:,:,KK[i]]]
    print(np.corrcoef(NS1[:,-4],NS2[:,-4])[0,1])

In [None]:
# Inspect channel 0 of the ROI voxels currently held in NS1.
NS1[:,0]

In [None]:
# 2x2 correlation matrix of channel 0 between the ROI voxels in NS1 and NS2.
np.corrcoef(NS1[:,0],NS2[:,0])

In [None]:
# Copies of Min_SBI[i] / Full_SBI[i] with every channel set to 0 at voxels
# outside the ROI on slice KK[i] (assumes Outlines is boolean).
NS1 = np.copy(Min_SBI[i])
NS1[~Outlines[i][:,:,KK[i]]] = 0

NS2 = np.copy(Full_SBI[i])
NS2[~Outlines[i][:,:,KK[i]]] = 0
#SBI_comp.append(ssim(NS1[...,-4], NS2[...,-4], data_range=1))

In [None]:
# Blank the SSIM map outside the ROI so np.nanmean / imshow ignore those voxels.
# Use np.nan: the notebook's `from math import *` does not bind the name `math`.
# (Assumes Outlines is a boolean mask.)
ssim_map[~Outlines[i][:,:,KK[i]]] = np.nan

In [None]:
# Working copy of channel -4 of Full_SBI, so masking below leaves the source intact.
t1 = np.copy(Full_SBI[i][...,-4])

In [None]:
# Blank t1 outside the ROI. Use np.nan because `from math import *` does not
# bind the name `math`. (Assumes Outlines is a boolean mask.)
t1[~Outlines[i][:,:,KK[i]]] = np.nan

In [None]:
# Blank the SSIM map outside the ROI, then keep a handle to the 2-D map in t1.
# These are two separate statements on purpose: a chained
# `t1 = arr[mask] = nan` binds t1 to the scalar NaN, which plt.imshow cannot
# display. np.nan is used because `from math import *` does not bind `math`.
ssim_map[~Outlines[i][:,:,KK[i]]] = np.nan
t1 = ssim_map

In [None]:
# Absolute Min-vs-Full SBI differences of channel -4, keeping only nonzero
# entries (presumably to drop voxels where both maps are identical, e.g.
# background -- confirm), shown as a boxplot without outliers.
X2 = abs(Min_SBI[i][...,-4]-Full_SBI[i][...,-4])[abs(Min_SBI[i][...,-4]-Full_SBI[i][...,-4])>0]
plt.boxplot(X2,showfliers=False)

In [None]:
# Nonzero absolute Min-vs-Full LS differences of channel -3 (box at the
# default position 1), next to the SBI differences X2 at position 2.
X = abs(Min_LS[i][...,-3]-Full_LS[i][...,-3])[abs(Min_LS[i][...,-3]-Full_LS[i][...,-3])>0]
plt.boxplot(X,showfliers=False)
plt.boxplot(X2,showfliers=False,positions=[2])

In [None]:
# Display t1, the image most recently bound to it by the cells above.
plt.imshow(t1)

In [None]:
# Channel -4 of Min_SBI for the selected dataset.
plt.imshow(Min_SBI[i][...,-4])

In [None]:
# Mean SSIM ignoring NaN entries (e.g. voxels blanked outside the ROI).
np.nanmean(ssim_map)

In [None]:
# Current SSIM map with a colorbar bound to this image.
plt.colorbar(plt.imshow(ssim_map))

In [None]:
# FA values from Full_SBI at the ROI voxels of the selected dataset.
FA_Full_SBI[i][Outlines[i][:,:,KK[i]]]

In [None]:
# Same ROI FA values as the previous cell (duplicate inspection cell).
FA_Full_SBI[i][Outlines[i][:,:,KK[i]]]

In [None]:
# SSIM of FA_Full_LS vs FA_Min_LS for dataset 0, with the full map.
# NOTE(review): uses skimage's default win_size, unlike the win_size=15 used above.
i = 0
score, ssim_map = ssim(FA_Full_LS[i],FA_Min_LS[i],data_range=1,full = True)

In [None]:
# Average of the current SSIM map over the ROI voxels.
ssim_map[Outlines[i][:,:,KK[i]]].mean()

In [None]:
# FA map from Min_SBI for the selected dataset.
plt.imshow(FA_Min_SBI[i])

In [None]:
# FA map from Full_SBI for the selected dataset.
plt.imshow(FA_Full_SBI[i])

In [None]:
# SSIM between the ROI FA values of Full_LS and Min_LS for dataset 0.
# NOTE(review): boolean indexing flattens the ROI to 1-D, so the SSIM window
# slides over raster-ordered values rather than spatial neighbourhoods.
i = 0
ssim(FA_Full_LS[i][Outlines[i][:,:,KK[i]]],FA_Min_LS[i][Outlines[i][:,:,KK[i]]],data_range=1)

In [None]:
# Whole-image PSNR (data_range=1) of FA_Min_SBI against FA_Full_SBI for dataset 2.
i = 2
psnr(FA_Full_SBI[i],FA_Min_SBI[i],data_range=1)

In [None]:
# FA map from Full_LS for the selected dataset, with its colorbar.
plt.colorbar(plt.imshow(FA_Full_LS[i]))

In [None]:
# Channel -4 of Full_SBI for dataset 0.
plt.imshow(Full_SBI[0][...,-4])

In [None]:
# Channel -4 of Min_SBI for dataset 0.
plt.imshow(Min_SBI[0][...,-4])

In [None]:
# FA map from Min_LS on a fixed [0, 1] colour scale, with its colorbar.
plt.colorbar(plt.imshow(FA_Min_LS[i], vmin=0, vmax=1))

In [None]:
# Raw FA map array from Min_LS for the selected dataset.
FA_Min_LS[i]