In [None]:
import ale_py
import gymnasium as gym
import torch
from stable_baselines3.common.atari_wrappers import AtariWrapper
# Register the ALE/* environment ids with gymnasium so gym.make("ALE/...") resolves them.
gym.register_envs(ale_py)

# Summarize only the Atari ids rather than dumping the whole (long) registry.
ale_env_ids = sorted(env_id for env_id in gym.envs.registry.keys() if env_id.startswith("ALE/"))
print(f"{len(ale_env_ids)} ALE environments registered, e.g. {ale_env_ids[:3]}")

# GPU sanity check.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
print("CUDA device count:", torch.cuda.device_count())
if cuda_available:
    # Only query the current device when CUDA is present; these calls can raise otherwise.
    print("CUDA device:", torch.cuda.get_device_name(torch.cuda.current_device()))


ModuleNotFoundError: No module named 'gymnasium'

In [5]:
from stable_baselines3.common.atari_wrappers import AtariWrapper
import gymnasium as gym

def make_atari_env(game):
    """Build one ALE env for ``game`` with SB3's standard Atari preprocessing.

    Observations stay channel-last, as AtariWrapper emits them.

    NOTE: a later cell redefines ``make_atari_env`` with a channel-first transpose;
    under Restart & Run All that later definition is the one in effect.

    :param game: gymnasium environment id, e.g. ``"ALE/Pong-v5"``.
    :return: the AtariWrapper-wrapped env.
    """
    # ALE v5 ids skip frames by default and AtariWrapper repeats each action itself
    # (its frame_skip option, per the SB3/ALE docs); disable the env-level skip so the
    # two do not compound.
    base_env = gym.make(game, render_mode=None, frameskip=1)
    return AtariWrapper(base_env)


AttributeError: module 'numpy' has no attribute 'bool8'

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.env_util import DummyVecEnv

def make_atari_env(game):
    """Create one preprocessed, channel-first Atari env for ``game``.

    The AtariWrapper output is transposed from (H, W, C) to (C, H, W) so that
    ``SharedCNN`` can read ``observation_space.shape[0]`` as the channel count.

    :param game: gymnasium environment id, e.g. ``"ALE/Pong-v5"``.
    :return: env whose observations and observation_space are channel-first.
    """
    # ALE v5 ids skip frames by default and AtariWrapper repeats each action itself
    # (its frame_skip option, per the SB3/ALE docs); disable the env-level skip so the
    # two do not compound.
    env = AtariWrapper(gym.make(game, render_mode=None, frameskip=1))

    # Derive the channel-first space from the wrapped one so bounds/dtype stay in sync.
    hwc_space = env.observation_space
    height, width, channels = hwc_space.shape
    chw_space = gym.spaces.Box(
        low=np.transpose(hwc_space.low, (2, 0, 1)),
        high=np.transpose(hwc_space.high, (2, 0, 1)),
        shape=(channels, height, width),
        dtype=hwc_space.dtype,
    )
    return gym.wrappers.TransformObservation(
        env, lambda obs: np.transpose(obs, (2, 0, 1)), observation_space=chw_space
    )

def make_vec_atari_env(game):
    """Return a single-worker DummyVecEnv around ``make_atari_env(game)``."""
    def _build_env():
        return make_atari_env(game)

    env_factories = [_build_env]
    return DummyVecEnv(env_factories)

# =====================================================
# 1. Task sequence, trained in this order (different action spaces!)
# =====================================================
tasks = [
    "ALE/Pong-v5",
    "ALE/Breakout-v5",
    "ALE/SpaceInvaders-v5",
]

# =====================================================
# 2. Shared CNN feature extractor
# =====================================================
class SharedCNN(BaseFeaturesExtractor):
    """Convolutional trunk shared by every task head.

    Three conv layers (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) with ReLUs, then a
    linear projection to ``features_dim`` units.

    :param observation_space: image Box whose shape is read as (channels, height, width).
    :param features_dim: size of the output feature vector.
    """

    def __init__(self, observation_space, features_dim=512):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]  # channel first
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened conv output size with one dummy forward pass.
        with torch.no_grad():
            sample = torch.zeros(1, n_input_channels,
                                 observation_space.shape[1],
                                 observation_space.shape[2])
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map an (N, C, H, W) batch to (N, features_dim) features.

        Integer input (raw uint8 pixels) is scaled to [0, 1] here. Floating-point input
        is assumed to be scaled already: per the SB3 docs, policies with
        ``normalize_images=True`` (the default) divide image observations by 255 before
        calling the extractor, and dividing again would squash pixels into [0, 1/255].
        NOTE(review): with ``normalize_images=False`` float input is used unscaled.
        """
        x = observations.float()
        if not torch.is_floating_point(observations):
            x = x / 255.0
        return self.linear(self.cnn(x))

# =====================================================
# 3. Custom Actor-Critic Policy with task-specific heads
# =====================================================
class MultiHeadPolicy(ActorCriticPolicy):
    """Actor-critic policy with a shared CNN trunk and one actor/critic head per task.

    SB3's ``ActorCriticPolicy`` computes action logits with ``self.action_net`` and
    values with ``self.value_net``; it never calls ``forward_actor``/``forward_critic``.
    ``set_active_task`` therefore re-points those two attributes at the chosen task's
    heads, so the per-task heads are what training and ``predict`` actually use.

    :param task_action_dims: mapping ``task id -> number of discrete actions``; one actor
        head of that size and one scalar critic head are created per entry. Required.
    :raises ValueError: if ``task_action_dims`` is missing or empty.
    """

    def __init__(self, *args, task_action_dims=None, **kwargs):
        if not task_action_dims:
            raise ValueError("MultiHeadPolicy requires a non-empty task_action_dims mapping")
        # CnnPolicy-style defaults: the shared CNN features feed the heads directly
        # (empty net_arch), so every head takes features_dim inputs.
        kwargs.setdefault("features_extractor_class", SharedCNN)
        kwargs.setdefault("features_extractor_kwargs", dict(features_dim=512))
        kwargs.setdefault("net_arch", [])
        super().__init__(*args, **kwargs)

        # Size the heads from the latent the policy really produces (equal to
        # features_dim with the empty net_arch default above).
        latent_dim_pi = self.mlp_extractor.latent_dim_pi
        latent_dim_vf = self.mlp_extractor.latent_dim_vf
        self.task_heads = nn.ModuleDict({
            task: nn.Linear(latent_dim_pi, n_actions)
            for task, n_actions in task_action_dims.items()
        })
        self.value_heads = nn.ModuleDict({
            task: nn.Linear(latent_dim_vf, 1)
            for task in task_action_dims.keys()
        })
        # Match the orthogonal init SB3 applies to its own action/value nets.
        if self.ortho_init:
            for head in self.task_heads.values():
                self.init_weights(head, gain=0.01)
            for head in self.value_heads.values():
                self.init_weights(head, gain=1)
        # super().__init__() built the optimizer before these heads existed; register
        # their parameters so PPO actually updates them.
        self.optimizer.add_param_group({
            "params": list(self.task_heads.parameters()) + list(self.value_heads.parameters())
        })
        self.active_task = None
        self.set_active_task(next(iter(task_action_dims)))

    def set_active_task(self, task):
        """Route all action/value computations through ``task``'s heads.

        :raises KeyError: if no heads were created for ``task``.
        """
        if task not in self.task_heads:
            raise KeyError(f"Unknown task {task!r}; known tasks: {list(self.task_heads)}")
        self.active_task = task
        self.action_net = self.task_heads[task]
        self.value_net = self.value_heads[task]

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        """Action logits from the active task's actor head."""
        return self.task_heads[self.active_task](features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        """State value from the active task's critic head."""
        return self.value_heads[self.active_task](features)

# =====================================================
# 4. Training & Evaluation
# =====================================================
# Build action dims dict: probe each game's discrete action-space size and close the
# probe env immediately so no emulator instances are left running.
def _action_space_size(env_id):
    """Return the number of discrete actions of ``env_id`` (probe env always closed)."""
    probe_env = gym.make(env_id)
    try:
        return probe_env.action_space.n
    finally:
        probe_env.close()


task_action_dims = {task: _action_space_size(task) for task in tasks}

# Shared model
policy_kwargs = dict(task_action_dims=task_action_dims)
model = PPO(
    MultiHeadPolicy,
    make_vec_atari_env(tasks[0]),
    policy_kwargs=policy_kwargs,
    verbose=1,
    n_steps=128,
    batch_size=64,
    # "auto" picks CUDA when available and falls back to CPU; a hard "cuda" fails on
    # machines without a GPU.
    device="auto",
)

# Evaluation history: eval_scores[task] gets one score appended per evaluation.
eval_scores = {task: [] for task in tasks}

def evaluate(model, task, n_episodes=3):
    """Return the mean undiscounted episode return of ``model`` on ``task``.

    Episodes run through ``make_atari_env``, the same preprocessing used for training:
    a bare ``gym.make(task)`` env skips the AtariWrapper resize/grayscale, so its frames
    would not match the policy's observation space. The policy's active task is switched
    to ``task`` for the evaluation and restored afterwards.

    NOTE(review): per the SB3 docs, AtariWrapper's defaults end episodes on life loss and
    clip rewards, so these are training-style returns rather than raw game scores --
    confirm that is the intended metric.

    :param model: SB3 model whose policy exposes ``set_active_task``.
    :param task: environment id to evaluate on.
    :param n_episodes: number of episodes to average; must be >= 1.
    :return: mean episode return.
    :raises ValueError: if ``n_episodes`` < 1 (the mean would be NaN).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    policy = model.policy
    previous_task = getattr(policy, "active_task", None)
    env = make_atari_env(task)
    returns = []
    try:
        policy.set_active_task(task)
        for _ in range(n_episodes):
            obs, _ = env.reset()
            done = False
            total_reward = 0.0
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward
                done = terminated or truncated
            returns.append(total_reward)
    finally:
        # Always release the emulator and undo the task switch, even on error.
        env.close()
        if previous_task is not None:
            policy.set_active_task(previous_task)
    return np.mean(returns)

timesteps = 50_000  # adjust for bigger training

for i, task in enumerate(tasks):
    print(f"\n=== Training on {task} ===")
    # env_kwargs passes this iteration's game explicitly instead of closing over `task`.
    env = make_vec_env(make_atari_env, n_envs=1, env_kwargs=dict(game=task))
    previous_env = model.get_env()
    # NOTE(review): SB3's set_env requires the new env's observation/action spaces to
    # equal the model's, so games whose action sets differ from tasks[0]'s are rejected
    # here. Making the spaces uniform (e.g. ALE's full 18-action set for every game) is
    # one way out -- confirm the intended design.
    try:
        model.set_env(env)
    except Exception:
        env.close()  # don't leak the freshly created env when it is rejected
        raise
    if previous_env is not None and previous_env is not env:
        previous_env.close()  # release the emulator of the env we just replaced
    model.policy.set_active_task(task)
    model.learn(total_timesteps=timesteps)

    # Evaluate across all tasks so far to track forgetting.
    for eval_task in tasks[: i + 1]:
        score = evaluate(model, eval_task)
        eval_scores[eval_task].append(score)
        print(f"Eval on {eval_task}: {score:.2f}")

# =====================================================
# 5. Plot results
# =====================================================
fig, ax = plt.subplots(figsize=(8, 5))
for task_idx, task in enumerate(tasks):
    scores = eval_scores[task]
    # tasks[task_idx] is first evaluated after training step task_idx + 1 and then once
    # per later step, so its curve starts there rather than at step 1.
    steps = range(task_idx + 1, task_idx + 1 + len(scores))
    # Markers keep single-evaluation tasks (e.g. the last one) visible.
    ax.plot(steps, scores, marker="o", label=task)
ax.set_xticks(range(1, len(tasks) + 1))
ax.set_xlabel("Task step")
ax.set_ylabel("Evaluation Score")
ax.set_title("Continual RL with Shared CNN + Task-Specific Heads (PPO)")
ax.legend()
plt.show()



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "c:\Users\jjul482\AppData\Local\anaconda3\envs\adapters\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "c:\Users\jjul482\AppData\Local\anaconda3\envs\adapters\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "c:\Users\jjul482\AppData\Local\anaconda3\envs\adapters\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "c:\Users\jjul482\AppData\Local\anaconda3\envs\adapters\lib\site-packages\trait

ModuleNotFoundError: No module named 'stable_baselines3'

In [None]:
# Calculate forgetting dynamics: best score ever reached on a task minus its final
# score (0 means the last evaluation was also the best one).
forgetting = {}
for task in tasks:
    scores = eval_scores[task]
    if not scores:
        continue  # never evaluated, e.g. the training loop was interrupted early
    forgetting[task] = max(scores) - scores[-1]

print("\nForgetting per task:")
for t, f in forgetting.items():
    print(f"{t}: {f:.2f}")


In [None]:
# Plot evaluation scores over task sequence.
fig, ax = plt.subplots(figsize=(8, 5))
for task_idx, task in enumerate(tasks):
    scores = eval_scores[task]
    # tasks[task_idx] is first evaluated after training step task_idx + 1, so align
    # its scores to the steps at which they were actually measured.
    steps = range(task_idx + 1, task_idx + 1 + len(scores))
    # Markers keep single-evaluation tasks visible.
    ax.plot(steps, scores, marker="o", label=task)

ax.set_xticks(range(1, len(tasks) + 1))
ax.set_xlabel("Task step")
ax.set_ylabel("Evaluation Score")
ax.set_title("Continual RL: Evaluation Across Tasks")
ax.legend()
plt.show()


In [20]:
import torch

# Smoke test: one observation through SharedCNN, with the shapes asserted inline.
env = make_atari_env("ALE/Pong-v5")
try:
    obs, _ = env.reset()
    print("Obs shape:", obs.shape)
    # The extractor sizes itself from the space, so the observation must match it.
    assert obs.shape == env.observation_space.shape, (obs.shape, env.observation_space.shape)

    extractor = SharedCNN(env.observation_space)
    with torch.no_grad():
        features = extractor(torch.from_numpy(obs[None]))
    print("Features shape:", features.shape)
    assert tuple(features.shape) == (1, extractor.features_dim)
finally:
    env.close()


Obs shape: (84, 84, 1)
Features shape: torch.Size([1, 512])
