In [1]:
import torch
import torch.nn as nn
import gymnasium as gym
import snntorch as snn
from snntorch import functional as SF
from snntorch import spikeplot as splt
import torchvision.transforms as T
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv, VecFrameStack

from spikingjelly.clock_driven import ann2snn, functional
from torch.utils.data import DataLoader, TensorDataset

In [3]:
# Path to the ANN model (update for your environment)
ann_model_path = "/PongNoFrameskip-v4.zip"

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create Atari Pong evaluation environment
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
video_folder = '/Volumes/export/isn/diana/bindsnet/examples/pong/logs/videos/'  # Folder to save videos
video_length = 2000  # Length of the recorded video (in timesteps)
env = VecVideoRecorder(env, video_folder,
                     record_video_trigger=lambda x: x == 0,  # Record starting from the first step
                     video_length=video_length,
                     name_prefix=f"PongNoFrameskip-v4-SNN")

In [4]:
# load models
ann_model_path = "/Volumes/export/isn/diana/rl-baselines3-zoo/logs/dqn/PongNoFrameskip-v4_1/PongNoFrameskip-v4.zip"
ann_model = DQN.load(ann_model_path, custom_objects={"replay_buffer_class": None, "optimize_memory_usage": False})
snn_model = torch.load("snn_pong_q_net_full.pt", weights_only=False, map_location=device)
fused_snn = torch.load("fused_snn_pong.pt", weights_only=False, map_location=device)
target_snn = torch.load("fused_snn_pong.pt", weights_only=False, map_location=device)

Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer


In [5]:
# 1) Freeze everything in feature extractor
features_extractor = ann_model.policy.q_net.features_extractor
for p in features_extractor.parameters():
    p.requires_grad = False

target_features_extractor = ann_model.policy.q_net_target.features_extractor
for p in target_features_extractor.parameters():
    p.requires_grad = False

# 3) Verify
for name, p in ann_model.policy.named_parameters():
    print(f"{name:40s} requires_grad={p.requires_grad}")

q_net.features_extractor.cnn.0.weight    requires_grad=False
q_net.features_extractor.cnn.0.bias      requires_grad=False
q_net.features_extractor.cnn.2.weight    requires_grad=False
q_net.features_extractor.cnn.2.bias      requires_grad=False
q_net.features_extractor.cnn.4.weight    requires_grad=False
q_net.features_extractor.cnn.4.bias      requires_grad=False
q_net.features_extractor.linear.0.weight requires_grad=False
q_net.features_extractor.linear.0.bias   requires_grad=False
q_net.q_net.0.weight                     requires_grad=True
q_net.q_net.0.bias                       requires_grad=True
q_net_target.features_extractor.cnn.0.weight requires_grad=False
q_net_target.features_extractor.cnn.0.bias requires_grad=False
q_net_target.features_extractor.cnn.2.weight requires_grad=False
q_net_target.features_extractor.cnn.2.bias requires_grad=False
q_net_target.features_extractor.cnn.4.weight requires_grad=False
q_net_target.features_extractor.cnn.4.bias requires_grad=False
q_net_ta

In [6]:
ann_model = DQN.load(ann_model_path, custom_objects={"replay_buffer_class": None, "optimize_memory_usage": False})

# glorot initialize q network
nn.init.xavier_normal_(ann_model.policy.q_net.q_net[0].weight)
nn.init.zeros_(ann_model.policy.q_net.q_net[0].bias)
nn.init.xavier_normal_(ann_model.policy.q_net_target.q_net[0].weight)
nn.init.zeros_(ann_model.policy.q_net_target.q_net[0].bias)

Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True)

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4
)

# finetune ann

In [17]:
import copy
import random
from collections import deque, namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from spikingjelly.clock_driven import functional as sf_func
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# ─── Hyperparameters ────────────────────────────────────────────────────────────
ENV_ID         = "PongNoFrameskip-v4"
NUM_EPISODES   = 500
GAMMA          = 0.99
LR             = 1e-4
TARGET_SYNC    = 10       # episodes between syncing target network
BUFFER_SIZE    = 100_000
BATCH_SIZE     = 32
MIN_REPLAY     = 1_000    # start training after this many transitions
EPS_START      = 1.0
EPS_END        = 0.1
EPS_DECAY      = 100_000  # frames over which epsilon decays
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Environment setup ──────────────────────────────────────────────────────────
env = make_atari_env(ENV_ID, n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
action_dim = env.action_space.n

def unwrap(obs_tuple):
    # unwrap VecEnv: obs_tuple is ([frames], infos)
    return obs_tuple[0]

# ─── Load & clone networks ──────────────────────────────────────────────────────
fine_tuned_ann_model_path = "./ann_q_net_finetuned.pth" # epsilon = 0.850
fine_tuned_ann_target_model_path = "./ann_q_net_finetuned_target.pth" # epsilon = 0.850
# ann_model = torch.load(fine_tuned_ann_model_path, weights_only=False, map_location=DEVICE)
# ann_model_target = torch.load(fine_tuned_ann_target_model_path, weights_only=False, map_location=DEVICE)
ann_model = DQN.load(ann_model_path, custom_objects={"replay_buffer_class": None, "optimize_memory_usage": False})
ann_model_target = ann_model.q_net_target
ann_model = ann_model.q_net

# ─── Replay buffer ──────────────────────────────────────────────────────────────
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])
replay_buffer = deque(maxlen=BUFFER_SIZE)

# ─── Optimizer & loss ───────────────────────────────────────────────────────────
optimizer = optim.Adam(ann_model.parameters(), lr=LR)
criterion = nn.SmoothL1Loss()   # Huber loss

# ─── Epsilon schedule ───────────────────────────────────────────────────────────
def epsilon_by_frame(frame_idx):
    return EPS_END + (EPS_START - EPS_END) * np.exp(-1.0 * frame_idx / EPS_DECAY)

# ─── Helper: compute Q‐rates for a batch of observations ─────────────────────────
def compute_q_rates(net, obs_batch):
    """
    obs_batch: Tensor of shape (B, H, W, C), values in [0,255]
    returns: Tensor of shape (B, action_dim)
    """
    qs = []
    for obs in obs_batch:
        x = obs.permute(2, 0, 1).unsqueeze(0).to(DEVICE) / 255.0
        sf_func.reset_net(net)
        qs.append(net(x))
    return torch.cat(qs, dim=0)

# ─── Pre-fill replay buffer with random play ────────────────────────────────────
obs = unwrap(env.reset())
for _ in range(MIN_REPLAY):
    action = env.action_space.sample()
    next_obs, reward, done, _ = env.step([action])
    next_obs, reward, done = next_obs[0], reward[0], done[0]
    replay_buffer.append(Transition(obs, action, reward, next_obs, done))
    obs = next_obs if not done else unwrap(env.reset())

# ─── Main training loop ─────────────────────────────────────────────────────────
frame_idx = 0
for ep in range(1, NUM_EPISODES + 1):
    obs = unwrap(env.reset())
    total_reward = 0
    done = False

    # reset spiking states
    sf_func.reset_net(ann_model)
    sf_func.reset_net(ann_model_target)

    while not done:
        frame_idx += 1
        eps = epsilon_by_frame(frame_idx)

        # ε-greedy action selection
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_rate = compute_q_rates(ann_model, torch.tensor(obs[None], dtype=torch.float32))
            action = q_rate.argmax(dim=1).item()

        # step environment
        next_obs, reward, done, _ = env.step([action])
        next_obs, reward, done = next_obs[0], reward[0], done[0]
        replay_buffer.append(Transition(obs, action, reward, next_obs, done))
        obs = next_obs
        total_reward += reward

        # once we have enough samples, perform a training step
        if len(replay_buffer) >= MIN_REPLAY:
            transitions = random.sample(replay_buffer, BATCH_SIZE)
            batch = Transition(*zip(*transitions))

            state_batch      = torch.stack([torch.tensor(s, dtype=torch.float32) for s in batch.state])
            next_state_batch = torch.stack([torch.tensor(s, dtype=torch.float32) for s in batch.next_state])
            action_batch     = torch.tensor(batch.action, dtype=torch.int64, device=DEVICE).unsqueeze(1)
            reward_batch     = torch.tensor(batch.reward, dtype=torch.float32, device=DEVICE).unsqueeze(1)
            done_batch       = torch.tensor(batch.done, dtype=torch.float32, device=DEVICE).unsqueeze(1)

            # current Q-values
            q_values = compute_q_rates(ann_model, state_batch)
            current_q = q_values.gather(1, action_batch)

            # Double DQN: select next action via online net
            with torch.no_grad():
                next_q_online = compute_q_rates(ann_model, next_state_batch)
                next_actions  = next_q_online.argmax(dim=1, keepdim=True)

                # evaluate with target net
                next_q_target = compute_q_rates(ann_model_target, next_state_batch)
                next_q        = next_q_target.gather(1, next_actions)

                # build TD target, mask terminals
                td_target = reward_batch + GAMMA * (1 - done_batch) * next_q

            # loss & optimize
            loss = criterion(current_q, td_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # sync target network periodically
    if ep % TARGET_SYNC == 0:
        ann_model_target.load_state_dict(ann_model.state_dict())

    print(f"Episode {ep:03d}  Reward: {total_reward:.1f}  Epsilon: {eps:.3f}")
    
    # save model every 10 episodes
    if ep % 10 == 0:
        torch.save(ann_model, f"ann_q_net_finetuned.pth")
        torch.save(ann_model_target, f"ann_q_net_finetuned_target.pth")
        print(f"Saved model at episode {ep}")




Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer
Exception: 'bytes' object cannot be interpreted as an integer


KeyboardInterrupt: 

# finetune SNN

In [None]:
import copy
import random
from collections import deque, namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from spikingjelly.clock_driven import functional as sf_func
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# ─── Hyperparameters ────────────────────────────────────────────────────────────
ENV_ID         = "PongNoFrameskip-v4"
NUM_EPISODES   = 500
TIME_STEPS     = 20       # SNN ticks per frame
GAMMA          = 0.99
LR             = 1e-4
TARGET_SYNC    = 10       # episodes between syncing target network
BUFFER_SIZE    = 100_000
BATCH_SIZE     = 32
MIN_REPLAY     = 1_000    # start training after this many transitions
EPS_START      = 1.0
EPS_END        = 0.1
EPS_DECAY      = 100_000  # frames over which epsilon decays
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Environment setup ──────────────────────────────────────────────────────────
env = make_atari_env(ENV_ID, n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
action_dim = env.action_space.n

def unwrap(obs_tuple):
    # unwrap VecEnv: obs_tuple is ([frames], infos)
    return obs_tuple[0]

# ─── Load & clone networks ──────────────────────────────────────────────────────
fused_snn = torch.load("path/to/fused_snn.pth", map_location=DEVICE)
fused_snn.to(DEVICE)
target_snn = copy.deepcopy(fused_snn).to(DEVICE)
target_snn.eval()

# ─── Replay buffer ──────────────────────────────────────────────────────────────
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])
replay_buffer = deque(maxlen=BUFFER_SIZE)

# ─── Optimizer & loss ───────────────────────────────────────────────────────────
optimizer = optim.Adam(fused_snn.parameters(), lr=LR)
criterion = nn.SmoothL1Loss()   # Huber loss

# ─── Epsilon schedule ───────────────────────────────────────────────────────────
def epsilon_by_frame(frame_idx):
    return EPS_END + (EPS_START - EPS_END) * np.exp(-1.0 * frame_idx / EPS_DECAY)

# ─── Helper: compute Q‐rates for a batch of observations ─────────────────────────
def compute_q_rates(net, obs_batch):
    """
    obs_batch: Tensor of shape (B, H, W, C), values in [0,255]
    returns: Tensor of shape (B, action_dim)
    """
    qs = []
    for obs in obs_batch:
        x = obs.permute(2, 0, 1).unsqueeze(0).to(DEVICE) / 255.0
        sf_func.reset_net(net)
        out_sum = torch.zeros((1, action_dim), device=DEVICE)
        for _ in range(TIME_STEPS):
            out_sum += net(x)
        qs.append(out_sum.div_(TIME_STEPS))
    return torch.cat(qs, dim=0)

# ─── Pre-fill replay buffer with random play ────────────────────────────────────
obs = unwrap(env.reset())
for _ in range(MIN_REPLAY):
    action = env.action_space.sample()
    next_obs, reward, done, _ = env.step([action])
    next_obs, reward, done = next_obs[0], reward[0], done[0]
    replay_buffer.append(Transition(obs, action, reward, next_obs, done))
    obs = next_obs if not done else unwrap(env.reset())

# ─── Main training loop ─────────────────────────────────────────────────────────
frame_idx = 0
for ep in range(1, NUM_EPISODES + 1):
    obs = unwrap(env.reset())
    total_reward = 0
    done = False

    # reset spiking states
    sf_func.reset_net(fused_snn)
    sf_func.reset_net(target_snn)

    while not done:
        frame_idx += 1
        eps = epsilon_by_frame(frame_idx)

        # ε-greedy action selection
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_rate = compute_q_rates(fused_snn, torch.tensor(obs[None], dtype=torch.float32))
            action = q_rate.argmax(dim=1).item()

        # step environment
        next_obs, reward, done, _ = env.step([action])
        next_obs, reward, done = next_obs[0], reward[0], done[0]
        replay_buffer.append(Transition(obs, action, reward, next_obs, done))
        obs = next_obs
        total_reward += reward

        # once we have enough samples, perform a training step
        if len(replay_buffer) >= MIN_REPLAY:
            transitions = random.sample(replay_buffer, BATCH_SIZE)
            batch = Transition(*zip(*transitions))

            state_batch      = torch.stack([torch.tensor(s, dtype=torch.float32) for s in batch.state])
            next_state_batch = torch.stack([torch.tensor(s, dtype=torch.float32) for s in batch.next_state])
            action_batch     = torch.tensor(batch.action, dtype=torch.int64, device=DEVICE).unsqueeze(1)
            reward_batch     = torch.tensor(batch.reward, dtype=torch.float32, device=DEVICE).unsqueeze(1)
            done_batch       = torch.tensor(batch.done, dtype=torch.float32, device=DEVICE).unsqueeze(1)

            # current Q-values
            q_values = compute_q_rates(fused_snn, state_batch)
            current_q = q_values.gather(1, action_batch)

            # Double DQN: select next action via online net
            with torch.no_grad():
                next_q_online = compute_q_rates(fused_snn, next_state_batch)
                next_actions  = next_q_online.argmax(dim=1, keepdim=True)

                # evaluate with target net
                next_q_target = compute_q_rates(target_snn, next_state_batch)
                next_q        = next_q_target.gather(1, next_actions)

                # build TD target, mask terminals
                td_target = reward_batch + GAMMA * (1 - done_batch) * next_q

            # loss & optimize
            loss = criterion(current_q, td_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # sync target network periodically
    if ep % TARGET_SYNC == 0:
        target_snn.load_state_dict(fused_snn.state_dict())

    print(f"Episode {ep:03d}  Reward: {total_reward:.1f}  Epsilon: {eps:.3f}")

# ─── Save final weights ─────────────────────────────────────────────────────────
torch.save(fused_snn.state_dict(), "fused_snn_dqn_final.pth")


TypeError: ReplayBuffer.__init__() got an unexpected keyword argument 'capacity'

# Evaluate

In [None]:
from spikingjelly.clock_driven import functional as sf_func

print("Evaluating SNN with rate coding...")
episodes   = 5
time_steps = 20  # how many SNN ticks per frame
rewards    = []
spike_outputs = []

# Make sure your network is in eval mode
fused_snn.eval()

for ep in range(episodes):
    obs    = env.reset()
    obs       = obs[0]    # unwrap VecEnv
    
    done      = False
    total_reward = 0
    steps_per_episode = 0
    sf_func.reset_net(fused_snn)
    
    while done == False:
        # preprocess frame to [1,4,84,84]
        x = (
            torch.tensor(obs, dtype=torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(device)
            # / 255.0
        )

        # reset all LIF states before rate‐coding loop
        sf_func.reset_net(fused_snn)

        # accumulate outputs over time_steps
        out_sum = torch.zeros(
            (1, fused_snn.action_space.n), device=device
        )

        with torch.no_grad():
            for t in range(time_steps):
                out = fused_snn(x)   # returns spike‐counts or membrane outputs for this tick
                spike_outputs.append(out.detach().cpu().numpy())
                out_sum += out

        # compute rate‐coded Q values
        q_rate = out_sum / float(time_steps)
        # print(q_rate)
        action = q_rate.argmax(dim=1).item()

        # step the environment
        next_obs, reward, done, info = env.step([action])
        done   = done[0]
        reward = reward[0]
        obs    = next_obs[0]

        total_reward += reward
        steps_per_episode += 1

    rewards.append(total_reward)
    print(f"Episode {ep+1} reward: {total_reward}, steps: {steps_per_episode}")
