In [2]:
import torch
import torch.nn as nn
import gymnasium as gym
import snntorch as snn
from snntorch import functional as SF
from snntorch import spikeplot as splt
import torchvision.transforms as T
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download ANN model from huggingface first
first install stable baselines3 zoo, then install their pong DQN model. Run these lines in terminal: \
pip install rl_zoo3 \
python -m rl_zoo3.load_from_hub --algo dqn --env PongNoFrameskip-v4 -orga sb3 -f logs/ \
python enjoy.py --algo dqn --env PongNoFrameskip-v4  -f logs/ 


# Prepare ANN and dataloader for spikingjelly's ann2snn converter to use

In [3]:
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv, VecFrameStack

from spikingjelly.clock_driven import ann2snn, functional
from torch.utils.data import DataLoader, TensorDataset

# Path to the ANN model (update for your environment)
ann_model_path = "/Volumes/export/isn/diana/rl-baselines3-zoo/logs/dqn/PongNoFrameskip-v4_1/PongNoFrameskip-v4.zip"

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create Atari Pong evaluation environment
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
video_folder = '/Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/'  # Folder to save videos
video_length = 2000  # Length of the recorded video (in timesteps)
env = VecVideoRecorder(env, video_folder,
                     record_video_trigger=lambda x: x == 0,  # Record starting from the first step
                     video_length=video_length,
                     name_prefix=f"PongNoFrameskip-v4-SNN")

# Collect observations using the ANN to estimate activation statistics
ann_model = DQN.load(ann_model_path, custom_objects={"replay_buffer_class": None, "optimize_memory_usage": False})
print("ANN model loaded successfully.")

obs = env.reset()
observations = []
print("Collecting observations...")
for _ in range(1000):
    action, _states = ann_model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    observations.append(obs[0])
    if done:
        obs = env.reset()
print("Collected observations.")

# Convert list of numpy arrays to a torch tensor
obs_array = np.stack(observations)                      # shape: [N, 84, 84, 4]
obs_array = np.transpose(obs_array, (0, 3, 1, 2))       # shape: [N, 4, 84, 84]
obs_tensor = torch.tensor(obs_array, dtype=torch.float32)
dummy_labels = torch.zeros(len(obs_tensor), dtype=torch.long)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer
  logger.warn(


ANN model loaded successfully.
Collecting observations...
Collected observations.


In [4]:
# Wrap in TensorDataset to make it compatible with ann2snn.Converter
obs_dataset = TensorDataset(obs_tensor, dummy_labels)
loader = DataLoader(obs_dataset, batch_size=32, shuffle=False)
print("Dataloader created")

Dataloader created


# Convert ANN to SNN

In [5]:

from spikingjelly.activation_based import model

# Convert the Q-network of the ANN policy to a SNN
print("Converting ANN to SNN...")
converter = ann2snn.Converter(dataloader=loader, mode=1.0 / 2)
ann_q_net = ann_model.policy.q_net
snn_q_net = converter(ann_q_net).to(device)

# save snn_q_net to disk
snn_q_net_path = "snn_pong_q_net.pth"
torch.save(snn_q_net.state_dict(), snn_q_net_path)
print(f"SNN model saved to {snn_q_net_path}")

Converting ANN to SNN...


100%|██████████| 32/32 [00:00<00:00, 367.17it/s]

SNN model saved to snn_pong_q_net.pth





In [41]:
print(snn_q_net)

QNetwork(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): Sequential(
        (0): VoltageScaler(1.371911)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2): VoltageScaler(0.728910)
      )
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): Sequential(
        (0): VoltageScaler(1.975725)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2): VoltageScaler(0.506143)
      )
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): Sequential(
        (0): VoltageScaler(2.588178)
        (1): IFNode(
          v_threshold=1.0, v_reset=None, detach_reset=False
          (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
        )
        (2):

# Convert the SNN to be Hi-AER Spike converter friendly

In [30]:
def fuse_and_remove_voltage_scalers(qnet: nn.Module):
    cnn         = qnet.features_extractor.cnn
    linear_feat = qnet.features_extractor.linear
    q_linear    = qnet.q_net[0]

    # ── 1) Gather every scale before mutating the model ──────────────
    # Block 1
    print("cnn[1]:", cnn[1])
    pre1  = cnn[1][0].scale
    post1 = cnn[1][2].scale
    # Block 2
    pre2  = cnn[3][0].scale
    post2 = cnn[3][2].scale
    # Block 3
    pre3  = cnn[5][0].scale
    post3 = cnn[5][2].scale
    # Final Linear block
    pre4  = linear_feat[1][0].scale
    post4 = linear_feat[1][2].scale

    # ── 2) Fold scalers into the weights/biases ────────────────────
    # Conv0 ← pre1
    cnn[0].weight.data.mul_(pre1)
    if cnn[0].bias is not None: cnn[0].bias.data.mul_(pre1)
    # Conv1 ← post1*pre2
    m12 = post1 * pre2
    cnn[2].weight.data.mul_(m12)
    if cnn[2].bias is not None: cnn[2].bias.data.mul_(m12)
    # Conv2 ← post2*pre3
    m23 = post2 * pre3
    cnn[4].weight.data.mul_(m23)
    if cnn[4].bias is not None: cnn[4].bias.data.mul_(m23)
    # Linear(3136→512) ← post3
    lin0 = linear_feat[0]
    lin0.weight.data.mul_(post3)
    if lin0.bias is not None: lin0.bias.data.mul_(post3)
    # q_net Linear ← post4
    q_linear.weight.data.mul_(post4)
    if q_linear.bias is not None: q_linear.bias.data.mul_(post4)

    # ── 3) Now strip out every VoltageScaler ──────────────────────
    # In cnn: replace each [VS, IFNode, VS] with just the IFNode
    for idx in (1, 3, 5):
        seq = cnn[idx]
        if isinstance(seq, nn.Sequential) and len(seq)==3:
            cnn[idx] = seq[1]
    # In linear_feat: replace [VS, IFNode, VS] with IFNode
    seq_lin = linear_feat[1]
    linear_feat[1] = seq_lin[1]

    return qnet


In [11]:
import copy
import torch
import torch.nn as nn

def fuse_and_remove_voltage_scalers(qnet: nn.Module):
    # 1) Work on a fresh copy so we never collide with a half-fused net
    net = copy.deepcopy(qnet)
    cnn         = net.features_extractor.cnn
    linear_feat = net.features_extractor.linear
    q_linear    = net.q_net[0]

    # 2) Gather _all_ the scales before touching the model
    s_pre1,  s_post1 = cnn[1][0].scale, cnn[1][2].scale
    s_pre2,  s_post2 = cnn[3][0].scale, cnn[3][2].scale
    s_pre3,  s_post3 = cnn[5][0].scale, cnn[5][2].scale
    s_pre4,  s_post4 = linear_feat[1][0].scale, linear_feat[1][2].scale

    # 3) Fold them into weights & biases
    # block1 → Conv0
    cnn[0].weight .data.mul_(s_pre1)
    if cnn[0].bias is not None: cnn[0].bias.data.mul_(s_pre1)

    # block2 → Conv1
    m12 = s_post1 * s_pre2
    cnn[2].weight .data.mul_(m12)
    if cnn[2].bias is not None: cnn[2].bias.data.mul_(m12)

    # block3 → Conv2
    m23 = s_post2 * s_pre3
    cnn[4].weight .data.mul_(m23)
    if cnn[4].bias is not None: cnn[4].bias.data.mul_(m23)

    # block4 pre → Linear(3136→512)
    m34 = s_post3 * s_pre4
    lin0 = linear_feat[0]
    lin0.weight .data.mul_(m34)
    if lin0.bias is not None: lin0.bias.data.mul_(m34)

    # final q_net → Linear(512→6)
    q_linear.weight.data.mul_(s_post4)
    if q_linear.bias is not None: q_linear.bias.data.mul_(s_post4)

    # 4) Now strip out every VoltageScaler, leaving only the IFNode
    for idx in (1, 3, 5):
        seq = cnn[idx]
        if isinstance(seq, nn.Sequential) and len(seq)==3:
            cnn[idx] = seq[1]   # keep only the IFNode

    # the Linear block
    seq_lin = linear_feat[1]
    if isinstance(seq_lin, nn.Sequential) and len(seq_lin)==3:
        linear_feat[1] = seq_lin[1]

    return net


In [12]:
fused_snn = fuse_and_remove_voltage_scalers(snn_q_net)

In [47]:
print(fused_snn)

QNetwork(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=3136, out_features=512, bias=True)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
    )
  )
  (q_ne

In [35]:
print(snn_q_net)

QNetwork(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=3136, out_features=512, bias=True)
      (1): IFNode(
        v_threshold=1.0, v_reset=None, detach_reset=False
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
    )
  )
  (q_ne

In [13]:
from spikingjelly.clock_driven import functional as sf_func

print("Evaluating SNN with rate coding...")
episodes   = 5
time_steps = 20  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
fused_snn.eval()

for ep in range(episodes):
    obs    = env.reset()
    obs       = obs[0]    # unwrap VecEnv
    done      = False
    total_reward = 0
    steps_per_episode = 0
    sf_func.reset_net(fused_snn)
    
    while done == False:
        # preprocess frame to [1,4,84,84]
        x = (
            torch.tensor(obs, dtype=torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(device)
            # / 255.0
        )

        # reset all LIF states before rate‐coding loop
        sf_func.reset_net(fused_snn)

        # accumulate outputs over time_steps
        out_sum = torch.zeros(
            (1, ann_model.action_space.n), device=device
        )

        with torch.no_grad():
            for t in range(time_steps):
                out = fused_snn(x)   # returns spike‐counts or membrane outputs for this tick
                spike_outputs.append(out.detach().cpu().numpy())
                out_sum += out

        # compute rate‐coded Q values
        q_rate = out_sum / float(time_steps)
        # print(q_rate)
        action = q_rate.argmax(dim=1).item()

        # step the environment
        next_obs, reward, done, info = env.step([action])
        done   = done[0]
        reward = reward[0]
        obs    = next_obs[0]

        total_reward += reward
        steps_per_episode += 1

    rewards.append(total_reward)
    print(f"Episode {ep+1} reward: {total_reward}, steps: {steps_per_episode}")


Evaluating SNN with rate coding...
Episode 1 reward: 21.0, steps: 1643
Episode 2 reward: 20.0, steps: 1718
Episode 3 reward: 20.0, steps: 1718
Episode 4 reward: 21.0, steps: 1638
Episode 5 reward: 21.0, steps: 1694


In [14]:
# save fused snn
fused_snn_path = "fused_snn_pong.pt"
torch.save(fused_snn, fused_snn_path)

# Evaluate the SNN
and record videos of the episodes

In [None]:
from spikingjelly.clock_driven import functional as sf_func

print("Evaluating SNN with rate coding...")
episodes   = 2
time_steps = 15  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
snn_q_net.eval()

for ep in range(episodes):
    obs    = env.reset()
    obs       = obs[0]    # unwrap VecEnv
    done      = False
    total_reward = 0
    steps_per_episode = 0
    sf_func.reset_net(snn_q_net)
    
    while done == False:
        # preprocess frame to [1,4,84,84]
        x = (
            torch.tensor(obs, dtype=torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(device)
            # / 255.0
        )

        # reset all LIF states before rate‐coding loop
        sf_func.reset_net(snn_q_net)

        # accumulate outputs over time_steps
        out_sum = torch.zeros(
            (1, ann_model.action_space.n), device=device
        )

        with torch.no_grad():
            for t in range(time_steps):
                out = snn_q_net(x)   # returns spike‐counts or membrane outputs for this tick
                spike_outputs.append(out.detach().cpu().numpy())
                out_sum += out

        # compute rate‐coded Q values
        q_rate = out_sum / float(time_steps)
        # print(q_rate)
        action = q_rate.argmax(dim=1).item()

        # step the environment
        next_obs, reward, done, info = env.step([action])
        done   = done[0]
        reward = reward[0]
        obs    = next_obs[0]

        total_reward += reward
        steps_per_episode += 1

    rewards.append(total_reward)
    print(f"Episode {ep+1} reward: {total_reward}, steps: {steps_per_episode}")


Evaluating SNN with rate coding...


tensor([[[[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.]],

         [[  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          ...,
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
          [  0.,   0.,   0.,  ...

                                                                            

MoviePy - Done !
MoviePy - video ready /Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/PongNoFrameskip-v4-SNN-step-0-to-step-2000.mp4
tensor([[[[ 52.,  52.,  52.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
          ...,
          [236., 236., 236.,  ..., 236., 236., 236.],
          [236., 236., 236.,  ..., 236., 236., 236.],
          [236., 236., 236.,  ..., 236., 236., 236.]],

         [[ 52.,  52.,  52.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
          ...,
          [236., 236., 236.,  ..., 236., 236., 236.],
          [236., 236., 236.,  ..., 236., 236., 236.],
          [236., 236., 236.,  ..., 236., 236., 236.]],

         [[ 52.,  52.,  52.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
          [ 87.,  87.,  87.,  ...,  87.,  87.,  87.],
         

KeyboardInterrupt: 

timesteps = 20 Evaluating SNN with rate coding...
Episode 1 reward: 21.0, steps: 1642
Episode 2 reward: 21.0, steps: 1626
Episode 3 reward: 21.0, steps: 1644
Episode 4 reward: 21.0, steps: 1622
Episode 5 reward: 20.0, steps: 1733

# Make input be binary

In [9]:
fused_snn_path = "/Volumes/export/isn/diana/bindsnet/examples/pong/fused_snn_pong.pt"
fused_snn = torch.load(fused_snn_path, weights_only = False, map_location=device)

In [21]:
def binary_encode(obs: torch.Tensor, T: int = 20) -> torch.Tensor:
    """
    obs:   torch.Tensor of shape (B, H, W, C*4) dtype uint8 from VecFrameStack
    returns: spike tensor shape (T, B, C*4, H, W), dtype float32
    """
    # to float, move channels first, normalize to [0,1]
    # x = obs.float().div_(255.0).permute(0, 3, 1, 2)  # → (B, C, H, W)
    # shape up: (T, B, C, H, W)
    x_rep = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
    # sample Poisson counts with rate = pixel intensity
    spikes = torch.bernoulli(x_rep)
    return spikes

In [23]:
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

print("Evaluating SNN with rate coding...")
episodes   = 2
time_steps = 15  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
fused_snn.eval()

for ep in range(episodes):
    obs    = env.reset()
    obs       = obs[0]    # unwrap VecEnv
    done      = False
    total_reward = 0
    steps_per_episode = 0
    sf_func.reset_net(fused_snn)
    
    while done == False:
        # preprocess frame to [1,4,84,84]
        x = (
            torch.tensor(obs, dtype=torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(device)
            / 255.0
        )
        
        # poisson encode x into 20 steps of spikes
        spikes = binary_encode(x, time_steps)
        # print("shape of spikes:", spikes.shape)
        # print("spikes:", spikes)

        # reset all LIF states before rate‐coding loop
        sf_func.reset_net(fused_snn)

        # accumulate outputs over time_steps
        out_sum = torch.zeros(
            (1, ann_model.action_space.n), device=device
        )

        with torch.no_grad():
            for t in range(time_steps):
                out = fused_snn(spikes[t])   # returns spike‐counts or membrane outputs for this tick
                spike_outputs.append(out.detach().cpu().numpy())
                out_sum += out

        # compute rate‐coded Q values
        q_rate = out_sum / float(time_steps)
        # print(q_rate)
        action = q_rate.argmax(dim=1).item()

        # step the environment
        next_obs, reward, done, info = env.step([action])
        done   = done[0]
        reward = reward[0]
        obs    = next_obs[0]

        total_reward += reward
        steps_per_episode += 1

    rewards.append(total_reward)
    print(f"Episode {ep+1} reward: {total_reward}, steps: {steps_per_episode}")


Evaluating SNN with rate coding...
Episode 1 reward: -21.0, steps: 757
Episode 2 reward: -21.0, steps: 759
