# High-Level Overview of the Spiking Neural Network Simulation Script
This script implements a biologically inspired spiking neural network (SNN) based on recent research. The network mimics certain aspects of how the human brain processes visual information and learns from experience. Below is a brief, non-technical summary of its main components and functionality:

## 1. Input Processing (Retina Module)
### Purpose:
Converts an input image (using MNIST as an example) into "spike trains"—a series of electrical pulses similar to signals sent by biological neurons.

### How it Works:
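The image is downsampled by 2x2 average pooling with stride 2, which reduces a 28x28 MNIST digit to a 14x14 grid (mimicking the compression of visual data in the human retina). The pooled values are normalized to [0, 1] and treated as firing rates: at each time step, every grid cell emits a spike with a probability proportional to its rate, which approximates a Poisson process.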
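
The pooling and spike-generation steps can be sketched in a few lines of NumPy. This is a minimal illustration rather than the script's own code: the helper names `average_pool` and `poisson_spikes` and the synthetic test image are assumptions, and the 50 Hz peak rate simply mirrors the `spike_rate_scaling` value in the configuration.

```python
import numpy as np

def average_pool(image, pool=2):
    """Average non-overlapping pool x pool windows of a 2D array."""
    h, w = image.shape
    trimmed = image[:h - h % pool, :w - w % pool]
    return trimmed.reshape(h // pool, pool, w // pool, pool).mean(axis=(1, 3))

def poisson_spikes(rates, num_steps, dt_ms=1.0, max_rate_hz=50.0, rng=None):
    """Bernoulli approximation of a Poisson process, one row per input unit."""
    rng = np.random.default_rng() if rng is None else rng
    # Per-step spike probability; a rate of 1.0 maps to max_rate_hz.
    p_spike = rates.reshape(-1, 1) * max_rate_hz * dt_ms / 1000.0
    return (rng.random((rates.size, num_steps)) < p_spike).astype(np.float32)

# Synthetic 28x28 "image": a bright square on a dark background (hypothetical input).
image = np.zeros((28, 28))
image[8:20, 8:20] = 255.0

rates = average_pool(image) / 255.0            # normalized to [0, 1]
spikes = poisson_spikes(rates, num_steps=1200, rng=np.random.default_rng(0))
print(rates.shape, spikes.shape)               # (14, 14) (196, 1200)
```

Dark cells have a rate of zero and therefore never spike, while the brightest cells fire at roughly the peak rate.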

## 2. Middle (Liquid) Layer
### Purpose:
Processes the spike trains from the retina in a network of neurons that resemble those in the brain.

### Components:

Excitatory Neurons: Send signals that stimulate other neurons.

Inhibitory Neurons: Send signals that suppress activity.

### How it Works:
Neurons follow a Leaky Integrate-and-Fire (LIF) model. They accumulate incoming electrical signals over time and "fire" a spike when a threshold is reached, then reset to a resting state.
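
As a rough illustration of the LIF dynamics, the sketch below integrates a single neuron with a simple Euler step. The function `simulate_lif` is a hypothetical helper, not part of the script; its default parameters are borrowed from the script's configuration (threshold 18 mV, time constant 20 ms, refractory period 2 ms).

```python
import math

def simulate_lif(current, dt=1.0, tau_m=20.0, v_rest=0.0, v_thresh=18.0,
                 v_reset=0.0, refractory=2.0):
    """Euler-integrate one LIF neuron and return its spike times in ms."""
    v = v_rest
    last_spike = -math.inf
    spike_times = []
    for step, i_in in enumerate(current):
        t = step * dt
        # The membrane only integrates input outside the refractory period.
        if t - last_spike >= refractory:
            v += (-(v - v_rest) + i_in) * dt / tau_m
        if v >= v_thresh:
            spike_times.append(t)
            v = v_reset
            last_spike = t
    return spike_times

strong = simulate_lif([30.0] * 200)  # steady state (30) lies above the threshold
weak = simulate_lif([10.0] * 200)    # steady state (10) stays below the threshold
print(len(strong), len(weak))
```

A constant input only produces spikes if the voltage it drives the membrane toward exceeds the threshold, which is why the weak input stays silent.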

## 3. Output Layer (Decision Making)
### Purpose:
Makes the final classification decision.

### How it Works:
Each neuron in the output layer corresponds to a class (for MNIST, digits 0-9). The predicted class is the output neuron that fires most often during the simulation, not counting spikes in the initial stabilization window (`t_skip`, 200 ms by default).
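
The readout amounts to counting spikes per output neuron and taking the largest count. The following sketch assumes the output spikes are available as a 0/1 array; the `readout` helper and the hand-made spike pattern are illustrative only.

```python
import numpy as np

def readout(output_spikes, dt=1.0, t_skip=200.0):
    """Pick the class whose neuron spikes most after the warm-up window.

    output_spikes: array of shape [num_classes, num_steps] with 0/1 entries.
    """
    first_step = int(t_skip / dt)
    counts = output_spikes[:, first_step:].sum(axis=1)
    return int(np.argmax(counts)), counts

spikes = np.zeros((10, 1200))
spikes[3, 250::10] = 1   # class 3 fires regularly after the warm-up
spikes[7, :200:2] = 1    # class 7 fires only during the warm-up window

predicted, counts = readout(spikes)
print(predicted)         # 3
```

Note that `np.argmax` resolves ties in favor of the lowest index, so a network whose output layer is completely silent would always be read out as class 0.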

## 4. Learning via Actor–Critic Module
### Purpose:
Enables the network to learn and improve its performance over time.

### Components:

Spike-Timing Dependent Plasticity (STDP): Adjusts the strength of connections based on the timing of spikes.

Actor–Critic Reinforcement Learning: Uses a "critic" to assess and provide feedback on the network's decisions, enabling reinforcement and fine-tuning of synaptic connections.

### How it Works:
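The network adjusts its internal wiring (i.e., synaptic weights) using a combination of spike-based plasticity and reward-based feedback (via the actor–critic method). In this script the rule takes a simplified form: when a middle-layer neuron and an output neuron spike in the same time step, the weight between them changes by the learning rate times the critic's temporal-difference error, and all weights are clipped to the range [0, 1].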
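
One way to picture a reward-modulated update of this kind is the NumPy sketch below. It is an assumption-laden simplification: `td_error` and `reward_modulated_update` are hypothetical names, and spike coincidence within a single time step stands in for a full STDP timing window.

```python
import numpy as np

def td_error(reward, value_now, value_prev, gamma=0.99):
    """Temporal-difference error: positive when the outcome beat expectations."""
    return reward + gamma * value_now - value_prev

def reward_modulated_update(weights, pre_spikes, post_spikes, delta, lr=0.001):
    """Strengthen (delta > 0) or weaken (delta < 0) synapses that were co-active."""
    coincidence = np.outer(pre_spikes, post_spikes)   # 1 where both neurons spiked
    return np.clip(weights + lr * delta * coincidence, 0.0, 1.0)

weights = np.full((4, 2), 0.5)          # 4 presynaptic, 2 postsynaptic neurons
pre = np.array([1.0, 0.0, 1.0, 0.0])    # presynaptic neurons 0 and 2 spiked
post = np.array([0.0, 1.0])             # postsynaptic neuron 1 spiked

delta = td_error(reward=1.0, value_now=0.0, value_prev=0.0)
new_weights = reward_modulated_update(weights, pre, post, delta)
print(new_weights - weights)            # non-zero only at [0, 1] and [2, 1]
```

Only synapses whose pre- and postsynaptic neurons were both active are touched, so a reward reinforces exactly the connections that contributed to the decision.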

## 5. GPU Acceleration with PyTorch
### Purpose:
To speed up computationally heavy tasks, particularly those involving large connectivity matrices.

### How it Works:
Key operations (like building and updating connectivity matrices) are performed using PyTorch. If a compatible GPU is available, these operations will run on the GPU, significantly reducing simulation time.
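
A minimal sketch of the device-selection pattern is shown below, assuming PyTorch is installed. The helper `build_connectivity` is hypothetical; its default probability and weight statistics are taken from the script's configuration, and the matrix is built on the CPU with a seeded generator before being moved to the selected device.

```python
import torch

# Fall back to the CPU when no CUDA-capable GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_connectivity(num_pre, num_post, prob=0.2, mean=0.5, std=0.1, seed=0):
    """Random sparse weight matrix: each connection exists with probability prob."""
    gen = torch.Generator(device="cpu").manual_seed(seed)
    mask = (torch.rand((num_pre, num_post), generator=gen) < prob).float()
    weights = torch.normal(mean, std, size=(num_pre, num_post), generator=gen)
    return (mask * weights).to(device)

w = build_connectivity(5000, 10)
spikes = (torch.rand(5000, device=device) < 0.05).float()

# One matrix-vector product propagates all middle-layer spikes to the outputs.
output_current = spikes @ w
print(output_current.shape)   # torch.Size([10])
```

Because every tensor lives on the same device, the same code runs unchanged on a CPU or a GPU.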

## 6. Modularity and Tunability
### Design Philosophy:
The code is organized using clear, object-oriented modules:

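Each key component is encapsulated in its own class: `Retina` (input encoding), `LayerGPU` (a vectorized layer of LIF neurons), `ActorCriticGPU` (critic and reward signal), and `PSACNetworkGPU` (the full network).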

All important parameters (e.g., time constants, thresholds, learning rates, connectivity probabilities) are stored in a central configuration section (CONFIG), making it easy to adjust and experiment with the model.
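
For experiments it can be convenient to derive a modified configuration without editing the original dictionary. The helper below is a hypothetical convenience, not something the script provides, and `BASE` is a cut-down stand-in for `CONFIG`.

```python
import copy

def with_overrides(config, overrides):
    """Return a deep copy of config with per-section overrides applied."""
    new_config = copy.deepcopy(config)
    for section, values in overrides.items():
        new_config.setdefault(section, {}).update(values)
    return new_config

# Cut-down stand-in for the script's CONFIG dictionary.
BASE = {
    "middle_layer": {"num_neurons": 5000, "ratio_exc": 0.8},
    "learning": {"learning_rate": 0.001},
}

small = with_overrides(BASE, {"middle_layer": {"num_neurons": 500}})
print(small["middle_layer"]["num_neurons"], BASE["middle_layer"]["num_neurons"])  # 500 5000
```

Since the script's classes take the configuration as a constructor argument, a dictionary produced this way could be passed to them in place of the global `CONFIG`.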

# %% [code]
"""
Extended GPU-accelerated Spiking Neural Network Training with PSAC for MNIST
Adapted for use in a Jupyter Notebook.
This cell defines all classes and functions and then runs the training and testing loops.
"""

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Set the device to CUDA if available (use "mps" for Apple M2 if needed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -------------------- Configuration --------------------
CONFIG = {
    "simulation": {
        "dt": 1.0,             # time-step (ms)
        "t_total": 1200,       # total simulation time per episode (ms)
        "t_skip": 200,         # initial period for stabilization (ms)
    },
    "neuron": {
        "v_rest": 0.0,         # resting potential (mV)
        "v_thresh": 18.0,      # spike threshold (mV)
        "v_reset": 0.0,        # reset potential (mV)
    },
    "middle_layer": {
        "num_neurons": 5000,
        "ratio_exc": 0.8,      # 80% excitatory neurons, 20% inhibitory
        "tau_m_exc": 20.0,     # membrane time constant for excitatory neurons (ms)
        "tau_m_inh": 10.0,     # membrane time constant for inhibitory neurons (ms)
        "refractory_exc": 2.0, # refractory period (ms) for excitatory neurons
        "refractory_inh": 1.0, # refractory period (ms) for inhibitory neurons
    },
    "output_layer": {
        "num_neurons": 10,     # MNIST: 10 classes
        "tau_m": 20.0,         # membrane time constant for output neurons (ms)
        "refractory": 2.0,     # refractory period for output neurons (ms)
    },
    "network": {
        "connection_prob": 0.2,   # probability for random connectivity from middle to output
    },
    "actor_critic": {
        "num_critic_neurons": 20,
        "gamma": 0.99,            # discount factor for future rewards
        "tau_r": 20.0,            # time constant for critic neurons (ms)
    },
    "learning": {
        "learning_rate": 0.001,   # learning rate for weight updates
    },
    "retina": {
        "input_size": (28, 28),   # MNIST images are 28x28
        "pool_size": 2,           # average pooling window size (2x2)
        "stride": 2,              # stride for pooling (28x28 input -> 14x14 output)
        "spike_rate_scaling": 50.0  # multiplier to boost retina spike probability
    },
    "training": {
        "num_epochs": 2,          # Number of training epochs (set low for a prototype)
        "batch_size": 1,          # We'll process one image per episode for simplicity
    }
}

# -------------------- Retina Module --------------------
class Retina:
    def __init__(self, config=CONFIG):
        self.config = config
        self.input_size = config["retina"]["input_size"]
        self.pool_size = config["retina"]["pool_size"]
        self.stride = config["retina"]["stride"]
        self.output_dim = (self.input_size[0] // self.stride,
                           self.input_size[1] // self.stride)

    def process_image(self, image):
        """
        Downsample the input image using average pooling.
        Returns a 2D array of activations (normalized between 0 and 1).
        """
        h, w = self.input_size
        pool_h, pool_w = self.pool_size, self.pool_size
        out_h, out_w = self.output_dim
        pooled = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                window = image[i*self.stride : i*self.stride+pool_h,
                               j*self.stride : j*self.stride+pool_w]
                pooled[i, j] = np.mean(window)
        spike_rates = pooled / 255.0
        return spike_rates

    def generate_spike_train(self, spike_rates, simulation_time, dt):
        """
        Generate spike trains using a Poisson process.
        Returns a torch tensor of shape [num_units, num_time_steps] on device.
        """
        num_units = spike_rates.size
        num_steps = int(simulation_time / dt)
        spike_train = np.zeros((num_units, num_steps), dtype=np.float32)
        scaling = self.config["retina"].get("spike_rate_scaling", 1.0)
        for unit in range(num_units):
            rate = spike_rates.flatten()[unit]  # normalized [0, 1]
            probs = np.random.rand(num_steps)
            spikes = (probs < (rate * scaling * dt / 1000.0)).astype(np.float32)
            spike_train[unit] = spikes
        return torch.tensor(spike_train, device=device, dtype=torch.float32)

# -------------------- GPU-Accelerated Layer Class --------------------
class LayerGPU:
    def __init__(self, num_neurons, tau_m, v_rest, v_thresh, v_reset, refractory, device):
        """
        Vectorized layer of LIF neurons.
        tau_m and refractory are torch tensors of shape [num_neurons].
        """
        self.num_neurons = num_neurons
        self.device = device
        self.tau_m = tau_m.to(device)
        self.v = torch.full((num_neurons,), v_rest, device=device)
        self.v_rest = v_rest
        self.v_thresh = v_thresh
        self.v_reset = v_reset
        self.refractory = refractory.to(device)
        self.last_spike_time = torch.full((num_neurons,), -1e6, device=device)
        self.spike_counts = torch.zeros((num_neurons,), device=device)

    def update(self, t, dt, input_current):
        not_refractory = (t - self.last_spike_time) >= self.refractory
        dv = torch.zeros_like(self.v)
        dv[not_refractory] = ((-self.v[not_refractory] + input_current[not_refractory]) * dt /
                              self.tau_m[not_refractory])
        self.v = self.v + dv
        spiked = self.v >= self.v_thresh
        if spiked.any():
            self.last_spike_time[spiked] = t
            self.v[spiked] = self.v_reset
            self.spike_counts[spiked] += 1
        return spiked.float()

    def reset(self):
        self.v.fill_(self.v_rest)
        self.spike_counts.zero_()
        self.last_spike_time.fill_(-1e6)

# -------------------- GPU-Accelerated Actor-Critic Module --------------------
class ActorCriticGPU:
    def __init__(self, config, device):
        self.config = config
        num_critic = config["actor_critic"]["num_critic_neurons"]
        tau_m = torch.full((num_critic,), 20.0, device=device)
        refractory = torch.full((num_critic,), 2.0, device=device)
        self.critic_layer = LayerGPU(num_critic,
                                     tau_m,
                                     v_rest=config["neuron"]["v_rest"],
                                     v_thresh=config["neuron"]["v_thresh"],
                                     v_reset=config["neuron"]["v_reset"],
                                     refractory=refractory,
                                     device=device)
        self.last_value = 0.0

    def compute_value(self):
        return self.critic_layer.spike_counts.mean().item()

    def update(self, reward, gamma):
        current_value = self.compute_value()
        delta = reward + gamma * current_value - self.last_value
        self.last_value = current_value
        return delta

    def simulate(self, t, dt):
        # Use the layer's own device rather than relying on the module-level global.
        input_current = torch.zeros((self.critic_layer.num_neurons,), device=self.critic_layer.device)
        self.critic_layer.update(t, dt, input_current)

    def reset(self):
        self.critic_layer.reset()
        self.last_value = 0.0

# -------------------- GPU-Accelerated PSAC Network --------------------
class PSACNetworkGPU:
    def __init__(self, config, device):
        self.config = config
        self.device = device
        self.dt = config["simulation"]["dt"]
        self.simulation_time = config["simulation"]["t_total"]

        self.retina = Retina(config=config)

        # Initialize Middle (Liquid) Layer
        num_middle = config["middle_layer"]["num_neurons"]
        num_exc = int(num_middle * config["middle_layer"]["ratio_exc"])
        tau_m = torch.empty(num_middle, device=device)
        refractory = torch.empty(num_middle, device=device)
        tau_m[:num_exc] = config["middle_layer"]["tau_m_exc"]
        refractory[:num_exc] = config["middle_layer"]["refractory_exc"]
        tau_m[num_exc:] = config["middle_layer"]["tau_m_inh"]
        refractory[num_exc:] = config["middle_layer"]["refractory_inh"]
        self.middle_layer = LayerGPU(num_middle,
                                     tau_m,
                                     v_rest=config["neuron"]["v_rest"],
                                     v_thresh=config["neuron"]["v_thresh"],
                                     v_reset=config["neuron"]["v_reset"],
                                     refractory=refractory,
                                     device=device)

        # Initialize Output Layer
        num_output = config["output_layer"]["num_neurons"]
        tau_m_output = torch.full((num_output,), config["output_layer"]["tau_m"], device=device)
        refractory_output = torch.full((num_output,), config["output_layer"]["refractory"], device=device)
        self.output_layer = LayerGPU(num_output,
                                     tau_m_output,
                                     v_rest=config["neuron"]["v_rest"],
                                     v_thresh=config["neuron"]["v_thresh"],
                                     v_reset=config["neuron"]["v_reset"],
                                     refractory=refractory_output,
                                     device=device)

        # Initialize Actor-Critic Module
        self.actor_critic = ActorCriticGPU(config, device)

        # Connectivity from Middle to Output Layer
        m = self.middle_layer.num_neurons
        n = self.output_layer.num_neurons
        self.middle_to_output = (torch.rand((m, n), device=device) < config["network"]["connection_prob"]).float()
        weights = torch.normal(0.5, 0.1, size=(m, n), device=device)
        self.middle_to_output = self.middle_to_output * weights

    def run_simulation(self, retina_spike_train, reward, verbose=False):
        dt = self.dt
        num_steps = retina_spike_train.shape[1]
        output_spike_counts = torch.zeros((self.output_layer.num_neurons,), device=self.device)

        for step in range(num_steps):
            t = step * dt
            retina_input = retina_spike_train[:, step].sum().item()
            feed_forward = retina_input * 500.0  # scaling parameter (tunable)
            middle_input = torch.full((self.middle_layer.num_neurons,), feed_forward, device=self.device)
            middle_spikes = self.middle_layer.update(t, dt, middle_input)

            output_input = torch.matmul(middle_spikes.unsqueeze(0), self.middle_to_output).squeeze(0)
            output_input = output_input * 10.0  # tuning parameter (tunable)
            output_spikes = self.output_layer.update(t, dt, output_input)

            if t >= self.config["simulation"]["t_skip"]:
                output_spike_counts += output_spikes

            self.actor_critic.simulate(t, dt)
            if middle_spikes.sum() > 0 and output_spikes.sum() > 0:
                delta = self.actor_critic.update(reward, self.config["actor_critic"]["gamma"])
                weight_update = self.config["learning"]["learning_rate"] * delta
                # torch.outer replaces the deprecated torch.ger
                update_matrix = torch.outer(middle_spikes, output_spikes) * weight_update
                self.middle_to_output += update_matrix
                self.middle_to_output.clamp_(0, 1.0)

            if verbose and step % 100 == 0:
                print(f"Time {t:.1f} ms: Middle spikes sum {middle_spikes.sum().item()}, Output spikes sum {output_spikes.sum().item()}")

        predicted_class = torch.argmax(output_spike_counts).item()
        return predicted_class, output_spike_counts.cpu().numpy()

    def reset_state(self):
        self.middle_layer.reset()
        self.output_layer.reset()
        self.actor_critic.reset()

# -------------------- Helper: Reward Function --------------------
def compute_reward(predicted_class, true_label):
    return 1.0 if predicted_class == true_label else -1.0

# -------------------- Training Loop --------------------
def train_network():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.squeeze(0) * 255)
    ])
    train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=True)

    retina = Retina(CONFIG)
    network = PSACNetworkGPU(CONFIG, device)

    num_epochs = CONFIG["training"]["num_epochs"]
    total_samples = 0
    correct_samples = 0

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")
        for batch_idx, (image, label) in enumerate(train_loader):
            try:
                image_np = image.squeeze(0).numpy().astype(np.uint8)
                true_label = label.item()
                spike_rates = retina.process_image(image_np)
                retina_spike_train = retina.generate_spike_train(spike_rates, CONFIG["simulation"]["t_total"], CONFIG["simulation"]["dt"])
                network.reset_state()
                predicted_class, _ = network.run_simulation(retina_spike_train, reward=0.0, verbose=(batch_idx < 2))
                reward = compute_reward(predicted_class, true_label)
                network.reset_state()
                predicted_class, _ = network.run_simulation(retina_spike_train, reward=reward, verbose=False)
                total_samples += 1
                if predicted_class == true_label:
                    correct_samples += 1
                if (batch_idx + 1) % 100 == 0:
                    print(f"Sample {batch_idx+1}: True {true_label}, Predicted {predicted_class}, Running accuracy = {correct_samples/total_samples:.2f}")
            except Exception as e:
                print(f"Error processing sample index {batch_idx}. True label: {label.item()}.")
                raise

        print(f"Cumulative accuracy after epoch {epoch+1}: {correct_samples/total_samples:.2f}")
    print(f"Final Training Accuracy: {correct_samples/total_samples:.2f}")
    return network

# -------------------- Testing Loop --------------------
def test_network(network):
    """Evaluate an already-trained network on the MNIST test set."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.squeeze(0) * 255)
    ])
    test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    retina = Retina(CONFIG)

    total = 0
    correct = 0
    print(f"Starting testing on {len(test_dataset)} samples.")
    for idx, (image, label) in enumerate(test_loader):
        try:
            image_np = image.squeeze(0).numpy().astype(np.uint8)
            true_label = label.item()
            spike_rates = retina.process_image(image_np)
            retina_spike_train = retina.generate_spike_train(spike_rates, CONFIG["simulation"]["t_total"], CONFIG["simulation"]["dt"])
            network.reset_state()
            predicted_class, _ = network.run_simulation(retina_spike_train, reward=0.0, verbose=False)
            total += 1
            if predicted_class == true_label:
                correct += 1
            if total % 100 == 0:
                print(f"[{total}/{len(test_dataset)}] Processed sample {total}: True label = {true_label}, Predicted = {predicted_class}, Running accuracy = {correct/total:.2f}")
        except Exception as e:
            print(f"Error processing sample index {idx}.")
            raise
    acc = correct / total if total > 0 else 0.0
    print(f"Final Test Accuracy: {acc:.2f}")

# -------------------- Run Training and Testing --------------------
# In a notebook, you can run these cells interactively.
trained_network = train_network()
test_network(trained_network)

# Optionally, visualize the output spike counts for one sample from the test set.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.squeeze(0) * 255)
])
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
sample, label = test_dataset[0]
sample_np = sample.numpy().astype(np.uint8)
print("Visualizing output for test sample, True label:", label)
retina = Retina(CONFIG)
spike_rates = retina.process_image(sample_np)
retina_spike_train = retina.generate_spike_train(spike_rates, CONFIG["simulation"]["t_total"], CONFIG["simulation"]["dt"])
trained_network.reset_state()
predicted_class, output_counts = trained_network.run_simulation(retina_spike_train, reward=0.0, verbose=False)
print("Predicted class:", predicted_class)
plt.figure(figsize=(6, 4))
plt.bar(range(len(output_counts)), output_counts)
plt.xlabel("Output Neuron (Class)")
plt.ylabel("Spike Count")
plt.title("Output Layer Spike Counts")
plt.show()
