In [229]:
import os
import pickle
import sys

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.model_free.a2c import a2c
from rlib.algorithms.model_free.ppo import ppo
from rlib.algorithms.model_free.reinforce import reinforce
from rlib.common.evaluation import get_trajectory, validation
from rlib.common.policies import (
    DiscreteStochasticMlpPolicy,
    MlpCritic,
    StochasticMlpPolicy,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [112]:
env = gym.make("CartPole-v1", render_mode="rgb_array")
discrete = True

In [230]:
env = gym.make("Pendulum-v1", render_mode="rgb_array")

min_action, max_action = -1, 1
env = RescaleAction(env, min_action, max_action)

discrete = False

In [231]:
obs_dim = env.observation_space.shape[0]

if discrete:
    action_dim = env.action_space.n
else:
    action_dim = env.action_space.shape[0]

print(obs_dim, action_dim)

3 1


### Reinforce

In [232]:
if discrete:
    policy = DiscreteStochasticMlpPolicy(obs_dim, action_dim)
else:
    policy = StochasticMlpPolicy(obs_dim, action_dim)

optimizer = Adam(policy.parameters(), lr=1e-3)

In [233]:
reinforce(env, policy, optimizer, total_timesteps=100_000)

In [234]:
validation(env, policy, deterministic=True)

-1490.2763956995082

### A2C

In [235]:
if discrete:
    actor = DiscreteStochasticMlpPolicy(obs_dim, action_dim)
else:
    actor = StochasticMlpPolicy(obs_dim, action_dim)

critic = MlpCritic(obs_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_optimizer = Adam(critic.parameters(), lr=1e-3)

In [236]:
a2c(env, actor, critic, actor_optimizer, critic_optimizer, total_timesteps=100_000)

In [237]:
validation(env, actor, deterministic=True)

-1569.926548787393

### PPO

In [240]:
if discrete:
    actor = DiscreteStochasticMlpPolicy(obs_dim, action_dim)
else:
    actor = StochasticMlpPolicy(obs_dim, action_dim)

critic = MlpCritic(obs_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_optimizer = Adam(critic.parameters(), lr=1e-3)

In [241]:
ppo(
    env,
    actor,
    critic,
    actor_optimizer,
    critic_optimizer,
    total_timesteps=100_000,
)

In [None]:
validation(env, actor, deterministic=True)

-291.40888919071506