# High-Level Overview of the Spiking Neural Network Training Script
## Purpose:
This script is designed to develop a computer model that mimics how biological brains process visual information. Specifically, it focuses on recognizing handwritten digits (0–9) from the MNIST dataset—a standard collection of digit images used for testing image classifiers. The model uses principles inspired by real brain activity, such as neurons firing in response to visual signals, and it “learns” from feedback (rewards) to improve its accuracy over time.

## Key Components:

### Image Preprocessing (The Retina Module):

What It Does:
The script first takes a handwritten digit image and processes it in a way loosely modeled on the human retina. It shrinks the 28×28 image to 14×14 by averaging each 2×2 block of pixels, then converts each pooled pixel's brightness into a “spike train” (a sequence of brief electrical impulses over time), with brighter pixels spiking more often. This is a simple stand-in for how neurons in the eye might signal the brain.

Why It Matters:
This step ensures that the image is in a format that the simulated brain (our neural network) can understand and process.
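
A minimal, dependency-free sketch of the two retina steps follows. It assumes 2×2 average pooling with stride 2 and a per-step spike probability of `intensity * 50 * dt / 1000`, matching the values in the notebook's `CONFIG`; the helper names `avg_pool` and `encode_spikes` are illustrative and not part of the script.

```python
import random

def avg_pool(image, pool=2, stride=2):
    """Average-pool a 2-D list of pixel values (0-255) and scale to [0, 1]."""
    rows = len(image) // stride
    cols = len(image[0]) // stride
    pooled = []
    for i in range(rows):
        row = []
        for j in range(cols):
            window = [image[i * stride + a][j * stride + b]
                      for a in range(pool) for b in range(pool)]
            row.append(sum(window) / len(window) / 255.0)
        pooled.append(row)
    return pooled

def encode_spikes(rate, sim_ms=1200, dt=1.0, scale=50.0, rng=None):
    """Bernoulli spike train: one 0/1 draw per time step."""
    rng = rng or random.Random()
    p = rate * scale * dt / 1000.0      # spike probability per step
    steps = int(sim_ms / dt)
    return [1 if rng.random() < p else 0 for _ in range(steps)]

# A 4x4 toy image: bright left half, dark right half.
toy = [[255, 255, 0, 0]] * 4
rates = avg_pool(toy)                    # [[1.0, 0.0], [1.0, 0.0]]
bright = encode_spikes(rates[0][0], rng=random.Random(0))
dark = encode_spikes(rates[0][1], rng=random.Random(0))
print(sum(bright), "spikes for the bright pixel;", sum(dark), "for the dark one")
```

A fully bright pixel spikes with probability 0.05 per millisecond under these settings, so roughly 60 spikes are expected over a 1200 ms presentation, while a black pixel never spikes.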

### Simulated Brain (Spiking Neural Network):

What It Does:
The network is built using thousands of simple units called neurons that simulate the behavior of brain cells. These neurons are organized into layers:

- Middle (Liquid) Layer: Processes the incoming signals from the retina and does the heavy lifting for pattern recognition.
- Output Layer: Makes the final decision on what digit is being shown by observing which of its neurons fires the most.

Why It Matters:
Unlike conventional artificial neural networks, which pass continuous-valued activations between layers, this model communicates through discrete spikes. That is closer to how biological brains operate and, on hardware designed for it, can make for a more energy-efficient and potentially more adaptable system.
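
To make the spiking behaviour concrete, here is a single leaky integrate-and-fire (LIF) neuron in plain Python. It is a sketch that assumes a forward-Euler update and the threshold (18), membrane time constant (20 ms), and refractory period (2 ms) found in the notebook's configuration; `simulate_lif` is an illustrative name.

```python
def simulate_lif(current, sim_ms=100, dt=1.0, tau_m=20.0,
                 v_rest=0.0, v_thresh=18.0, v_reset=0.0, refractory=2.0):
    """Return spike times (ms) of one LIF neuron driven by a constant current."""
    v = v_rest
    last_spike = -1e6
    spike_times = []
    for step in range(int(sim_ms / dt)):
        t = step * dt
        if t - last_spike >= refractory:
            # Forward-Euler step of dv/dt = (-(v - v_rest) + I) / tau_m
            v += (-(v - v_rest) + current) * dt / tau_m
        if v >= v_thresh:
            spike_times.append(t)
            last_spike = t
            v = v_reset
    return spike_times

weak = simulate_lif(current=10.0)    # settles near 10, below threshold
strong = simulate_lif(current=30.0)  # crosses threshold repeatedly
print(len(weak), "spikes with I=10;", len(strong), "spikes with I=30")
```

Because the membrane potential settles toward the input current, the neuron fires only when that current exceeds the threshold, which is why the weak input produces no spikes at all.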

### Learning Mechanism (Actor–Critic Module & Power-STDP):

What It Does:
The script includes a built-in feedback loop. When the network makes a prediction (like saying "this is a 7") and the prediction is correct, it receives a reward of +1; if the prediction is wrong, it receives a penalty of -1. An actor–critic module turns this reward into an error signal, which is combined with a learning rule (called Power-STDP) that strengthens or weakens the connections between neurons that fired together, to improve future predictions.

Why It Matters:
This learning process is analogous to how humans learn from feedback and mistakes, helping the model improve its accuracy with each example it processes.
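
The update performed in the notebook's simulation loop can be sketched for a handful of synapses as follows. This is a simplified illustration, not a full STDP implementation: it assumes a temporal-difference error `delta = reward + gamma * value - previous_value` and a weight change proportional to `delta` for every pre/post pair that spiked in the same step, with weights clamped to [0, 1]. The function names are hypothetical.

```python
def td_error(reward, value, previous_value, gamma=0.99):
    """Temporal-difference error used to scale the weight change."""
    return reward + gamma * value - previous_value

def update_weights(weights, pre_spikes, post_spikes, delta, lr=0.001):
    """Reward-modulated coincidence rule; returns a new weight matrix."""
    new_weights = []
    for i, row in enumerate(weights):
        new_row = []
        for j, w in enumerate(row):
            w = w + lr * delta * pre_spikes[i] * post_spikes[j]
            new_row.append(min(1.0, max(0.0, w)))   # clamp to [0, 1]
        new_weights.append(new_row)
    return new_weights

weights = [[0.5, 0.5], [0.5, 0.5]]
pre, post = [1, 0], [0, 1]           # only synapse (0, 1) saw a coincidence

rewarded = update_weights(weights, pre, post, td_error(+1.0, 0.0, 0.0))
punished = update_weights(weights, pre, post, td_error(-1.0, 0.0, 0.0))
print(rewarded[0][1], punished[0][1])
```

Only the synapse whose pre- and post-synaptic neurons fired together changes: it grows after a reward and shrinks after a penalty, and every other weight is left alone.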

### Training and Testing Loop:

What It Does:
The script automatically runs multiple rounds (or “epochs”) where the network processes many digit images. In each round, it adjusts its internal connections based on its performance. After training, the model is tested on a set of images to evaluate its accuracy.

Why It Matters:
Continuous training and testing allow the model to refine its predictions over time, leading to a system that can accurately recognize handwritten digits.
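
One detail worth making explicit is how accuracy is tallied across rounds. The sketch below uses a hypothetical `AccuracyMeter` helper, which is not part of the script, to keep a per-epoch count alongside the running total so the two figures can be reported separately. The prediction/label pairs are made up purely for illustration.

```python
class AccuracyMeter:
    """Tracks accuracy for the current epoch and across all epochs."""

    def __init__(self):
        self.epoch_correct = self.epoch_total = 0
        self.run_correct = self.run_total = 0

    def add(self, prediction, label):
        hit = int(prediction == label)
        self.epoch_correct += hit
        self.epoch_total += 1
        self.run_correct += hit
        self.run_total += 1

    def end_epoch(self):
        """Return this epoch's accuracy and reset the per-epoch counters."""
        acc = self.epoch_correct / max(self.epoch_total, 1)
        self.epoch_correct = self.epoch_total = 0
        return acc

    @property
    def running(self):
        return self.run_correct / max(self.run_total, 1)

meter = AccuracyMeter()
# Made-up (prediction, label) pairs standing in for two short epochs.
epochs = [[(3, 3), (7, 1)], [(3, 3), (1, 1)]]
per_epoch = []
for pairs in epochs:
    for pred, label in pairs:
        meter.add(pred, label)
    per_epoch.append(meter.end_epoch())
print(per_epoch, meter.running)
```

Reporting the two numbers separately makes it easier to see whether a later epoch actually improved on an earlier one, which a single running total tends to smooth over.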

### GPU Acceleration for Speed:

What It Does:
The script is designed to run on a Graphics Processing Unit (GPU), which can speed up processing considerably compared with a regular computer processor (CPU). As written, the device setup selects an Nvidia GPU (via CUDA) when one is available and otherwise falls back to the CPU; running on Apple silicon such as the M2 would need an additional check for PyTorch's MPS backend.

Why It Matters:
Faster processing means quicker training times, making it feasible to experiment with and improve the model iteratively.
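
A hedged sketch of device selection that covers both kinds of hardware mentioned above follows. It assumes PyTorch's `torch.cuda.is_available()` and `torch.backends.mps.is_available()` checks (the latter exists only in PyTorch releases with Metal support); `select_device_name` is an illustrative helper, not part of the script.

```python
def select_device_name(cuda_available, mps_available):
    """Prefer CUDA, then Apple's Metal (MPS) backend, then the CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

try:
    import torch
    cuda_ok = torch.cuda.is_available()
    # torch.backends.mps is absent in older PyTorch releases, hence getattr.
    mps = getattr(torch.backends, "mps", None)
    mps_ok = mps is not None and mps.is_available()
except ImportError:
    # PyTorch is not installed; fall back to the CPU label.
    cuda_ok = mps_ok = False

device_name = select_device_name(cuda_ok, mps_ok)
print("Using device:", device_name)
```

The resulting string can be passed to `torch.device(...)` in place of the hard-coded CUDA-or-CPU choice.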

## In Summary:
This script is a prototype that simulates brain-like processing of visual information. It takes images of handwritten digits, converts them to spike trains, processes them through a network of simulated neurons, and learns from feedback. The end goal is an efficient, biologically inspired classifier that can accurately recognize digits, an approach that could lead to low-power, highly adaptable computing systems.

In [1]:
# Cell 1: Imports & Device Setup
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Choose CUDA if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [2]:
# Cell 2: Configuration Dictionary
CONFIG = {
    "simulation": {"dt":1.0,"t_total":1200,"t_skip":200},
    "neuron": {"v_rest":0.0,"v_thresh":18.0,"v_reset":0.0},
    "middle_layer": {
        "num_neurons":5000, "ratio_exc":0.8,
        "tau_m_exc":20.0,"tau_m_inh":10.0,
        "refractory_exc":2.0,"refractory_inh":1.0
    },
    "output_layer": {"num_neurons":10,"tau_m":20.0,"refractory":2.0},
    "network": {"connection_prob":0.2},
    "actor_critic": {"num_critic_neurons":20,"gamma":0.99,"tau_r":20.0},
    "learning": {"learning_rate":0.001},
    "retina": {
        "input_size":(28,28),"pool_size":2,"stride":2,
        "spike_rate_scaling":50.0
    },
    "training": {"num_epochs":2,"batch_size":1}
}
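
A few quantities follow directly from these settings. The arithmetic below is a standalone sketch that copies the relevant values rather than importing `CONFIG`; the variable names are illustrative.

```python
sim = {"dt": 1.0, "t_total": 1200, "t_skip": 200}
retina = {"input_size": (28, 28), "stride": 2}
middle = {"num_neurons": 5000, "ratio_exc": 0.8}

# Simulation steps per image, and how many of them count toward the prediction.
steps = int(sim["t_total"] / sim["dt"])
scored_steps = int((sim["t_total"] - sim["t_skip"]) / sim["dt"])

# Pooled pixels produced by the retina (one spike train each).
input_units = ((retina["input_size"][0] // retina["stride"])
               * (retina["input_size"][1] // retina["stride"]))

# Excitatory / inhibitory split of the middle layer.
n_exc = int(middle["num_neurons"] * middle["ratio_exc"])
n_inh = middle["num_neurons"] - n_exc

print(steps, scored_steps, input_units, n_exc, n_inh)
```

With the values above, each image is simulated for 1200 steps, of which the last 1000 are scored; the retina yields 196 spike trains, and the middle layer splits into 4000 excitatory and 1000 inhibitory neurons.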

In [3]:
# Cell 3: Retina Module
class Retina:
    def __init__(self, config):
        self.cfg = config["retina"]
        self.in_sz = self.cfg["input_size"]
        self.pool, self.stride = self.cfg["pool_size"], self.cfg["stride"]
        self.out_sz = (self.in_sz[0]//self.stride, self.in_sz[1]//self.stride)

    def process_image(self, img):
        pooled = np.zeros(self.out_sz, dtype=np.float32)
        for i in range(self.out_sz[0]):
            for j in range(self.out_sz[1]):
                window = img[i*self.stride:i*self.stride+self.pool,
                             j*self.stride:j*self.stride+self.pool]
                pooled[i,j] = window.mean()
        return pooled/255.0

    def generate_spike_train(self, rates, sim_t, dt):
        # Bernoulli approximation of a Poisson process: in each time step a
        # unit spikes with probability rate * scale * dt / 1000 (dt in ms).
        steps = int(sim_t/dt)
        p = rates.reshape(-1, 1)*self.cfg["spike_rate_scaling"]*dt/1000.0
        # One row per input unit, one column per time step.
        train = (np.random.rand(rates.size, steps) < p).astype(np.float32)
        return torch.tensor(train, device=device)

In [4]:
# Cell 4: LIF Neuron Layer & Actor-Critic
class LayerGPU:
    def __init__(self, N, tau_m, v_rest, v_thresh, v_reset, refractory):
        self.N = N
        self.tau_m = tau_m.to(device)            # membrane time constants (ms)
        self.v_rest, self.v_thresh, self.v_reset = v_rest, v_thresh, v_reset
        self.v = torch.full((N,),v_rest,device=device)
        self.refractory = refractory.to(device)  # refractory periods (ms)
        self.last_spike = torch.full((N,),-1e6,device=device)
        self.counts = torch.zeros((N,),device=device)

    def update(self, t, dt, I):
        # Neurons still inside their refractory period are not integrated.
        ok = (t-self.last_spike)>=self.refractory
        # Forward-Euler step of dv/dt = (-(v - v_rest) + I)/tau_m
        dv = torch.zeros_like(self.v)
        dv[ok] = ((self.v_rest-self.v[ok]+I[ok])*dt/self.tau_m[ok])
        self.v += dv
        # Threshold crossing: record the spike, reset, and count it.
        sp = self.v>=self.v_thresh
        if sp.any():
            self.last_spike[sp]=t
            self.v[sp]=self.v_reset
            self.counts[sp]+=1
        return sp.float()

    def reset(self):
        # Return to the resting potential this layer was built with.
        self.v.fill_(self.v_rest)
        self.last_spike.fill_(-1e6)
        self.counts.zero_()

class ActorCriticGPU:
    def __init__(self, cfg):
        num = cfg["actor_critic"]["num_critic_neurons"]
        tau = torch.full((num,),20.0,device=device)
        ref = torch.full((num,),2.0,device=device)
        self.layer = LayerGPU(num,tau,
                              cfg["neuron"]["v_rest"],
                              cfg["neuron"]["v_thresh"],
                              cfg["neuron"]["v_reset"],
                              ref)
        self.last_val=0.0

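    def compute_value(self):
        # Value estimate = mean spike count of the critic neurons.
        return self.layer.counts.mean().item()

    def update(self,reward,gamma):
        # Temporal-difference error: delta = r + gamma*V(now) - V(previous)
        cur=self.compute_value()
        delta=reward+gamma*cur-self.last_val
        self.last_val=cur
        return delta

    def simulate(self,t,dt):
        # NOTE: the critic is driven with zero input current, so it never
        # reaches threshold; the value stays at 0 and delta equals the reward.
        self.layer.update(t,dt,torch.zeros(self.layer.N,device=device))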

    def reset(self):
        self.layer.reset()
        self.last_val=0.0

In [5]:
# Cell 5: PSAC Network Definition
class PSACNetworkGPU:
    def __init__(self,cfg):
        self.cfg, self.dt = cfg, cfg["simulation"]["dt"]
        self.retina = Retina(cfg)
        # Middle
        M=cfg["middle_layer"]["num_neurons"]
        exc=int(M*cfg["middle_layer"]["ratio_exc"])
        tau = torch.empty(M,device=device)
        ref = torch.empty(M,device=device)
        tau[:exc]=cfg["middle_layer"]["tau_m_exc"]; ref[:exc]=cfg["middle_layer"]["refractory_exc"]
        tau[exc:]=cfg["middle_layer"]["tau_m_inh"]; ref[exc:]=cfg["middle_layer"]["refractory_inh"]
        self.middle=LayerGPU(M,tau,
                             cfg["neuron"]["v_rest"],
                             cfg["neuron"]["v_thresh"],
                             cfg["neuron"]["v_reset"],
                             ref)
        # Output
        O=cfg["output_layer"]["num_neurons"]
        tau_o = torch.full((O,),cfg["output_layer"]["tau_m"],device=device)
        ref_o = torch.full((O,),cfg["output_layer"]["refractory"],device=device)
        self.output=LayerGPU(O,tau_o,
                             cfg["neuron"]["v_rest"],
                             cfg["neuron"]["v_thresh"],
                             cfg["neuron"]["v_reset"],
                             ref_o)
        # Actor-Critic
        self.ac=ActorCriticGPU(cfg)
        # Connectivity
        p=cfg["network"]["connection_prob"]
        self.W=((torch.rand((M,O),device=device)<p).float()
                *torch.normal(0.5,0.1,(M,O),device=device))

    def run_sim(self,spikes,reward,verbose=False):
        T=spikes.shape[1]; dt=self.dt
        outc=torch.zeros(self.output.N,device=device)
        for step in range(T):
            t=step*dt
            # Total retina spikes this step; every middle neuron gets the same drive.
            rin=spikes[:,step].sum().item()
            I_mid=torch.full((self.middle.N,),rin*500.0,device=device)
            sp_mid=self.middle.update(t,dt,I_mid)
            # Project middle-layer spikes onto the output neurons through W.
            I_out=(sp_mid.unsqueeze(0)@self.W).squeeze(0)*10.0
            sp_out=self.output.update(t,dt,I_out)
            # Count output spikes only after the initial transient (t_skip).
            if t>=self.cfg["simulation"]["t_skip"]:
                outc+=sp_out
            self.ac.simulate(t,dt)
            # Reward-modulated update for pre/post pairs that fired together.
            if sp_mid.sum()>0 and sp_out.sum()>0:
                d=self.ac.update(reward,self.cfg["actor_critic"]["gamma"])
                upd=torch.outer(sp_mid,sp_out)*self.cfg["learning"]["learning_rate"]*d
                self.W.add_(upd).clamp_(0,1.0)
            if verbose and step%200==0:
                print(f"t={int(t)}ms mid={int(sp_mid.sum())} out={int(sp_out.sum())}")
        # The predicted digit is the output neuron with the most counted spikes.
        pred=int(outc.argmax().item())
        return pred,outc.cpu().numpy()

    def reset(self):
        self.middle.reset(); self.output.reset(); self.ac.reset()

In [6]:
# Cell 6: Reward & Training Functions
def compute_reward(pred,label):
    return 1.0 if pred==label else -1.0

def train_network():
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x:x.squeeze(0)*255)
    ])
    ds=torchvision.datasets.MNIST(root="data",train=True,download=True,transform=transform)
    loader=DataLoader(ds,batch_size=CONFIG["training"]["batch_size"],shuffle=True)
    retina,net=Retina(CONFIG),PSACNetworkGPU(CONFIG)
    total,correct=0,0
    for ep in range(CONFIG["training"]["num_epochs"]):
        print(f"--- Epoch {ep+1}/{CONFIG['training']['num_epochs']} ---")
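        for i,(img,lbl) in enumerate(loader):
            # Round before casting: x/255*255 in float32 can land just below an integer.
            arr=np.rint(img.squeeze(0).numpy()).astype(np.uint8)
            rates=retina.process_image(arr)
            spikes=retina.generate_spike_train(rates,CONFIG["simulation"]["t_total"],CONFIG["simulation"]["dt"])
            # Pass 1: zero reward, so the weights stay fixed and p is the network's own guess.
            net.reset()
            p,_ = net.run_sim(spikes,0.0,verbose=(i<2))
            r=compute_reward(p,lbl.item())
            # Pass 2: replay the same spikes with the reward to update the weights.
            net.reset()
            net.run_sim(spikes,r)
            # Score the pass-1 prediction, which was made before the label was used.
            total+=1; correct+=(p==lbl.item())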
            if (i+1)%100==0:
                print(f"Sample {i+1} True={lbl.item()} Pred={p} Acc={correct/total:.2f}")
        print(f"Epoch {ep+1} Acc={correct/total:.2f}")
    print(f"Training complete. Final Acc={correct/total:.2f}")
    return net

In [7]:
# Cell 7: Evaluation Function
def evaluate_network(net):
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x:x.squeeze(0)*255)
    ])
    ds=torchvision.datasets.MNIST(root="data",train=False,download=True,transform=transform)
    loader=DataLoader(ds,batch_size=1,shuffle=False)
    retina=Retina(CONFIG)
    total,correct=0,0
    print(f"▶️ Evaluating over {len(ds)} samples")
    for idx,(img,lbl) in enumerate(loader):
        # Round before casting: x/255*255 in float32 can land just below an integer.
        arr=np.rint(img.squeeze(0).numpy()).astype(np.uint8)
        rates=retina.process_image(arr)
        spikes=retina.generate_spike_train(rates,CONFIG["simulation"]["t_total"],CONFIG["simulation"]["dt"])
        net.reset()
        p,_=net.run_sim(spikes,0.0)
        total+=1; correct+=(p==lbl.item())
        if total%100==0:
            print(f"[{total}/{len(ds)}] True={lbl.item()} Pred={p} Acc={correct/total:.2f}")
    print(f"✅ Final Eval Acc={correct/total:.2f}")

In [8]:
# Cell 8: Run Training & Evaluation
trained_net = train_network()
evaluate_network(trained_net)

--- Epoch 1/2 ---
t=0ms mid=5000 out=10
t=200ms mid=1000 out=0
t=400ms mid=5000 out=10
t=600ms mid=0 out=0
t=800ms mid=0 out=0
t=1000ms mid=0 out=0
t=0ms mid=0 out=0
t=200ms mid=5000 out=10
t=400ms mid=5000 out=10
t=600ms mid=1000 out=0
t=800ms mid=0 out=0
t=1000ms mid=0 out=0


KeyboardInterrupt: 