# In [None]:
"""
Shooting Method for SNN Training via Pontryagin's Maximum Principle
====================================================================

Priority 1: Solve linearized PMP using shooting method
Priority 2: Full nonlinear training with multi-class support

Multi-class classification with cross-entropy loss.
"""

import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import UnivariateSpline, interp1d
from scipy.optimize import minimize, fsolve
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, List, Callable, Optional
import pandas as pd
import os
import time
import warnings
from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeoutError
from functools import partial
import signal

# For confusion matrix (minimal NumPy fallback when scikit-learn is absent)
try:
    from sklearn.metrics import confusion_matrix
except ImportError:
    def confusion_matrix(y_true, y_pred):
        """Return the confusion matrix of ``y_true`` against ``y_pred``.

        Rows index the true class and columns the predicted class; both are
        ordered by the sorted union of labels found in either array.
        """
        labels = np.unique(np.concatenate([y_true, y_pred]))
        k = len(labels)
        counts = np.zeros((k, k), dtype=int)
        for row, actual in enumerate(labels):
            actual_mask = y_true == actual
            for col, predicted in enumerate(labels):
                counts[row, col] = np.sum(actual_mask & (y_pred == predicted))
        return counts

# Progress bar (degrade to a plain iterable when tqdm is not installed)
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        """Stand-in for ``tqdm``: ignores every option and hands back *iterable*."""
        return iterable

# NOTE: silences every warning process-wide, including NumPy overflow and
# SciPy integration warnings emitted elsewhere in this file.
warnings.filterwarnings('ignore')


class TimeoutException(Exception):
    """Raised when a computation exceeds its allotted wall-clock budget."""


def timeout_handler(signum, frame):
    """Signal callback that aborts the running computation.

    The ``(signum, frame)`` signature is the one signal handlers receive;
    neither argument is used.

    Raises
    ------
    TimeoutException
        Always, with the message "Computation timed out".
    """
    message = "Computation timed out"
    raise TimeoutException(message)


@dataclass
class NetworkConfig:
    """Network architecture, integration and training configuration.

    The simulated state vector has length ``n_state``: one input neuron,
    ``L * P`` hidden neurons and ``n_classes`` output neurons.
    """
    d: int = 2              # Input dimension
    L: int = 2              # Number of hidden layers
    P: int = 3              # Neurons per hidden layer
    n_classes: int = 1      # Number of output classes (overwrite from the data)

    # Time parameters
    T: float = 60.0         # Time horizon for testing/visualization
    # Time horizon used during training. NOTE(review): this was previously
    # annotated "much shorter!" although the default is 5x longer than T —
    # confirm which value is intended.
    T_train: float = 300.0
    dt: float = 0.01        # Integration time step

    # Timeout settings
    max_simulation_time: float = 10.0  # Max seconds per forward simulation
    max_gradient_time: float = 30.0    # Max seconds per gradient computation

    # Training settings
    batch_size: int = 16    # Mini-batch size
    use_parallel: bool = True  # Parallel processing
    n_workers: int = 4      # Number of parallel workers

    # Input layer parameters
    tau_v: float = 8.0      # Membrane time constant
    theta_v: float = 0.8    # Firing threshold

    # Hidden layer parameters
    tau_h: float = 6.0
    theta_h: float = 0.25

    # Output layer parameters
    tau_u: float = 10.0
    theta_u: float = 0.3

    # Spike kernel width
    mu: float = 0.2

    # Mollification schedule: ζ goes from zeta_0 (first epoch) to zeta_1 (last)
    zeta_0: float = 3.0
    zeta_1: float = 10.0

    def get_zeta(self, epoch: int, total_epochs: int) -> float:
        """Return the mollification sharpness ζ for ``epoch``.

        Geometric schedule: ``zeta_0`` at epoch 0 and ``zeta_1`` at epoch
        ``total_epochs - 1``. With one epoch (or fewer) ``zeta_1`` is
        returned directly, avoiding a division by zero.
        """
        if total_epochs <= 1:
            return self.zeta_1
        progress = epoch / (total_epochs - 1)
        return self.zeta_0 * (self.zeta_1 / self.zeta_0) ** progress

    @property
    def n_state(self) -> int:
        """Total state dimension: 1 input + L*P hidden + n_classes output"""
        return 1 + self.L * self.P + self.n_classes


class GaussianKernel:
    """Gaussian spike kernel.

    Evaluates ``Σ_s exp(-((t - s)/μ)² / 2) / (μ √(2π))`` over spike times
    ``s``: every spike contributes a unit-area Gaussian bump of width ``μ``.
    """

    def __init__(self, mu: float = 0.2):
        """
        Parameters
        ----------
        mu : float
            Kernel width; must be positive.

        Raises
        ------
        ValueError
            If ``mu <= 0`` (the normalisation would be infinite/undefined).
        """
        if mu <= 0:
            raise ValueError(f"Kernel width mu must be positive, got {mu}")
        self.mu = mu
        self.coef = 1.0 / (mu * np.sqrt(2 * np.pi))

    def __call__(self, t: float, spike_times: List[float]) -> float:
        """Return the summed kernel value at time ``t``.

        ``spike_times`` may be any sequence (list or array) of floats; an
        empty sequence yields ``0.0``.
        """
        if len(spike_times) == 0:
            return 0.0
        # Vectorised over all spikes: this runs for every neuron at every
        # integration step, and spike lists grow during a simulation.
        z = (t - np.asarray(spike_times, dtype=float)) / self.mu
        return float(self.coef * np.sum(np.exp(-0.5 * z ** 2)))


class MollifiedReset:
    """Smooth (mollified) membrane reset.

    The hard Heaviside step is replaced by the logistic-shaped ``H_ζ``;
    a larger ``zeta`` gives a sharper transition.
    """

    def __init__(self, zeta: float = 10.0):
        self.zeta = zeta

    def H(self, s: np.ndarray) -> np.ndarray:
        """Smoothed step H_ζ(s) = 1/2(1 + tanh(ζs/2))"""
        half_arg = self.zeta * s / 2
        return 0.5 * (1 + np.tanh(half_arg))

    def D(self, s: np.ndarray, theta: float) -> np.ndarray:
        """Reset map D_ζ(s; θ) = (1 - H_ζ(s - θ))s"""
        keep_fraction = 1 - self.H(s - theta)
        return keep_fraction * s

    def dD_ds(self, s: np.ndarray, theta: float) -> np.ndarray:
        """Derivative ∂D/∂s = (1 - H_ζ(s - θ)) - s·H_ζ'(s - θ),
        where H_ζ'(x) = (ζ/4)·sech²(ζx/2)."""
        shifted = s - theta
        sech = 1 / np.cosh(self.zeta * shifted / 2)
        step_slope = 0.25 * self.zeta * sech ** 2
        return (1 - self.H(shifted)) - s * (step_slope)


class SNNDynamics:
    """SNN dynamics with multi-class output.

    State vector ``X`` (length ``config.n_state``) is laid out as::

        X[0]                                  input neuron potential v
        X[1 + ell*P : 1 + (ell+1)*P]          hidden layer ``ell`` potentials ξ
        X[1 + L*P : 1 + L*P + n_classes]      output potentials u

    ``simulate`` integrates with explicit Euler, records spike times on the
    instance and applies the mollified reset after each step. ``dynamics``
    reads those stored spike times through the Gaussian kernel, so it always
    reflects the spikes of the most recent simulation.
    """

    def __init__(self, config: NetworkConfig):
        self.config = config
        self.kernel = GaussianKernel(config.mu)
        # Use the configured final sharpness until ``update_zeta`` anneals it
        # (identical to the previous hard-coded 10.0 for the default config).
        self.reset = MollifiedReset(config.zeta_1)

        # Initialize parameters. The draw order (a, omega, w, nu) is kept so
        # seeded runs reproduce the same weights.
        self.a = np.random.randn(config.d) * 0.1

        self.omega = []
        for ell in range(config.L):
            # Layer 0 is driven by the single input neuron; deeper layers by
            # the P neurons of the previous hidden layer.
            n_pre = 1 if ell == 0 else config.P
            self.omega.append(np.random.randn(config.P, n_pre) * 0.1)

        self.w = np.random.randn(config.n_classes, config.P) * 0.1
        self.nu = np.random.randn(config.n_classes) * 0.1

        # Spike times storage: input list, and hidden lists indexed [layer][neuron]
        self.spike_times_input = []
        self.spike_times_hidden = [[[] for _ in range(config.P)]
                                   for _ in range(config.L)]

        # Trajectory cache (filled by ``simulate(record=True)``)
        self.X_trajectory = None
        self.t_grid = None

    def reset_spikes(self):
        """Clear spike times"""
        self.spike_times_input = []
        self.spike_times_hidden = [[[] for _ in range(self.config.P)]
                                   for _ in range(self.config.L)]

    def update_zeta(self, epoch: int, total_epochs: int):
        """Set the reset sharpness ζ from the config schedule and return it."""
        zeta = self.config.get_zeta(epoch, total_epochs)
        self.reset.zeta = zeta
        return zeta

    def dynamics(self, t: float, X: np.ndarray, x_input: np.ndarray) -> np.ndarray:
        """Compute dX/dt = F(X, Υ) at time ``t``.

        Each presynaptic population's kernel activity J(t) is evaluated once
        and shared by all of its targets. Hidden neuron p of layer ``ell``
        receives ``Σ_q omega[ell][p, q] J_q``; output class c receives
        ``Σ_q w[c, q] J_q`` from the last hidden layer.

        Parameters
        ----------
        t : float
            Current time (kernels are evaluated against stored spike times).
        X : np.ndarray
            State vector of length ``config.n_state``.
        x_input : np.ndarray
            Input features of length ``config.d``.

        Returns
        -------
        np.ndarray
            Time derivative, same shape as ``X``.
        """
        cfg = self.config
        dX = np.zeros_like(X)

        # Input neuron: leaky integration of the projected input a·x.
        dX[0] = (1.0 / cfg.tau_v) * (-X[0] + np.dot(self.a, x_input))

        # Kernel activity of the population feeding the current layer;
        # for layer 0 that is the single input neuron.
        J_pre = np.array([self.kernel(t, self.spike_times_input)])
        for ell in range(cfg.L):
            lo = 1 + ell * cfg.P
            hi = lo + cfg.P
            synaptic_input = self.omega[ell] @ J_pre
            dX[lo:hi] = (1.0 / cfg.tau_h) * (-X[lo:hi] + synaptic_input)
            J_pre = np.array([self.kernel(t, spikes)
                              for spikes in self.spike_times_hidden[ell]])

        # Output layer (multi-class): each readout weight scales its own
        # presynaptic neuron's activity.
        out = 1 + cfg.L * cfg.P
        u = X[out:out + cfg.n_classes]
        dX[out:out + cfg.n_classes] = (1.0 / cfg.tau_u) * (-u + self.w @ J_pre)

        return dX

    def simulate(self, x_input: np.ndarray,
                 T: Optional[float] = None,
                 record: bool = True,
                 timeout: Optional[float] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Forward simulation with optional timeout.

        Parameters:
        -----------
        x_input : np.ndarray
            Input features (length ``config.d``).
        T : float, optional
            Time horizon; defaults to ``config.T``.
        record : bool
            If True, cache the result in ``X_trajectory`` / ``t_grid``.
        timeout : float, optional
            Maximum time in seconds for simulation. On timeout the partial
            trajectory computed so far is returned (and cached if ``record``).

        Returns:
        --------
        t_grid : np.ndarray (n_steps,)
        X : np.ndarray (n_steps, n_state)
        """
        self.reset_spikes()
        cfg = self.config

        if T is None:
            T = cfg.T

        start_time = time.time()

        dt = cfg.dt
        # Build the grid from an integer step count: np.arange with a float
        # step can emit an extra sample past T through rounding.
        n_steps = max(int(round(T / dt)), 0) + 1
        t_grid = np.arange(n_steps) * dt

        X = np.zeros((n_steps, cfg.n_state))  # X[0] = 0: start at rest

        for i in range(n_steps - 1):
            # Check timeout
            if timeout is not None and (time.time() - start_time) > timeout:
                print(f"  Warning: Simulation timeout after {time.time() - start_time:.2f}s")
                # Return partial trajectory
                if record:
                    self.X_trajectory = X[:i+1]
                    self.t_grid = t_grid[:i+1]
                return t_grid[:i+1], X[:i+1]

            dX = self.dynamics(t_grid[i], X[i], x_input)
            X[i+1] = X[i] + dX * dt

            # Thresholds are tested on the new state X(t_{i+1}), so that is
            # the time a spike is recorded at.
            t_spike = t_grid[i+1]

            # Input neuron
            if X[i+1, 0] >= cfg.theta_v:
                self.spike_times_input.append(t_spike)
                X[i+1, 0] = self.reset.D(np.array([X[i+1, 0]]), cfg.theta_v)[0]

            # Hidden neurons
            idx = 1
            for ell in range(cfg.L):
                for p in range(cfg.P):
                    if X[i+1, idx] >= cfg.theta_h:
                        self.spike_times_hidden[ell][p].append(t_spike)
                        X[i+1, idx] = self.reset.D(np.array([X[i+1, idx]]),
                                                   cfg.theta_h)[0]
                    idx += 1
            # Output neurons (remaining entries) are never thresholded or reset.

        if record:
            self.X_trajectory = X
            self.t_grid = t_grid

        return t_grid, X

    def simulate_fast(self, x_input: np.ndarray, T: float) -> np.ndarray:
        """
        Fast simulation for training (only returns final state).

        Uses ``config.max_simulation_time`` as timeout, so on timeout the last
        *computed* state is returned rather than X(T). Any exception is
        printed and an all-zero state is returned (best effort).

        Returns:
        --------
        X_final : np.ndarray
            Final state X(T)
        """
        try:
            _, X_traj = self.simulate(x_input, T=T, record=False,
                                      timeout=self.config.max_simulation_time)
            return X_traj[-1]
        except Exception as e:
            print(f"  Error in simulation: {e}")
            return np.zeros(self.config.n_state)

    def compute_output(self, X_final: np.ndarray) -> np.ndarray:
        """
        Compute network output (softmax for multi-class).

        Returns:
        --------
        probs : np.ndarray (n_classes,)
            Class probabilities
        """
        u_final = X_final[-self.config.n_classes:]

        # Logits: ν_c σ(u_c - θ_u)
        sigma_u = 1.0 / (1.0 + np.exp(-(u_final - self.config.theta_u)))
        logits = self.nu * sigma_u

        # Softmax
        exp_logits = np.exp(logits - np.max(logits))  # Numerical stability
        probs = exp_logits / np.sum(exp_logits)

        return probs

    def compute_loss(self, probs: np.ndarray, y_true: int) -> float:
        """
        Cross-entropy loss for multi-class.

        L = -log(p_y)   (1e-10 added to avoid log(0))
        """
        return -np.log(probs[y_true] + 1e-10)

    def compute_state_jacobian(self, t: float, X: np.ndarray,
                               x_input: np.ndarray, eps: float = 1e-6) -> np.ndarray:
        """Compute A(t) = ∂F/∂X by forward finite differences.

        Spike times are held fixed at their currently stored values.
        """
        n = len(X)
        A = np.zeros((n, n))

        F0 = self.dynamics(t, X, x_input)

        for i in range(n):
            X_pert = X.copy()
            X_pert[i] += eps
            F_pert = self.dynamics(t, X_pert, x_input)
            A[:, i] = (F_pert - F0) / eps

        return A


class LinearizedSystem:
    """
    Linearized system around reference trajectory.

    δẋ = A(t)δx + B(t)δu
    -δλ̇ = A(t)^T δλ

    ``linearize_along_trajectory`` samples A(t) = ∂F/∂X along a simulated
    trajectory; ``solve_adjoint_shooting`` solves the adjoint two-point
    problem by shooting on λ(0), with A(t) interpolated between samples.
    """

    def __init__(self, snn: SNNDynamics):
        self.snn = snn
        self.A_traj = None  # A(t) samples, shape (n_time, n_state, n_state)
        self.t_grid = None  # sample times matching A_traj
        self.X_traj = None  # subsampled reference trajectory

    def linearize_along_trajectory(self, x_input: np.ndarray):
        """
        Compute linearization A(t) along reference trajectory.

        Runs ``snn.simulate`` without an explicit horizon (its default),
        keeps every 10th grid point and evaluates the state Jacobian there.

        Returns
        -------
        X_sample : np.ndarray (n_time, n_state)
        A_traj : np.ndarray (n_time, n_state, n_state)
        """
        print("  Linearizing system along reference trajectory...")

        # Simulate reference trajectory
        t_grid_full, X_traj_full = self.snn.simulate(x_input, record=True)

        # Subsample for efficiency
        sample_rate = 10
        t_sample = t_grid_full[::sample_rate]
        X_sample = X_traj_full[::sample_rate]

        self.A_traj = np.array([
            self.snn.compute_state_jacobian(t, X_t, x_input)
            for t, X_t in zip(t_sample, X_sample)
        ])
        self.t_grid = t_sample
        self.X_traj = X_sample

        print(f"  Linearized at {len(t_sample)} time points")
        print(f"  Reference trajectory shape: {X_sample.shape}")
        print(f"  Time grid shape: {t_sample.shape}")

        return X_sample, self.A_traj

    def create_A_interpolator(self) -> Callable:
        """Return ``A_interp(t)`` giving the (n, n) matrix A at time ``t``.

        Every entry is interpolated with a cubic spline in time and
        extrapolated outside the sampled range. A single ``interp1d`` along
        ``axis=0`` handles all n² entries, so each evaluation is one
        vectorised call instead of n² scalar calls.
        """
        interp = interp1d(self.t_grid, self.A_traj, kind='cubic', axis=0,
                          fill_value='extrapolate')

        def A_interp(t):
            return np.asarray(interp(t), dtype=float)

        return A_interp

    def solve_adjoint_shooting(self, lambda_T: np.ndarray,
                               lambda_0_guess: np.ndarray,
                               method: str = 'newton') -> Tuple[np.ndarray, List[float]]:
        """
        Solve adjoint equation using shooting method.

        Find λ(0) such that integrating -λ̇ = A(t)^T λ from t=0 to t=T
        gives λ(T) = λ_T^target.

        Cost function: J(λ(0)) = ½||λ(T; λ(0)) - λ_T^target||²

        The starting point is the λ(0) obtained by integrating the adjoint
        backward from ``lambda_T``; ``lambda_0_guess`` is used only if that
        backward integration fails. ``method`` is only reported: the solver
        is always damped Newton, falling back to a gradient step when the
        sensitivity matrix is singular.

        The adjoint ODE is linear, so λ(T) = Φ λ(0) for the state-transition
        matrix Φ. Φ is built once (on first need) by integrating the unit
        vectors, instead of re-deriving it by finite differences every
        iteration — with eps=1e-6 and atol=1e-8 such differences can carry
        errors of roughly atol/eps.

        Returns:
        --------
        lambda_traj : np.ndarray (n_time, n_state)
            Adjoint trajectory on ``self.t_grid`` for the final λ(0)
        costs : List[float]
            Cost at each iteration

        Raises:
        -------
        RuntimeError
            If a forward adjoint integration does not succeed.
        """
        print(f"  Solving adjoint via shooting method ({method})...")
        print(f"  Target λ(T) norm: {np.linalg.norm(lambda_T):.6f}")

        A_interp = self.create_A_interpolator()
        t_grid = self.t_grid

        def adjoint_rhs(t, lam):
            """−λ̇ = A(t)^T λ"""
            A = A_interp(t)
            return -A.T @ lam

        def forward_integrate_adjoint(lam_0):
            """Integrate adjoint forward from λ(0) to λ(T)"""
            sol = solve_ivp(adjoint_rhs, [t_grid[0], t_grid[-1]], lam_0,
                            t_eval=t_grid, method='RK45',
                            rtol=1e-6, atol=1e-8)
            if not sol.success:
                # A truncated solution's last row would not be λ(T).
                raise RuntimeError(f"Adjoint integration failed: {sol.message}")
            return sol.y.T  # Shape: (n_time, n_state)

        def cost_function(lam_0):
            """Cost: ½||λ(T) - λ_T^target||²"""
            lam_traj = forward_integrate_adjoint(lam_0)
            cost = 0.5 * np.sum((lam_traj[-1] - lambda_T)**2)
            return cost, lam_traj

        print(f"  Finding initial guess via backward integration...")

        # Backward integration from λ(T) to get λ(0)
        sol_backward = solve_ivp(adjoint_rhs, [t_grid[-1], t_grid[0]], lambda_T,
                                 t_eval=t_grid[::-1], method='RK45',
                                 rtol=1e-6, atol=1e-8)
        if sol_backward.success:
            lambda_0_current = sol_backward.y[:, -1].copy()  # λ(0) from backward
        else:
            print(f"  Backward integration failed ({sol_backward.message}); "
                  f"using lambda_0_guess")
            lambda_0_current = np.asarray(lambda_0_guess, dtype=float).copy()

        print(f"  Backward λ(0) norm: {np.linalg.norm(lambda_0_current):.6f}")

        # Shooting method iterations
        max_iter = 20
        tol = 1e-6
        costs = []
        Phi = None  # dλ(T)/dλ(0), built lazily (constant: the ODE is linear)

        print(f"\n  Starting shooting iterations...")
        print(f"  {'Iter':<6} {'Cost':<12} {'||dλ||':<12} {'Step'}")
        print(f"  {'-'*50}")

        for iteration in range(max_iter):
            cost, lam_traj = cost_function(lambda_0_current)
            costs.append(cost)

            residual = lam_traj[-1] - lambda_T
            residual_norm = np.linalg.norm(residual)

            print(f"  {iteration:<6} {cost:<12.6e} {residual_norm:<12.6e}", end='')

            if residual_norm < tol:
                print(f" ✓ Converged!")
                break

            if Phi is None:
                # Column i of Φ is λ(T) obtained from λ(0) = e_i.
                Phi = np.column_stack([
                    forward_integrate_adjoint(e_i)[-1]
                    for e_i in np.eye(len(lambda_0_current))
                ])

            # Newton update: λ(0) ← λ(0) - Φ^{-1} (λ(T) - λ_T^target)
            try:
                delta = np.linalg.solve(Phi, residual)
                step_size = 0.5  # Damped Newton
                lambda_0_current = lambda_0_current - step_size * delta
                print(f" Newton")
            except np.linalg.LinAlgError:
                # Fall back to gradient descent on the cost
                gradient = Phi.T @ residual
                step_size = 0.01
                lambda_0_current = lambda_0_current - step_size * gradient
                print(f" Gradient")

        # Final trajectory (and the cost of the λ(0) actually returned)
        final_cost, final_traj = cost_function(lambda_0_current)

        print(f"\n  Final cost: {final_cost:.6e}")
        print(f"  Final λ(0) norm: {np.linalg.norm(lambda_0_current):.6f}")
        print(f"  Final λ(T) norm: {np.linalg.norm(final_traj[-1]):.6f}")

        return final_traj, costs


class ShootingMethodSolver:
    """Shooting method for full nonlinear system.

    Wraps an SNN: runs the forward simulation, integrates the adjoint
    backward in time, and provides simplified gradients plus a mini-batch
    training step.
    """

    def __init__(self, snn: SNNDynamics):
        self.snn = snn

    def create_state_interpolator(self, t_grid: np.ndarray,
                                  X_traj: np.ndarray) -> List[Callable]:
        """Create one interpolating cubic spline per state component.

        ``s=0`` makes each spline pass through every sample of
        ``X_traj[:, i]`` over ``t_grid``.
        """
        return [UnivariateSpline(t_grid, X_traj[:, i], k=3, s=0)
                for i in range(X_traj.shape[1])]

    def _terminal_adjoint(self, X_final: np.ndarray, y_target: int) -> np.ndarray:
        """Return λ(T) = ∂L/∂X(T) for the cross-entropy loss at ``X_final``.

        The loss depends on X(T) only through the output potentials u, via
        logits_c = ν_c σ(u_c - θ_u) (mirroring ``snn.compute_output``), so
        ∂L/∂u_c = (p_c - 1[c = y]) · ν_c σ(u_c - θ_u)(1 - σ(u_c - θ_u)).
        """
        cfg = self.snn.config
        n_out = cfg.n_classes
        probs = self.snn.compute_output(X_final)

        # Cross-entropy through softmax: ∂L/∂logits = p - onehot(y).
        dloss_dlogits = np.array(probs, dtype=float)
        dloss_dlogits[y_target] -= 1.0

        # Chain rule through logits = ν σ(u - θ_u).
        u_final = X_final[-n_out:]
        sigma_u = 1.0 / (1.0 + np.exp(-(u_final - cfg.theta_u)))
        dlogits_du = self.snn.nu * sigma_u * (1.0 - sigma_u)

        lambda_T = np.zeros(cfg.n_state)
        lambda_T[-n_out:] = dloss_dlogits * dlogits_du
        return lambda_T

    def solve_forward_backward(self, x_input: np.ndarray,
                               y_target: int,
                               use_train_time: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Solve forward-backward system.

        Parameters:
        -----------
        use_train_time : bool
            If True, simulate up to ``config.T_train``; otherwise ``config.T``.

        Returns:
        --------
        X_traj : np.ndarray (n_time, n_state)
            Forward trajectory (truncated if the simulation timed out).
        lambda_traj : np.ndarray
            Adjoint λ(t) = ∂L/∂X(t) ordered by increasing time, obtained by
            integrating λ̇ = -A(t)ᵀλ backward from the terminal condition.
        """
        # Forward pass
        T = self.snn.config.T_train if use_train_time else self.snn.config.T
        t_grid, X_traj = self.snn.simulate(x_input, T=T, record=True,
                                           timeout=self.snn.config.max_simulation_time)

        # Terminal condition
        lambda_T = self._terminal_adjoint(X_traj[-1], y_target)

        # Backward pass with A(t) evaluated on a smooth state interpolant
        X_interp = self.create_state_interpolator(t_grid, X_traj)

        def adjoint_rhs(t, lam):
            X_t = np.array([interp(t) for interp in X_interp])
            A = self.snn.compute_state_jacobian(t, X_t, x_input)
            return -A.T @ lam

        sol = solve_ivp(adjoint_rhs, [t_grid[-1], t_grid[0]], lambda_T,
                        t_eval=t_grid[::-1], method='RK45',
                        rtol=1e-6, atol=1e-8)

        lambda_traj = sol.y.T[::-1]

        return X_traj, lambda_traj

    def compute_gradients(self, x_input: np.ndarray, y_target: int,
                          gamma: float = 0.001) -> dict:
        """Compute gradients (simplified for speed).

        Only ``a`` receives a data gradient, from the adjoint identity
        dL/da = ∫₀ᵀ λ(t)ᵀ ∂F/∂a dt. The integral uses a rectangle rule on
        every 20th grid point; ∂F/∂a comes from forward finite differences
        with spike times held fixed. ``omega``, ``w`` and ``nu`` receive only
        the L2 regularization term; every parameter gets 2γθ added.
        """
        X_traj, lambda_traj = self.solve_forward_backward(x_input, y_target)

        stride = 20
        weight = stride * self.snn.config.dt  # rectangle width per sample
        t_grid = self.snn.t_grid
        n_time = min(len(t_grid), len(X_traj), len(lambda_traj))
        sample_indices = np.arange(0, n_time, stride)

        grad_a = np.zeros_like(self.snn.a)
        grad_omega = [np.zeros_like(w) for w in self.snn.omega]
        grad_w = np.zeros_like(self.snn.w)
        grad_nu = np.zeros_like(self.snn.nu)

        eps = 1e-6
        for k in sample_indices:
            t_k = t_grid[k]
            F0 = self.snn.dynamics(t_k, X_traj[k], x_input)
            for i in range(len(self.snn.a)):
                original = self.snn.a[i]
                self.snn.a[i] = original + eps
                try:
                    F_pert = self.snn.dynamics(t_k, X_traj[k], x_input)
                finally:
                    self.snn.a[i] = original  # restore exactly
                grad_a[i] += weight * float(lambda_traj[k] @ (F_pert - F0)) / eps

        # Regularization
        grad_a += 2 * gamma * self.snn.a
        for ell in range(self.snn.config.L):
            grad_omega[ell] += 2 * gamma * self.snn.omega[ell]
        grad_w += 2 * gamma * self.snn.w
        grad_nu += 2 * gamma * self.snn.nu

        return {'a': grad_a, 'omega': grad_omega, 'w': grad_w, 'nu': grad_nu}

    def compute_loss_and_pred(self, x_input: np.ndarray, y_target: int) -> Tuple[float, int, np.ndarray]:
        """
        Fast loss computation for training.

        On any exception the error is printed and ``(1.0, 0, zeros)`` is
        returned (best effort).

        Returns:
        --------
        loss : float
        prediction : int
        X_final : np.ndarray
        """
        try:
            X_final = self.snn.simulate_fast(x_input, self.snn.config.T_train)
            probs = self.snn.compute_output(X_final)
            loss = self.snn.compute_loss(probs, y_target)
            pred = np.argmax(probs)
            return loss, pred, X_final
        except Exception as e:
            print(f"  Error in loss computation: {e}")
            return 1.0, 0, np.zeros(self.snn.config.n_state)

    def compute_gradients_fast(self, x_input: np.ndarray, y_target: int,
                               X_final: np.ndarray, gamma: float = 0.001) -> dict:
        """
        Fast gradient computation (terminal only, no full adjoint).

        Only ``nu`` gets a data gradient: forward finite differences of the
        loss at the fixed terminal state ``X_final``. ``a``, ``omega`` and
        ``w`` get Gaussian exploration noise (std 0.001, drawn from the global
        NumPy RNG) in place of a gradient. All four include the L2 term 2γθ.
        ``x_input`` is unused.
        """
        probs = self.snn.compute_output(X_final)
        base_loss = self.snn.compute_loss(probs, y_target)

        # Gradient w.r.t. nu (readout) by forward differences
        eps = 1e-5
        grad_nu = np.zeros_like(self.snn.nu)
        for i in range(len(self.snn.nu)):
            original = self.snn.nu[i]
            self.snn.nu[i] = original + eps
            try:
                perturbed_loss = self.snn.compute_loss(
                    self.snn.compute_output(X_final), y_target)
            finally:
                self.snn.nu[i] = original  # restore exactly
            grad_nu[i] = (perturbed_loss - base_loss) / eps

        # Exploration noise (NOT gradients) for the remaining parameters
        grad_a = np.random.randn(*self.snn.a.shape) * 0.001
        grad_omega = [np.random.randn(*w.shape) * 0.001 for w in self.snn.omega]
        grad_w = np.random.randn(*self.snn.w.shape) * 0.001

        # Add regularization
        grad_a += 2 * gamma * self.snn.a
        for ell in range(self.snn.config.L):
            grad_omega[ell] += 2 * gamma * self.snn.omega[ell]
        grad_w += 2 * gamma * self.snn.w
        grad_nu += 2 * gamma * self.snn.nu

        return {'a': grad_a, 'omega': grad_omega, 'w': grad_w, 'nu': grad_nu}

    def train_step(self, X_data: np.ndarray, y_data: np.ndarray,
                   learning_rate: float = 0.001, gamma: float = 0.001) -> Tuple[float, float]:
        """
        Optimized training step with mini-batching and simplified gradients.

        Samples are visited in order (no shuffling), in batches of
        ``config.batch_size``. Each batch's gradient is averaged over the
        samples whose gradient computation succeeded; a batch with none
        leaves the parameters unchanged.

        Returns:
        --------
        loss : float
            Summed per-sample loss divided by N.
        accuracy : float
            Correct predictions divided by N.

        Raises:
        -------
        ValueError
            If ``X_data`` is empty.
        """
        N = len(X_data)
        if N == 0:
            raise ValueError("train_step requires at least one sample")
        batch_size = min(self.snn.config.batch_size, N)
        n_batches = (N + batch_size - 1) // batch_size

        total_loss = 0.0
        correct = 0

        # Process in mini-batches
        print(f"  Training on {n_batches} batches (batch_size={batch_size})...")

        for batch_idx in tqdm(range(n_batches), desc="Batches"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, N)

            batch_X = X_data[start_idx:end_idx]
            batch_y = y_data[start_idx:end_idx]

            # Accumulate gradients for batch
            grad_a_batch = np.zeros_like(self.snn.a)
            grad_omega_batch = [np.zeros_like(w) for w in self.snn.omega]
            grad_w_batch = np.zeros_like(self.snn.w)
            grad_nu_batch = np.zeros_like(self.snn.nu)

            batch_loss = 0.0
            n_grads = 0  # samples that contributed a gradient

            for i in range(len(batch_X)):
                try:
                    # Forward pass
                    loss, pred, X_final = self.compute_loss_and_pred(batch_X[i], batch_y[i])

                    batch_loss += loss
                    if pred == batch_y[i]:
                        correct += 1

                    grads = self.compute_gradients_fast(batch_X[i], batch_y[i], X_final, gamma)

                    grad_a_batch += grads['a']
                    for ell in range(self.snn.config.L):
                        grad_omega_batch[ell] += grads['omega'][ell]
                    grad_w_batch += grads['w']
                    grad_nu_batch += grads['nu']
                    n_grads += 1

                except Exception as e:
                    print(f"  Error processing sample {start_idx + i}: {e}")
                    continue

            total_loss += batch_loss

            if n_grads == 0:
                continue  # nothing to average; keep parameters unchanged

            # Average over the samples that actually produced a gradient
            grad_a_batch /= n_grads
            for ell in range(self.snn.config.L):
                grad_omega_batch[ell] /= n_grads
            grad_w_batch /= n_grads
            grad_nu_batch /= n_grads

            # Update parameters
            self.snn.a -= learning_rate * grad_a_batch
            for ell in range(self.snn.config.L):
                self.snn.omega[ell] -= learning_rate * grad_omega_batch[ell]
            self.snn.w -= learning_rate * grad_w_batch
            self.snn.nu -= learning_rate * grad_nu_batch

        return total_loss / N, correct / N


def load_data_from_csv(filename: str = 'lif.csv') -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Load class labels from a headerless CSV and synthesise 2-D features.

    Only column 1 is read (as a float label). Labels are remapped to
    0..n_classes-1 in sorted order. Features are not read from the file:
    after reseeding NumPy's global RNG with 42 they are Gaussian noise
    (scale 0.5) shifted by 0.6 along angle 2πc/n_classes for class c.
    If the file is missing, 30 random points (scale 0.8) with 3 random
    labels are returned instead.

    Returns: X_data, y_data, n_classes
    """
    if not os.path.exists(filename):
        print(f"Warning: {filename} not found. Generating synthetic data.")
        np.random.seed(42)
        n_samples, n_synthetic_classes = 30, 3
        features = np.random.randn(n_samples, 2) * 0.8
        labels = np.random.randint(0, n_synthetic_classes, n_samples)
        return features, labels, n_synthetic_classes

    frame = pd.read_csv(filename, header=None)
    print(f"Loaded data from {filename}")
    print(f"Shape: {frame.shape}")

    raw_labels = frame.iloc[:, 1].values.astype(float)
    unique_labels = np.unique(raw_labels)
    n_classes = len(unique_labels)

    print(f"Unique labels: {unique_labels}")
    print(f"Number of classes: {n_classes}")

    label_map = {label: position for position, label in enumerate(unique_labels)}
    y_mapped = np.fromiter((label_map[label] for label in raw_labels),
                           dtype=int, count=len(raw_labels))

    print(f"Label mapping: {label_map}")

    np.random.seed(42)
    features = np.random.randn(len(raw_labels), 2) * 0.5

    # Push each class towards its own direction on the unit circle.
    angles = 2 * np.pi * y_mapped / n_classes
    features += 0.6 * np.column_stack([np.cos(angles), np.sin(angles)])

    print(f"Data shape: X={features.shape}, y={y_mapped.shape}")
    print(f"Class distribution: {np.bincount(y_mapped)}")

    return features, y_mapped, n_classes


def visualize_data(X_data, y_data, n_classes):
    """Scatter the samples coloured by class, save data_distribution.png and show it."""
    plt.figure(figsize=(8, 6))

    palette = plt.cm.rainbow(np.linspace(0, 1, n_classes))
    for class_id, color in enumerate(palette):
        in_class = y_data == class_id
        plt.scatter(
            X_data[in_class, 0],
            X_data[in_class, 1],
            c=[color],
            label=f'Class {class_id}',
            s=100,
            alpha=0.7,
            edgecolors='k',
        )

    plt.xlabel('Feature 1', fontsize=12)
    plt.ylabel('Feature 2', fontsize=12)
    plt.title('Data Distribution', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    out_path = 'data_distribution.png'
    plt.savefig(out_path, dpi=150)
    print("Saved: data_distribution.png")
    plt.show()


def visualize_shooting_convergence(costs, filename='shooting_convergence.png'):
    """Plot shooting-method cost per iteration on linear and log axes.

    The figure is saved to ``filename``, shown, and returned.
    """
    fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(14, 5))
    steps = range(len(costs))

    panels = (
        (ax_lin, ax_lin.plot, 'bo-',
         'Cost ||λ(T) - λ_T^target||²', 'Shooting Method Convergence'),
        (ax_log, ax_log.semilogy, 'ro-',
         'Cost (log scale)', 'Shooting Method Convergence (Log)'),
    )
    for ax, draw, fmt, ylabel, title in panels:
        draw(steps, costs, fmt, linewidth=2, markersize=8)
        ax.set_xlabel('Iteration', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_linearized_system(X_traj, A_traj, t_sample, filename='linearized_system.png',
                                n_eigs=5):
    """Plot a reference trajectory and the Jacobians of the linearized dynamics.

    Panels:
      * top-left: state component 0 (labelled as the input neuron ``v``);
      * top-right: state components 1 and 2 (labelled as hidden neurons), when present;
      * bottom-left: heatmap of the Jacobian at the middle time sample;
      * bottom-right: real parts of the ``n_eigs`` eigenvalues with the largest
        real part at each time sample.

    Args:
        X_traj: state trajectory, indexed as ``X_traj[time, component]``.
        A_traj: stacked square Jacobians, indexed as ``A_traj[time, i, j]``.
        t_sample: time samples; the last entry sets the span of the state plots.
        filename: output image path.
        n_eigs: maximum number of eigenvalue curves to draw.

    Returns:
        The matplotlib Figure.
    """
    X_traj = np.asarray(X_traj)
    A_traj = np.asarray(A_traj)
    t_sample = np.asarray(t_sample)
    n = A_traj.shape[1]

    # State plots span [0, t_sample[-1]] with one point per trajectory row.
    t_state = np.linspace(0, t_sample[-1], len(X_traj))
    # Jacobian plots: use t_sample directly when it lines up with A_traj;
    # otherwise resample its span so the indexing below cannot go out of range.
    if len(t_sample) == len(A_traj):
        t_jac = t_sample
    else:
        t_jac = np.linspace(t_sample[0], t_sample[-1], len(A_traj))

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Input neuron (state component 0)
    axes[0, 0].plot(t_state, X_traj[:, 0], 'b-', linewidth=2)
    axes[0, 0].set_title('Input Neuron v(t)', fontsize=12, fontweight='bold')
    axes[0, 0].set_xlabel('Time (s)')
    axes[0, 0].set_ylabel('v')
    axes[0, 0].grid(True, alpha=0.3)

    # First hidden neurons (state components 1 and 2), only if they exist
    plotted_hidden = False
    for col, fmt, label in ((1, 'g-', 'ξ₁,₁'), (2, 'r-', 'ξ₁,₂')):
        if col < X_traj.shape[1]:
            axes[0, 1].plot(t_state, X_traj[:, col], fmt, linewidth=2, label=label)
            plotted_hidden = True
    axes[0, 1].set_title('Hidden Neurons', fontsize=12, fontweight='bold')
    axes[0, 1].set_xlabel('Time (s)')
    axes[0, 1].set_ylabel('ξ')
    if plotted_hidden:
        axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Heatmap of the Jacobian at the middle time sample
    mid_idx = len(A_traj) // 2
    im = axes[1, 0].imshow(A_traj[mid_idx], cmap='RdBu', aspect='auto')
    axes[1, 0].set_title(f'Jacobian A(t={t_jac[mid_idx]:.1f})', fontsize=12, fontweight='bold')
    axes[1, 0].set_xlabel('State j')
    axes[1, 0].set_ylabel('State i')
    plt.colorbar(im, ax=axes[1, 0])

    # np.linalg.eigvals returns eigenvalues in no particular order, so taking
    # column i directly would splice unrelated eigenvalues together over time.
    # Sort real parts descending so curve i is always the i-th largest Re(λ).
    re_eigs = np.sort(np.linalg.eigvals(A_traj).real, axis=1)[:, ::-1]
    for i in range(min(n_eigs, n)):
        axes[1, 1].plot(t_jac, re_eigs[:, i], linewidth=2, label=f'λ_{i+1}')
    axes[1, 1].set_title('Eigenvalue Evolution (sorted by Re, descending)',
                         fontsize=12, fontweight='bold')
    axes[1, 1].set_xlabel('Time (s)')
    axes[1, 1].set_ylabel('Re(λ)')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_adjoint_trajectories(X_traj, lambda_traj, t_grid, n_classes, filename='adjoint_trajectories.png'):
    """Plot up to 8 evenly spaced state components of X(t) next to the adjoint λ(t).

    Args:
        X_traj: forward trajectory, indexed as ``X_traj[time, component]``.
        lambda_traj: adjoint trajectory, indexed the same way as ``X_traj``.
        t_grid: time samples for the trajectories.
        n_classes: unused; kept for call-site compatibility.
        filename: output image path.

    Returns:
        The matplotlib Figure, or None if there is nothing to plot.
    """
    X_traj = np.asarray(X_traj)
    lambda_traj = np.asarray(lambda_traj)
    t_grid = np.asarray(t_grid)
    n_state = X_traj.shape[1]

    # Select key components to plot
    n_plots = min(8, n_state)
    if n_plots == 0:
        print("  No state components to plot")
        return None

    def time_axis(n_points):
        # NOTE(review): when lengths differ we assume the trajectory covers the
        # same interval as t_grid, which may not hold if another horizon was used.
        if len(t_grid) == n_points:
            return t_grid
        print(f"  Warning: time grid has {len(t_grid)} points but trajectory has "
              f"{n_points}; resampling over [{t_grid[0]}, {t_grid[-1]}]")
        return np.linspace(t_grid[0], t_grid[-1], n_points)

    t_fwd = time_axis(len(X_traj))
    t_adj = time_axis(len(lambda_traj))

    # squeeze=False keeps axes 2-D even when n_plots == 1, so axes[i, k] is valid.
    fig, axes = plt.subplots(n_plots, 2, figsize=(16, 2.5 * n_plots), squeeze=False)

    plot_indices = np.linspace(0, n_state - 1, n_plots, dtype=int)

    for i, idx in enumerate(plot_indices):
        # Forward
        axes[i, 0].plot(t_fwd, X_traj[:, idx], 'b-', linewidth=1.5)
        axes[i, 0].set_ylabel(f'X_{idx+1}', fontsize=10)
        axes[i, 0].grid(True, alpha=0.3)
        if i == 0:
            axes[i, 0].set_title('Forward State X(t)', fontsize=13, fontweight='bold')
        if i == n_plots - 1:
            axes[i, 0].set_xlabel('Time (s)', fontsize=11)

        # Adjoint
        axes[i, 1].plot(t_adj, lambda_traj[:, idx], 'r-', linewidth=1.5)
        axes[i, 1].set_ylabel(f'λ_{idx+1}', fontsize=10)
        axes[i, 1].grid(True, alpha=0.3)
        if i == 0:
            axes[i, 1].set_title('Adjoint State λ(t)', fontsize=13, fontweight='bold')
        if i == n_plots - 1:
            axes[i, 1].set_xlabel('Time (s)', fontsize=11)

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_training_history(history, filename='training_history.png'):
    """Draw loss, accuracy, ζ schedule and log-scale loss against epoch.

    Args:
        history: list of dicts, each with keys 'epoch', 'loss', 'accuracy' and 'zeta'.
        filename: path the 2x2 figure is saved to (dpi=150).
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # One list per tracked quantity, extracted in the same key order as before.
    series = {key: [record[key] for record in history]
              for key in ('epoch', 'loss', 'accuracy', 'zeta')}
    epochs = series['epoch']

    # (axis, plot method, y data, format, y-label, title, y-limits or None)
    panels = [
        (axes[0, 0], 'plot', series['loss'], 'b-o', 'Loss', 'Training Loss', None),
        (axes[0, 1], 'plot', series['accuracy'], 'g-o', 'Accuracy', 'Training Accuracy', [0, 1.1]),
        (axes[1, 0], 'plot', series['zeta'], 'r-o', 'ζ', 'Mollification Schedule', None),
        (axes[1, 1], 'semilogy', series['loss'], 'b-o', 'Loss (log)', 'Training Loss (Log Scale)', None),
    ]
    for ax, method, values, fmt, y_label, title, y_limits in panels:
        getattr(ax, method)(epochs, values, fmt, linewidth=2, markersize=6)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        if y_limits is not None:
            ax.set_ylim(y_limits)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    print(f"Saved: {filename}")
    plt.show()


def _predict_labels(snn, X):
    """Return argmax class predictions for each row of X; -1 marks a failed simulation."""
    predictions = np.full(len(X), -1, dtype=int)
    for i in tqdm(range(len(X)), desc="Predictions"):
        try:
            X_final = snn.simulate_fast(X[i], snn.config.T_train)
            probs = snn.compute_output(X_final)
            predictions[i] = int(np.argmax(probs))
        except Exception as e:
            # Best-effort: report and keep going; the sample stays marked -1.
            print(f"  Error on sample {i}: {e}")
    return predictions


def _class_confusion_matrix(y_true, y_pred, n_classes):
    """Return an (n_classes, n_classes) count matrix (rows = true, cols = predicted).

    Pairs with a label outside [0, n_classes), such as failed predictions marked
    -1, are ignored. The shape therefore never depends on which classes happen
    to appear in the sample.
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    valid = (y_true >= 0) & (y_true < n_classes) & (y_pred >= 0) & (y_pred < n_classes)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true[valid], y_pred[valid]), 1)
    return cm


def visualize_predictions(snn, X_data, y_data, n_classes, filename='predictions.png',
                          max_viz_samples=200):
    """Plot predicted vs. true classes and a confusion matrix.

    At most ``max_viz_samples`` samples are drawn at random without replacement,
    using NumPy's global RNG. Samples whose simulation raises are reported,
    excluded from the confusion matrix and counted as incorrect in the accuracy.

    Args:
        snn: network exposing ``simulate_fast``, ``compute_output`` and ``config.T_train``.
        X_data: inputs; the first two columns are used as plot coordinates.
        y_data: integer class labels in ``[0, n_classes)``.
        n_classes: number of classes; fixes the confusion-matrix size.
        filename: output image path.
        max_viz_samples: cap on the number of samples simulated for the plot.

    Returns:
        The matplotlib Figure, or None if there are no samples.
    """
    X_data = np.asarray(X_data)
    y_data = np.asarray(y_data)
    N = len(X_data)
    if N == 0:
        print("  No samples to visualize")
        return None

    n_viz = min(max_viz_samples, N)  # Limit for visualization speed
    if N > n_viz:
        print(f"  Visualizing {n_viz} samples (out of {N})")
        viz_indices = np.random.choice(N, n_viz, replace=False)
        X_viz, y_viz = X_data[viz_indices], y_data[viz_indices]
    else:
        X_viz, y_viz = X_data, y_data

    print("  Computing predictions...")
    predictions = _predict_labels(snn, X_viz)
    n_failed = int(np.sum(predictions < 0))
    if n_failed:
        print(f"  {n_failed} sample(s) failed and are counted as incorrect")

    fig, (ax_scatter, ax_cm) = plt.subplots(1, 2, figsize=(14, 6))

    # Circles = true class, crosses = predicted class (failed samples get no cross)
    colors = plt.cm.rainbow(np.linspace(0, 1, n_classes))
    for c, color in enumerate(colors):
        mask = y_viz == c
        if np.any(mask):
            ax_scatter.scatter(X_viz[mask, 0], X_viz[mask, 1],
                               c=[color], label=f'True Class {c}',
                               s=100, alpha=0.7, edgecolors='k', linewidths=2)
        mask_pred = predictions == c
        if np.any(mask_pred):
            ax_scatter.scatter(X_viz[mask_pred, 0], X_viz[mask_pred, 1],
                               c=[color], marker='x', s=150, linewidths=3)

    ax_scatter.set_xlabel('Feature 1', fontsize=12)
    ax_scatter.set_ylabel('Feature 2', fontsize=12)
    ax_scatter.set_title('Predictions (circles=true, crosses=predicted)', fontsize=12, fontweight='bold')
    ax_scatter.legend()
    ax_scatter.grid(True, alpha=0.3)

    # Fixed-size confusion matrix: a class missing from the subsample would
    # otherwise shrink the matrix and break the n_classes x n_classes annotation loop.
    cm = _class_confusion_matrix(y_viz, predictions, n_classes)
    im = ax_cm.imshow(cm, cmap='Blues')
    ax_cm.set_title('Confusion Matrix', fontsize=12, fontweight='bold')
    ax_cm.set_xlabel('Predicted', fontsize=12)
    ax_cm.set_ylabel('True', fontsize=12)
    ax_cm.set_xticks(range(n_classes))
    ax_cm.set_yticks(range(n_classes))

    # Annotate counts; use white text on dark cells for legibility.
    threshold = cm.max() / 2 if cm.size else 0
    for i in range(n_classes):
        for j in range(n_classes):
            ax_cm.text(j, i, str(cm[i, j]), ha='center', va='center', fontsize=14,
                       color='white' if cm[i, j] > threshold else 'black')

    plt.colorbar(im, ax=ax_cm)

    # Failed samples (-1) never equal a true label, so they count as wrong.
    accuracy = float(np.mean(predictions == y_viz))
    plt.suptitle(f'Accuracy: {accuracy:.2%}', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    print(f"Saved: {filename}")
    plt.show()
    return fig


def _stratified_subset(X_data, y_data, n_classes, max_samples):
    """Return a shuffled, class-balanced subset of (X_data, y_data).

    Each class contributes at most ``max_samples // n_classes`` rows, or all of
    its rows if it has fewer. At least one row per class is requested, so when
    ``n_classes > max_samples`` the subset may exceed ``max_samples`` instead of
    being empty.
    """
    per_class = max(1, max_samples // max(n_classes, 1))
    chosen = []
    for c in range(n_classes):
        class_idx = np.where(y_data == c)[0]
        if len(class_idx) > per_class:
            class_idx = np.random.choice(class_idx, per_class, replace=False)
        chosen.extend(class_idx)
    # dtype=int keeps fancy indexing valid even if no indices were collected
    chosen = np.array(chosen, dtype=int)
    np.random.shuffle(chosen)
    return X_data[chosen], y_data[chosen]


def run_complete_analysis(csv_path='lif.csv', max_train_samples=100, n_epochs=10,
                          learning_rate=0.005, gamma=0.0001, patience=3, seed=None):
    """Complete analysis: linearized PMP via shooting first, then full nonlinear training.

    Phase 1 loads the data and builds an ``SNNDynamics`` network. It linearizes
    the dynamics along the first sample's trajectory and solves the linearized
    adjoint with the shooting method (Newton). Phase 2 trains on a
    class-stratified subset with early stopping on training accuracy. It then
    plots predictions and the final nonlinear adjoint solution.

    Args:
        csv_path: data file passed to ``load_data_from_csv``.
        max_train_samples: dataset size above which stratified subsampling is used.
        n_epochs: maximum number of training epochs.
        learning_rate: step size passed to ``ShootingMethodSolver.train_step``.
        gamma: regularization weight passed to ``train_step``.
        patience: epochs without accuracy improvement before early stopping.
        seed: if not None, seeds NumPy's global RNG so subsampling is reproducible.
    """
    if seed is not None:
        np.random.seed(seed)

    print("="*80)
    print("PHASE 1: LINEARIZED PMP VIA SHOOTING METHOD")
    print("="*80)

    # Load and visualize data
    X_data, y_data, n_classes = load_data_from_csv(csv_path)
    N = len(X_data)
    visualize_data(X_data, y_data, n_classes)

    # Create configuration with correct number of classes
    config = NetworkConfig()
    config.n_classes = n_classes

    print("\nNetwork Configuration:")
    print(f"  Input dimension: {config.d}")
    print(f"  Hidden layers: {config.L}")
    print(f"  Neurons per layer: {config.P}")
    print(f"  Output classes: {config.n_classes}")
    print(f"  State dimension: {config.n_state}")
    print(f"  Time horizon: {config.T}s (dt={config.dt}s)")

    # Create network
    snn = SNNDynamics(config)
    snn.update_zeta(0, 5)

    # Step 1: Linearize system
    print("\n" + "="*80)
    print("STEP 1: Linearizing SNN dynamics")
    print("="*80)

    linear_sys = LinearizedSystem(snn)
    X_ref, A_traj = linear_sys.linearize_along_trajectory(X_data[0])
    visualize_linearized_system(X_ref, A_traj, linear_sys.t_grid)

    # Step 2: Solve adjoint for linearized system
    print("\n" + "="*80)
    print("STEP 2: Solving linearized adjoint via shooting method")
    print("="*80)

    # Terminal condition for class y_data[0].
    # NOTE(review): assumes the output neurons are the last n_classes state
    # entries — confirm against SNNDynamics' state layout.
    lambda_T = np.zeros(config.n_state)
    lambda_T[-n_classes + y_data[0]] = 1.0
    lambda_0_guess = np.zeros(config.n_state)

    lambda_traj, shooting_costs = linear_sys.solve_adjoint_shooting(lambda_T, lambda_0_guess, method='newton')
    visualize_shooting_convergence(shooting_costs, 'shooting_convergence.png')
    visualize_adjoint_trajectories(X_ref, lambda_traj, linear_sys.t_grid,
                                   n_classes, 'linearized_adjoint.png')

    # Step 3: Full nonlinear training
    print("\n" + "="*80)
    print("PHASE 2: FULL NONLINEAR TRAINING")
    print("="*80)

    # Use a class-stratified subset of the data to keep training time bounded
    if N > max_train_samples:
        print(f"Using {max_train_samples} samples for training (out of {N} total)")
        X_train, y_train = _stratified_subset(X_data, y_data, n_classes, max_train_samples)
    else:
        X_train, y_train = X_data, y_data

    print(f"Training set: {len(X_train)} samples")
    print(f"Class distribution: {np.bincount(y_train)}")

    solver = ShootingMethodSolver(snn)

    print(f"\nTraining for {n_epochs} epochs...")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Regularization: {gamma}")
    print(f"  Time horizon (training): {config.T_train}s")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Max simulation time: {config.max_simulation_time}s")

    history = []
    best_accuracy = 0.0
    patience_counter = 0

    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch+1}/{n_epochs}")
        start_time = time.time()

        # Update ζ (mollification schedule)
        zeta = snn.update_zeta(epoch, n_epochs)

        try:
            loss, accuracy = solver.train_step(X_train, y_train, learning_rate, gamma)
        except Exception as e:
            # Top-level boundary: report and stop training, keep what we have.
            print(f"  Error in training: {e}")
            break

        elapsed = time.time() - start_time
        history.append({
            'epoch': epoch,
            'loss': loss,
            'accuracy': accuracy,
            'zeta': zeta
        })
        print(f"  Loss={loss:.6f}, Acc={accuracy:.3f}, ζ={zeta:.2f}, Time={elapsed:.1f}s")

        # Early stopping on training accuracy
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f"  Early stopping: no improvement for {patience} epochs")
            break

    if history:
        visualize_training_history(history)

    # Visualize final predictions (on training set)
    print("\n" + "="*80)
    print("Evaluating on training set...")
    visualize_predictions(snn, X_train, y_train, n_classes)

    # Final adjoint solution
    print("\n" + "="*80)
    print("Computing final nonlinear adjoint solution...")
    print("="*80)

    X_final, lambda_final = solver.solve_forward_backward(X_train[0], y_train[0])
    visualize_adjoint_trajectories(X_final, lambda_final, snn.t_grid,
                                   n_classes, 'final_adjoint.png')

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE!")
    print("="*80)

    # Only list training_history.png if at least one epoch completed.
    figures = ["data_distribution.png", "linearized_system.png",
               "shooting_convergence.png", "linearized_adjoint.png"]
    if history:
        figures.append("training_history.png")
    figures += ["predictions.png", "final_adjoint.png"]
    print("\nGenerated figures:")
    for k, name in enumerate(figures, 1):
        print(f"  {k}. {name}")


if __name__ == "__main__":
    # Script entry point: run the linearized-PMP phase followed by nonlinear training.
    run_complete_analysis()