In [1]:
import torch
import torch.nn as nn
import gymnasium as gym
import snntorch as snn
from snntorch import functional as SF
from snntorch import spikeplot as splt
import torchvision.transforms as T
import numpy as np
from stable_baselines3 import A2C

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
import torch
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv, VecFrameStack

from spikingjelly.clock_driven import ann2snn, functional
from torch.utils.data import DataLoader, TensorDataset

# Path to the ANN model (update for your environment)
ann_model_path = "/Volumes/export/isn/diana/rl-baselines3-zoo/logs/dqn/PongNoFrameskip-v4_1/PongNoFrameskip-v4.zip"

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create Atari Pong evaluation environment
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
video_folder = '/Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/'  # Folder to save videos
video_length = 2000  # Length of the recorded video (in timesteps)
env = VecVideoRecorder(env, video_folder,
                     record_video_trigger=lambda x: x == 0,  # Record starting from the first step
                     video_length=video_length,
                     name_prefix=f"PongNoFrameskip-v4-SNN")

# Collect observations using the ANN to estimate activation statistics
ann_model = DQN.load(ann_model_path, custom_objects={"replay_buffer_class": None, "optimize_memory_usage": False})
print("ANN model loaded successfully.")

obs = env.reset()
observations = []
print("Collecting observations...")
for _ in range(1000):
    action, _states = ann_model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    observations.append(obs[0])
    if done:
        obs = env.reset()
print("Collected observations.")

# Convert list of numpy arrays to a torch tensor
obs_array = np.stack(observations)                      # shape: [N, 84, 84, 4]
obs_array = np.transpose(obs_array, (0, 3, 1, 2))       # shape: [N, 4, 84, 84]
obs_tensor = torch.tensor(obs_array, dtype=torch.float32)
dummy_labels = torch.zeros(len(obs_tensor), dtype=torch.long)

ANN model loaded successfully.
Collecting observations...
Collected observations.


In [3]:
# Wrap in TensorDataset to make it compatible with ann2snn.Converter
obs_dataset = TensorDataset(obs_tensor, dummy_labels)
loader = DataLoader(obs_dataset, batch_size=32, shuffle=False)
print("Dataloader created")

Dataloader created


In [4]:

from spikingjelly.activation_based import model

# Convert the Q-network of the ANN policy to a SNN
print("Converting ANN to SNN...")
converter = ann2snn.Converter(dataloader=loader, mode=1.0 / 2)
ann_q_net = ann_model.policy.q_net
snn_q_net = converter(ann_q_net).to(device)

# save snn_q_net to disk
snn_q_net_path = "snn_pong_q_net.pth"
torch.save(snn_q_net.state_dict(), snn_q_net_path)
print(f"SNN model saved to {snn_q_net_path}")

Converting ANN to SNN...


100%|██████████| 32/32 [00:00<00:00, 407.27it/s]

SNN model saved to snn_pong_q_net.pth





In [20]:
from spikingjelly.clock_driven import functional as sf_func

print("Evaluating SNN with rate coding...")
episodes   = 2
time_steps = 15  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
snn_q_net.eval()

for ep in range(episodes):
    obs    = env.reset()
    obs       = obs[0]    # unwrap VecEnv
    done      = False
    total_reward = 0
    steps_per_episode = 0
    sf_func.reset_net(snn_q_net)
    
    while done == False:
        # preprocess frame to [1,4,84,84]
        x = (
            torch.tensor(obs, dtype=torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(device)
            # / 255.0
        )

        # reset all LIF states before rate‐coding loop
        sf_func.reset_net(snn_q_net)

        # accumulate outputs over time_steps
        out_sum = torch.zeros(
            (1, ann_model.action_space.n), device=device
        )

        with torch.no_grad():
            for t in range(time_steps):
                out = snn_q_net(x)   # returns spike‐counts or membrane outputs for this tick
                spike_outputs.append(out.detach().cpu().numpy())
                out_sum += out

        # compute rate‐coded Q values
        q_rate = out_sum / float(time_steps)
        # print(q_rate)
        action = q_rate.argmax(dim=1).item()

        # step the environment
        next_obs, reward, done, info = env.step([action])
        done   = done[0]
        reward = reward[0]
        obs    = next_obs[0]

        total_reward += reward
        steps_per_episode += 1

    rewards.append(total_reward)
    print(f"Episode {ep+1} reward: {total_reward}, steps: {steps_per_episode}")


Evaluating SNN with rate coding...
Episode 1 reward: 1.0, steps: 77
Saving video to /Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/PongNoFrameskip-v4-SNN-step-0-to-step-2000.mp4
MoviePy - Building video /Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/PongNoFrameskip-v4-SNN-step-0-to-step-2000.mp4.
MoviePy - Writing video /Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/PongNoFrameskip-v4-SNN-step-0-to-step-2000.mp4



                                                                            

MoviePy - Done !
MoviePy - video ready /Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/PongNoFrameskip-v4-SNN-step-0-to-step-2000.mp4
Episode 2 reward: 21.0, steps: 1632


# visualize an episode

In [16]:
obs    = env.reset()
obs       = obs[0]    # unwrap VecEnv
done      = False
total_reward = 0
steps_per_episode = 0
time_steps = 15  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
snn_q_net.eval()
sf_func.reset_net(snn_q_net)

while done == False:
    # preprocess frame to [1,4,84,84]
    x = (
        torch.tensor(obs, dtype=torch.float32)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device)
        # / 255.0
    )

    # reset all LIF states before rate‐coding loop
    sf_func.reset_net(snn_q_net)

    # accumulate outputs over time_steps
    out_sum = torch.zeros(
        (1, ann_model.action_space.n), device=device
    )

    with torch.no_grad():
        for t in range(time_steps):
            out = snn_q_net(x)   # returns spike‐counts or membrane outputs for this tick
            spike_outputs.append(out.detach().cpu().numpy())
            out_sum += out

    # compute rate‐coded Q values
    q_rate = out_sum / float(time_steps)
    # print(q_rate)
    action = q_rate.argmax(dim=1).item()

    # step the environment
    next_obs, reward, done, info = env.step([action])
    done   = done[0]
    reward = reward[0]
    obs    = next_obs[0]

    total_reward += reward
    steps_per_episode += 1

rewards.append(total_reward)
print(f"Episode reward: {total_reward}, steps: {steps_per_episode}")

Episode reward: 9.0, steps: 623


timesteps = 20 Evaluating SNN with rate coding...
Episode 1 reward: 21.0, steps: 1642
Episode 2 reward: 21.0, steps: 1626
Episode 3 reward: 21.0, steps: 1644
Episode 4 reward: 21.0, steps: 1622
Episode 5 reward: 20.0, steps: 1733

# quantization

In [6]:
print(snn_q_net)

QNetwork(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): Sequential(
        (0): VoltageScaler(1.371911)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2): VoltageScaler(0.728910)
      )
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): Sequential(
        (0): VoltageScaler(1.975725)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2): VoltageScaler(0.506143)
      )
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): Sequential(
        (0): VoltageScaler(2.588178)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2):

In [7]:
print(ann_q_net)

QNetwork(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=3136, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (q_net): Sequential(
    (0): Linear(in_features=512, out_features=6, bias=True)
  )
)


In [None]:

from hs_api.converter.cri_converter import Quantize_Network

alpha = 4
qn = Quantize_Network(w_alpha=alpha)
net_quan = qn.quantize(ann_q_net)

# Hi-AER Spike Conversion

In [None]:
from hs_api.converter.cri_converter import CRI_Converter

input_layer = 1 #first pytorch layer that acts as synapses, indexing begins at 0 
output_layer = 4 #last pytorch layer that acts as synapses
snn_layers = 2 # number of snn layers 
input_shape = (1, 28, 28)
backend = 'spikingjelly'
v_threshold = qn.v_threshold
    
cn = CRI_Converter(num_steps = args.T,
                   input_layer = input_layer, 
                   output_layer = output_layer, 
                   input_shape = input_shape,
                   snn_layers = snn_layers,
                   backend = backend,
                   v_threshold = int(v_threshold))

cn.layer_converter(net_quan)

In [None]:
# initiate the model

config = {}
config['neuron_type'] = "I&F"
config['global_neuron_params'] = {}
config['global_neuron_params']['v_thr'] = int(qn.v_threshold)

softwareNetwork = CRI_network(dict(cn.axon_dict),
                              connections=dict(cn.neuron_dict),
                              config=config,target='simpleSim', 
                              outputs = cn.output_neurons,
                              coreID=1)