In [1]:
import os
import sys

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.a2c import a2c
from rlib.algorithms.ppo import ppo
from rlib.algorithms.reinforce import reinforce
from rlib.common.evaluation import validation
from rlib.common.policies import (
    DiscreteStochasticMlpPolicy,
    MlpCritic,
    StochasticMlpPolicy,
)

%load_ext autoreload
%autoreload 2

In [226]:
# One of three alternative environment cells; each one assigns `env`, so only
# the cell executed last takes effect.
# NOTE(review): the execution counts of the three cells are out of order
# (226, 2, 213) — confirm which environment the recorded outputs below belong to.
env = gym.make("CartPole-v1", render_mode="rgb_array")

In [2]:
# Pendulum-v1 wrapped in RescaleAction with bounds [min_action, max_action].
min_action, max_action = -1, 1
env = RescaleAction(
    gym.make("Pendulum-v1", render_mode="rgb_array"),
    min_action,
    max_action,
)

In [213]:
# Alternative environment; running this cell replaces any `env` created above.
# NOTE(review): on a top-to-bottom run this is the last cell that assigns `env`,
# while the recorded `3 1` dimensions further down appear to come from another
# environment cell — confirm which environment is intended and drop the others.
env = gym.make("BipedalWalker-v3", render_mode="rgb_array")

In [4]:
# Derive the action-space type from the environment itself rather than from a
# hand-maintained flag, so this cell stays correct whichever environment cell
# was executed before it.
discrete = isinstance(env.action_space, gym.spaces.Discrete)

# Length of the (flat) observation vector.
obs_dim = env.observation_space.shape[0]

# Discrete spaces report the number of actions through `n`; other spaces are
# read through the first entry of `shape`, exactly as before.
action_dim = env.action_space.n if discrete else env.action_space.shape[0]

print(obs_dim, action_dim)

3 1


### Reinforce

In [146]:
# Choose the policy class that matches the action-space type, then build it.
policy_cls = DiscreteStochasticMlpPolicy if discrete else StochasticMlpPolicy
policy = policy_cls(obs_dim, action_dim)

# Single optimizer over all policy parameters.
optimizer = Adam(policy.parameters(), lr=1e-3)

In [147]:
# Run REINFORCE with `policy` and its optimizer for total_timesteps=100_000.
reinforce(env, policy, optimizer, total_timesteps=100_000)

In [148]:
# Evaluate `policy`; the number recorded below is this call's return value.
# deterministic=True presumably turns off action sampling — confirm against
# rlib.common.evaluation.validation.
validation(env, policy, deterministic=True)

-1153.2254475851644

### A2C

In [149]:
# Actor: the policy class is selected by the action-space type.
actor_cls = DiscreteStochasticMlpPolicy if discrete else StochasticMlpPolicy
actor = actor_cls(obs_dim, action_dim)

# Critic takes only the observation size.
critic = MlpCritic(obs_dim)

# Separate optimizers (and learning rates) for actor and critic.
actor_optimizer = Adam(actor.parameters(), lr=3e-4)
critic_optimizer = Adam(critic.parameters(), lr=1e-4)

In [150]:
# Run A2C with the actor/critic pair and their optimizers for
# total_timesteps=100_000.
a2c(env, actor, critic, actor_optimizer, critic_optimizer, total_timesteps=100_000)

In [151]:
# Evaluate the A2C actor; the number recorded below is this call's return value.
validation(env, actor, deterministic=True)

-1314.010065162001

### PPO

In [155]:
# Fresh actor for PPO; this rebinds the `actor` name used by the A2C section.
actor_cls = DiscreteStochasticMlpPolicy if discrete else StochasticMlpPolicy
actor = actor_cls(obs_dim, action_dim)

# Fresh critic, likewise rebinding `critic`.
critic = MlpCritic(obs_dim)

# Separate optimizers (and learning rates) for actor and critic.
actor_optimizer = Adam(actor.parameters(), lr=1e-4)
critic_optimizer = Adam(critic.parameters(), lr=5e-4)

In [156]:
# Run PPO with the actor/critic pair and their optimizers for
# total_timesteps=100_000.
ppo(env, actor, critic, actor_optimizer, critic_optimizer, total_timesteps=100_000)

In [157]:
# Evaluate the PPO actor; the number recorded below is this call's return value.
validation(env, actor, deterministic=True)

-231.61235894825842

### Debug: rollout buffer and loss shape checks

In [140]:
# Reset the environment, keeping the first returned value and discarding the second.
# NOTE(review): `obs` is not read by any later cell visible in this notebook.
obs, _ = env.reset()

In [141]:
import torch

In [164]:
# Feed one randomly sampled observation, reshaped to a single-row batch,
# through the actor's forward pass and display the result.
obs_batch = env.observation_space.sample().reshape(1, -1)
actor.forward(torch.FloatTensor(obs_batch))

(tensor([[0.1195]], grad_fn=<TanhBackward0>),
 tensor([[0.0302]], grad_fn=<TanhBackward0>))

In [171]:
# Call the actor's predict() on one randomly sampled (unbatched) observation
# and display the result.
random_obs = env.observation_space.sample()
actor.predict(random_obs)

(array([0.6580007], dtype=float32),
 tensor([[-1.1629]], grad_fn=<ViewBackward0>))

In [116]:
from rlib.common.buffer import RolloutBuffer

In [123]:
# Instantiate a RolloutBuffer with its default constructor arguments.
rb = RolloutBuffer()

In [124]:
# Collect rollouts with the current `actor`; rollout_size=500 matches the
# leading dimension of 500 in the shapes printed further down.
rb.collect_rollouts(env, actor, rollout_size=500)

In [100]:
# Alternative collection mode: one trajectory using the REINFORCE `policy`.
# NOTE(review): this cell's execution count (100) predates the buffer creation
# above (123), so it did not contribute to the recorded shapes below. On a
# top-to-bottom run it would execute after the rollout_size=500 collection and
# change what `rb.get_data()` returns — confirm whether it should be kept.
rb.collect_rollouts(env, policy, trajectories_n=1)

In [125]:
# Fetch the collected rollout; later cells index it by "observations",
# "actions", "rewards", "terminated", "log_probs" and "q_estimations".
data = rb.get_data()

In [126]:
# Shapes of the rollout fields on one line. For "log_probs" the shape of the
# first element is printed rather than the shape of the whole field.
print(
    *(
        data[key].shape
        for key in ("observations", "actions", "rewards", "terminated")
    ),
    data["log_probs"][0].shape,
    data["q_estimations"].shape,
)

torch.Size([500, 3]) torch.Size([500, 1]) torch.Size([500, 1]) torch.Size([500, 1]) torch.Size([1]) torch.Size([500, 1])


In [128]:
# Shape check for the REINFORCE actor loss on the collected rollout.
returns = data["q_estimations"]
log_probs = data["log_probs"]

# Standardise the returns (zero mean, unit std); the small constant guards
# against division by zero when the std is 0.
normalize_returns = True
if normalize_returns:
    mean, std = returns.mean(), returns.std()
    returns = (returns - mean) / (std + 1e-8)

# Log-probabilities weighted by the (standardised) returns; the loss is the
# negated mean.
weighted_log_probs = log_probs * returns
loss = {"actor": -weighted_log_probs.mean()}

print(returns.shape, log_probs.shape, weighted_log_probs.shape)
print(loss["actor"].shape)

torch.Size([500, 1]) torch.Size([500, 1]) torch.Size([500, 1])
torch.Size([])


In [129]:
import torch

In [132]:
# Shape check for the A2C actor and critic losses on the collected rollout.
observations, log_probs, targets = (
    data[key] for key in ("observations", "log_probs", "q_estimations")
)

values = critic(observations)
# Every term drops the final transition ([:-1]); targets carry no gradient.
advantages = targets[:-1].detach() - values[:-1]

# The advantages are detached in the actor term, so the actor loss does not
# backpropagate into the critic; the critic loss is the mean squared advantage.
weighted_log_probs = log_probs[:-1] * advantages.detach()
loss = {
    "actor": -weighted_log_probs.mean(),
    "critic": advantages.pow(2).mean(),
}

print(values.shape, advantages.shape, weighted_log_probs.shape)
print(loss["actor"].shape)

torch.Size([500, 1]) torch.Size([499, 1]) torch.Size([499, 1])
torch.Size([])


In [133]:
# Clipping half-width used by the PPO ratio clamp in the next cell.
epsilon = 0.1

# Rollout fields needed for the PPO loss check.
observations, actions, old_log_probs = (
    data[key] for key in ("observations", "actions", "log_probs")
)

In [136]:
# Shape check for the PPO clipped-surrogate actor loss and the critic loss.
# Relies on `epsilon` from the previous cell.
observations, old_log_probs, actions, targets = (
    data[key] for key in ("observations", "log_probs", "actions", "q_estimations")
)

# get_action receives the stored actions, so new_log_probs is presumably their
# log-probability under the current actor — confirm against rlib.common.policies.
_, new_log_probs = actor.get_action(observations, action=actions)

# Probability ratio between current and rollout-time policy, plus its clamp
# to [1 - epsilon, 1 + epsilon].
ratio = (new_log_probs - old_log_probs.detach()).exp()
ratio_clipped = ratio.clamp(1 - epsilon, 1 + epsilon)

values = critic(observations)
advantages = targets.detach() - values
# Detached copy for the actor terms so they do not backpropagate into the critic.
fixed_advantages = advantages.detach()

actor_loss_1 = ratio * fixed_advantages
actor_loss_2 = ratio_clipped * fixed_advantages
# Element-wise minimum of the unclipped and clipped objectives.
surrogate = torch.min(actor_loss_1, actor_loss_2)

loss = {
    "actor": -surrogate.mean(),
    "critic": (advantages**2).mean(),
}

print(ratio.shape, new_log_probs.shape, surrogate.shape)
print(loss["actor"].shape, loss["critic"].shape)

torch.Size([500, 1]) torch.Size([500, 1]) torch.Size([500, 1])
torch.Size([]) torch.Size([])
