In [1]:
import os
import random
import sys
from copy import deepcopy

import gymnasium as gym
import numpy as np
import torch
from gymnasium.wrappers import RescaleAction
from torch.distributions import Normal
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.model_free.ddpg import ddpg
from rlib.algorithms.model_free.sac import sac
from rlib.algorithms.model_free.td3 import td3
from rlib.common.evaluation import get_trajectory, validation
from rlib.common.policies import DeterministicMlpPolicy, MlpQCritic, StochasticMlpPolicy

# IPython autoreload: re-import modified modules (e.g. the local rlib package)
# before each cell executes, so library edits apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Pendulum-v1 wrapped so the agent acts in [min_action, max_action] = [-1, 1];
# RescaleAction maps that range onto the environment's native action bounds.
min_action, max_action = -1, 1
env = RescaleAction(
    gym.make("Pendulum-v1", render_mode="rgb_array"),
    min_action,
    max_action,
)

# Not read by any cell visible here; presumably marks the action space as
# continuous for downstream helpers — TODO confirm where it is consumed.
discrete = False

In [3]:
# Leading dimension of each space; these sizes build the MLP actor and critics below.
obs_dim, action_dim = (
    space.shape[0] for space in (env.observation_space, env.action_space)
)

print(obs_dim, action_dim)

3 1


### DDPG (Deep Deterministic Policy Gradient)

In [4]:
# Seed every RNG this experiment may draw from, so the DDPG section is reproducible
# on its own, regardless of which other sections already ran in this kernel.
DDPG_SEED = 0
random.seed(DDPG_SEED)
np.random.seed(DDPG_SEED)  # NOTE(review): only matters if rlib samples via NumPy's global RNG
torch.manual_seed(DDPG_SEED)  # fixes network initialisation below
env.reset(seed=DDPG_SEED)  # later env.reset() calls continue from this seeded RNG
env.action_space.seed(DDPG_SEED)  # in case rlib warms up with action_space.sample()

# NOTE: `actor` / `critic` / `*_optimizer` are shared names, rebound by every
# section's setup cell — run this section's cells together, top to bottom.
actor = DeterministicMlpPolicy(obs_dim, action_dim)
critic = MlpQCritic(obs_dim, action_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_optimizer = Adam(critic.parameters(), lr=1e-3)

In [None]:
# Networks are trained in place: the validation cell below reuses `actor` without
# reassigning it. The step budget (30k) differs from the TD3 (20k) and SAC (10k)
# cells, so the three validation scores are not a like-for-like comparison.
ddpg(
    env,
    actor,
    critic,
    actor_optimizer,
    critic_optimizer,
    total_timesteps=30_000,
)

In [None]:
# Score of the DDPG actor (presumably the evaluation episode return — TODO confirm
# against rlib.common.evaluation). `actor` is rebound by every section's setup
# cell, so run this right after the DDPG setup and training cells.
ddpg_return = validation(env, actor)
ddpg_return

-161.23902026577827

### TD3 (Twin Delayed DDPG)

In [12]:
# Seed every RNG this experiment may draw from, so the TD3 section is reproducible
# on its own, regardless of which other sections already ran in this kernel.
TD3_SEED = 0
random.seed(TD3_SEED)
np.random.seed(TD3_SEED)  # NOTE(review): only matters if rlib samples via NumPy's global RNG
torch.manual_seed(TD3_SEED)  # fixes network initialisation below
env.reset(seed=TD3_SEED)  # later env.reset() calls continue from this seeded RNG
env.action_space.seed(TD3_SEED)  # in case rlib warms up with action_space.sample()

# NOTE: `actor` / `critic_*` / `*_optimizer` are shared names, rebound by every
# section's setup cell — run this section's cells together, top to bottom.
actor = DeterministicMlpPolicy(obs_dim, action_dim)
critic_1 = MlpQCritic(obs_dim, action_dim)
critic_2 = MlpQCritic(obs_dim, action_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_1_optimizer = Adam(critic_1.parameters(), lr=1e-3)
critic_2_optimizer = Adam(critic_2.parameters(), lr=1e-3)

In [None]:
# Positional order: env, the three networks, then their optimizers in matching order.
td3_optimizers = (actor_optimizer, critic_1_optimizer, critic_2_optimizer)
td3(env, actor, critic_1, critic_2, *td3_optimizers, total_timesteps=20_000)

In [None]:
# Score of the TD3 actor (presumably the evaluation episode return — TODO confirm).
# `actor` is shared across sections; re-run this section's setup and training
# cells before trusting the number.
td3_return = validation(env, actor)
td3_return

-1427.9277668169286

### SAC (Soft Actor-Critic)

In [10]:
# Seed every RNG this experiment may draw from, so the SAC section is reproducible
# on its own, regardless of which other sections already ran in this kernel.
SAC_SEED = 0
random.seed(SAC_SEED)
np.random.seed(SAC_SEED)  # NOTE(review): only matters if rlib samples via NumPy's global RNG
torch.manual_seed(SAC_SEED)  # fixes network init (and torch-based action sampling)
env.reset(seed=SAC_SEED)  # later env.reset() calls continue from this seeded RNG
env.action_space.seed(SAC_SEED)  # in case rlib warms up with action_space.sample()

# NOTE: `actor` / `critic_*` / `*_optimizer` are shared names, rebound by every
# section's setup cell — run this section's cells together, top to bottom.
actor = StochasticMlpPolicy(obs_dim, action_dim)
critic_1 = MlpQCritic(obs_dim, action_dim)
critic_2 = MlpQCritic(obs_dim, action_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_1_optimizer = Adam(critic_1.parameters(), lr=1e-3)
critic_2_optimizer = Adam(critic_2.parameters(), lr=1e-3)

In [11]:
# Same positional order as the TD3 call: env, actor, both critics, then the actor
# optimizer followed by the two critic optimizers.
sac_critics = (critic_1, critic_2)
sac_critic_optimizers = (critic_1_optimizer, critic_2_optimizer)
sac(
    env,
    actor,
    *sac_critics,
    actor_optimizer,
    *sac_critic_optimizers,
    total_timesteps=10_000,
)

In [22]:
# `actor` is rebound by every section's setup cell, so this cell evaluates whichever
# ran last. The saved execution counts (SAC setup In[10], training In[11], TD3 setup
# In[12], this cell In[22]) indicate a TD3 actor was in scope here. Fail loudly
# instead of silently scoring the wrong policy.
assert isinstance(actor, StochasticMlpPolicy), (
    "`actor` is not the SAC policy; re-run the SAC setup and training cells first"
)

# deterministic=True presumably evaluates the policy's mean action rather than
# sampling — TODO confirm in rlib.common.evaluation.
sac_return = validation(env, actor, deterministic=True)
sac_return

-756.9275584807227