# High-Level Overview of the Spiking Neural Network Simulation Script
This script implements a biologically inspired spiking neural network (SNN) based on the paper "An accurate and fast learning approach in the biologically spiking neural network." The network mimics certain aspects of how the human brain processes visual information and learns from experience. Below is a brief, non-technical summary of its main components and functionality:

## 1. Input Processing (Retina Module)
### Purpose:
Converts an input image (an MNIST digit in this example) into "spike trains": sequences of electrical pulses similar to the signals sent by biological neurons.

### How it Works:
The image is processed via average pooling to reduce its resolution (mimicking the compression of visual data in the human retina). The resulting values are normalized and used to generate spike trains through a Poisson process.
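
The two steps above can be sketched without any dependencies. This is a minimal illustration, not the script's own code: it assumes a 2x2 window with stride 2, and it introduces a hypothetical peak firing rate `max_rate_hz` (not a parameter of the script) to turn a normalized intensity into a rate in Hz.

```python
import random

def average_pool(image, pool=2, stride=2):
    """Average-pool a 2D list of pixel values (windows do not overlap when pool == stride)."""
    out_h = (len(image) - pool) // stride + 1
    out_w = (len(image[0]) - pool) // stride + 1
    pooled = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            window = [image[i * stride + a][j * stride + b]
                      for a in range(pool) for b in range(pool)]
            row.append(sum(window) / len(window))
        pooled.append(row)
    return pooled

def poisson_spike_train(intensity, t_total_ms, dt_ms=1.0, max_rate_hz=100.0, rng=random):
    """Bernoulli approximation of a Poisson process: in each step of length dt,
    a spike is emitted with probability rate * dt (rate in Hz, dt in seconds).
    max_rate_hz is an assumed scale factor, chosen here only for illustration."""
    p_spike = intensity * max_rate_hz * dt_ms / 1000.0
    steps = int(t_total_ms / dt_ms)
    return [k * dt_ms for k in range(steps) if rng.random() < p_spike]

# Toy 4x4 "image" with pixel values in 0..255.
image = [[0, 0, 255, 255],
         [0, 0, 255, 255],
         [0, 0, 0, 0],
         [255, 255, 255, 255]]
pooled = average_pool(image)
rates = [[v / 255.0 for v in row] for row in pooled]  # normalize to [0, 1]
train = poisson_spike_train(rates[0][1], t_total_ms=1000)  # spike times in ms
```

Brighter regions get a higher per-step spike probability, so they produce denser spike trains; a region with intensity 0 never spikes.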

## 2. Middle (Liquid) Layer
### Purpose:
Processes the spike trains from the retina in a network of neurons that resemble those in the brain.

### Components:

Excitatory Neurons: Send signals that stimulate other neurons.

Inhibitory Neurons: Send signals that suppress activity.

### How it Works:
Neurons follow a Leaky Integrate-and-Fire (LIF) model. They accumulate incoming electrical signals over time and "fire" a spike when a threshold is reached, then reset to a resting state.
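
As a rough illustration of the LIF behaviour described above, the following standalone sketch integrates a single neuron with a constant input. The default values mirror the excitatory entries of the script's CONFIG; the function name `simulate_lif` and the constant-current setup are assumptions made for this example.

```python
def simulate_lif(input_current, t_total_ms=100.0, dt=1.0, tau_m=20.0,
                 v_rest=0.0, v_thresh=18.0, v_reset=0.0, refractory=2.0):
    """Euler-integrate tau_m * dv/dt = -(v - v_rest) + I and return spike times in ms."""
    v = v_rest
    last_spike = float("-inf")
    spikes = []
    for step in range(int(t_total_ms / dt)):
        t = step * dt
        # During the refractory period the neuron is held at the reset potential.
        if t - last_spike < refractory:
            v = v_reset
            continue
        # Leak toward rest plus drive from the input current.
        v += (-(v - v_rest) + input_current) * dt / tau_m
        # Fire and reset once the threshold is reached.
        if v >= v_thresh:
            spikes.append(t)
            last_spike = t
            v = v_reset
    return spikes

weak = simulate_lif(input_current=10.0)    # settles near 10 mV, below threshold
strong = simulate_lif(input_current=30.0)  # would settle near 30 mV, so it fires repeatedly
```

With a constant input the membrane potential relaxes toward the input value, so the neuron only fires when that value exceeds the threshold.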

## 3. Output Layer (Decision Making)
### Purpose:
Makes the final classification decision.

### How it Works:
Each neuron in the output layer corresponds to a class (for MNIST, digits 0-9). The decision is made based on which output neuron fires the most over the simulation period.
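
The decision rule can be sketched as a spike count followed by an argmax. The helper `decide` below is hypothetical; it assumes spike times are available per output neuron and that an initial stabilization period (`t_skip_ms`, 200 ms in the script's CONFIG) is ignored.

```python
def decide(output_spike_times, t_skip_ms=200.0):
    """Return (predicted_class, counts), counting only spikes at or after t_skip_ms.
    Ties resolve to the lowest class index."""
    counts = [sum(1 for t in times if t >= t_skip_ms) for times in output_spike_times]
    predicted = max(range(len(counts)), key=counts.__getitem__)
    return predicted, counts

# Three output neurons; spikes before 200 ms are ignored.
spike_times = [[50.0, 250.0], [300.0, 400.0, 900.0], [100.0]]
predicted, counts = decide(spike_times)
```

Note that if no output neuron fires at all, every count is zero and the rule falls back to class 0, so a prediction of 0 with all-zero counts carries no information.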

## 4. Learning via Actor–Critic Module
### Purpose:
Enables the network to learn and improve its performance over time.

### Components:

Spike-Timing Dependent Plasticity (STDP): Adjusts the strength of connections based on the timing of spikes.

Actor–Critic Reinforcement Learning: Uses a "critic" to assess and provide feedback on the network's decisions, enabling reinforcement and fine-tuning of synaptic connections.

### How it Works:
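The network adjusts its internal wiring (i.e., synaptic weights) using reward-based feedback from the actor–critic module. In the current script the STDP step is a placeholder: whenever middle-layer and output neurons spike in the same time step, the weights connecting them are shifted in proportion to the critic's reward prediction error and then clipped to the range [0, 1].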
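
A minimal sketch of a reward-modulated update of this kind is shown below. It is a simplification, not a full spike-timing rule: co-activity within one time step stands in for precise spike timing, and the function names and nested-list weight layout are assumptions for illustration.

```python
def td_error(reward, value_next, value_prev, gamma=0.99):
    """Reward prediction error: delta = r + gamma * V(s') - V(s)."""
    return reward + gamma * value_next - value_prev

def reward_modulated_update(weights, pre_spiked, post_spiked, delta,
                            learning_rate=0.001, w_min=0.0, w_max=1.0):
    """Shift weights[pre][post] by learning_rate * delta for every co-active pair,
    then clip the result to [w_min, w_max]. Modifies and returns `weights`."""
    for i in pre_spiked:
        for j in post_spiked:
            w = weights[i][j] + learning_rate * delta
            weights[i][j] = min(w_max, max(w_min, w))
    return weights

weights = [[0.5, 0.5],
           [0.5, 0.5]]
delta = td_error(reward=1.0, value_next=2.0, value_prev=1.0)
reward_modulated_update(weights, pre_spiked=[0], post_spiked=[1],
                        delta=delta, learning_rate=0.1)
```

A positive prediction error strengthens the connections that were just active, a negative one weakens them, and connections between silent neurons are left unchanged.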

## 5. GPU Acceleration with PyTorch
### Purpose:
To speed up computationally heavy tasks, particularly those involving large connectivity matrices.

### How it Works:
Connectivity matrices are stored and updated as PyTorch tensors. If a CUDA-capable GPU is available, these tensors are placed on the GPU; otherwise the script falls back to the CPU. The per-neuron updates still run in Python loops on the CPU, so any speedup is limited to the matrix operations.

## 6. Modularity and Tunability
### Design Philosophy:
The code is organized using clear, object-oriented modules:

Each key component is encapsulated in its own class: `Retina`, `Neuron`, `Synapse`, `Layer`, `ActorCritic`, and `PSACNetwork`.

All important parameters (e.g., time constants, thresholds, learning rates, connectivity probabilities) are stored in a central configuration section (CONFIG), making it easy to adjust and experiment with the model.
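
One way to experiment with such a configuration is to derive a modified copy instead of editing the dictionary in place. The sketch below uses a trimmed-down stand-in (`BASE_CONFIG`) and a hypothetical helper `with_overrides`; neither name exists in the script.

```python
import copy

# Trimmed-down stand-in for the script's CONFIG dictionary.
BASE_CONFIG = {
    "simulation": {"dt": 1.0, "t_total": 1200, "t_skip": 200},
    "network": {"num_middle_neurons": 5000, "connection_prob": 0.2},
}

def with_overrides(base, overrides):
    """Return a deep copy of `base` with the nested `overrides` merged in.
    `base` itself is left untouched."""
    merged = copy.deepcopy(base)
    for section, values in overrides.items():
        merged.setdefault(section, {}).update(values)
    return merged

# A much smaller middle layer for quick experiments.
small = with_overrides(BASE_CONFIG, {"network": {"num_middle_neurons": 200}})
```

Working on a copy matters here because `PSACNetwork.__init__` writes the computed `max_distance` back into the dictionary it receives, so passing the shared CONFIG directly mutates it.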

#!/usr/bin/env python3
"""
A complete spiking neural network simulation and PSAC learning framework
based on the paper “An accurate and fast learning approach in the biologically spiking neural network.”

This version:
  1. Loads the MNIST dataset from torchvision.
  2. Utilizes PyTorch to offload some computations (connectivity matrices) to the GPU.
  3. Maintains a modular, class-based code structure with clearly tunable parameters.
"""

import numpy as np
import matplotlib.pyplot as plt
import torch

# Check for GPU support
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on device:", device)

# Global configuration dictionary: modify these to tune the simulation.
CONFIG = {
    "simulation": {
        "dt": 1.0,             # time-step (ms)
        "t_total": 1200,       # total simulation time in ms
        "t_skip": 200,         # initial period skipped for stabilization (ms)
    },
    "neuron": {
        "tau_m_exc": 20.0,     # membrane time constant for excitatory neurons (ms)
        "tau_m_inh": 10.0,     # for inhibitory neurons (ms)
        "v_rest": 0.0,         # resting potential (mV)
        "v_thresh": 18.0,      # spike threshold (mV)
        "v_reset": 0.0,        # reset potential after spike (mV)
        "refractory_exc": 2.0, # refractory period for excitatory neurons (ms)
        "refractory_inh": 1.0, # refractory period for inhibitory neurons (ms)
    },
    "synapse": {
        "tau_rA": 1.0,         # excitatory synaptic rise time (ms)
        "tau_dA": 20.0,        # excitatory synaptic decay time (ms)
        "tau_rG": 1.0,         # inhibitory synaptic rise time (ms)
        "tau_dG": 1.0,         # inhibitory synaptic decay time (ms)
    },
    "network": {
        "num_middle_neurons": 5000,  # number of neurons in the middle (liquid) layer
        "ratio_exc": 0.8,            # ratio of excitatory neurons in the middle layer
        "ratio_inh": 0.2,            # ratio of inhibitory neurons in the middle layer
        "connection_prob": 0.2,      # connection probability (sparse connectivity)
        "max_distance": None,        # computed based on spatial layout; see below.
    },
    "output": {
        "num_classes": 10,      # number of output neurons (for MNIST, 10 classes)
        "simulation_window": 1000,  # ms over which output neurons’ firing rates are measured
    },
    "actor_critic": {
        "num_critic_neurons": 20,
        "gamma": 0.99,         # discount factor for future rewards
        "tau_r": 20.0,         # time constant scaling factor for the critic (ms)
    },
    "learning": {
        "learning_rate": 0.001,  # base learning rate for weight updates (STDP modulation)
        "stdp_window": 20.0,     # window (ms) for STDP temporal difference
    },
    "retina": {
        "input_size": (28, 28),  # input image size (MNIST)
        "pool_size": 2,          # pooling window size (e.g. 2x2)
        "stride": 2,             # stride for pooling
    }
}

# -------------------- Neuron, Synapse, and Layer Classes --------------------

class Neuron:
    def __init__(self, neuron_id, is_excitatory=True, config=CONFIG):
        self.id = neuron_id
        self.is_excitatory = is_excitatory
        self.config = config
        if self.is_excitatory:
            self.tau_m = config["neuron"]["tau_m_exc"]
            self.refractory_period = config["neuron"]["refractory_exc"]
        else:
            self.tau_m = config["neuron"]["tau_m_inh"]
            self.refractory_period = config["neuron"]["refractory_inh"]
        self.v = config["neuron"]["v_rest"]
        self.v_thresh = config["neuron"]["v_thresh"]
        self.v_reset = config["neuron"]["v_reset"]
        self.last_spike_time = -np.inf  # tracks last spike time (ms)
        self.spike_times = []           # records all spike times

    def update(self, t, dt, input_current):
        # Check for refractory period
        if (t - self.last_spike_time) < self.refractory_period:
            self.v = self.v_reset
            return False

        # Euler integration for LIF dynamics
        dv = (-self.v + input_current) * dt / self.tau_m
        self.v += dv

        # Spike if threshold is crossed
        if self.v >= self.v_thresh:
            self.spike(t)
            return True
        return False

    def spike(self, t):
        self.spike_times.append(t)
        self.last_spike_time = t
        self.v = self.v_reset  # reset membrane potential

# A minimal Synapse blueprint is provided. In this implementation, synaptic dynamics are managed via connectivity matrices.
class Synapse:
    def __init__(self, pre_neuron, post_neuron, weight, delay=1.0, is_excitatory=True, config=CONFIG):
        self.pre = pre_neuron
        self.post = post_neuron
        self.weight = weight
        self.delay = delay  # synaptic delay (ms)
        self.is_excitatory = is_excitatory
        self.config = config
        self.x = 0.0  # auxiliary variable for synaptic rise dynamics
        self.I = 0.0  # synaptic current variable
        self.tau_r = config["synapse"]["tau_rA"] if is_excitatory else config["synapse"]["tau_rG"]
        self.tau_d = config["synapse"]["tau_dA"] if is_excitatory else config["synapse"]["tau_dG"]

    def update(self, dt):
        # Euler integration for synaptic dynamics (simplified)
        dx = -self.x * dt / self.tau_r
        self.x += dx
        dI = (-self.I + self.x) * dt / self.tau_d
        self.I += dI

    def transmit_spike(self):
        # When a spike is transmitted, update the auxiliary variable
        self.x += self.weight  # additional scaling can be added as needed


class Layer:
    def __init__(self, size, is_excitatory_array, config=CONFIG, name="layer"):
        self.config = config
        self.name = name
        self.neurons = []
        for i in range(size):
            neuron = Neuron(i, is_excitatory=is_excitatory_array[i], config=config)
            self.neurons.append(neuron)
        self.size = size

    def update(self, t, dt, input_currents):
        """Update every neuron in the layer with its corresponding input current.
           Returns a list of indices for neurons that spiked at time t."""
        spikes = []
        for i, neuron in enumerate(self.neurons):
            if neuron.update(t, dt, input_currents[i]):
                spikes.append(i)
        return spikes


# -------------------- Retina Module --------------------

class Retina:
    def __init__(self, config=CONFIG):
        self.config = config
        self.input_size = config["retina"]["input_size"]
        self.pool_size = config["retina"]["pool_size"]
        self.stride = config["retina"]["stride"]
        # Determine output dimensions after pooling
        self.output_dim = (
            self.input_size[0] // self.stride,
            self.input_size[1] // self.stride
        )

    def process_image(self, image):
        """
        Process the image through average pooling to reduce dimensionality.
        The resulting array represents activation levels.
        """
        pooled = self.pool(image, self.pool_size, self.stride)
        # Normalize pooled image (assuming pixel values 0-255)
        spike_rates = pooled / 255.0  # yields a value between 0 and 1
        return spike_rates

    def pool(self, image, pool_size, stride):
        out_h = image.shape[0] // stride
        out_w = image.shape[1] // stride
        pooled = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                window = image[i*stride:i*stride+pool_size, j*stride:j*stride+pool_size]
                pooled[i, j] = np.mean(window)
        return pooled

    def generate_spike_trains(self, spike_rates, simulation_time, dt):
        """
        For each pooled unit, generate a spike train based on a Poisson process.
        Returns a dictionary: key: neuron index, value: list of spike times.
        """
        num_neurons = spike_rates.shape[0] * spike_rates.shape[1]
        spike_trains = {i: [] for i in range(num_neurons)}
        num_steps = int(simulation_time / dt)
        for step in range(num_steps):
            t = step * dt
            for i in range(spike_rates.shape[0]):
                for j in range(spike_rates.shape[1]):
                    idx = i * spike_rates.shape[1] + j
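                    # spike_rates is normalized to [0, 1] but is used here as a rate in Hz, so the
                    # per-step spike probability is at most dt / 1000; scale by a peak rate for denser input.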
                    if np.random.rand() < spike_rates[i, j] * dt / 1000.0:
                        spike_trains[idx].append(t)
        return spike_trains


# -------------------- Actor-Critic Module --------------------

class ActorCritic:
    def __init__(self, config=CONFIG):
        self.config = config
        self.num_neurons = config["actor_critic"]["num_critic_neurons"]
        # Create a critic layer (all excitatory for simplicity)
        self.critic_layer = Layer(self.num_neurons, [True]*self.num_neurons, config=config, name="Critic")
        self.last_value = 0.0

    def compute_value(self):
        """Estimate the state value as the mean cumulative spike count of the critic neurons."""
        rates = [len(neuron.spike_times) for neuron in self.critic_layer.neurons]
        value = np.mean(rates)
        return value

    def update(self, reward, gamma):
        """
        Compute the reward prediction error (delta) and update state.
        delta = r_{t+1} + gamma * V(s_{t+1}) - V(s_t)
        """
        current_value = self.compute_value()
        delta = reward + gamma * current_value - self.last_value
        self.last_value = current_value
        return delta

    def simulate(self, t, dt):
        """Update the critic neurons (with a dummy input current of 0)."""
        input_currents = [0.0] * self.num_neurons
        _ = self.critic_layer.update(t, dt, input_currents)


# -------------------- PSAC Network (Main) --------------------

class PSACNetwork:
    def __init__(self, config=CONFIG):
        self.config = config
        self.dt = config["simulation"]["dt"]

        # Build the retina module
        self.retina = Retina(config=config)
        self.input_dim = self.retina.output_dim
        self.num_input_neurons = self.input_dim[0] * self.input_dim[1]

        # Build the middle layer (liquid layer)
        num_middle = config["network"]["num_middle_neurons"]
        num_exc = int(num_middle * config["network"]["ratio_exc"])
        num_inh = num_middle - num_exc
        exc_flags = [True] * num_exc + [False] * num_inh
        np.random.shuffle(exc_flags)  # randomize neuron types
        self.middle_layer = Layer(num_middle, exc_flags, config=config, name="Middle")

        # Build the output layer (all excitatory for classification)
        num_classes = config["output"]["num_classes"]
        self.output_layer = Layer(num_classes, [True] * num_classes, config=config, name="Output")

        # Build the Actor-Critic module.
        self.actor_critic = ActorCritic(config=config)

        # Set up connectivity matrices between layers:
        # (1) From middle to output: random sparse connectivity (each connection kept with probability connection_prob).
        self.middle_to_output = self.create_connectivity(len(self.middle_layer.neurons),
                                                         len(self.output_layer.neurons), config)
        # (2) In the middle layer, use distance-based (exponential decay) connectivity.
        self.middle_layer_positions = self.initialize_positions(num_middle)
        self.config["network"]["max_distance"] = np.sqrt(
            (np.max(self.middle_layer_positions[:, 0]) - np.min(self.middle_layer_positions[:, 0])) ** 2 +
            (np.max(self.middle_layer_positions[:, 1]) - np.min(self.middle_layer_positions[:, 1])) ** 2
        )
        self.middle_layer_connectivity = self.create_spatial_connectivity(num_middle, config)

    def initialize_positions(self, num_neurons):
        """Assign each neuron a random 2D position for distance-dependent connectivity."""
        positions = np.random.rand(num_neurons, 2) * 100.0  # positions in the range [0, 100)
        return positions

    def create_connectivity(self, pre_size, post_size, config):
        connection_prob = config["network"]["connection_prob"]
        # Create connectivity matrix using PyTorch on the GPU if available.
        connectivity = (torch.rand(pre_size, post_size, device=device) < connection_prob).float()
        weights = torch.normal(mean=0.5, std=0.1, size=(pre_size, post_size), device=device)
        weights *= connectivity  # apply mask so only some connections exist
        return weights  # stays on `device` (GPU if available, otherwise CPU)

    def create_spatial_connectivity(self, num_neurons, config):
        """Distance-dependent weights exp(-r / D), each kept with probability connection_prob; no self-connections."""
        connection_prob = config["network"]["connection_prob"]
        D = float(config["network"]["max_distance"])
        # Vectorized on `device`: the equivalent nested Python loop needs num_neurons**2 iterations.
        positions = torch.as_tensor(self.middle_layer_positions, dtype=torch.float32, device=device)
        distances = torch.cdist(positions, positions)  # pairwise Euclidean distances r[i, j]
        weights = torch.exp(-distances / D)
        # Random mask so that only a fraction of the possible connections exist.
        mask = (torch.rand(num_neurons, num_neurons, device=device) < connection_prob).float()
        weights *= mask
        weights.fill_diagonal_(0.0)  # remove self-connections
        return weights

    def run_simulation(self, input_image, reward, verbose=False):
        """
        Run one simulation episode for a single image.
        Simulate the network dynamics over time, apply STDP/actor-critic updates,
        and return the predicted class (based on highest output spike count).
        """
        # Process the input image through the retina.
        spike_rates = self.retina.process_image(input_image)
        input_spike_trains = self.retina.generate_spike_trains(spike_rates, self.config["simulation"]["t_total"], self.dt)

        t_total = self.config["simulation"]["t_total"]
        dt = self.dt
        num_steps = int(t_total / dt)
        output_spike_counts = np.zeros(len(self.output_layer.neurons))

        # Main simulation loop.
        for step in range(num_steps):
            t = step * dt

            # --- Input layer processing ---
            input_currents = np.zeros(self.num_input_neurons)
            for neuron_idx, spike_times in input_spike_trains.items():
                if t in spike_times:
                    input_currents[neuron_idx] = 10.0  # tunable scale factor

            # Map input layer to middle layer: using a simple uniform projection.
            middle_input = np.zeros(len(self.middle_layer.neurons))
            middle_input += np.sum(input_currents) * 0.01

            # --- Middle layer update ---
            middle_spikes = self.middle_layer.update(t, dt, middle_input)

            # Compute recurrent input within the middle layer using spatial connectivity.
            # NOTE: recurrent_input is computed here but not yet fed back into the middle layer.
            recurrent_input = np.zeros(len(self.middle_layer.neurons))
            for neuron_idx in middle_spikes:
                # Convert the corresponding row from GPU (torch) to numpy.
                recurrent_input += self.middle_layer_connectivity[neuron_idx, :].cpu().numpy()

            # --- Projection to Output layer ---
            output_input = np.zeros(len(self.output_layer.neurons))
            for i in middle_spikes:
                # Accumulate contributions from the middle-to-output connections.
                output_input += self.middle_to_output[i, :].cpu().numpy() * 10.0  # tunable factor
            output_spikes = self.output_layer.update(t, dt, output_input)

            # Record output spikes after stabilization period.
            if t >= self.config["simulation"]["t_skip"]:
                for idx in output_spikes:
                    output_spike_counts[idx] += 1

            # --- Actor-Critic update ---
            self.actor_critic.simulate(t, dt)
            # A placeholder STDP update: if middle and output neurons spiked, adjust weights modulated by RPE.
            if middle_spikes and output_spikes:
                delta = self.actor_critic.update(reward, self.config["actor_critic"]["gamma"])
                for m in middle_spikes:
                    for o in output_spikes:
                        # Update weight on the GPU, then clamp between 0 and 1.
                        self.middle_to_output[m, o] += self.config["learning"]["learning_rate"] * delta
                        self.middle_to_output[m, o] = torch.clamp(self.middle_to_output[m, o], 0, 1.0)

            if verbose and step % 100 == 0:
                print(f"Time {t:.1f} ms: Middle spikes: {len(middle_spikes)}, Output spikes: {len(output_spikes)}")

        # Determine the decision by taking the output neuron with the highest spike count.
        predicted_class = int(np.argmax(output_spike_counts))
        return predicted_class, output_spike_counts


# -------------------- Main Routine (Using MNIST) --------------------

# Import torchvision to load MNIST
from torchvision import datasets, transforms

# Define transform to convert image to tensor and scale it to [0, 255]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 255)
])

# Download and load the MNIST test dataset
mnist_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
sample_image_tensor, label = mnist_dataset[0]
dummy_image = sample_image_tensor.squeeze().numpy()  # Convert to (28,28) numpy array

print("True label of the selected MNIST sample:", label)

# Define a dummy reward (for example, 1.0 if correct classification)
reward = 1.0

# Instantiate the PSACNetwork
network = PSACNetwork(CONFIG)
predicted_class, output_counts = network.run_simulation(dummy_image, reward, verbose=True)

print("Predicted class:", predicted_class)
print("Output spike counts:", output_counts)

# Optionally, visualize the output spike counts
plt.figure(figsize=(6, 4))
plt.bar(range(len(output_counts)), output_counts)
plt.xlabel("Output Neuron (Class)")
plt.ylabel("Spike Count")
plt.title("Output Layer Spike Counts After Simulation")
plt.show()
