In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from brian2 import *

# Set random seeds for reproducibility. NumPy's generator drives the input
# encoding; Brian2's own generators drive `syn.connect(p=...)` and `rand()` in
# synapse initialisation. Brian2's compiled code-generation targets do not draw
# from NumPy's global state, so both are seeded.
np.random.seed(0)
seed(0)

# Simulation parameters
n_input = 784  # 28x28 Fashion-MNIST images
n_neurons = 100  # Number of excitatory neurons
sim_time = 50 * ms  # Presentation time per sample
tau = 10 * ms  # Membrane time constant
threshold = -50 * mV  # Firing threshold
reset_potential = -65 * mV  # Reset potential
rest_potential = -65 * mV  # Resting potential
refractory_period = 5 * ms  # Refractory period
batch_size = 1000  # Number of samples per batch

# Load Fashion-MNIST dataset from CSV files
def load_fashion_mnist(train_path, test_path):
    """Read the Fashion-MNIST train and test splits from CSV files.

    Each CSV holds the class label in its first column and the raw pixel
    values in the remaining columns, one image per row (read with the
    default ``pandas.read_csv`` options).

    Returns ``((train_images, train_labels), (test_images, test_labels))``;
    the images are float32 arrays with pixel values divided by 255.
    """

    def split(frame):
        # First column -> labels, remaining columns -> normalised pixels.
        labels = frame.iloc[:, 0].values
        pixels = frame.iloc[:, 1:].values.astype(np.float32)
        return pixels / 255.0, labels

    train_frame = pd.read_csv(train_path)
    test_frame = pd.read_csv(test_path)
    return split(train_frame), split(test_frame)

# Paths to the two CSV splits; edit these to point at your local copies.
train_path = "D:/fashion-mnist_train.csv"
test_path = "D:/fashion-mnist_test.csv"
train_split, test_split = load_fashion_mnist(train_path, test_path)
train_images, train_labels = train_split
test_images, test_labels = test_split

# Create Poisson-based input spikes from images
def poisson_encoding(images, rate=100 * Hz, duration=None):
    """Turn pixel intensities into one binary input spike per pixel.

    Each pixel is treated as a Poisson process with rate
    ``intensity * rate`` observed for ``duration``. The returned entry is 1
    when that process emits at least one spike, which happens with
    probability ``1 - exp(-intensity * rate * duration)``. (Clipping
    ``intensity * rate * duration`` to [0, 1] is only a small-rate
    approximation: with the default 100 Hz over 50 ms it reaches 1 for every
    pixel brighter than 0.2.)

    Parameters
    ----------
    images : array_like
        Pixel intensities, normally in [0, 1]; any shape, e.g.
        (samples, pixels). Non-positive intensities never spike.
    rate : Quantity
        Firing rate of a pixel with intensity 1.
    duration : Quantity, optional
        Observation window; defaults to the module-level ``sim_time``.

    Returns
    -------
    numpy.ndarray
        Integer 0/1 array with the same shape as ``images``.
    """
    if duration is None:
        duration = sim_time
    images = np.asarray(images)
    # rate * duration is dimensionless, so this is a plain float array.
    expected_spikes = np.asarray(images * rate * duration, dtype=float)
    spike_prob = 1.0 - np.exp(-expected_spikes)
    return (np.random.rand(*images.shape) < spike_prob).astype(int)

# Define neuron model (Leaky Integrate-and-Fire): without input, v relaxes
# exponentially towards rest_potential with time constant tau. Crossing
# `threshold` emits a spike, v is reset, and v is held during the refractory
# period.
eqs = '''
dv/dt = (rest_potential - v) / tau : volt (unless refractory)
'''

# Input layer: one spike source per pixel. It starts without spikes; every
# presentation supplies its own spikes through set_spikes().
input_layer = SpikeGeneratorGroup(n_input, [], [] * ms)
exc_layer = NeuronGroup(
    n_neurons,
    eqs,
    threshold="v > threshold",
    reset="v = reset_potential",
    refractory=refractory_period,
    method="exact",
)
# NeuronGroup state variables start at 0; 0 mV lies above the -50 mV
# threshold, so without this every neuron would fire spuriously at the start.
exc_layer.v = rest_potential

# Synapses with STDP learning.
#
# Trace-based pair STDP: every presynaptic spike bumps the trace `apre` and
# every postsynaptic spike bumps `apost`. Both traces decay exponentially
# between spikes; `event-driven` means they are only updated when a spike
# arrives. Pre-before-post pairs potentiate (w += apre on a post spike).
# Post-before-pre pairs depress (w += apost on a pre spike, with apost < 0).
# Without the decay terms the traces would only accumulate. With positive
# increments on both sides, every pairing would potentiate and the weights
# could only grow.
taupre = 20 * ms  # decay time constant of the presynaptic trace
taupost = 20 * ms  # decay time constant of the postsynaptic trace
wmax = 1.0  # upper weight bound
Apre = 0.01  # potentiation step
# Depression step; the 1.05 factor biases the rule slightly towards depression.
Apost = -Apre * taupre / taupost * 1.05

syn = Synapses(
    input_layer,
    exc_layer,
    model='''
    w : 1
    dapre/dt = -apre / taupre : 1 (event-driven)
    dapost/dt = -apost / taupost : 1 (event-driven)
    ''',
    on_pre='''
    v_post += w * mV
    apre += Apre
    w = clip(w + apost, 0, wmax)
    ''',
    on_post='''
    apost += Apost
    w = clip(w + apre, 0, wmax)
    ''',
)
syn.connect(p=0.1)  # each input/neuron pair is connected with probability 0.1
syn.w = "0.2 + 0.2 * rand()"  # initial weights drawn uniformly from [0.2, 0.4)
syn.apre = 0
syn.apost = 0

# Set up monitors.
# Excitatory spikes are the network's output.
spike_mon = SpikeMonitor(exc_layer)
# Membrane potential of every neuron at every time step. Nothing in this script
# reads it, and the recording grows with each simulated step of each neuron, so
# the monitor starts switched off. Set `state_mon.active = True` before a short
# diagnostic run to use it.
state_mon = StateMonitor(exc_layer, 'v', record=True)
state_mon.active = False

# Train the network in batches.
#
# Within a batch the samples are presented one after another. Each sample gets
# its own `sim_time` window, with its input spikes at the start of that window.
# SpikeGeneratorGroup spike times are absolute simulation times, so each batch
# is offset by the current clock time; spikes scheduled before it would never
# be delivered. Putting all samples of a batch at the same instant would make
# input neurons spike more than once in a single time step, which Brian2
# rejects.
print("Training the SNN on Fashion-MNIST...")
state_mon.active = False  # per-step voltage recording is not used during training
n_train_samples = len(train_images)  # lower this for a quicker run
num_batches = int(np.ceil(n_train_samples / batch_size))
for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = min(start_idx + batch_size, n_train_samples)
    batch_images = train_images[start_idx:end_idx]

    # Encode the current batch; row k holds the input spikes of sample k.
    batch_spikes = poisson_encoding(batch_images)

    # Schedule sample k's spikes at the start of its presentation window.
    # np.nonzero walks the array row by row, so the times come out sorted.
    sample_idx, spike_indices = np.nonzero(batch_spikes)
    batch_start = defaultclock.t
    spike_times = batch_start + sample_idx * sim_time
    input_layer.set_spikes(spike_indices, spike_times)

    # Run long enough to present every sample of the batch.
    run(len(batch_images) * sim_time)

    # SpikeMonitor variables are read-only, so the monitor is not cleared here;
    # it keeps accumulating spikes across batches.
    print(f"Processed batch {batch_idx + 1}/{num_batches}")

# Plot a spike raster of the most recent activity. The SpikeMonitor cannot be
# cleared and holds every spike since the start of the simulation. Only the
# last `batch_size * sim_time` of simulated time is shown, which is one full
# batch when each sample is presented for `sim_time`.
window_start = defaultclock.t - batch_size * sim_time
spike_t = spike_mon.t[:]
spike_i = spike_mon.i[:]
recent = spike_t >= window_start
plt.figure(figsize=(10, 5))
plt.plot(spike_t[recent] / ms, spike_i[recent], '.k', markersize=2)
plt.xlabel("Time (ms)")
plt.ylabel("Neuron Index")
plt.title("Spiking Activity of Excitatory Neurons")
plt.show()

# Test the trained network (basic evaluation).
#
# Unsupervised STDP leaves the excitatory neurons without class labels, so the
# index of the most active neuron cannot be compared with a class label.
# Instead, each neuron is first assigned the class it responds to most strongly
# on a labelled subset of the training data. A test sample is then predicted
# as the class whose assigned neurons respond most strongly on average.
n_label_samples = 1000  # labelled training samples used to assign neuron classes
n_test_samples = 10  # test samples to classify
n_classes = int(max(train_labels.max(), test_labels.max())) + 1


def _present_frozen(images):
    """Present *images* back to back and count excitatory spikes per image.

    Each image gets its own ``sim_time`` window, starting at the current
    simulation time. Plasticity is frozen while the images are shown:
    - the postsynaptic ``on_post`` pathway is switched off;
    - ``apost`` is set to 0, so the ``w + apost`` update in ``on_pre`` adds
      nothing.
    The weights and traces are restored afterwards in any case.

    Returns an int array of shape (len(images), n_neurons) holding the number
    of excitatory spikes emitted during each image's window.
    """
    saved = {name: np.array(getattr(syn, name)) for name in ("w", "apre", "apost")}
    syn.apost = 0
    syn.post.active = False
    try:
        sample_idx, pixel_idx = np.nonzero(poisson_encoding(images))
        t_start = defaultclock.t  # spike times are absolute, so offset by "now"
        input_layer.set_spikes(pixel_idx, t_start + sample_idx * sim_time)
        exc_layer.v = rest_potential
        first_spike = spike_mon.num_spikes  # ignore spikes from earlier runs
        run(len(images) * sim_time)
    finally:
        syn.post.active = True
        for name, values in saved.items():
            setattr(syn, name, values)

    spike_t = np.asarray(spike_mon.t[:] / ms)[first_spike:]
    spike_i = np.asarray(spike_mon.i[:])[first_spike:]
    # Index of the presentation window each spike fell into.
    window = (spike_t - float(t_start / ms)) // float(sim_time / ms)
    window = np.clip(window.astype(int), 0, len(images) - 1)
    counts = np.zeros((len(images), n_neurons), dtype=int)
    np.add.at(counts, (window, spike_i), 1)
    return counts


# Assign each neuron the class with the highest mean response.
label_counts = _present_frozen(train_images[:n_label_samples])
label_subset = train_labels[:n_label_samples]
class_response = np.zeros((n_classes, n_neurons))
for c in range(n_classes):
    in_class = label_subset == c
    if in_class.any():
        class_response[c] = label_counts[in_class].mean(axis=0)
neuron_labels = np.argmax(class_response, axis=0)
# Neurons that never fired on the labelled subset carry no class information.
labelled = class_response.max(axis=0) > 0

# Classify the test samples.
test_counts = _present_frozen(test_images[:n_test_samples])
correct = 0
for i, counts in enumerate(test_counts):
    class_votes = np.zeros(n_classes)
    for c in range(n_classes):
        members = labelled & (neuron_labels == c)
        if members.any():
            class_votes[c] = counts[members].mean()
    # -1 when no labelled neuron fired for this sample.
    predicted_label = int(np.argmax(class_votes)) if class_votes.any() else -1
    true_label = test_labels[i]

    if predicted_label == true_label:
        correct += 1

accuracy = correct / len(test_counts) * 100
print(f"Test Accuracy on {len(test_counts)} samples: {accuracy:.2f}%")