In [4]:
!pip install matplotlib==3.10.7

Collecting matplotlib==3.10.7
  Downloading matplotlib-3.10.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kiwisolver>=1.3.1
  Downloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m68.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting contourpy>=1.0.1
  Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m325.0/325.0 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
Collecting pyparsing>=3
  Downloading pyparsing-3.2.5-py3-none-any.whl (113 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.9/113.9 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
Collecting fonttools>=4.22.0
  Downloading f

In [7]:
"""
Shooting Method for SNN Training via Pontryagin's Maximum Principle
====================================================================

Features:
- Proper shooting method with linear interpolation for initial guesses
- Time-series current input format (neurons × time)
- Binary classification with explicit confusion matrix
"""

import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import UnivariateSpline, interp1d
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, List, Callable, Optional
import pandas as pd
import os
import time
import warnings

warnings.filterwarnings('ignore')

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

import pipeline as pi
def data(path):
    """Build the ``(X, y)`` sequences from the data located at ``path``.

    ``pi.markover(path, 7)`` is called six times; each result is placed in
    front of what has been collected so far. The collection is then passed
    through ``pi.sampler_regressor`` and the first two items of its result are
    handed to ``pi.sequence_generator2``.

    NOTE(review): the loop historically iterated over 2..7 without ever using
    the loop value -- every call passes the literal 7. If the second argument
    was meant to vary per pass this is a bug; confirm against
    ``pipeline.markover`` before changing it.
    """
    collected = []
    for _ in range(6):
        collected = pi.markover(path, 7) + collected

    sampled = pi.sampler_regressor(collected)
    X, y = pi.sequence_generator2(sampled[0], sampled[1])
    return X, y


@dataclass
class NetworkConfig:
    """Hyper-parameters for the network, its time grid and the training loop."""
    # --- architecture ---
    d: int = 10              # input dimension (number of input channels)
    L: int = 2               # how many hidden layers
    P: int = 3               # neurons in each hidden layer
    n_classes: int = 2       # output units (binary classification)

    # --- time grid ---
    T: float = 60.0          # long horizon (visualisation)
    T_train: float = 2.0     # horizon used while training
    dt: float = 0.01         # integration step

    # --- input time series ---
    n_timepoints: int = 240  # samples per input series (not read by the code shown here)

    # --- wall-clock limits, in seconds ---
    max_simulation_time: float = 10.0
    max_gradient_time: float = 30.0

    # --- training ---
    batch_size: int = 16

    # --- input neuron: time constant and spike threshold ---
    tau_v: float = 8.0
    theta_v: float = 0.2

    # --- hidden neurons: time constant and spike threshold ---
    tau_h: float = 6.0
    theta_h: float = 0.25

    # --- output units: time constant and readout offset ---
    tau_u: float = 10.0
    theta_u: float = 0.1

    # --- width of the spike kernel ---
    mu: float = 0.2

    # --- mollification schedule: zeta goes from zeta_0 to zeta_1 ---
    zeta_0: float = 3.0
    zeta_1: float = 10.0

    # --- initial weight scale ---
    weight_scale: float = 0.5

    def get_zeta(self, epoch: int, total_epochs: int) -> float:
        """Return ζ for ``epoch`` on a geometric ramp from ``zeta_0`` to ``zeta_1``.

        With a single epoch (or fewer) there is nothing to interpolate and the
        final value ``zeta_1`` is returned.
        """
        if total_epochs > 1:
            progress = epoch / (total_epochs - 1)
            return self.zeta_0 * (self.zeta_1 / self.zeta_0) ** progress
        return self.zeta_1

    @property
    def n_state(self) -> int:
        """Length of the state vector: 1 input + L*P hidden + n_classes outputs."""
        hidden_units = self.L * self.P
        return 1 + hidden_units + self.n_classes


class GaussianKernel:
    """Superposition of normalised Gaussian bumps, one per spike time."""

    def __init__(self, mu: float = 0.2):
        self.mu = mu
        # Normalisation 1 / (mu * sqrt(2*pi)) of a Gaussian with width mu.
        self.coef = 1.0 / (mu * np.sqrt(2 * np.pi))

    def __call__(self, t: float, spike_times: List[float]) -> float:
        """Evaluate the kernel at time ``t``; returns 0.0 when there are no spikes."""
        if len(spike_times) == 0:
            return 0.0

        total = 0
        for ts in spike_times:
            z = (t - ts) / self.mu
            total += self.coef * np.exp(-0.5 * z ** 2)
        return total


class MollifiedReset:
    """Smooth (tanh-based) stand-in for the hard spike reset."""

    def __init__(self, zeta: float = 10.0):
        # Sharpness of the smoothed step.
        self.zeta = zeta

    def H(self, s: np.ndarray) -> np.ndarray:
        """Smoothed step: H_ζ(s) = (1 + tanh(ζ s / 2)) / 2."""
        arg = self.zeta * s / 2
        return 0.5 * (1 + np.tanh(arg))

    def D(self, s: np.ndarray, theta: float) -> np.ndarray:
        """Reset map D_ζ(s; θ) = (1 − H_ζ(s − θ)) · s."""
        keep = 1 - self.H(s - theta)
        return keep * s

    def dD_ds(self, s: np.ndarray, theta: float) -> np.ndarray:
        """Derivative of ``D`` with respect to ``s`` (product rule)."""
        shifted = s - theta
        gate = self.H(shifted)
        sech = 1 / np.cosh(self.zeta * shifted / 2)
        # dH/ds = (ζ / 4) · sech²(ζ (s − θ) / 2)
        slope = 0.25 * self.zeta * sech ** 2
        return (1 - gate) - s * slope


class SNNDynamics:
    """Forward model of the spiking network driven by a time-series input.

    State vector layout (length ``config.n_state``)::

        [ v | hidden layer 0 (P values) | ... | hidden layer L-1 (P values) | u_0 .. u_{C-1} ]

    Spikes are detected in :meth:`simulate` (explicit Euler) and stored as
    lists of spike times; :meth:`dynamics` turns those lists into synaptic
    drive through the Gaussian kernel.
    """

    def __init__(self, config: "NetworkConfig"):
        self.config = config
        self.kernel = GaussianKernel(config.mu)
        self.reset = MollifiedReset()

        # Honour the configured weight scale. This used to be a hard-coded 0.5
        # that silently ignored ``NetworkConfig.weight_scale`` (whose default
        # is also 0.5, so default behaviour is unchanged).
        scale = getattr(config, 'weight_scale', 0.5)

        # Parameters. The draw order (a, omega[0..L-1], w, nu) is kept so that
        # runs with a fixed NumPy seed reproduce earlier results.
        self.a = np.random.randn(config.d) * scale + 0.2

        self.omega = []
        for ell in range(config.L):
            # Layer 0 listens to the single input neuron, deeper layers to P neurons.
            fan_in = 1 if ell == 0 else config.P
            self.omega.append(np.random.randn(config.P, fan_in) * scale)

        self.w = np.random.randn(config.n_classes, config.P) * scale
        self.nu = np.random.randn(config.n_classes) * scale

        # Spike time storage (filled by simulate()).
        self.spike_times_input = []
        self.spike_times_hidden = [[[] for _ in range(config.P)]
                                   for _ in range(config.L)]

        # Trajectory cache (written by simulate(record=True)).
        self.X_trajectory = None
        self.t_grid = None

        # One interpolator per input channel, created by set_input_timeseries().
        self.input_interpolator = None

    def set_input_timeseries(self, X_timeseries: np.ndarray):
        """
        Set time-series input and create interpolator.

        Parameters:
        -----------
        X_timeseries : np.ndarray, shape (d, n_timepoints)
            Current values for each input neuron over time

        Notes:
        ------
        The samples are always spread evenly over ``[0, config.T_train]``.
        NOTE(review): when ``simulate`` is run with a longer horizon the input
        beyond ``T_train`` comes from linear extrapolation of the last two
        samples -- confirm that this is intended.
        """
        d, n_timepoints = X_timeseries.shape
        assert d == self.config.d, f"Expected {self.config.d} input neurons, got {d}"

        # Time stamps attached to the input samples.
        t_input = np.linspace(0, self.config.T_train, n_timepoints)

        self.input_interpolator = [
            interp1d(t_input, X_timeseries[i, :],
                     kind='linear', fill_value='extrapolate')
            for i in range(d)
        ]

    def get_input_at_time(self, t: float) -> np.ndarray:
        """Return the interpolated input vector at time ``t`` (zeros if no input is set)."""
        if self.input_interpolator is None:
            return np.zeros(self.config.d)
        return np.array([interp(t) for interp in self.input_interpolator])

    def reset_spikes(self):
        """Forget all recorded spike times."""
        self.spike_times_input = []
        self.spike_times_hidden = [[[] for _ in range(self.config.P)]
                                   for _ in range(self.config.L)]

    def update_zeta(self, epoch: int, total_epochs: int):
        """Set the reset sharpness from the config schedule and return it."""
        zeta = self.config.get_zeta(epoch, total_epochs)
        self.reset.zeta = zeta
        return zeta

    def dynamics(self, t: float, X: np.ndarray) -> np.ndarray:
        """Compute dX/dt = F(X, Υ) at time ``t`` for state ``X``.

        Kernel sums depend only on ``t`` and the stored spike times, so each
        one is evaluated once per call instead of once per receiving neuron
        or per output class; the arithmetic per neuron is unchanged.
        """
        cfg = self.config
        dX = np.zeros_like(X)

        # Input neuron: leaky integration of the weighted input a · x(t).
        x_input = self.get_input_at_time(t)
        dX[0] = (1.0 / cfg.tau_v) * (-X[0] + np.dot(self.a, x_input))
        idx = 1

        # Hidden layers.
        for ell in range(cfg.L):
            if ell == 0:
                J_in = self.kernel(t, self.spike_times_input)
                J_prev = None
            else:
                J_in = None
                J_prev = [self.kernel(t, self.spike_times_hidden[ell - 1][q])
                          for q in range(cfg.P)]

            for p in range(cfg.P):
                if ell == 0:
                    synaptic_input = self.omega[ell][p, 0] * J_in
                else:
                    synaptic_input = 0.0
                    for q in range(cfg.P):
                        synaptic_input += self.omega[ell][p, q] * J_prev[q]

                dX[idx] = (1.0 / cfg.tau_h) * (-X[idx] + synaptic_input)
                idx += 1

        # Output layer: total kernel activity of the last hidden layer.
        Phi = 0.0
        for q in range(cfg.P):
            Phi += self.kernel(t, self.spike_times_hidden[-1][q])

        for c in range(cfg.n_classes):
            # NOTE(review): this multiplies the *sum* of the class weights by
            # the *total* activity, i.e. the weights are not applied per hidden
            # neuron as they are in the hidden layers. Kept as-is; confirm
            # against the model definition.
            drive = np.dot(self.w[c, :], np.ones(cfg.P)) * Phi
            dX[idx] = (1.0 / cfg.tau_u) * (-X[idx] + drive)
            idx += 1

        return dX

    def simulate(self, X_timeseries: np.ndarray,
                 T: Optional[float] = None,
                 record: bool = True,
                 timeout: Optional[float] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Forward simulation (explicit Euler) with time-series input.

        Parameters:
        -----------
        X_timeseries : np.ndarray, shape (d, n_timepoints)
            Current values for input neurons over time
        T : float, optional
            Horizon; defaults to ``config.T_train``.
        record : bool
            When True, cache the trajectory on ``self.X_trajectory`` / ``self.t_grid``.
        timeout : float, optional
            Wall-clock budget in seconds; when exceeded the trajectory computed
            so far is returned.

        Returns:
        --------
        (t_grid, X) : time stamps and the state at each time stamp.
        """
        self.reset_spikes()
        self.set_input_timeseries(X_timeseries)

        if T is None:
            T = self.config.T_train

        start_time = time.time()
        cfg = self.config
        dt = cfg.dt

        # Build the grid from an integer step count. ``np.arange(0, T + dt, dt)``
        # derives its length from a floating-point quotient and can therefore
        # append one step beyond T (e.g. T=0.2, dt=0.1). The small tolerance
        # keeps "T is a multiple of dt" exact; otherwise the last point is the
        # first grid point >= T, as before.
        n_steps = max(int(np.ceil(T / dt - 1e-9)), 0) + 1
        t_grid = dt * np.arange(n_steps)

        X = np.zeros((n_steps, cfg.n_state))  # X[0] is the zero initial state

        for i in range(n_steps - 1):
            if timeout is not None and (time.time() - start_time) > timeout:
                if record:
                    self.X_trajectory = X[:i+1]
                    self.t_grid = t_grid[:i+1]
                return t_grid[:i+1], X[:i+1]

            t = t_grid[i]
            X[i+1] = X[i] + self.dynamics(t, X[i]) * dt

            # Input neuron: record the spike and apply the mollified reset.
            if X[i+1, 0] >= cfg.theta_v:
                self.spike_times_input.append(t)
                X[i+1, 0] = self.reset.D(np.array([X[i+1, 0]]), cfg.theta_v)[0]

            # Hidden neurons, in state-vector order. Output units never reset.
            idx = 1
            for ell in range(cfg.L):
                for p in range(cfg.P):
                    if X[i+1, idx] >= cfg.theta_h:
                        self.spike_times_hidden[ell][p].append(t)
                        X[i+1, idx] = self.reset.D(np.array([X[i+1, idx]]),
                                                   cfg.theta_h)[0]
                    idx += 1

        if record:
            self.X_trajectory = X
            self.t_grid = t_grid

        return t_grid, X

    def simulate_fast(self, X_timeseries: np.ndarray, T: float) -> np.ndarray:
        """Simulate without caching and return only the final state.

        Best effort: any exception raised by the simulation yields an all-zero
        state instead of propagating.
        """
        try:
            _, X_traj = self.simulate(X_timeseries, T=T, record=False,
                                      timeout=self.config.max_simulation_time)
            return X_traj[-1]
        except Exception:
            return np.zeros(self.config.n_state)

    def compute_output(self, X_final: np.ndarray) -> np.ndarray:
        """
        Compute network output (sigmoid for binary classification).

        Returns:
        --------
        probs : np.ndarray (n_classes,)
            Class probabilities
        """
        u_final = X_final[-self.config.n_classes:]

        # Per-class logit: ν_c · σ(u_c − θ_u)
        sigma_u = 1.0 / (1.0 + np.exp(-(u_final - self.config.theta_u)))
        logits = self.nu * sigma_u

        if self.config.n_classes == 2:
            # NOTE(review): only logits[0] enters the binary probability, so
            # nu[1] and the second output unit have no effect here -- confirm.
            prob_1 = 1.0 / (1.0 + np.exp(-logits[0]))
            probs = np.array([1 - prob_1, prob_1])
        else:
            # Softmax, shifted by the max logit for numerical stability.
            exp_logits = np.exp(logits - np.max(logits))
            probs = exp_logits / np.sum(exp_logits)

        return probs

    def compute_loss(self, probs: np.ndarray, y_true: int) -> float:
        """Cross-entropy of the true class; 1e-10 guards against log(0)."""
        return -np.log(probs[y_true] + 1e-10)

    def compute_state_jacobian(self, t: float, X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
        """Forward-difference estimate of A(t) = ∂F/∂X at ``(t, X)``."""
        n = len(X)
        A = np.zeros((n, n))

        F0 = self.dynamics(t, X)

        for i in range(n):
            X_pert = X.copy()
            X_pert[i] += eps
            A[:, i] = (self.dynamics(t, X_pert) - F0) / eps

        return A


class LinearizedSystem:
    """Linearisation of the SNN along a reference trajectory plus a
    shooting solver for the associated adjoint equation."""

    def __init__(self, snn: "SNNDynamics"):
        self.snn = snn
        self.A_traj = None   # (n_samples, n, n) Jacobians along the trajectory
        self.t_grid = None   # (n_samples,) time stamps of the Jacobians
        self.X_traj = None   # (n_samples, n) reference states

    def linearize_along_trajectory(self, X_timeseries: np.ndarray,
                                   sample_rate: int = 10):
        """Simulate a reference trajectory and compute A(t) along it.

        Parameters:
        -----------
        X_timeseries : np.ndarray
            Input passed to ``snn.simulate``; the horizon is ``snn.config.T``.
        sample_rate : int
            Keep every ``sample_rate``-th simulation step (default 10).

        Returns:
        --------
        (X_sample, A_traj) : the sub-sampled states and their Jacobians.
        """
        if sample_rate < 1:
            raise ValueError("sample_rate must be a positive integer")

        print("  Linearizing system along reference trajectory...")

        t_grid_full, X_traj_full = self.snn.simulate(
            X_timeseries, T=self.snn.config.T, record=True)

        # Sub-sample for efficiency.
        t_sample = t_grid_full[::sample_rate]
        X_sample = X_traj_full[::sample_rate]

        self.A_traj = np.array([
            self.snn.compute_state_jacobian(t, x)
            for t, x in zip(t_sample, X_sample)
        ])
        self.t_grid = t_sample
        self.X_traj = X_sample

        print(f"  Linearized at {len(t_sample)} time points")

        return X_sample, self.A_traj

    def create_A_interpolator(self) -> Callable:
        """Return a function ``t -> A(t)`` (cubic interpolation in time).

        A single ``interp1d`` over the time axis interpolates every matrix
        entry at once, instead of building and calling n² scalar interpolators.
        """
        if self.A_traj is None or self.t_grid is None:
            raise RuntimeError(
                "linearize_along_trajectory() must be called before "
                "create_A_interpolator()")

        n = self.A_traj.shape[1]
        interp = interp1d(self.t_grid, self.A_traj, kind='cubic', axis=0,
                          fill_value='extrapolate')

        def A_interp(t):
            return np.asarray(interp(t), dtype=float).reshape(n, n)

        return A_interp

    def solve_adjoint_shooting(self, lambda_T: np.ndarray,
                               method: str = 'secant',
                               max_iter: int = 100) -> Tuple[np.ndarray, List[float]]:
        """
        Solve the adjoint equation −λ̇ = A(t)ᵀ λ by shooting on λ(0).

        Two starting guesses are used: λ(0) obtained by integrating backwards
        from ``lambda_T``, and a randomly perturbed copy of it. From the third
        guess on a component-wise secant (linear interpolation) update is used:

            m          = (guess_k − guess_{k−1}) / (solution_k − solution_{k−1})
            guess_{k+1} = guess_k + m · (target − solution_k)

        Parameters:
        -----------
        lambda_T : np.ndarray
            Target terminal value λ(T).
        method : str
            Only echoed in the log output; the update rule is always the one above.
        max_iter : int
            Upper bound on the iteration counter (which starts at 2).

        Returns:
        --------
        (trajectory, costs) : λ on ``self.t_grid`` for the last evaluated guess
        and the cost 0.5·||λ(T) − lambda_T||² of every evaluated guess.
        """
        print(f"  Solving adjoint via shooting method ({method})...")
        print(f"  Target λ(T) norm: {np.linalg.norm(lambda_T):.6f}")
        print(f"  Maximum iterations: {max_iter}")

        A_interp = self.create_A_interpolator()
        t_grid = self.t_grid
        n_state = len(lambda_T)

        def adjoint_rhs(t, lam):
            """−λ̇ = A(t)^T λ"""
            return -A_interp(t).T @ lam

        def forward_integrate_adjoint(lam_0):
            """Integrate adjoint forward from λ(0) to λ(T)"""
            sol = solve_ivp(adjoint_rhs, [t_grid[0], t_grid[-1]], lam_0,
                            t_eval=t_grid, method='RK45',
                            rtol=1e-8, atol=1e-10)
            return sol.y.T

        def cost_function(lam_0):
            """Cost: ||λ(T) - λ_T^target||²"""
            lam_traj = forward_integrate_adjoint(lam_0)
            lam_T_computed = lam_traj[-1]
            cost = 0.5 * np.sum((lam_T_computed - lambda_T)**2)
            return cost, lam_traj, lam_T_computed

        print(f"  Initializing with two guesses for linear interpolation...")

        # Guess 1: integrate backwards from the target terminal value.
        sol_backward = solve_ivp(adjoint_rhs, [t_grid[-1], t_grid[0]], lambda_T,
                                 t_eval=t_grid[::-1], method='RK45',
                                 rtol=1e-8, atol=1e-10)
        guess_1 = sol_backward.y[:, -1]
        cost_1, traj_1, solution_1 = cost_function(guess_1)

        # Guess 2: random perturbation of guess 1 (needed to form a secant).
        guess_2 = guess_1 + 0.1 * np.random.randn(n_state)
        cost_2, traj_2, solution_2 = cost_function(guess_2)

        print(f"  Guess 1: cost = {cost_1:.6e}, ||λ(0)|| = {np.linalg.norm(guess_1):.6f}")
        print(f"  Guess 2: cost = {cost_2:.6e}, ||λ(0)|| = {np.linalg.norm(guess_2):.6f}")

        costs = [cost_1, cost_2]
        guesses = [guess_1, guess_2]
        solutions = [solution_1, solution_2]

        tol = 1e-8

        print(f"\n  Starting shooting iterations with linear interpolation...")
        print(f"  {'Iter':<6} {'Cost':<14} {'||Residual||':<14} {'Method'}")
        print(f"  {'-'*60}")

        current_guess = guess_2
        current_solution = solution_2
        current_traj = traj_2

        for iteration in range(2, max_iter):
            residual_norm = np.linalg.norm(current_solution - lambda_T)

            print(f"  {iteration:<6} {costs[-1]:<14.6e} {residual_norm:<14.6e}", end='')

            if residual_norm < tol:
                print(f" ✓ Converged!")
                break

            # Component-wise secant slope from the last two evaluations.
            delta_solution = solutions[-1] - solutions[-2]
            delta_guess = guesses[-1] - guesses[-2]

            # Keep the denominator away from zero. ``np.sign`` returns 0 for a
            # zero component, which left the denominator at exactly 0; treat
            # zero as positive so the offset always applies.
            offset_sign = np.where(delta_solution < 0, -1.0, 1.0)
            m_vec = delta_guess / (delta_solution + 1e-12 * offset_sign)

            next_guess = guesses[-1] + m_vec * (lambda_T - solutions[-1])
            print(f" Linear interp")

            if not np.all(np.isfinite(next_guess)):
                # Do not feed NaN/inf into the ODE solver; keep the last iterate.
                print("  Non-finite shooting update; stopping with last finite iterate")
                break

            new_cost, new_traj, new_solution = cost_function(next_guess)

            costs.append(new_cost)
            guesses.append(next_guess)
            solutions.append(new_solution)

            current_guess = next_guess
            current_solution = new_solution
            current_traj = new_traj

        final_residual = np.linalg.norm(current_solution - lambda_T)
        status = "Converged" if final_residual < tol else "Stopped WITHOUT converging"

        print(f"\n  {status} after {len(costs)} evaluations")
        print(f"  Final cost: {costs[-1]:.6e}")
        print(f"  Final λ(0) norm: {np.linalg.norm(current_guess):.6f}")
        print(f"  Final λ(T) norm: {np.linalg.norm(current_solution):.6f}")
        print(f"  Final residual: {final_residual:.6e}")

        return current_traj, costs


class ShootingMethodSolver:
    """Forward/backward solver and mini-batch trainer for the nonlinear SNN."""

    def __init__(self, snn: "SNNDynamics", pos_weight: Optional[float] = None):
        """
        Parameters:
        -----------
        snn : SNNDynamics
            The network to simulate and train.
        pos_weight : float, optional
            Multiplier applied to the loss of samples whose label is 1.
            When omitted, a module-level ``pos_weight`` is used if one exists
            (the historical behaviour), otherwise 1.0.
        """
        self.snn = snn
        self.pos_weight = pos_weight

    def _positive_class_weight(self) -> float:
        """Resolve the loss weight for label 1 (explicit > module global > 1.0)."""
        if self.pos_weight is not None:
            return self.pos_weight
        return globals().get('pos_weight', 1.0)

    def create_state_interpolator(self, t_grid: np.ndarray,
                                  X_traj: np.ndarray) -> List[Callable]:
        """Return one interpolating cubic spline (``s=0``) per state component."""
        return [
            UnivariateSpline(t_grid, X_traj[:, i], k=3, s=0)
            for i in range(X_traj.shape[1])
        ]

    def solve_forward_backward(self, X_timeseries: np.ndarray,
                               y_target: int,
                               use_train_time: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """Simulate forward, then integrate the adjoint backwards in time.

        Returns ``(X_traj, lambda_traj)``, both ordered by increasing time.
        """
        cfg = self.snn.config
        T = cfg.T_train if use_train_time else cfg.T
        t_grid, X_traj = self.snn.simulate(X_timeseries, T=T, record=True,
                                           timeout=cfg.max_simulation_time)

        probs = self.snn.compute_output(X_traj[-1])

        # Derivative of the cross-entropy with respect to the class probabilities.
        dloss_dprobs = np.zeros(cfg.n_classes)
        dloss_dprobs[y_target] = -1.0 / (probs[y_target] + 1e-10)

        # Terminal condition: only the output-unit components are non-zero.
        # NOTE(review): this is d(loss)/d(probs) placed directly on the output
        # states, without the chain-rule factor d(probs)/d(u) -- confirm.
        lambda_T = np.zeros(cfg.n_state)
        lambda_T[-cfg.n_classes:] = dloss_dprobs

        X_interp = self.create_state_interpolator(t_grid, X_traj)

        def adjoint_rhs(t, lam):
            X_t = np.array([interp(t) for interp in X_interp])
            A = self.snn.compute_state_jacobian(t, X_t)
            return -A.T @ lam

        sol = solve_ivp(adjoint_rhs, [t_grid[-1], t_grid[0]], lambda_T,
                        t_eval=t_grid[::-1], method='RK45',
                        rtol=1e-6, atol=1e-8)

        # The solver output runs from T down to 0; flip to increasing time.
        lambda_traj = sol.y.T[::-1]

        return X_traj, lambda_traj

    def compute_loss_and_pred(self, X_timeseries: np.ndarray, y_target: int) -> Tuple[float, int, np.ndarray]:
        """Return ``(loss, predicted_class, final_state)`` for one sample.

        The loss of label-1 samples is multiplied by the positive-class weight.
        Previously that weight was read from an undefined free name; the
        resulting ``NameError`` was swallowed by the ``except`` below, so every
        positive sample silently produced the fallback values.

        Best effort: on any exception the fallback ``(1.0, 0, zeros)`` is returned.
        """
        try:
            X_final = self.snn.simulate_fast(X_timeseries, self.snn.config.T_train)
            probs = self.snn.compute_output(X_final)
            loss = self.snn.compute_loss(probs, y_target)

            if y_target == 1:
                loss *= self._positive_class_weight()

            pred = np.argmax(probs)
            return loss, pred, X_final
        except Exception:
            return 1.0, 0, np.zeros(self.snn.config.n_state)

    def compute_gradients_fast(self, X_timeseries: np.ndarray, y_target: int,
                               X_final: np.ndarray, gamma: float = 0.001) -> dict:
        """Cheap gradient estimate for one sample.

        * ``nu``: forward finite difference of the (unweighted) loss.
        * ``a``, ``omega``, ``w``: NOTE(review) -- these are *random noise*
          (std 0.001) plus the L2 term, not gradients of the loss. Kept for
          backward compatibility; replace with the adjoint-based gradient.

        ``gamma`` is the L2 regularisation strength (adds ``2·gamma·param``).
        """
        snn = self.snn
        probs = snn.compute_output(X_final)
        base_loss = snn.compute_loss(probs, y_target)  # loop-invariant

        eps = 1e-5
        grad_nu = np.zeros_like(snn.nu)
        for i in range(len(snn.nu)):
            saved = snn.nu[i]
            try:
                snn.nu[i] = saved + eps
                perturbed_loss = snn.compute_loss(snn.compute_output(X_final), y_target)
            finally:
                # Restore the exact value; ``+= eps`` followed by ``-= eps``
                # does not round-trip in floating point and drifts over time.
                snn.nu[i] = saved
            grad_nu[i] = (perturbed_loss - base_loss) / eps

        # Placeholder "gradients" (see docstring). Draw order is preserved.
        grad_a = np.random.randn(*snn.a.shape) * 0.001
        grad_omega = [np.random.randn(*w.shape) * 0.001 for w in snn.omega]
        grad_w = np.random.randn(*snn.w.shape) * 0.001

        # L2 regularisation.
        grad_a += 2 * gamma * snn.a
        for ell in range(snn.config.L):
            grad_omega[ell] += 2 * gamma * snn.omega[ell]
        grad_w += 2 * gamma * snn.w
        grad_nu += 2 * gamma * snn.nu

        return {'a': grad_a, 'omega': grad_omega, 'w': grad_w, 'nu': grad_nu}

    def train_step(self, X_data: np.ndarray, y_data: np.ndarray,
                   learning_rate: float = 0.001, gamma: float = 0.001) -> Tuple[float, float]:
        """Run one pass over the data in mini-batches with plain gradient descent.

        Returns ``(mean_loss, accuracy)`` over all ``N`` samples. Samples that
        raise are skipped (they contribute neither loss nor gradient) and are
        now counted and reported instead of being dropped silently.
        """
        N = len(X_data)
        if N == 0:
            raise ValueError("train_step requires at least one sample")

        snn = self.snn
        n_layers = snn.config.L
        batch_size = min(snn.config.batch_size, N)
        n_batches = (N + batch_size - 1) // batch_size

        total_loss = 0.0
        correct = 0
        n_failed = 0

        print(f"  Training on {n_batches} batches (batch_size={batch_size})...")

        for batch_idx in tqdm(range(n_batches), desc="Batches"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, N)

            batch_X = X_data[start_idx:end_idx]
            batch_y = y_data[start_idx:end_idx]
            batch_size_actual = len(batch_X)

            grad_a_batch = np.zeros_like(snn.a)
            grad_omega_batch = [np.zeros_like(w) for w in snn.omega]
            grad_w_batch = np.zeros_like(snn.w)
            grad_nu_batch = np.zeros_like(snn.nu)

            batch_loss = 0.0

            for i in range(batch_size_actual):
                try:
                    loss, pred, X_final = self.compute_loss_and_pred(batch_X[i], batch_y[i])

                    batch_loss += loss
                    if pred == batch_y[i]:
                        correct += 1

                    grads = self.compute_gradients_fast(batch_X[i], batch_y[i], X_final, gamma)

                    grad_a_batch += grads['a']
                    for ell in range(n_layers):
                        grad_omega_batch[ell] += grads['omega'][ell]
                    grad_w_batch += grads['w']
                    grad_nu_batch += grads['nu']
                except Exception:
                    n_failed += 1
                    continue

            # Average over the batch, then take one descent step.
            grad_a_batch /= batch_size_actual
            for ell in range(n_layers):
                grad_omega_batch[ell] /= batch_size_actual
            grad_w_batch /= batch_size_actual
            grad_nu_batch /= batch_size_actual

            snn.a -= learning_rate * grad_a_batch
            for ell in range(n_layers):
                snn.omega[ell] -= learning_rate * grad_omega_batch[ell]
            snn.w -= learning_rate * grad_w_batch
            snn.nu -= learning_rate * grad_nu_batch

            total_loss += batch_loss

        if n_failed:
            print(f"  Warning: {n_failed}/{N} samples raised an exception and were skipped")

        return total_loss / N, correct / N


def _generate_synthetic_timeseries() -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Build the deterministic synthetic two-class sinusoid dataset.

    Seeds NumPy's global RNG with 42 (a module-wide side effect) so repeated
    calls return identical arrays.

    Returns:
    --------
    X_data : np.ndarray, shape (100, 2, 24)
    y_data : np.ndarray, shape (100,)
        Alternating labels 0, 1, 0, 1, ...
    n_classes : int
        Always 2.
    """
    np.random.seed(42)
    N = 100  # Number of samples
    d = 2    # Number of input neurons
    n_timepoints = 24  # Time points (e.g., 24 hours)

    X_data = []
    y_data = []

    for i in range(N):
        label = i % 2  # Binary labels, perfectly balanced

        if label == 0:
            # Class 0: one sine period, low amplitude, low noise
            t = np.linspace(0, 2*np.pi, n_timepoints)
            amplitude, noise = 0.3, 0.1
        else:
            # Class 1: two sine periods, high amplitude, higher noise
            t = np.linspace(0, 4*np.pi, n_timepoints)
            amplitude, noise = 0.8, 0.2

        # Time-series matrix (d neurons x n_timepoints); each neuron is phase-shifted by pi/4
        X_matrix = np.zeros((d, n_timepoints))
        for neuron in range(d):
            X_matrix[neuron, :] = (amplitude * np.sin(t + neuron * np.pi/4)
                                   + noise * np.random.randn(n_timepoints))

        X_data.append(X_matrix)
        y_data.append(label)

    X_data = np.array(X_data)
    y_data = np.array(y_data)

    print(f"Generated synthetic data: {N} samples")
    print(f"  X shape: {X_data.shape} (N, d={d}, n_timepoints={n_timepoints})")
    print(f"  Y shape: {y_data.shape}")
    print(f"  Class distribution: class 0={np.sum(y_data==0)}, class 1={np.sum(y_data==1)}")

    return X_data, y_data, 2


def load_data_from_csv(filename: str = 'lif.csv') -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Load time-series data from CSV or generate random data.

    CSV parsing is not implemented yet, so synthetic data is returned whether
    or not `filename` exists; only the printed warning differs.

    Format: Array of matrices X (each is neurons × time), vector Y of binary labels

    Parameters:
    -----------
    filename : str
        Path of the CSV file that would be loaded once parsing is implemented.

    Returns:
    --------
    X_data : np.ndarray, shape (N, d, n_timepoints)
        Array of time-series matrices
    y_data : np.ndarray, shape (N,)
        Binary labels (0 or 1)
    n_classes : int
        Number of classes (2 for binary)
    """
    if os.path.exists(filename):
        # TODO: Implement CSV loading for time-series format.
        # Do NOT call load_data_from_csv() again here: with the default
        # filename present on disk that recursion never terminates.
        print(f"Warning: loading time-series data from {filename} is not implemented yet. "
              "Generating synthetic time-series data instead.")
    else:
        print(f"Warning: {filename} not found. Generating synthetic time-series data.")

    return _generate_synthetic_timeseries()


def visualize_data(X_data, y_data):
    """Visualize time-series data with explicit labels for presentation.

    Draws a 3x2 grid: for each of the classes 0 and 1, the first sample of
    that class (left) and the per-class mean +/- one standard deviation
    (right); the bottom row is a bar chart of the class counts.

    Parameters
    ----------
    X_data : np.ndarray, shape (N, d, n_timepoints)
        Input currents, one (neurons x time) matrix per sample.
    y_data : np.ndarray, shape (N,)
        Labels; only the values 0 and 1 are plotted and counted.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, which is also saved to 'data_timeseries.png' and shown.
    """
    N, d, n_timepoints = X_data.shape
    time_axis = range(n_timepoints)

    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

    # Two hand-picked colours for the usual d == 2 case; extend the palette
    # from tab10 so that d > 2 no longer raises IndexError.
    colors = ['#1976D2', '#D32F2F']
    colors += [plt.cm.tab10(k % 10) for k in range(d - len(colors))]

    # Plot example time series for each class
    for label in [0, 1]:
        ax1 = fig.add_subplot(gs[label, 0])
        ax2 = fig.add_subplot(gs[label, 1])

        class_indices = np.where(y_data == label)[0]
        if len(class_indices) == 0:
            # A class without samples has no example and no statistics:
            # show a placeholder instead of crashing on [0][0] / mean of empty.
            for ax in (ax1, ax2):
                ax.text(0.5, 0.5, f'No samples for class {label}',
                        ha='center', va='center', fontsize=14, fontweight='bold',
                        transform=ax.transAxes)
                ax.set_axis_off()
            continue
        idx = class_indices[0]

        for neuron in range(d):
            ax1.plot(time_axis, X_data[idx, neuron, :], linewidth=2.5,
                    color=colors[neuron], marker='o', markersize=4,
                    label=f'Input Neuron {neuron+1}', alpha=0.8)
        ax1.set_title(f'Class {label} Example (Sample #{idx})\nTime-Series Input Currents',
                     fontsize=14, fontweight='bold', pad=10)
        ax1.set_xlabel('Time Point Index', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Input Current (A.U.)', fontsize=12, fontweight='bold')
        ax1.legend(fontsize=11, loc='best')
        ax1.grid(True, alpha=0.3, linestyle='--')
        ax1.set_xlim([-0.5, n_timepoints-0.5])

        # Plot average over class
        class_samples = X_data[y_data == label]
        mean_series = np.mean(class_samples, axis=0)
        std_series = np.std(class_samples, axis=0)

        for neuron in range(d):
            ax2.plot(time_axis, mean_series[neuron, :], linewidth=3,
                    color=colors[neuron], label=f'Input Neuron {neuron+1}', alpha=0.9)
            ax2.fill_between(time_axis,
                           mean_series[neuron, :] - std_series[neuron, :],
                           mean_series[neuron, :] + std_series[neuron, :],
                           color=colors[neuron], alpha=0.2)
        ax2.set_title(f'Class {label} Average (n={len(class_samples)} samples)\n' +
                     r'Mean $\pm$ Standard Deviation',
                     fontsize=14, fontweight='bold', pad=10)
        ax2.set_xlabel('Time Point Index', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Input Current (A.U.)', fontsize=12, fontweight='bold')
        ax2.legend(fontsize=11, loc='best')
        ax2.grid(True, alpha=0.3, linestyle='--')
        ax2.set_xlim([-0.5, n_timepoints-0.5])

    # Class distribution
    ax3 = fig.add_subplot(gs[2, :])
    class_counts = [np.sum(y_data == 0), np.sum(y_data == 1)]
    bars = ax3.bar(['Class 0\n(Negative)', 'Class 1\n(Positive)'], class_counts,
                   color=['#2196F3', '#F44336'], alpha=0.7, edgecolor='black', linewidth=2)

    # Add value labels on bars
    for bar, count in zip(bars, class_counts):
        height = bar.get_height()
        share = count / N * 100 if N else 0.0  # avoid 0/0 on an empty dataset
        ax3.text(bar.get_x() + bar.get_width()/2., height,
                f'{count}\n({share:.1f}%)',
                ha='center', va='bottom', fontsize=14, fontweight='bold')

    ax3.set_ylabel('Number of Samples', fontsize=13, fontweight='bold')
    ax3.set_title('Dataset Class Distribution\n' +
                 f'Total: {N} samples, {d} input neurons, {n_timepoints} time points',
                 fontsize=14, fontweight='bold', pad=10)
    ax3.grid(True, alpha=0.3, axis='y', linestyle='--')
    # 20% headroom for the bar labels; fall back to 1 so the axis never collapses to [0, 0].
    ax3.set_ylim([0, max(max(class_counts), 1) * 1.2])

    plt.savefig('data_timeseries.png', dpi=300, bbox_inches='tight')
    print("Saved: data_timeseries.png")
    plt.show()
    return fig


def visualize_confusion_matrix(y_true, y_pred, filename='confusion_matrix.png'):
    """Explicit confusion matrix visualization for binary classification.

    Draws a 2x2 heatmap annotated with TN/FP/FN/TP next to a table of
    accuracy, precision, recall and F1 (positive class = 1), saves the
    figure, and prints the same numbers to stdout.

    Parameters
    ----------
    y_true, y_pred : array-like of 0/1 labels
        Cast to int before use.
    filename : str
        Path the figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Imported lazily so that the rest of the module works without scikit-learn.
    from sklearn.metrics import confusion_matrix as sk_confusion_matrix
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    # Ensure binary labels
    y_true = np.array(y_true).astype(int)
    y_pred = np.array(y_pred).astype(int)

    # labels=[0, 1] pins the matrix to 2x2 even when a class is absent,
    # so no shape fallback is needed below.
    cm = sk_confusion_matrix(y_true, y_pred, labels=[0, 1])

    # Compute metrics with proper binary classification settings
    accuracy = accuracy_score(y_true, y_pred)

    # For binary classification, specify pos_label=1 and zero_division=0
    precision = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
    recall = recall_score(y_true, y_pred, pos_label=1, zero_division=0)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)

    # Row = true label, column = predicted label
    TN, FP, FN, TP = cm.ravel()

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # Confusion matrix heatmap
    im = axes[0].imshow(cm, cmap='Blues', aspect='auto')
    axes[0].set_title('Confusion Matrix\n(Binary Classification)',
                     fontsize=16, fontweight='bold', pad=20)
    axes[0].set_xlabel('Predicted Label', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('True Label', fontsize=14, fontweight='bold')
    axes[0].set_xticks([0, 1])
    axes[0].set_yticks([0, 1])
    axes[0].set_xticklabels(['Negative (0)', 'Positive (1)'], fontsize=12)
    axes[0].set_yticklabels(['Negative (0)', 'Positive (1)'], fontsize=12)

    # Add counts with labels
    labels = [['TN', 'FP'], ['FN', 'TP']]
    half_max = cm.max() / 2
    for i in range(2):
        for j in range(2):
            value = cm[i, j]
            # White text on the dark (high-count) cells, black elsewhere
            text_color = "white" if value > half_max else "black"

            # Main count
            axes[0].text(j, i, f'{value}',
                        ha="center", va="center",
                        color=text_color,
                        fontsize=32, fontweight='bold')

            # Label (TN, FP, etc.)
            axes[0].text(j, i-0.35, f'({labels[i][j]})',
                        ha="center", va="center",
                        color=text_color,
                        fontsize=14, style='italic')

    plt.colorbar(im, ax=axes[0], fraction=0.046, pad=0.04)

    # Metrics table
    axes[1].axis('tight')
    axes[1].axis('off')

    table_data = [
        ['Metric', 'Value', 'Formula'],
        ['', '', ''],
        ['True Negative (TN)', f'{TN}', 'Correctly predicted as 0'],
        ['False Positive (FP)', f'{FP}', 'Incorrectly predicted as 1'],
        ['False Negative (FN)', f'{FN}', 'Incorrectly predicted as 0'],
        ['True Positive (TP)', f'{TP}', 'Correctly predicted as 1'],
        ['', '', ''],
        ['Accuracy', f'{accuracy:.4f}', '(TP + TN) / Total'],
        ['Precision', f'{precision:.4f}', 'TP / (TP + FP)'],
        ['Recall (Sensitivity)', f'{recall:.4f}', 'TP / (TP + FN)'],
        ['F1-Score', f'{f1:.4f}', '2·P·R / (P + R)'],
        ['', '', ''],
        ['Total Samples', f'{len(y_true)}', ''],
    ]

    table = axes[1].table(cellText=table_data, cellLoc='left',
                         loc='center', bbox=[0.05, 0.1, 0.9, 0.85])
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 1.8)

    # Style header
    for i in range(3):
        table[(0, i)].set_facecolor('#2196F3')
        table[(0, i)].set_text_props(weight='bold', color='white', size=13)

    # Style section headers (first row of each group in table_data)
    for i in [2, 7, 12]:
        if i < len(table_data):
            table[(i, 0)].set_text_props(weight='bold', size=12)

    # Highlight metrics rows
    for i in [7, 8, 9, 10]:
        if i < len(table_data):
            table[(i, 0)].set_facecolor('#E3F2FD')
            table[(i, 1)].set_facecolor('#E3F2FD')
            table[(i, 2)].set_facecolor('#E3F2FD')

    plt.suptitle(f'Classification Performance (Accuracy: {accuracy:.1%})',
                fontsize=18, fontweight='bold', y=0.96)

    # Leave the top 6% of the canvas free for the suptitle
    plt.tight_layout(rect=[0, 0, 1, 0.94])
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"Saved: {filename}")
    print("\n" + "="*60)
    print("CONFUSION MATRIX BREAKDOWN")
    print("="*60)
    print(f"  True Negative (TN):   {TN:4d}  (Correctly predicted class 0)")
    print(f"  False Positive (FP):  {FP:4d}  (Wrongly predicted class 1)")
    print(f"  False Negative (FN):  {FN:4d}  (Wrongly predicted class 0)")
    print(f"  True Positive (TP):   {TP:4d}  (Correctly predicted class 1)")
    print("\n" + "="*60)
    print("PERFORMANCE METRICS")
    print("="*60)
    print(f"  Accuracy:   {accuracy:.4f}  ({accuracy:.1%})")
    print(f"  Precision:  {precision:.4f}  ({precision:.1%})")
    print(f"  Recall:     {recall:.4f}  ({recall:.1%})")
    print(f"  F1-Score:   {f1:.4f}  ({f1:.1%})")
    print("="*60)
    plt.show()
    return fig


def visualize_shooting_convergence(costs, filename='shooting_convergence.png', tolerance=1e-6):
    """Visualize shooting method convergence with clear labels for presentation.

    Plots the per-iteration cost twice, on a linear and on a logarithmic
    y-axis, saves the figure and shows it.

    Parameters
    ----------
    costs : sequence of float
        Cost value at each shooting iteration, in iteration order.
    filename : str
        Path the figure is written to.
    tolerance : float, default 1e-6
        Height of the reference line drawn on the log-scale panel. This is
        purely a plotting aid; pass the solver's actual tolerance if it
        differs from the default.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    iterations = range(len(costs))

    # Linear scale
    axes[0].plot(iterations, costs, 'bo-', linewidth=2.5, markersize=10,
                markerfacecolor='blue', markeredgecolor='darkblue', markeredgewidth=2)
    axes[0].set_xlabel('Iteration Number', fontsize=14, fontweight='bold')
    axes[0].set_ylabel(r'Cost Function $J(\lambda(0)) = \frac{1}{2}\|\lambda(T) - \lambda_T^*\|^2$',
                      fontsize=13, fontweight='bold')
    axes[0].set_title('Shooting Method Convergence\n(Linear Scale)',
                     fontsize=15, fontweight='bold', pad=15)
    axes[0].grid(True, alpha=0.4, linestyle='--', linewidth=1)
    axes[0].set_xlim(-1, len(costs))

    # Add convergence annotation (only meaningful once there is more than one iterate)
    if len(costs) > 1:
        final_cost = costs[-1]
        axes[0].axhline(y=final_cost, color='green', linestyle='--',
                       linewidth=2, alpha=0.7, label=f'Final cost: {final_cost:.2e}')
        axes[0].legend(fontsize=12, loc='upper right')

    # Log scale
    axes[1].semilogy(iterations, costs, 'ro-', linewidth=2.5, markersize=10,
                    markerfacecolor='red', markeredgecolor='darkred', markeredgewidth=2)
    axes[1].set_xlabel('Iteration Number', fontsize=14, fontweight='bold')
    axes[1].set_ylabel(r'Cost Function $J(\lambda(0))$ (Log Scale)',
                      fontsize=13, fontweight='bold')
    axes[1].set_title('Shooting Method Convergence\n(Logarithmic Scale)',
                     fontsize=15, fontweight='bold', pad=15)
    axes[1].grid(True, alpha=0.4, linestyle='--', linewidth=1, which='both')
    axes[1].set_xlim(-1, len(costs))

    # Add convergence threshold line
    axes[1].axhline(y=tolerance, color='green', linestyle='--',
                   linewidth=2, alpha=0.7, label=f'Tolerance: {tolerance:.0e}')
    axes[1].legend(fontsize=12, loc='upper right')

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_linearized_system(X_traj, A_traj, t_sample, config, filename='linearized_system.png'):
    """Visualize linearized system with explicit layer labels.

    Produces a 2x3 figure: membrane potentials of the input, first hidden and
    output layers (top row), and the Jacobian at the middle time sample, the
    real parts of its eigenvalues over time, and its Frobenius norm (bottom row).

    Parameters
    ----------
    X_traj : np.ndarray, indexed as [time, state]
        Reference trajectory. The state layout assumed by the indexing below
        is: column 0 = input neuron, columns 1 .. L*P = hidden neurons,
        remaining columns = output neurons.
    A_traj : np.ndarray, indexed as [time, i, j]
        Jacobian matrices along the trajectory (must be square for eigvals).
    t_sample : sequence of float
        Time stamp of each row of X_traj / each matrix of A_traj.
    config : object
        Must provide theta_v, theta_h, theta_u, P, L and n_classes.
    filename : str
        Path the figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n = A_traj.shape[1]

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # Plot 1: Input neuron (Layer 0)
    axes[0, 0].plot(t_sample, X_traj[:, 0], 'b-', linewidth=2.5)
    axes[0, 0].set_title('Input Layer: Neuron Membrane Potential\n' + r'$v(t)$ (Layer 0)',
                        fontsize=14, fontweight='bold', pad=10)
    axes[0, 0].set_xlabel('Time (s)', fontsize=12, fontweight='bold')
    axes[0, 0].set_ylabel(r'Membrane Potential $v(t)$', fontsize=12, fontweight='bold')
    axes[0, 0].axhline(y=config.theta_v, color='r', linestyle='--',
                      linewidth=2, alpha=0.7, label=f'Threshold θ_v={config.theta_v}')
    axes[0, 0].grid(True, alpha=0.3, linestyle='--')
    axes[0, 0].legend(fontsize=11)

    # Plot 2: Hidden layer neurons (Layer 1); at most the first three are drawn
    colors_hidden = ['#2E7D32', '#C62828', '#1565C0']
    for p in range(min(3, config.P)):
        axes[0, 1].plot(t_sample, X_traj[:, 1+p], linewidth=2.5,
                       color=colors_hidden[p], label=f'ξ₁,{p+1}')
    axes[0, 1].set_title('Hidden Layer 1: Neuron Membrane Potentials\n' + r'$\xi_{1,p}(t)$ (Layer 1)',
                        fontsize=14, fontweight='bold', pad=10)
    axes[0, 1].set_xlabel('Time (s)', fontsize=12, fontweight='bold')
    axes[0, 1].set_ylabel(r'Membrane Potential $\xi(t)$', fontsize=12, fontweight='bold')
    axes[0, 1].axhline(y=config.theta_h, color='orange', linestyle='--',
                      linewidth=2, alpha=0.7, label=f'Threshold θ_h={config.theta_h}')
    axes[0, 1].legend(fontsize=11)
    axes[0, 1].grid(True, alpha=0.3, linestyle='--')

    # Plot 3: Output layer neurons
    output_start_idx = 1 + config.L * config.P
    # Two hand-picked colours for the binary case, extended from tab10 so
    # that n_classes > 2 does not index past the end of the palette.
    colors_output = ['#6A1B9A', '#D84315']
    colors_output += [plt.cm.tab10(k % 10) for k in range(config.n_classes - len(colors_output))]
    for c in range(config.n_classes):
        if output_start_idx + c < X_traj.shape[1]:
            axes[0, 2].plot(t_sample, X_traj[:, output_start_idx + c], linewidth=2.5,
                           color=colors_output[c], label=f'u_{c+1}')
    axes[0, 2].set_title('Output Layer: Membrane Potentials\n' + r'$u_c(t)$ (Final Layer)',
                        fontsize=14, fontweight='bold', pad=10)
    axes[0, 2].set_xlabel('Time (s)', fontsize=12, fontweight='bold')
    axes[0, 2].set_ylabel(r'Membrane Potential $u(t)$', fontsize=12, fontweight='bold')
    axes[0, 2].axhline(y=config.theta_u, color='purple', linestyle='--',
                      linewidth=2, alpha=0.7, label=f'Threshold θ_u={config.theta_u}')
    axes[0, 2].legend(fontsize=11)
    axes[0, 2].grid(True, alpha=0.3, linestyle='--')

    # Plot 4: Jacobian A(t) at the middle time sample (colour scale clipped to [-1, 1])
    mid_idx = len(A_traj) // 2
    im = axes[1, 0].imshow(A_traj[mid_idx], cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)
    axes[1, 0].set_title(r'Linearized Dynamics: Jacobian $\mathbf{A}(t)$' + f'\nat t={t_sample[mid_idx]:.1f}s',
                        fontsize=14, fontweight='bold', pad=10)
    axes[1, 0].set_xlabel('State Component j', fontsize=12, fontweight='bold')
    axes[1, 0].set_ylabel('State Component i', fontsize=12, fontweight='bold')

    # Add layer annotations along both matrix axes
    layer_boundaries = [0, 1, 1 + config.L*config.P, n]
    layer_names = ['Input', 'Hidden', 'Output']
    for i, (start, end) in enumerate(zip(layer_boundaries[:-1], layer_boundaries[1:])):
        mid = (start + end) / 2
        if i < len(layer_names):
            axes[1, 0].text(-1.5, mid, layer_names[i], fontsize=10,
                          rotation=90, va='center', ha='right', fontweight='bold')
            axes[1, 0].text(mid, -1.5, layer_names[i], fontsize=10,
                          rotation=0, ha='center', va='top', fontweight='bold')

    cbar = plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04)
    cbar.set_label(r'$\partial F_i / \partial X_j$', fontsize=11, fontweight='bold')

    # Plot 5: Eigenvalue evolution of A(t)
    eigenvalues = []
    for A in A_traj:
        eigvals = np.linalg.eigvals(A)
        # Sort by descending real part for consistent tracking across time
        eigvals = eigvals[np.argsort(-np.real(eigvals))]
        eigenvalues.append(eigvals)
    eigenvalues = np.array(eigenvalues)

    # Only the (up to) eight eigenvalues with the largest real part are drawn
    n_shown = min(8, n)
    colors_eig = plt.cm.viridis(np.linspace(0, 1, n_shown))
    for i in range(n_shown):
        axes[1, 1].plot(t_sample, np.real(eigenvalues[:, i]), linewidth=2.5,
                       color=colors_eig[i], label=f'λ_{i+1}', alpha=0.8)
    axes[1, 1].set_title(r'Eigenvalues of Jacobian $\mathbf{A}(t)$' + '\n(Real Part)',
                        fontsize=14, fontweight='bold', pad=10)
    axes[1, 1].set_xlabel('Time (s)', fontsize=12, fontweight='bold')
    axes[1, 1].set_ylabel(r'$\mathrm{Re}(\lambda_i)$', fontsize=12, fontweight='bold')
    axes[1, 1].legend(fontsize=10, ncol=2, loc='best')
    axes[1, 1].grid(True, alpha=0.3, linestyle='--')
    axes[1, 1].axhline(y=0, color='k', linestyle='-', linewidth=1.5, alpha=0.5)

    # Plot 6: Frobenius norm of A(t)
    A_norms = [np.linalg.norm(A, 'fro') for A in A_traj]
    axes[1, 2].plot(t_sample, A_norms, 'b-', linewidth=2.5)
    axes[1, 2].fill_between(t_sample, 0, A_norms, alpha=0.3)
    axes[1, 2].set_title(r'Jacobian Magnitude: $\|\mathbf{A}(t)\|_F$' + '\n(Frobenius Norm)',
                        fontsize=14, fontweight='bold', pad=10)
    axes[1, 2].set_xlabel('Time (s)', fontsize=12, fontweight='bold')
    axes[1, 2].set_ylabel(r'$\|\mathbf{A}(t)\|_F$', fontsize=12, fontweight='bold')
    axes[1, 2].grid(True, alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_adjoint_trajectories(X_traj, lambda_traj, t_grid, config, filename='adjoint_trajectories.png'):
    """Visualize adjoint trajectories with explicit layer labels.

    One figure row per layer (input, every hidden layer, output); the left
    column shows the forward state, the right column the adjoint state.

    Parameters
    ----------
    X_traj : np.ndarray, indexed as [time, state]
        Forward state trajectory.
    lambda_traj : np.ndarray, indexed as [time, state]
        Adjoint trajectory, using the same column layout as X_traj.
    t_grid : sequence of float
        Time stamp of each trajectory row.
    config : object
        Must provide L (number of hidden layers) and P (neurons per hidden
        layer). The state layout assumed is: column 0 = input neuron, then
        L consecutive blocks of P hidden neurons, then the output neurons.
    filename : str
        Path the figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_state = X_traj.shape[1]
    P = config.P

    # Each entry: (title, first column, one-past-last column, state labels, adjoint labels)
    layers = [('Input Layer', 0, 1, ['v(t)'], ['λ_v(t)'])]
    # Cover every hidden layer, not just the first two, so L > 2 is not silently truncated.
    for ell in range(config.L):
        layer_num = ell + 1
        layers.append((
            f'Hidden Layer {layer_num}',
            1 + ell * P,
            1 + layer_num * P,
            [f'ξ_{layer_num},{p+1}(t)' for p in range(P)],
            [f'λ_{layer_num},{p+1}(t)' for p in range(P)],
        ))
    output_start = 1 + config.L * P
    n_outputs = max(0, n_state - output_start)
    layers.append((
        'Output Layer',
        output_start,
        n_state,
        [f'u_{c+1}(t)' for c in range(n_outputs)],
        [f'λ_u{c+1}(t)' for c in range(n_outputs)],
    ))

    # Create subplots for each layer; squeeze=False keeps axes 2-D for any row count
    n_layers = len(layers)
    fig, axes = plt.subplots(n_layers, 2, figsize=(18, 4*n_layers), squeeze=False)

    for layer_idx, (layer_name, start_idx, end_idx, state_labels, adjoint_labels) in enumerate(layers):
        n_neurons = max(0, end_idx - start_idx)
        colors = plt.cm.tab10(np.linspace(0, 1, n_neurons))
        legend_cols = max(1, min(3, n_neurons))  # legend needs at least one column

        panels = (
            (axes[layer_idx, 0], X_traj, state_labels, 'State Value',
             f'{layer_name}: Forward State\n' + r'$X(t)$'),
            (axes[layer_idx, 1], lambda_traj, adjoint_labels, 'Adjoint Value',
             f'{layer_name}: Adjoint State\n' + r'$\lambda(t)$'),
        )
        for ax, traj, labels, ylabel, title in panels:
            for i, idx in enumerate(range(start_idx, end_idx)):
                ax.plot(t_grid, traj[:, idx], linewidth=2,
                        color=colors[i], label=labels[i], alpha=0.8)

            ax.set_ylabel(ylabel, fontsize=11, fontweight='bold')
            ax.set_title(title, fontsize=13, fontweight='bold', pad=10)
            ax.legend(fontsize=10, loc='best', ncol=legend_cols)
            ax.grid(True, alpha=0.3, linestyle='--')

            # X-labels only on bottom row
            if layer_idx == n_layers - 1:
                ax.set_xlabel('Time (s)', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved: {filename}")
    plt.show()
    return fig


def visualize_training_history(history, filename='training_history.png'):
    """Plot loss, accuracy and ζ per epoch on a 2x2 grid.

    Parameters
    ----------
    history : list of dict
        One record per epoch with the keys 'epoch', 'loss', 'accuracy'
        and 'zeta'.
    filename : str
        Path the figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
        Panels: loss (linear), accuracy, ζ schedule, loss (log scale).
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    loss_ax, acc_ax = axes[0]
    zeta_ax, log_loss_ax = axes[1]

    def column(key):
        # Pull one field out of every epoch record, preserving order
        return [record[key] for record in history]

    epochs = column('epoch')
    losses = column('loss')
    accuracies = column('accuracy')
    zetas = column('zeta')

    def decorate(ax, ylabel, title):
        # Shared axis cosmetics for all four panels
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)

    line_style = dict(linewidth=2, markersize=6)

    loss_ax.plot(epochs, losses, 'b-o', **line_style)
    decorate(loss_ax, 'Loss', 'Training Loss')

    acc_ax.plot(epochs, accuracies, 'g-o', **line_style)
    decorate(acc_ax, 'Accuracy', 'Training Accuracy')
    acc_ax.set_ylim([0, 1.1])

    zeta_ax.plot(epochs, zetas, 'r-o', **line_style)
    decorate(zeta_ax, 'ζ', 'Mollification Schedule')

    log_loss_ax.semilogy(epochs, losses, 'b-o', **line_style)
    decorate(log_loss_ax, 'Loss (log)', 'Training Loss (Log Scale)')

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    print(f"Saved: {filename}")
    plt.show()
    return fig


def _upward_threshold_crossings(trace, threshold):
    """Return indices i+1 where trace[i] < threshold <= trace[i+1]."""
    trace = np.asarray(trace)
    return np.flatnonzero((trace[:-1] < threshold) & (trace[1:] >= threshold)) + 1


def visualize_output_neurons_with_spikes(snn, X_timeseries, config, filename='output_spikes.png'):
    """
    Visualize output neuron membrane potentials with threshold lines.
    Shows where actual output pulses would occur.

    A "spike" here is an upward crossing of config.theta_u between two
    consecutive samples of the recorded trajectory.

    Parameters
    ----------
    snn : object
        Must provide simulate(X_timeseries, T=..., record=True) returning
        (t_grid, X_traj) with X_traj indexed as [time, state].
    X_timeseries : array-like
        Input passed through to snn.simulate unchanged.
    config : object
        Must provide T, L, P, n_classes and theta_u. Output neurons are
        assumed to start at state column 1 + L*P.
    filename : str
        Path the figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
        One stacked panel per output neuron.
    """
    print("  Simulating to visualize output spikes...")

    # Simulate to get full trajectory
    t_grid, X_traj = snn.simulate(X_timeseries, T=config.T, record=True)

    # Extract output neuron indices
    output_start_idx = 1 + config.L * config.P

    fig, axes = plt.subplots(config.n_classes, 1, figsize=(16, 4*config.n_classes), squeeze=False)

    # Two hand-picked colours for the binary case, extended from tab10 so
    # that n_classes > 2 does not index past the end of the palette.
    colors = ['#6A1B9A', '#D84315']
    colors += [plt.cm.tab10(k % 10) for k in range(config.n_classes - len(colors))]

    spike_counts = []  # reused for the summary printed after saving

    for c in range(config.n_classes):
        ax = axes[c, 0]

        # Plot membrane potential
        u_trace = X_traj[:, output_start_idx + c]
        ax.plot(t_grid, u_trace, linewidth=2.5, color=colors[c],
               label=f'Output Neuron {c+1}: u_{c+1}(t)')

        # Plot threshold
        ax.axhline(y=config.theta_u, color='red', linestyle='--',
                  linewidth=2.5, alpha=0.8, label=f'Threshold θ_u = {config.theta_u}')

        # Detect and mark threshold crossings (spikes)
        crossings = _upward_threshold_crossings(u_trace, config.theta_u)
        spike_counts.append(len(crossings))

        # Mark spike times with vertical lines
        for spike_idx in crossings:
            ax.axvline(x=t_grid[spike_idx], color='orange', linestyle=':',
                      linewidth=2, alpha=0.6)

        # Add spike markers
        if len(crossings) > 0:
            spike_times = t_grid[crossings]
            spike_values = u_trace[crossings]
            ax.plot(spike_times, spike_values, 'o', color='orange',
                   markersize=12, markeredgewidth=2, markeredgecolor='red',
                   label=f'Spike Events ({len(crossings)} spikes)', zorder=5)

        # Styling
        ax.set_ylabel(r'Membrane Potential $u_{' + str(c+1) + '}(t)$',
                     fontsize=13, fontweight='bold')
        ax.set_title(f'Output Neuron {c+1}: Membrane Potential & Spike Detection\n' +
                    f'(Class {c} output)', fontsize=14, fontweight='bold', pad=10)
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=1)
        ax.set_xlim([0, config.T])

        # Add shaded region above threshold
        ax.fill_between(t_grid, config.theta_u, ax.get_ylim()[1],
                       color='red', alpha=0.05, label='Spiking Region')

        # The legend is built from the artists present when legend() is called,
        # so it must come after fill_between for 'Spiking Region' to be listed.
        ax.legend(fontsize=12, loc='upper right', framealpha=0.9)

        # Only add x-label to bottom plot
        if c == config.n_classes - 1:
            ax.set_xlabel('Time (s)', fontsize=13, fontweight='bold')

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved: {filename}")
    print(f"  Output neuron spike counts:")
    for c, n_spikes in enumerate(spike_counts):
        print(f"    Neuron {c+1}: {n_spikes} spikes")
    plt.show()
    return fig


def run_complete_analysis():
    """Complete analysis with corrected shooting method and time-series input.

    End-to-end driver with no parameters and no return value. In order it:

    1. loads the dataset via load_data_from_csv('lif.csv') and plots it;
    2. builds a NetworkConfig / SNNDynamics pair sized from the data;
    3. linearizes the dynamics along the first sample's trajectory and solves
       the linearized adjoint with LinearizedSystem.solve_adjoint_shooting;
    4. trains for up to 10 epochs with ShootingMethodSolver.train_step,
       stopping early after 3 epochs without an accuracy improvement;
    5. predicts on the training samples and plots a confusion matrix;
    6. plots the forward/adjoint solution for the first training sample.

    Side effects: prints progress to stdout and, through the visualize_*
    helpers, writes several PNG files into the current working directory.
    """

    print("="*80)
    print("PHASE 1: LINEARIZED PMP VIA SHOOTING METHOD")
    print("="*80)

    # Load time-series data
    X_data, y_data, n_classes = load_data_from_csv('lif.csv')
    N = len(X_data)

    # Visualize time-series data
    visualize_data(X_data, y_data)

    # Create configuration; dimensions are taken from X_data, which is
    # indexed as (sample, input neuron, time point)
    config = NetworkConfig()
    config.n_classes = n_classes
    config.d = X_data.shape[1]  # Number of input neurons
    config.n_timepoints = X_data.shape[2]  # Number of time points

    print(f"\nNetwork Configuration:")
    print(f"  Input neurons: {config.d}")
    print(f"  Time points: {config.n_timepoints}")
    print(f"  Hidden layers: {config.L}")
    print(f"  Neurons per layer: {config.P}")
    print(f"  Output classes: {config.n_classes}")
    print(f"  State dimension: {config.n_state}")
    print(f"  Time horizon: {config.T}s (visualization), {config.T_train}s (training)")

    # Create network
    snn = SNNDynamics(config)
    # NOTE(review): presumably initialises the mollification parameter ζ for the
    # linearized phase (epoch 0 of a nominal 5); confirm against SNNDynamics.update_zeta
    snn.update_zeta(0, 5)

    # Step 1: Linearize system
    print("\n" + "="*80)
    print("STEP 1: Linearizing SNN dynamics")
    print("="*80)

    # Only the first sample is used as the reference trajectory
    linear_sys = LinearizedSystem(snn)
    X_ref, A_traj = linear_sys.linearize_along_trajectory(X_data[0])

    # Visualize linearized system
    visualize_linearized_system(X_ref, A_traj, linear_sys.t_grid, config)

    # Step 2: Solve adjoint with proper shooting method
    print("\n" + "="*80)
    print("STEP 2: Solving linearized adjoint via shooting method")
    print("       (with linear interpolation for guesses)")
    print("="*80)

    # Terminal adjoint condition: one-hot over the last n_classes state
    # components, set on the component matching the first sample's label
    lambda_T = np.zeros(config.n_state)
    lambda_T[-n_classes + y_data[0]] = 1.0

    lambda_traj, shooting_costs = linear_sys.solve_adjoint_shooting(lambda_T, method='secant')

    # Visualize shooting convergence
    visualize_shooting_convergence(shooting_costs, 'shooting_convergence.png')

    # Visualize adjoint solution
    visualize_adjoint_trajectories(X_ref, lambda_traj, linear_sys.t_grid, config,
                                   'linearized_adjoint.png')

    # Step 3: Visualize output neurons with threshold and spikes
    print("\n" + "="*80)
    print("STEP 3: Visualizing output neurons with spike detection")
    print("="*80)
    visualize_output_neurons_with_spikes(snn, X_data[0], config, 'output_spikes.png')

    # Phase 2: Full nonlinear training
    print("\n" + "="*80)
    print("PHASE 2: FULL NONLINEAR TRAINING")
    print("="*80)

    # Use subset: at most 100 samples, drawn without replacement when N > 100
    max_train_samples = min(100, N)
    if N > max_train_samples:
        indices = np.random.choice(N, max_train_samples, replace=False)
        X_train = X_data[indices]
        y_train = y_data[indices]
    else:
        X_train = X_data
        y_train = y_data

    print(f"Training set: {len(X_train)} samples")
    print(f"Class distribution: class 0={np.sum(y_train==0)}, class 1={np.sum(y_train==1)}")

    solver = ShootingMethodSolver(snn)

    # Training hyper-parameters (gamma is reported below as the regularization strength)
    n_epochs = 10
    learning_rate = 0.005
    gamma = 0.0001

    print(f"\nTraining for {n_epochs} epochs...")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Regularization: {gamma}")

    # Early-stopping bookkeeping: stop after `patience` consecutive epochs
    # whose accuracy does not exceed the best seen so far
    history = []
    best_accuracy = 0.0
    patience = 3
    patience_counter = 0

    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch+1}/{n_epochs}")
        start_time = time.time()

        # update_zeta returns the ζ value for this epoch, which is logged in history
        zeta = snn.update_zeta(epoch, n_epochs)

        try:
            loss, accuracy = solver.train_step(X_train, y_train, learning_rate, gamma)
            elapsed = time.time() - start_time

            history.append({
                'epoch': epoch,
                'loss': loss,
                'accuracy': accuracy,
                'zeta': zeta
            })

            print(f"  Loss={loss:.6f}, Acc={accuracy:.3f}, ζ={zeta:.2f}, Time={elapsed:.1f}s")

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"  Early stopping: no improvement for {patience} epochs")
                break
        except Exception as e:
            # Any failure aborts training; the analysis continues with the
            # parameters and history accumulated so far
            print(f"  Error in training: {e}")
            break

    # Visualize training history (skipped if no epoch completed)
    if len(history) > 0:
        visualize_training_history(history)

    # Compute predictions for confusion matrix
    print("\n" + "="*80)
    print("Evaluating and computing confusion matrix...")
    print("="*80)

    # Note: predictions are made on the training samples themselves,
    # so the resulting metrics are training metrics, not held-out ones
    y_pred = []
    for i in tqdm(range(len(X_train)), desc="Predictions"):
        X_final = snn.simulate_fast(X_train[i], snn.config.T_train)
        probs = snn.compute_output(X_final)
        pred = np.argmax(probs)
        y_pred.append(pred)

    y_pred = np.array(y_pred)

    # Visualize explicit confusion matrix
    visualize_confusion_matrix(y_train, y_pred, 'confusion_matrix.png')

    # Final adjoint solution
    print("\n" + "="*80)
    print("Computing final nonlinear adjoint solution...")
    print("="*80)

    # Forward/backward solve for the first training sample only
    X_final, lambda_final = solver.solve_forward_backward(X_train[0], y_train[0])
    visualize_adjoint_trajectories(X_final, lambda_final, snn.t_grid, config,
                                   'final_adjoint.png')

    print("\n" + "="*80)
    print("ANALYSIS COMPLETE!")
    print("="*80)
    print(f"\nGenerated figures (suitable for A0 presentation):")
    print("  1. data_timeseries.png       - Time-series input data")
    print("  2. linearized_system.png     - Linearization with all layers")
    print("  3. shooting_convergence.png  - Shooting method iterations")
    print("  4. linearized_adjoint.png    - Adjoint trajectories (linearized)")
    print("  5. output_spikes.png         - Output neurons with threshold")
    print("  6. training_history.png      - Training metrics")
    print("  7. confusion_matrix.png      - Explicit confusion matrix")
    print("  8. final_adjoint.png         - Final adjoint solution")


#if __name__ == "__main__":
#    run_complete_analysis()