In [None]:
import os
import pickle

import gymnasium as gym
import numpy as np
import torch
from stable_baselines3.common.atari_wrappers import (
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

from phd.streaming_rl.core.algorithms.stream_q import StreamQ
from phd.streaming_rl.core.processing import NormalizeObservation, ScaleReward

In [None]:
def _make_atari_env(env_name, gamma):
    """Create the environment and apply the preprocessing wrappers in a fixed order."""
    env = gym.make(env_name, render_mode="rgb_array")
    # Episode statistics are recorded directly on the base env, i.e. beneath
    # EpisodicLifeEnv and ScaleReward in the wrapper stack.
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = NormalizeObservation(env)
    env = gym.wrappers.FrameStack(env, 4)
    env = ScaleReward(env, gamma=gamma)
    return env


def _restore_agent_and_stats(agent, env, load_dir, seed):
    """Load agent weights and the saved reward/observation statistics from `load_dir`."""
    # SECURITY: torch.load(weights_only=False) and pickle.load can execute
    # arbitrary code while deserializing. Only point `load_dir` at checkpoints
    # produced by a trusted run of this notebook.
    weights_path = os.path.join(load_dir, "seed_{}.pth".format(seed))
    agent.load_state_dict(torch.load(weights_path, weights_only=False))
    stats_path = os.path.join(load_dir, "stats_data_{}.pkl".format(seed))
    with open(stats_path, "rb") as f:  # context manager so the handle is always closed
        reward_stats, obs_stats = pickle.load(f)
    # Copy field by field into the live wrappers' statistics objects instead of
    # replacing those objects, observation statistics first.
    for live, saved in ((env.obs_stats, obs_stats), (env.reward_stats, reward_stats)):
        for field in ("mean", "var", "count", "p"):
            setattr(live, field, getattr(saved, field))


def _save_run(agent, env, env_name, seed, returns, term_time_steps):
    """Write the learning curve, agent weights and running statistics to disk."""
    env_id = env.spec.id

    results_dir = "data_stream_q_{}".format(env_id)
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, "seed_{}.pkl".format(seed)), "wb") as f:
        pickle.dump((returns, term_time_steps, env_name), f)

    checkpoint_dir = "stream_q_{}".format(env_id)
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(agent.state_dict(), os.path.join(checkpoint_dir, "seed_{}.pth".format(seed)))
    with open(os.path.join(checkpoint_dir, "stats_data_{}.pkl".format(seed)), "wb") as f:
        pickle.dump((env.reward_stats, env.obs_stats), f)


def agent_env_interaction(env_name='PongNoFrameskip-v4', seed=0, lr=1.0, gamma=0.99, lamda=0.8, total_steps=10_000, epsilon_target=0.01, epsilon_start=1.0, exploration_fraction=0.05, kappa_value=2.0, debug=True, save_agent=True, update_agent=True, load_agent=None, render=False):
    """Run a StreamQ agent in a wrapped Atari environment for `total_steps` steps.

    Args:
        env_name: Gymnasium environment id passed to `gym.make`.
        seed: Seed for torch, numpy and the first `env.reset`; also used in the
            names of the files that are loaded and saved.
        lr, lamda, epsilon_target, epsilon_start, exploration_fraction,
            kappa_value: Forwarded unchanged to `StreamQ`.
        gamma: Forwarded to both `StreamQ` and the `ScaleReward` wrapper.
        total_steps: Number of environment steps to run; also forwarded to
            `StreamQ`.
        debug: If true, print the seed/env id once and a line per finished episode.
        save_agent: If true, after the run write `(returns, term_time_steps,
            env_name)` to `data_stream_q_<env id>/seed_<seed>.pkl`, and the agent
            state dict plus reward/observation statistics to `stream_q_<env id>/`
            (both relative to the current working directory).
        update_agent: If true, call `agent.update_params` after every step.
        load_agent: Optional directory containing `seed_<seed>.pth` and
            `stats_data_<seed>.pkl` to restore before the run starts.
        render: If true, append `env.render()` to the returned list on every
            fourth step.

    Returns:
        The list of rendered frames (empty when `render` is false).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = _make_atari_env(env_name, gamma)
    # NOTE: a pyvirtualdisplay-based virtual display for rendering used to be
    # started here; it was commented out and has been left disabled.
    try:
        agent = StreamQ(n_actions=env.action_space.n, lr=lr, gamma=gamma, lamda=lamda, epsilon_target=epsilon_target, epsilon_start=epsilon_start, exploration_fraction=exploration_fraction, total_steps=total_steps, kappa_value=kappa_value)

        if load_agent is not None:
            _restore_agent_and_stats(agent, env, load_agent, seed)

        if debug:
            print("seed: {}".format(seed), "env: {}".format(env.spec.id))

        returns, term_time_steps, frames = [], [], []
        s, _ = env.reset(seed=seed)
        episode_num = 1
        for t in range(1, total_steps+1):
            a, is_nongreedy = agent.sample_action(s)
            # NOTE(review): the `truncated` flag is discarded; only `terminated`
            # reaches the agent update. Confirm this is intended.
            s_prime, r, terminated, _, info = env.step(a)
            if update_agent:
                agent.update_params(s, a, r, s_prime, terminated, is_nongreedy)
            s = s_prime
            if render and t % 4 == 0:
                frames.append(env.render())
            # The episode boundary is taken from the "episode" entry in `info`,
            # not from `terminated`.
            if info and "episode" in info:
                episode_return = info['episode']['r'][0]
                if debug:
                    print("Episodic Return: {}, Time Step {}, Episode Number {}, Epsilon {}".format(episode_return, t, episode_num, agent.epsilon))
                returns.append(episode_return)
                term_time_steps.append(t)
                s, _ = env.reset()
                episode_num += 1
    finally:
        # Release the emulator even if the rollout raised.
        env.close()

    if save_agent:
        _save_run(agent, env, env_name, seed, returns, term_time_steps)

    return frames

In [9]:
# Training run on Pong: 300k environment steps with seed 0, agent updates
# enabled, and results saved when the run finishes.
agent_env_interaction(
    # environment, seed and run length
    env_name='PongNoFrameskip-v4', seed=0, total_steps=300_000,
    # learning hyper-parameters
    lr=1.0, gamma=0.99, lamda=0.8, kappa_value=2.0,
    # epsilon settings
    epsilon_start=1.0, epsilon_target=0.01, exploration_fraction=0.05,
    # run-mode switches
    debug=True, save_agent=True, update_agent=True,
)

seed: 0 env: PongNoFrameskip-v4


  x = torch.tensor(np.array(x), dtype=torch.float)
  s, a, r, s_prime, done_mask = torch.tensor(np.array(s), dtype=torch.float), torch.tensor([a], dtype=torch.int).squeeze(0), \
  torch.tensor(np.array(r)), torch.tensor(np.array(s_prime), dtype=torch.float), \


Episodic Return: -20.0, Time Step 891, Episode Number 1, Epsilon 0.941194
Episodic Return: -21.0, Time Step 1672, Episode Number 2, Epsilon 0.889648
Episodic Return: -20.0, Time Step 2605, Episode Number 3, Epsilon 0.82807
Episodic Return: -21.0, Time Step 3388, Episode Number 4, Epsilon 0.776392
Episodic Return: -21.0, Time Step 4256, Episode Number 5, Epsilon 0.719104
Episodic Return: -21.0, Time Step 5040, Episode Number 6, Epsilon 0.66736
Episodic Return: -21.0, Time Step 6133, Episode Number 7, Epsilon 0.5952219999999999
Episodic Return: -21.0, Time Step 7030, Episode Number 8, Epsilon 0.5360199999999999
Episodic Return: -21.0, Time Step 7847, Episode Number 9, Epsilon 0.4820979999999999
Episodic Return: -21.0, Time Step 8845, Episode Number 10, Epsilon 0.41623
Episodic Return: -21.0, Time Step 9660, Episode Number 11, Epsilon 0.36244
Episodic Return: -19.0, Time Step 10594, Episode Number 12, Epsilon 0.30079599999999995
Episodic Return: -21.0, Time Step 11590, Episode Number 13, 

  logger.warn(


[]