In [201]:
import os
import sys

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
from torch.optim import Adam

# Add the parent directory to the import path so the `rlib` imports below
# resolve (presumably the repository root -- confirm against the project layout).
sys.path.append(os.path.abspath(".."))

from rlib.algorithms.a2c import a2c
from rlib.algorithms.ppo import ppo
from rlib.algorithms.reinforce import reinforce
from rlib.common.evaluation import validation
from rlib.common.policies import (
    DiscreteStochasticMlpPolicy,
    MlpCritic,
    StochasticMlpPolicy,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [226]:
# Environment option: CartPole-v1. Every environment cell rebinds `env`, so the
# cells below use whichever environment cell was executed last.
env = gym.make("CartPole-v1", render_mode="rgb_array")

In [202]:
# Environment option: Pendulum-v1, wrapped so that the agent's actions are
# expressed in the [min_action, max_action] range and rescaled by the wrapper.
min_action = -1
max_action = 1
env = RescaleAction(
    gym.make("Pendulum-v1", render_mode="rgb_array"),
    min_action,
    max_action,
)

In [213]:
# Environment option: BipedalWalker-v3. Rebinds `env`; the cells below use
# whichever environment cell was executed last.
env = gym.make("BipedalWalker-v3", render_mode="rgb_array")

In [225]:
# Infer the action-space type from the environment instead of hard-coding it,
# so this cell stays correct whichever environment cell was executed last.
discrete = isinstance(env.action_space, gym.spaces.Discrete)

# Length of the (flat) observation vector.
obs_dim = env.observation_space.shape[0]

# Discrete spaces expose the number of actions as `.n`; otherwise the action
# dimensionality is read from the space's shape.
action_dim = env.action_space.n if discrete else env.action_space.shape[0]

print(obs_dim, action_dim)

4 2


### REINFORCE

In [229]:
# Build the policy whose class matches the action-space type.
policy = (DiscreteStochasticMlpPolicy if discrete else StochasticMlpPolicy)(
    obs_dim, action_dim
)

# Single Adam optimizer over the policy parameters.
optimizer = Adam(policy.parameters(), lr=1e-3)

In [230]:
# Train `policy` with REINFORCE using total_timesteps=100_000. The call prints
# steps_n / mean_trajectory_rewards / mean_trajectory_length as it progresses
# (see the log below).
reinforce(env, policy, optimizer, total_timesteps=100_000)

steps_n: 1026
mean_trajectory_rewards: 17.0
mean_trajectory_length: 17.0
steps_n: 2076
mean_trajectory_rewards: 41.0
mean_trajectory_length: 41.0
steps_n: 3070
mean_trajectory_rewards: 41.0
mean_trajectory_length: 41.0
steps_n: 4036
mean_trajectory_rewards: 28.0
mean_trajectory_length: 28.0
steps_n: 5044
mean_trajectory_rewards: 25.0
mean_trajectory_length: 25.0
steps_n: 6042
mean_trajectory_rewards: 31.0
mean_trajectory_length: 31.0
steps_n: 7002
mean_trajectory_rewards: 46.0
mean_trajectory_length: 46.0
steps_n: 8008
mean_trajectory_rewards: 16.0
mean_trajectory_length: 16.0
steps_n: 9014
mean_trajectory_rewards: 17.0
mean_trajectory_length: 17.0
steps_n: 10032
mean_trajectory_rewards: 45.0
mean_trajectory_length: 45.0
steps_n: 11046
mean_trajectory_rewards: 48.0
mean_trajectory_length: 48.0
steps_n: 12018
mean_trajectory_rewards: 64.0
mean_trajectory_length: 64.0
steps_n: 13154
mean_trajectory_rewards: 94.0
mean_trajectory_length: 94.0
steps_n: 14128
mean_trajectory_rewards: 110.0
m

In [231]:
# Evaluate the REINFORCE policy; the returned score is shown as the cell output.
# deterministic=True presumably disables action sampling -- confirm in
# rlib.common.evaluation.
validation(env, policy, deterministic=True)

np.float64(322.3)

### A2C

In [232]:
# Actor: choose the policy class matching the action-space type.
actor = (DiscreteStochasticMlpPolicy if discrete else StochasticMlpPolicy)(
    obs_dim, action_dim
)
# Critic: constructed from the observation dimension only.
critic = MlpCritic(obs_dim)

# Separate Adam optimizers so the actor and critic use different learning rates.
actor_optimizer = Adam(actor.parameters(), lr=3e-4)
critic_optimizer = Adam(critic.parameters(), lr=1e-4)

In [233]:
# Train the actor/critic pair with A2C using total_timesteps=100_000. Progress
# (steps_n, mean trajectory reward and length) is printed by the call itself.
a2c(env, actor, critic, actor_optimizer, critic_optimizer, total_timesteps=100_000)

steps_n: 1202
mean_trajectory_rewards: 31.5
mean_trajectory_length: 31.5
steps_n: 2145
mean_trajectory_rewards: 21.100000381469727
mean_trajectory_length: 21.100000381469727
steps_n: 3176
mean_trajectory_rewards: 26.799999237060547
mean_trajectory_length: 26.80000114440918
steps_n: 4079
mean_trajectory_rewards: 22.700000762939453
mean_trajectory_length: 22.700000762939453
steps_n: 5129
mean_trajectory_rewards: 27.5
mean_trajectory_length: 27.5
steps_n: 6215
mean_trajectory_rewards: 25.899999618530273
mean_trajectory_length: 25.899999618530273
steps_n: 7160
mean_trajectory_rewards: 23.799999237060547
mean_trajectory_length: 23.80000114440918
steps_n: 8231
mean_trajectory_rewards: 30.899999618530273
mean_trajectory_length: 30.899999618530273
steps_n: 9009
mean_trajectory_rewards: 32.5
mean_trajectory_length: 32.5
steps_n: 10220
mean_trajectory_rewards: 25.299999237060547
mean_trajectory_length: 25.30000114440918
steps_n: 11313
mean_trajectory_rewards: 32.599998474121094
mean_trajectory_l

In [234]:
# Evaluate the A2C actor; the returned score is shown as the cell output.
validation(env, actor, deterministic=True)

np.float64(73.25)

### PPO

In [235]:
# Fresh actor for PPO (this rebinds `actor`/`critic` from the A2C section).
actor = (
    DiscreteStochasticMlpPolicy(obs_dim, action_dim)
    if discrete
    else StochasticMlpPolicy(obs_dim, action_dim)
)
critic = MlpCritic(obs_dim)

# One Adam optimizer per network, each with its own learning rate.
actor_optimizer = Adam(actor.parameters(), lr=1e-4)
critic_optimizer = Adam(critic.parameters(), lr=5e-4)

In [236]:
# Train the actor/critic pair with PPO using total_timesteps=100_000. Progress
# (steps_n, mean trajectory reward and length) is printed by the call itself.
ppo(env, actor, critic, actor_optimizer, critic_optimizer, total_timesteps=100_000)

steps_n: 1489
mean_trajectory_rewards: 25.850000381469727
mean_trajectory_length: 25.850000381469727
steps_n: 2226
mean_trajectory_rewards: 36.79999923706055
mean_trajectory_length: 36.79999923706055
steps_n: 3079
mean_trajectory_rewards: 42.599998474121094
mean_trajectory_length: 42.60000228881836
steps_n: 4266
mean_trajectory_rewards: 59.29999923706055
mean_trajectory_length: 59.29999923706055
steps_n: 6534
mean_trajectory_rewards: 113.3499984741211
mean_trajectory_length: 113.3499984741211
steps_n: 10041
mean_trajectory_rewards: 175.3000030517578
mean_trajectory_length: 175.3000030517578
steps_n: 14533
mean_trajectory_rewards: 224.5500030517578
mean_trajectory_length: 224.5500030517578
steps_n: 19944
mean_trajectory_rewards: 270.5
mean_trajectory_length: 270.5
steps_n: 25425
mean_trajectory_rewards: 274.0
mean_trajectory_length: 274.0
steps_n: 32757
mean_trajectory_rewards: 366.54998779296875
mean_trajectory_length: 366.5500183105469
steps_n: 38646
mean_trajectory_rewards: 294.39999

In [237]:
# Evaluate the PPO actor; the returned score is shown as the cell output.
validation(env, actor, deterministic=True)

np.float64(485.6)

### Debug

In [140]:
# Reset the environment and keep the initial observation; the second element of
# the returned pair (the info dict) is discarded.
obs, _ = env.reset()

In [141]:
import torch

In [164]:
# Run the actor on one randomly sampled observation, reshaped to a batch of one.
# The module is called directly rather than through `.forward`, matching how
# `critic(observations)` is invoked elsewhere in this notebook and going through
# PyTorch's regular call path (assumes `actor` is a torch.nn.Module, as its use
# with `actor.parameters()` suggests).
actor(torch.FloatTensor(env.observation_space.sample().reshape(1, -1)))

(tensor([[0.1195]], grad_fn=<TanhBackward0>),
 tensor([[0.0302]], grad_fn=<TanhBackward0>))

In [171]:
# Query the actor's `predict` helper with one randomly sampled observation.
# Per the output below it returns a pair: a NumPy action array and a tensor
# (presumably the action's log-probability -- confirm in rlib.common.policies).
actor.predict(env.observation_space.sample())

(array([0.6580007], dtype=float32),
 tensor([[-1.1629]], grad_fn=<ViewBackward0>))

In [None]:
from rlib.common.buffer import RolloutBuffer

In [49]:
# Fresh rollout buffer used by the manual loss computation further down.
rb = RolloutBuffer()

In [65]:
# Fill the buffer by running the current actor in `env`. With rollout_size=10
# the tensors printed further down have 10 rows, so this looks like a number of
# steps -- confirm in rlib.common.buffer.
rb.collect_rollouts(env, actor, rollout_size=10)

In [66]:
# Retrieve the collected batch. The cells below index it as a mapping with the
# keys "observations", "actions", "log_probs" and "q_estimations".
data = rb.get_data()

In [67]:
# Pull the tensors needed for the clipped-ratio computation out of the batch.
observations, actions, old_log_probs = (
    data["observations"],
    data["actions"],
    data["log_probs"],
)
# Clipping half-width applied to the probability ratio.
epsilon = 0.1

In [72]:
# Manual, step-by-step computation of the clipped actor loss and the squared
# error critic loss on the batch collected above.
loss = {}

# Set to True to standardize the advantages used by the actor loss. Off by
# default, which reproduces the previous behaviour of this cell.
normalize_advantages = False

# Re-evaluate the stored actions under the current actor.
_, new_log_probs = actor.get_action(observations, action=actions)

# Probability ratio between the current and the data-collecting policy, plus
# its version clipped to [1 - epsilon, 1 + epsilon].
ratio = torch.exp(new_log_probs - old_log_probs.detach())
ratio_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

values = critic(observations).reshape(ratio.shape)

# Targets are treated as constants; gradients flow only through `values`.
targets = data["q_estimations"].reshape(ratio.shape)
advantages = targets.detach() - values

# The actor gets its own detached copy, so optional normalization can never
# leak into the critic loss below (which must use the raw residuals).
actor_advantages = advantages.detach()
if normalize_advantages:
    actor_advantages = (actor_advantages - actor_advantages.mean()) / (
        actor_advantages.std() + 1e-8
    )

actor_loss_1 = ratio * actor_advantages
actor_loss_2 = ratio_clipped * actor_advantages

print(advantages)
print(advantages.shape, ratio.shape)

# Element-wise minimum of the two surrogates, negated because the optimizer
# minimizes.
loss["actor"] = -(torch.min(actor_loss_1, actor_loss_2)).mean()
# Mean squared residual between targets and value predictions.
loss["critic"] = (advantages**2).mean()

print(loss["actor"], loss["critic"])

tensor([[-31.3869],
        [-31.8648],
        [-32.4259],
        [-32.6909],
        [-32.5504],
        [-32.0703],
        [-30.3432],
        [-26.3773],
        [-20.5584],
        [-12.6074]], grad_fn=<SubBackward0>)
torch.Size([10, 1]) torch.Size([10, 1])
tensor(25.7980, grad_fn=<NegBackward0>) tensor(840.6730, grad_fn=<MeanBackward0>)


In [13]:
# Evaluate the current `actor` deterministically. The stored output comes from
# an earlier execution (In [13]), so it may not correspond to the actor trained
# in the sections above.
validation(env, actor, deterministic=True)

-945.4736045591187

In [142]:
# Mean tensor for the distribution experiment: one row, two components, all zero.
mu = torch.zeros(1, 2)
mu

tensor([[0., 0.]])

In [144]:
# Standard-deviation tensor for the distribution experiment: one row, two
# components, all one.
std = torch.ones(1, 2)
std

tensor([[1., 1.]])

In [145]:
from torch.distributions import Normal

In [146]:
# Element-wise Normal distribution parameterized by the `mu` and `std` tensors
# defined above; draw one sample from it.
dist = Normal(mu, std)
action = dist.sample()

In [148]:
# Shape check: the sample has the same shape as `mu`/`std` (output: [1, 2]).
action.shape

torch.Size([1, 2])

In [151]:
# log_prob is evaluated per element, so its result has the same shape as the
# sample (output: [1, 2]) rather than one value per row.
dist.log_prob(action).shape

torch.Size([1, 2])

In [153]:
# Sum the per-element log-probabilities over dim=1 to obtain a single
# log-probability per row (the components are independent).
dist.log_prob(action).sum(dim=1)

tensor([-1.9676])

In [180]:
from stable_baselines3 import PPO

In [191]:
# Reference setup with Stable-Baselines3: a plain Pendulum-v1 (no RescaleAction
# wrapper, no render_mode) and a PPO agent with the "MlpPolicy" that has not
# been trained in this notebook.
# Note: this rebinds `env` for every cell executed afterwards.
env = gym.make("Pendulum-v1")
agent = PPO("MlpPolicy", env)

In [192]:
# Reset the environment and display the returned (observation, info) pair.
env.reset()

(array([ 0.03049863, -0.9995348 ,  0.8736682 ], dtype=float32), {})

In [193]:
# Ask the SB3 agent for an action on a randomly sampled observation; the second
# element of the returned pair is ignored.
random_observation = env.observation_space.sample()
action, _ = agent.predict(random_observation)

In [195]:
# Inspect the action returned by the SB3 agent.
action

array([-1.9694774], dtype=float32)

In [194]:
# Apply that action to the unwrapped Pendulum env. The output is the 5-tuple
# (observation, reward, terminated, truncated, info).
env.step(action)

(array([ 0.02193137, -0.9997595 , -0.17140453], dtype=float32),
 -2.4527108869676995,
 False,
 False,
 {})