# In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from scipy.linalg import toeplitz
from typing import NamedTuple, Tuple

# Switch JAX to 64-bit (double) precision for better numerical accuracy.
# This runs at import time, ahead of every definition below.
jax.config.update("jax_enable_x64", True)

class NUTSState(NamedTuple):
    """Summary of one (sub)tree built by ``JAXNUTSDeblurring.build_tree``.

    The annotations below are nominal: ``build_tree`` fills ``n_valid`` and
    ``s`` with int32 JAX scalars and ``alpha`` with a JAX scalar.
    """
    # Position at the left-most (backward) end of the subtree.
    theta_minus: jnp.ndarray
    # Position at the right-most (forward) end of the subtree.
    theta_plus: jnp.ndarray
    # Momentum paired with theta_minus.
    p_minus: jnp.ndarray
    # Momentum paired with theta_plus.
    p_plus: jnp.ndarray
    # Candidate position selected from this subtree.
    theta_sample: jnp.ndarray
    # Number of states in the subtree that satisfy the slice condition.
    n_valid: int
    # 1 while the subtree may keep growing, 0 once a stop condition fired.
    s: int
    # Sum of per-leaf acceptance probabilities in the subtree.
    alpha: float
    # Number of leaves that contributed to alpha.
    n_alpha: int


class JAXNUTSDeblurring:
    """
    NUTS sampler for Bayesian deblurring using JAX
    """
    
    def __init__(self, A_matrix, y_obs, sigma, L_matrix, gamma, initial_x):
        """
        Parameters:
        -----------
        A_matrix : array
            Forward operator (observation matrix)
        y_obs : array
            Observed data
        sigma : float
            Observation noise standard deviation
        L_matrix : array
            Prior precision matrix factor
        gamma : float
            Prior parameter
        initial_x : array
            Initial parameter guess
        """
        # Convert to JAX arrays
        self.A = jnp.array(A_matrix)
        self.y_obs = jnp.array(y_obs)
        self.sigma = sigma
        self.L = jnp.array(L_matrix)
        self.gamma = gamma
        self.x_init = jnp.array(initial_x)
        
        # NUTS parameters
        self.max_tree_depth = 10
        self.target_accept = 0.8
        self.Delta_max = 1000  # Maximum allowed energy difference
        
        # Mass matrix (will be adapted)
        self.mass_matrix = jnp.eye(len(initial_x))
        self.mass_matrix_inv = jnp.eye(len(initial_x))
        
        # Compile functions
        self._compile_functions()
        
    def _compile_functions(self):
        """Compile JAX functions for efficiency"""
        
        @jit
        def forward_operator(x):
            """Forward operator: x -> Ax"""
            return self.A @ x
        
        @jit
        def log_likelihood(x):
            """Compute log-likelihood"""
            predictions = forward_operator(x)
            residual = predictions - self.y_obs
            return -0.5 * jnp.sum(residual**2) / self.sigma**2
        
        @jit
        def log_prior(x):
            """Compute log-prior using precision matrix"""
            Lx = self.L @ x
            return -0.5 * jnp.sum(Lx**2) / self.gamma**2
        
        @jit
        def log_posterior(x):
            """Compute log-posterior"""
            return log_likelihood(x) + log_prior(x)
        
        # Store compiled functions
        self.forward_operator = forward_operator
        self.log_posterior_fn = log_posterior
        self.grad_log_posterior_fn = jit(grad(log_posterior))
        
        # Combined function for efficiency
        @jit
        def log_posterior_and_grad(x):
            return log_posterior(x), grad(log_posterior)(x)
        
        self.log_posterior_and_grad_fn = log_posterior_and_grad
    
    @partial(jit, static_argnums=(0,))
    def leapfrog_step(self, theta, p, epsilon):
        """Single leapfrog integration step"""
        # Half step for momentum
        _, grad_theta = self.log_posterior_and_grad_fn(theta)
        p = p + 0.5 * epsilon * grad_theta
        
        # Full step for position
        theta = theta + epsilon * self.mass_matrix_inv @ p
        
        # Half step for momentum
        _, grad_theta = self.log_posterior_and_grad_fn(theta)
        p = p + 0.5 * epsilon * grad_theta
        
        return theta, p
    
    @partial(jit, static_argnums=(0,))
    def compute_hamiltonian(self, theta, p):
        """Compute Hamiltonian (total energy)"""
        log_p = self.log_posterior_fn(theta)
        kinetic = 0.5 * p @ self.mass_matrix_inv @ p
        return -log_p + kinetic
    
    def find_reasonable_epsilon(self, theta, key):
        """Find reasonable initial step size (from Stan)"""
        # Sample momentum
        key, subkey = random.split(key)
        p = random.multivariate_normal(subkey, jnp.zeros(len(theta)), self.mass_matrix)
        
        # Initial step
        epsilon = 1.0
        theta_new, p_new = self.leapfrog_step(theta, p, epsilon)
        
        # Compute acceptance probability
        H_old = self.compute_hamiltonian(theta, p)
        H_new = self.compute_hamiltonian(theta_new, p_new)
        
        a = 2.0 * (jnp.exp(H_old - H_new) > 0.5) - 1.0
        
        # Scale epsilon until acceptance probability crosses 0.5
        while (jnp.exp(a * (H_old - H_new)) ** a) > (2.0 ** (-a)):
            epsilon = epsilon * (2.0 ** a)
            theta_new, p_new = self.leapfrog_step(theta, p, epsilon)
            H_new = self.compute_hamiltonian(theta_new, p_new)
        
        return epsilon
    
    def stop_criterion(self, theta_minus, theta_plus, p_minus, p_plus):
        """Return True while the trajectory has not made a U-turn.

        The result is True only when both end-point momenta have a
        non-negative projection onto the vector from theta_minus to
        theta_plus; callers use False as the signal to stop growing.
        """
        extent = theta_plus - theta_minus
        left_ok = jnp.dot(extent, p_minus) >= 0
        right_ok = jnp.dot(extent, p_plus) >= 0
        return left_ok & right_ok
    
    def build_tree(self, theta, p, u, v, j, epsilon, key):
        """
        Recursively build a binary tree of leapfrog states for NUTS.

        Parameters:
        -----------
        theta, p : array
            Position and momentum the (sub)tree is grown from
        u : scalar
            Slice variable; a state is counted as valid when u <= exp(-H)
        v : scalar
            Direction: compared against -1 (left) and used to sign epsilon
        j : int
            Tree depth; j == 0 takes a single leapfrog step
        epsilon : scalar
            Leapfrog step size
        key : PRNG key
            Split once each time two subtrees are merged

        Returns:
        --------
        (NUTSState, key) : summary of the subtree and the key to use next
        """
        if j == 0:
            # Base case: single leapfrog step in direction v
            theta_prime, p_prime = self.leapfrog_step(theta, p, v * epsilon)
            
            # Energy of the new state
            H = self.compute_hamiltonian(theta_prime, p_prime)
            
            # n_prime: 1 if the new state lies inside the slice, else 0.
            # s_prime: 0 once the energy exceeds the slice level by more
            # than Delta_max, which stops further growth.
            n_prime = (u <= jnp.exp(-H)).astype(jnp.int32)
            s_prime = (u < jnp.exp(self.Delta_max - H)).astype(jnp.int32)
            
            # Acceptance probability of this single step, measured against
            # the energy of the state it started from (theta, p).
            # NOTE(review): this differs from a statistic measured against
            # the trajectory's initial state - confirm which is intended.
            alpha = jnp.minimum(1.0, jnp.exp(-H + self.compute_hamiltonian(theta, p)))
            
            # A leaf is its own left and right end point
            return NUTSState(
                theta_minus=theta_prime,
                theta_plus=theta_prime,
                p_minus=p_prime,
                p_plus=p_prime,
                theta_sample=theta_prime,
                n_valid=n_prime,
                s=s_prime,
                alpha=alpha,
                n_alpha=1
            ), key
        
        else:
            # Recursion: build first subtree of depth j - 1
            tree, key = self.build_tree(theta, p, u, v, j - 1, epsilon, key)
            
            # If first subtree didn't fail, build second from its outer edge
            if tree.s == 1:
                if v == -1:
                    # Build tree to the left
                    tree2, key = self.build_tree(
                        tree.theta_minus, tree.p_minus, u, v, j - 1, epsilon, key
                    )
                    theta_minus = tree2.theta_minus
                    p_minus = tree2.p_minus
                    theta_plus = tree.theta_plus
                    p_plus = tree.p_plus
                else:
                    # Build tree to the right
                    tree2, key = self.build_tree(
                        tree.theta_plus, tree.p_plus, u, v, j - 1, epsilon, key
                    )
                    theta_minus = tree.theta_minus
                    p_minus = tree.p_minus
                    theta_plus = tree2.theta_plus
                    p_plus = tree2.p_plus
                
                # Pick the second subtree's candidate with probability
                # proportional to its share of the valid states.
                # NOTE(review): when neither subtree has a valid state the
                # ratio is 0/0; presumably that compares as False and keeps
                # the first candidate - confirm.
                key, subkey = random.split(key)
                accept = random.uniform(subkey) < (tree2.n_valid / (tree.n_valid + tree2.n_valid))
                theta_sample = jnp.where(accept, tree2.theta_sample, tree.theta_sample)
                
                # Update acceptance statistics
                alpha = tree.alpha + tree2.alpha
                n_alpha = tree.n_alpha + tree2.n_alpha
                
                # Continue only if the second subtree is still valid and the
                # merged span has not made a U-turn
                s = tree2.s * self.stop_criterion(theta_minus, theta_plus, p_minus, p_plus).astype(jnp.int32)
                n_valid = tree.n_valid + tree2.n_valid
                
                return NUTSState(
                    theta_minus=theta_minus,
                    theta_plus=theta_plus,
                    p_minus=p_minus,
                    p_plus=p_plus,
                    theta_sample=theta_sample,
                    n_valid=n_valid,
                    s=s,
                    alpha=alpha,
                    n_alpha=n_alpha
                ), key
            else:
                # First subtree already stopped: return it unchanged
                return tree, key
    
    def nuts_step(self, theta_current, epsilon, key):
        """Single NUTS transition"""
        # Sample momentum
        key, subkey = random.split(key)
        p = random.multivariate_normal(subkey, jnp.zeros(len(theta_current)), self.mass_matrix)
        
        # Sample slice variable
        key, subkey = random.split(key)
        u = random.uniform(subkey) * jnp.exp(-self.compute_hamiltonian(theta_current, p))
        
        # Initialize tree
        theta_minus = theta_current.copy()
        theta_plus = theta_current.copy()
        p_minus = p.copy()
        p_plus = p.copy()
        
        j = 0  # Tree depth
        theta_next = theta_current.copy()
        n = 1
        s = 1
        
        alpha_sum = 0.0
        n_alpha_sum = 0
        
        # Build tree until U-turn or max depth
        while (s == 1) & (j < self.max_tree_depth):
            # Choose direction uniformly
            key, subkey = random.split(key)
            v = 2 * (random.uniform(subkey) < 0.5) - 1
            
            if v == -1:
                # Build tree to the left
                tree, key = self.build_tree(theta_minus, p_minus, u, v, j, epsilon, key)
                theta_minus = tree.theta_minus
                p_minus = tree.p_minus
            else:
                # Build tree to the right
                tree, key = self.build_tree(theta_plus, p_plus, u, v, j, epsilon, key)
                theta_plus = tree.theta_plus
                p_plus = tree.p_plus
            
            # Metropolis sampling from the tree
            if tree.s == 1:
                key, subkey = random.split(key)
                accept = random.uniform(subkey) < (tree.n_valid / n)
                theta_next = jnp.where(accept, tree.theta_sample, theta_next)
            
            # Update statistics
            alpha_sum += tree.alpha
            n_alpha_sum += tree.n_alpha
            
            # Update number of valid points
            n = n + tree.n_valid
            
            # Check for U-turn
            s = tree.s * self.stop_criterion(theta_minus, theta_plus, p_minus, p_plus).astype(jnp.int32)
            
            j += 1
        
        # Average acceptance probability
        avg_alpha = alpha_sum / jnp.maximum(n_alpha_sum, 1)
        
        return theta_next, avg_alpha, j, key
    
    def adapt_step_size(self, epsilon, H_bar, iteration, avg_alpha):
        """Dual averaging for step size adaptation"""
        # Parameters for dual averaging
        gamma = 0.05
        t0 = 10.0
        kappa = 0.75
        mu = jnp.log(10.0 * epsilon)
        
        # Update H_bar
        eta = 1.0 / (iteration + t0)
        H_bar = (1.0 - eta) * H_bar + eta * (self.target_accept - avg_alpha)
        
        # Update log epsilon
        log_epsilon = mu - jnp.sqrt(iteration) / gamma * H_bar
        
        # Update epsilon with damping
        epsilon = jnp.exp(log_epsilon)
        
        # Also compute smoothed version
        epsilon_bar = jnp.exp(
            iteration ** (-kappa) * log_epsilon + 
            (1 - iteration ** (-kappa)) * jnp.log(epsilon)
        )
        
        return epsilon, epsilon_bar, H_bar
    
    def sample(self, n_samples, n_warmup=1000, seed=42):
        """
        Run NUTS sampling with adaptation
        """
        key = random.PRNGKey(seed)
        theta = self.x_init.copy()
        
        # Find reasonable initial step size
        print("Finding reasonable initial step size...")
        key, subkey = random.split(key)
        epsilon = self.find_reasonable_epsilon(theta, subkey)
        epsilon_bar = epsilon
        print(f"Initial epsilon: {epsilon:.4f}")
        
        # Storage
        samples = []
        accept_probs = []
        tree_depths = []
        epsilons = []
        
        # Adaptation parameters
        H_bar = 0.0
        mu = jnp.log(10.0 * epsilon)
        
        print("Starting warmup phase...")
        warmup_samples = []
        
        # Combined warmup and sampling
        total_iterations = n_warmup + n_samples
        
        for i in range(total_iterations):
            # NUTS step
            theta, avg_alpha, tree_depth, key = self.nuts_step(theta, epsilon, key)
            
            # During warmup
            if i < n_warmup:
                # Adapt step size
                iteration = i + 1
                epsilon, epsilon_bar, H_bar = self.adapt_step_size(
                    epsilon, H_bar, iteration, avg_alpha
                )
                
                # Store for mass matrix adaptation (second half of warmup)
                if i > n_warmup // 2:
                    warmup_samples.append(theta)
                
                # Adapt mass matrix periodically
                if (i + 1) % 100 == 0 and len(warmup_samples) > 50:
                    # Compute sample covariance
                    warmup_array = jnp.array(warmup_samples)
                    sample_cov = jnp.cov(warmup_array.T)
                    
                    # Regularize for stability
                    reg = 1e-6 * jnp.trace(sample_cov) / len(sample_cov)
                    self.mass_matrix = sample_cov + reg * jnp.eye(len(sample_cov))
                    self.mass_matrix_inv = jnp.linalg.inv(self.mass_matrix)
                
                # Use smoothed epsilon during warmup
                if i < n_warmup - 100:
                    epsilon = epsilon_bar
                
                # Progress report
                if (i + 1) % 200 == 0:
                    print(f"Warmup {i+1}/{n_warmup}, "
                          f"avg_alpha: {avg_alpha:.3f}, "
                          f"epsilon: {epsilon:.4f}, "
                          f"tree_depth: {tree_depth}")
            
            # After warmup - collect samples
            else:
                samples.append(theta)
                accept_probs.append(avg_alpha)
                tree_depths.append(tree_depth)
                epsilons.append(epsilon)
                
                # Progress report
                if (i + 1 - n_warmup) % 100 == 0:
                    current_iter = i + 1 - n_warmup
                    avg_accept = np.mean(accept_probs)
                    avg_depth = np.mean(tree_depths)
                    print(f"Sample {current_iter}/{n_samples}, "
                          f"avg_accept: {avg_accept:.3f}, "
                          f"avg_tree_depth: {avg_depth:.1f}")
        
        samples = jnp.array(samples)
        
        # Final adaptation of mass matrix from all warmup samples
        if len(warmup_samples) > 100:
            print(f"\nFinal mass matrix adaptation using {len(warmup_samples)} warmup samples")
            warmup_array = jnp.array(warmup_samples)
            sample_cov = jnp.cov(warmup_array.T)
            reg = 1e-6 * jnp.trace(sample_cov) / len(sample_cov)
            self.mass_matrix = sample_cov + reg * jnp.eye(len(sample_cov))
            self.mass_matrix_inv = jnp.linalg.inv(self.mass_matrix)
        
        # Compute diagnostics
        avg_accept = np.mean(accept_probs)
        avg_tree_depth = np.mean(tree_depths)
        max_tree_depth = np.max(tree_depths)
        
        print(f"\nSampling completed!")
        print(f"Average acceptance probability: {avg_accept:.3f}")
        print(f"Average tree depth: {avg_tree_depth:.1f}")
        print(f"Maximum tree depth: {max_tree_depth}")
        print(f"Final epsilon: {epsilon:.4f}")
        
        # Diagnostics
        diagnostics = {
            'accept_rate': avg_accept,
            'accept_probs': jnp.array(accept_probs),
            'tree_depths': jnp.array(tree_depths),
            'epsilons': jnp.array(epsilons),
            'final_epsilon': epsilon,
            'final_mass_matrix': self.mass_matrix,
            'avg_tree_depth': avg_tree_depth,
            'max_tree_depth': max_tree_depth
        }
        
        return samples, diagnostics
    
    def posterior_predictive(self, samples, n_pred=None):
        """Generate posterior predictive samples"""
        if n_pred is None:
            n_pred = len(samples)
        
        # Select random subset of samples
        key = random.PRNGKey(0)
        indices = random.choice(key, len(samples), shape=(n_pred,), replace=True)
        selected_samples = samples[indices]
        
        # Vectorized forward operator
        predictions = vmap(self.forward_operator)(selected_samples)
        
        return predictions


def compute_analytical_posterior(A, y_obs, sigma, L, gamma):
    """
    Closed-form Gaussian posterior of the linear inverse problem.

    The posterior precision is A^T A / sigma^2 + L^T L / gamma^2; the
    covariance is its inverse and the mean is
    covariance @ (A^T y_obs / sigma^2).

    Returns (mean_post, cov_post).
    """
    from scipy.linalg import inv

    data_weight = 1/sigma**2
    prior_weight = 1/gamma**2

    # Precision contributions from the data and from the prior
    precision_post = data_weight * A.T @ A + prior_weight * L.T @ L

    cov_post = inv(precision_post)

    # Data term back-projected into parameter space
    rhs = data_weight * A.T @ y_obs
    mean_post = cov_post @ rhs

    return mean_post, cov_post


def plot_nuts_vs_analytical(t, nuts_samples, xtrue, t_obs, y_obs, diagnostics, 
                            analytical_mean, analytical_cov):
    """Draw a 3x3 comparison figure and print summary error metrics.

    Panels: reconstruction with 95% bands, mean and std scatter against
    the analytical values, absolute errors (log scale), tree-depth and
    acceptance traces, step size, and a tree-depth histogram. Afterwards
    RMSE / max-error figures and tree-depth statistics are printed.
    """

    def soft_grid(axis):
        axis.grid(True, alpha=0.3)

    def identity_scatter(axis, reference, estimate, quantity):
        """Scatter NUTS against analytical values with a y=x guide line."""
        axis.plot(reference, estimate, 'b.', alpha=0.6)
        axis.plot([reference.min(), reference.max()],
                  [reference.min(), reference.max()], 'r--', label='y=x')
        axis.set_xlabel(f'Analytical {quantity}')
        axis.set_ylabel(f'NUTS {quantity}')
        axis.set_title(f'Posterior {quantity} Comparison')
        axis.legend()
        soft_grid(axis)

    def abs_error_panel(axis, error, line_style, quantity):
        """Absolute error along t on a logarithmic y axis."""
        axis.plot(t, error, line_style, linewidth=2)
        axis.set_xlabel('t')
        axis.set_ylabel(f'|NUTS {quantity} - Analytical {quantity}|')
        axis.set_title(f'Absolute Error in Posterior {quantity}')
        soft_grid(axis)
        axis.set_yscale('log')

    def trace_panel(axis, series, average, average_label, ylabel, title):
        """Per-iteration trace with a dashed line at its average."""
        axis.plot(series, alpha=0.7)
        axis.axhline(y=average, color='red', linestyle='--', label=average_label)
        axis.set_xlabel('Iteration')
        axis.set_ylabel(ylabel)
        axis.set_title(title)
        axis.legend()
        soft_grid(axis)

    # Summary statistics of the NUTS draws
    nuts_mean = np.mean(nuts_samples, axis=0)
    nuts_std = np.std(nuts_samples, axis=0)
    nuts_q025 = np.percentile(nuts_samples, 2.5, axis=0)
    nuts_q975 = np.percentile(nuts_samples, 97.5, axis=0)

    # Gaussian 95% band from the analytical covariance
    analytical_std = np.sqrt(np.diag(analytical_cov))
    half_width = 1.96 * analytical_std
    analytical_q025 = analytical_mean - half_width
    analytical_q975 = analytical_mean + half_width

    fig, axes = plt.subplots(3, 3, figsize=(18, 15))

    # --- [0, 0] reconstruction with both uncertainty bands ---
    panel = axes[0, 0]
    t_outline = np.concatenate([t, t[::-1]])
    panel.fill(t_outline, np.concatenate([nuts_q025, nuts_q975[::-1]]),
               color='lightblue', alpha=0.5, label='NUTS 95% CI')
    panel.fill(t_outline, np.concatenate([analytical_q025, analytical_q975[::-1]]),
               color='lightcoral', alpha=0.5, label='Analytical 95% CI')
    panel.plot(t, nuts_mean, 'b-', linewidth=2, label='NUTS Mean')
    panel.plot(t, analytical_mean, 'r--', linewidth=2, label='Analytical Mean')
    panel.plot(t, xtrue, 'k-', linewidth=1.5, label='Truth')
    panel.plot(t_obs, y_obs, 'go', markersize=6, label='Observations')
    panel.legend()
    panel.set_xlabel('t')
    panel.set_ylabel('x(t)')
    panel.set_title('NUTS vs Analytical Posterior')
    soft_grid(panel)

    # --- [0, 1] and [0, 2] mean / std agreement ---
    identity_scatter(axes[0, 1], analytical_mean, nuts_mean, 'Mean')
    identity_scatter(axes[0, 2], analytical_std, nuts_std, 'Std')

    # --- [1, 0] and [1, 1] absolute errors ---
    abs_error_panel(axes[1, 0], np.abs(nuts_mean - analytical_mean), 'r-', 'Mean')
    abs_error_panel(axes[1, 1], np.abs(nuts_std - analytical_std), 'b-', 'Std')

    # --- [1, 2] tree depth per iteration ---
    trace_panel(axes[1, 2], diagnostics['tree_depths'],
                diagnostics['avg_tree_depth'],
                f"Average: {diagnostics['avg_tree_depth']:.1f}",
                'Tree Depth', 'NUTS Tree Depth')

    # --- [2, 0] acceptance probability per iteration ---
    trace_panel(axes[2, 0], diagnostics['accept_probs'],
                diagnostics['accept_rate'],
                f"Average: {diagnostics['accept_rate']:.3f}",
                'Acceptance Probability', 'NUTS Acceptance Rate')

    # --- [2, 1] step size per iteration ---
    panel = axes[2, 1]
    panel.plot(diagnostics['epsilons'], alpha=0.7)
    panel.set_xlabel('Iteration')
    panel.set_ylabel('Step Size (epsilon)')
    panel.set_title('NUTS Step Size')
    soft_grid(panel)

    # --- [2, 2] tree depth histogram, one bin centred on each depth ---
    panel = axes[2, 2]
    depth_bins = np.arange(0, diagnostics['max_tree_depth']+2)-0.5
    panel.hist(diagnostics['tree_depths'], bins=depth_bins,
               edgecolor='black', alpha=0.7)
    panel.set_xlabel('Tree Depth')
    panel.set_ylabel('Frequency')
    panel.set_title('Tree Depth Distribution')
    soft_grid(panel)

    plt.tight_layout()
    plt.show()

    # Numerical summary
    banner = "="*60
    print("\n" + banner)
    print("NUMERICAL COMPARISON: NUTS vs Analytical")
    print(banner)

    mean_diff = nuts_mean - analytical_mean
    mean_rmse = np.sqrt(np.mean(mean_diff**2))
    mean_max_error = np.max(np.abs(mean_diff))
    print("Posterior Mean:")
    print(f"  RMSE: {mean_rmse:.6e}")
    print(f"  Max absolute error: {mean_max_error:.6e}")
    print(f"  Relative RMSE: {mean_rmse/np.std(analytical_mean):.6e}")

    std_diff = nuts_std - analytical_std
    std_rmse = np.sqrt(np.mean(std_diff**2))
    std_max_error = np.max(np.abs(std_diff))
    print("\nPosterior Standard Deviation:")
    print(f"  RMSE: {std_rmse:.6e}")
    print(f"  Max absolute error: {std_max_error:.6e}")
    print(f"  Relative RMSE: {std_rmse/np.mean(analytical_std):.6e}")

    depth_series = diagnostics['tree_depths']
    print("\nNUTS Tree Depth Statistics:")
    print(f"  Mean: {diagnostics['avg_tree_depth']:.1f}")
    print(f"  Max: {diagnostics['max_tree_depth']}")
    print(f"  Min: {np.min(depth_series)}")
    print(f"  Std: {np.std(depth_series):.2f}")


def effective_sample_size(samples):
    """Estimate the effective sample size of each parameter's chain.

    Parameters
    ----------
    samples : array-like
        Draws with shape (n_draws, n_params). A 1-D input is treated as a
        single parameter.

    Returns
    -------
    numpy.ndarray
        One ESS value per parameter, computed as
        ``n / (1 + 2 * sum(rho_k))`` where the autocorrelations ``rho_k``
        are summed from lag 1 up to (not including) the first negative one.
        An empty chain gives 0; a chain with zero variance gives ``n``.
    """
    from scipy import signal

    def ess_single_chain(x):
        """ESS for a single chain"""
        n = len(x)
        if n == 0:
            return 0.0
        x = x - np.mean(x)

        # Autocovariance at lags 0 .. n-1
        autocorr = signal.correlate(x, x, mode='full')[n - 1:]

        # Zero (or undefined) variance: normalising would divide by zero.
        # Report n, which is what the unguarded computation fell back to.
        if not autocorr[0] > 0:
            return float(n)
        autocorr = autocorr / autocorr[0]

        # Truncate at the first negative autocorrelation
        negative_lags = np.flatnonzero(autocorr < 0)
        cutoff = negative_lags[0] if negative_lags.size else len(autocorr)

        # Sum autocorrelations up to cutoff
        sum_autocorr = 2 * np.sum(autocorr[1:cutoff]) + 1

        return n / sum_autocorr if sum_autocorr > 0 else n

    samples = np.asarray(samples, dtype=float)
    if samples.ndim == 1:
        samples = samples[:, None]

    # Compute ESS for each parameter (column)
    return np.array([ess_single_chain(column) for column in samples.T])


def main():
    """Run the NUTS deblurring example end to end.

    Builds a 1-D deblurring problem with sparse noisy observations,
    computes the analytical Gaussian posterior, samples the same posterior
    with NUTS starting from the MAP estimate, prints comparison metrics
    and shows the comparison figure.

    Returns (samples, diagnostics, sampler, analytical_mean,
    analytical_cov).

    Raises ValueError if the hard-coded ``PriorFlag`` is neither 1 nor 2.
    """
    
    # Set random seed for reproducibility  
    np.random.seed(20)
    
    # Setup problem: n + 1 grid points on [0, 1]
    n = 100
    s = np.linspace(0, 1, n+1)
    t = s.copy()
    
    # Sparse observation settings
    n_obs = 15
    obs_indices = np.sort(np.random.choice(n+1, n_obs, replace=False))
    t_obs = t[obs_indices]
    
    # Prior settings
    PriorFlag = 2
    
    # Discretize the deblurring kernel (Gaussian of width beta)
    beta = 0.05
    a = (1/np.sqrt(2*np.pi*beta**2)) * np.exp(-0.5*(1/beta**2)*t**2)
    A_full = (1/n) * toeplitz(a)
    A = A_full[obs_indices, :]  # Sparse observation matrix
    
    # Truth
    xtrue = 10*(t-0.5)*np.exp(-0.5*1e2*(t-0.5)**2) - 0.8 + 1.6*t
    
    # Generate observations with noise (noise is a percentage of max |Ax|)
    noise = 5
    y0_full = A_full @ xtrue
    y0 = y0_full[obs_indices]
    sigma = np.max(np.abs(y0_full)) * noise / 100
    y_obs = y0 + sigma * np.random.randn(n_obs)
    
    # Prior construction
    gamma = 1/n
    
    if PriorFlag == 1:
        L = (np.diag(np.ones(n+1)) - 
             np.diag(0.5*np.ones(n), 1) - 
             np.diag(0.5*np.ones(n), -1))
    elif PriorFlag == 2:
        from scipy.linalg import inv
        L_D = (np.diag(np.ones(n+1)) - 
               np.diag(0.5*np.ones(n), 1) - 
               np.diag(0.5*np.ones(n), -1))
        
        L_Dinv = inv(L_D)
        Dev = np.sqrt(gamma**2 * np.diag(L_Dinv @ L_Dinv.T))
        
        # Replace the first and last rows by a scaled identity row
        delta = gamma / Dev[n//2]
        L = L_D.copy()
        L[0, :] = 0
        L[0, 0] = delta
        L[-1, :] = 0
        L[-1, -1] = delta
    else:
        # Without this branch L would be undefined further down
        raise ValueError(f"Unsupported PriorFlag: {PriorFlag!r} (expected 1 or 2)")
    
    # Compute analytical posterior solution
    print("Computing analytical posterior...")
    analytical_mean, analytical_cov = compute_analytical_posterior(A, y_obs, sigma, L, gamma)
    
    # Get MAP estimate as initial point for NUTS (should equal analytical mean)
    A_aug = np.vstack([(1/sigma)*A, (1/gamma)*L])
    b_aug = np.concatenate([(1/sigma)*y_obs, np.zeros(n+1)])
    x_map = np.linalg.lstsq(A_aug, b_aug, rcond=None)[0]
    
    print("Setting up NUTS sampler...")
    
    # Initialize NUTS sampler
    sampler = JAXNUTSDeblurring(
        A_matrix=A,
        y_obs=y_obs,
        sigma=sigma,
        L_matrix=L,
        gamma=gamma,
        initial_x=x_map  # Start from MAP estimate
    )
    
    # Run NUTS sampling
    print("Running NUTS sampling...")
    samples, diagnostics = sampler.sample(
        n_samples=2000,  # Number of samples to collect
        n_warmup=1000,   # Warmup iterations for adaptation
        seed=42
    )
    
    # Print results
    print(f"\nSampling completed!")
    print(f"Final acceptance rate: {diagnostics['accept_rate']:.3f}")
    print(f"Average tree depth: {diagnostics['avg_tree_depth']:.1f}")
    print(f"Final step size: {diagnostics['final_epsilon']:.4f}")
    
    # Compute NUTS posterior statistics
    nuts_mean = np.mean(samples, axis=0)
    
    print(f"\nFirst 5 elements comparison:")
    print(f"Truth:           {xtrue[:5]}")
    print(f"Analytical mean: {analytical_mean[:5]}")
    print(f"NUTS mean:       {nuts_mean[:5]}")
    print(f"MAP estimate:    {x_map[:5]}")
    
    # Compute effective sample size
    ess = effective_sample_size(samples)
    print(f"\nEffective sample size statistics:")
    print(f"  Mean ESS: {np.mean(ess):.1f}")
    print(f"  Min ESS:  {np.min(ess):.1f}")
    print(f"  Max ESS:  {np.max(ess):.1f}")
    print(f"  ESS/iteration: {np.mean(ess)/len(samples):.2%}")
    
    # Verification that MAP = analytical mean (should be very close)
    map_analytical_error = np.max(np.abs(x_map - analytical_mean))
    print(f"\nMAP vs Analytical mean max error: {map_analytical_error:.2e}")
    print("(Should be very small - confirms analytical solution)")
    
    # Compare NUTS efficiency to standard HMC.
    # A transition that reached depth j doubled its trajectory j times,
    # taking 1 + 2 + ... + 2**(j-1) = 2**j - 1 leapfrog steps at most
    # (fewer when a subtree stopped early). The tree depth itself is not a
    # step count.
    depths = np.asarray(diagnostics['tree_depths'])
    leapfrog_steps = 2 ** depths - 1
    avg_leapfrog = np.mean(leapfrog_steps)
    print(f"\nNUTS Efficiency Metrics:")
    print(f"  Total leapfrog steps: {np.sum(leapfrog_steps)}")
    print(f"  Average leapfrog steps per sample: {avg_leapfrog:.1f}")
    print(f"  Compared to fixed HMC with L=15: {avg_leapfrog/15:.2f}x")
    
    # Plot comprehensive comparison
    plot_nuts_vs_analytical(t, samples, xtrue, t_obs, y_obs, diagnostics, 
                            analytical_mean, analytical_cov)
    
    return samples, diagnostics, sampler, analytical_mean, analytical_cov


# Run the example only when executed as a script; the results are kept in
# module-level names so they stay available for inspection afterwards.
if __name__ == "__main__":
    samples, diagnostics, sampler, analytical_mean, analytical_cov = main()