# This file is for creating the benchmark, with stacked frames.

# import relevant packages

In [18]:
import gymnasium as gym
import torch
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, EveryNTimesteps
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

from feature_extraction.callbacks.wandb_reward_logging_callback import WandbRewardLoggingCallback
from utils import evaluate_policy


# Settings

In [19]:
progress_bar = False
train_model = True
eval_model = True
save_name = "a2c_breakout_benchmark_framestack"

# Login to wanb and create a project with config

In [20]:
wandb.login()

config = dict(
    env_id="ALE/Breakout-v5",
    algorithm='PPO',
    #Hyperparams
    policy="CnnPolicy",
    learning_rate=2.5e-4,
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    n_envs=8,
    n_timesteps=10_000,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.1,
    clip_range_vf=None,
    normalize_advantage=True,
    normalize=False,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    use_sde=False,
    sde_sample_freq=-1,
    rollout_buffer_class=None,
    rollout_buffer_kwargs=None,
    target_kl=None,
    stats_window_size=100,
    tensorboard_log=None,
    policy_kwargs=None,
    verbose=0,
    seed=None,
    device='auto',
    _init_setup_model=True,
    env_wrapper='stable_baselines3.common.atari_wrappers.AtariWrapper',
    frame_stack=4,
)

wandb.init(project=save_name, config=config)
config = wandb.config

# Create callbacks

In [21]:
vec_eval_env = make_atari_env(config.env_id, n_envs=config.n_envs)
vec_eval_env = VecFrameStack(vec_eval_env, n_stack=config.frame_stack)
vec_eval_env = VecTransposeImage(vec_eval_env)

# WandbCallback
wandb_callback_after_eval = WandbRewardLoggingCallback()

# Save best model
eval_callback = EvalCallback(vec_eval_env, best_model_save_path="./logs/",
                             log_path="./logs/", eval_freq=max(500 // config.n_envs, 1), callback_after_eval=wandb_callback_after_eval,
                             deterministic=True, render=False)


# Create vectorized env and stack frames

In [22]:
vec_train_env = make_atari_env(config.env_id, n_envs=config.n_envs)
# Frame-stacking with 4 frames
vec_train_env = VecFrameStack(vec_train_env, n_stack=config.frame_stack)
vec_train_env = VecTransposeImage(vec_train_env)

# Create model, learn and save with wandb

In [23]:
if train_model:
    ppo_params_keys = [
        'policy', 'learning_rate', 'n_steps', 'batch_size', 'n_epochs',
        'gamma', 'gae_lambda', 'clip_range', 'clip_range_vf', 'normalize_advantage',
        'ent_coef', 'vf_coef', 'max_grad_norm', 'use_sde', 'sde_sample_freq',
        'rollout_buffer_class', 'rollout_buffer_kwargs', 'target_kl',
        'stats_window_size', 'tensorboard_log', 'policy_kwargs', 'verbose',
        'seed', 'device', '_init_setup_model'
    ]   
    
    # Step 2: Filter the config dictionary to extract only the hyperparameters for PPO
    ppo_hyperparams = {key: config[key] for key in ppo_params_keys if key in config}
    
    # Step 3: Unpack the filtered hyperparameters dictionary into the PPO constructor
    model = PPO(**ppo_hyperparams, env=vec_train_env)
    
    model = PPO(config.policy, vec_train_env, verbose=1)
    wandb.watch(model.policy, log="all", log_freq=10)
    model.learn(total_timesteps=config.n_timesteps, callback=eval_callback, progress_bar=progress_bar)
    model.save(save_name)

Using cpu device
Eval num_timesteps=496, episode_reward=2.20 +/- 0.75
Episode length: 236.80 +/- 26.13
---------------------------------
| eval/              |          |
|    mean_ep_length  | 237      |
|    mean_reward     | 2.2      |
| time/              |          |
|    total_timesteps | 496      |
---------------------------------
New best mean reward!
Eval num_timesteps=992, episode_reward=2.40 +/- 1.02
Episode length: 5599.40 +/- 10700.34
---------------------------------
| eval/              |          |
|    mean_ep_length  | 5.6e+03  |
|    mean_reward     | 2.4      |
| time/              |          |
|    total_timesteps | 992      |
---------------------------------
New best mean reward!


KeyboardInterrupt: 

# Load and evaluate Model

In [None]:
if eval_model:
    model = PPO.load("logs/best_model.zip", env=vec_eval_env)
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=2, render=False, fps=30)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
    

# Wrap up

In [None]:
wandb.finish()

# Export model to ONNX

In [None]:
#Example for creating an ONNX model (should be saved to wandb)
dummy_input = torch.randn(1, 4, 84, 84)  # Batch size of 1

torch.onnx.export(model.policy,             # Model's policy to export
                  dummy_input,              # Example input for the model
                  save_name + ".onnx", # Path to save the ONNX model
                  export_params=True,       # Export model parameters
                  opset_version=11,         # Set the ONNX version
                  do_constant_folding=True, # Optimization
                  input_names=['input'],    # Naming the input layer
                  output_names=['action_output', 'value_output'], # Naming the output layers
                  dynamic_axes={'input': {0: 'batch_size'},    # Handling variable batch sizes
                                'action_output': {0: 'batch_size'},
                                'value_output': {0: 'batch_size'}})
