In [1]:
import os
import sys
import pickle
from copy import deepcopy

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.wrappers import RescaleAction
from torch import nn
from torch.distributions import Normal
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.sac import sac
from rlib.common.buffer import RolloutBuffer, ReplayBuffer
from rlib.common.policies import DeterministicMlpPolicy, StochasticMlpPolicy, MlpQCritic
from rlib.common.evaluation import get_trajectory, validation
from rlib.common.logger import TensorBoardLogger

%load_ext autoreload
%autoreload 2

In [2]:
# Pendulum-v1 in "rgb_array" render mode, wrapped so that the agent's actions
# live in [min_action, max_action] = [-1, 1].
min_action, max_action = -1, 1
env = RescaleAction(
    gym.make("Pendulum-v1", render_mode="rgb_array"),
    min_action,
    max_action,
)

In [3]:
# Size of the first axis of the observation and action spaces (in that order).
obs_dim, action_dim = (
    space.shape[0] for space in (env.observation_space, env.action_space)
)

print(obs_dim, action_dim)

3 1


## GCL (Guided Cost Learning)

### Expert

In [4]:
# Networks for the expert: one stochastic policy and two Q-critics.
# They are built in the order actor -> critic_1 -> critic_2.
actor = StochasticMlpPolicy(obs_dim, action_dim)
critic_1, critic_2 = (MlpQCritic(obs_dim, action_dim) for _ in range(2))

# One Adam optimizer per network; the critics use a larger learning rate (1e-3)
# than the actor (1e-4).
actor_optimizer = Adam(actor.parameters(), lr=1e-4)
critic_1_optimizer, critic_2_optimizer = (
    Adam(critic.parameters(), lr=1e-3) for critic in (critic_1, critic_2)
)

In [5]:
# Train the expert with SAC. On a top-to-bottom run this is the `sac` imported
# from rlib.algorithms.sac (another `sac` is defined in a later cell). Only the
# positional arguments are passed, so every other parameter keeps its default.
# The return value is not bound; presumably the networks are updated in place
# - TODO confirm in rlib.
sac(
    env,
    actor,
    critic_1,
    critic_2,
    actor_optimizer,
    critic_1_optimizer,
    critic_2_optimizer,
)

In [None]:
# Evaluate the trained actor; the number displayed below the cell is the value
# returned by `validation` (presumably an average episode return - confirm in
# rlib.common.evaluation).
validation(env, actor)

-1090.8483348587272

In [None]:
# Persist the trained expert policy.
# pickle produces bytes, so the file has to be opened in binary mode ("wb");
# a text-mode handle makes pickle.dump raise a TypeError.
os.makedirs("./models", exist_ok=True)  # open() does not create missing directories
with open("./models/pendulum_stoc_expert", "wb") as file:
    pickle.dump(actor, file)

In [None]:
# Restore the expert policy saved above.
# pickle consumes bytes, so the file has to be opened in binary mode ("rb").
# NOTE: unpickling can execute arbitrary code - only load files you produced yourself.
with open("./models/pendulum_stoc_expert", "rb") as file:
    expert_actor = pickle.load(file)

In [None]:
# Sanity check of the reloaded policy: evaluate `expert_actor` the same way the
# freshly trained `actor` was evaluated earlier.
validation(env, expert_actor)

In [None]:
# Buffer that the next cells fill with rollouts of the expert policy.
rb = RolloutBuffer()

In [None]:
# Roll out `expert_actor` in `env`; `trajectories_n=30` is presumably the number
# of episodes gathered - confirm in rlib.common.buffer.RolloutBuffer.
rb.collect_rollouts(env, expert_actor, trajectories_n=30)

In [None]:
# Expert demonstrations pulled out of the buffer. Presumably this is the
# dict of tensors that `gcl(expert_trajectories=...)` expects - TODO confirm.
expert_data = rb.get_data()

### Train

In [None]:
class RewardNet(nn.Module):
    """MLP that maps an (observation, action) pair to a single scalar.

    The network has two hidden layers of ``hidden_size`` units with ReLU
    activations and a linear output layer of width 1.
    """

    def __init__(self, obs_dim, action_dim, hidden_size=256):
        """
        Args:
            obs_dim (int): size of one observation vector
            action_dim (int): size of one action vector
            hidden_size (int): width of both hidden layers
        """
        super().__init__()
        layers = [
            nn.Linear(obs_dim + action_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, observations, actions):
        """
        Args:
            observations (torch.Tensor): (B, obs_dim)
            actions (torch.Tensor): (B, action_dim)

        Returns:
            rewards: (torch.Tensor): (B, 1)
        """
        # Observations and actions are joined along the feature axis (dim=1).
        obs_action = torch.cat([observations, actions], dim=1)
        rewards = self.net(obs_action)
        return rewards

In [None]:
def gcl_loss():
    """Placeholder for the GCL loss: not implemented yet, takes no arguments and returns None."""
    pass

In [None]:
def gcl(
    env: gym.Env,
    expert_trajectories: dict[str, torch.Tensor],
    learning_actor: DeterministicMlpPolicy,
    actor_optimizer: Adam,
    reward_net: RewardNet,
    reward_optimizer: Adam,
    total_episodes: int = 1000,
):
    """Skeleton of the GCL training loop (work in progress).

    Every episode collects 30 rollouts with ``learning_actor`` and evaluates
    ``gcl_loss``. No optimisation step is performed yet, so
    ``expert_trajectories``, ``actor_optimizer``, ``reward_net`` and
    ``reward_optimizer`` are accepted but currently unused.

    Args:
        env: environment the learning policy is rolled out in.
        expert_trajectories: expert demonstrations (currently unused).
        learning_actor: policy used to collect rollouts.
        actor_optimizer: optimizer for ``learning_actor`` (currently unused).
        reward_net: learned reward model (currently unused).
        reward_optimizer: optimizer for ``reward_net`` (currently unused).
        total_episodes: number of iterations of the outer loop.

    Returns:
        None
    """
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # TODO: replay_buffer and logger are created but not used by the loop yet.
    replay_buffer = ReplayBuffer(obs_dim, action_dim)
    rollout_buffer = RolloutBuffer()

    logger = TensorBoardLogger(log_dir="./tb_logs/gcl_")

    for episode_n in range(total_episodes):

        rollout_buffer.collect_rollouts(env, learning_actor, trajectories_n=30)
        # NOTE(review): the same rollout_buffer is reused every episode; whether
        # get_data() returns only the latest rollouts is not visible here - confirm.
        learning_trajectories = rollout_buffer.get_data()

        # Call the loss function; a bare `gcl_loss` would only bind the function
        # object. The current stub takes no arguments and returns None.
        # TODO: pass the expert/learning trajectories and reward_net once
        # gcl_loss is implemented, then step the optimizers.
        loss = gcl_loss()

In [None]:
def sac(
    env: gym.Env,
    actor: DeterministicMlpPolicy,
    critic_1: MlpQCritic,
    critic_2: MlpQCritic,
    actor_optimizer: Adam,
    critic_1_optimizer: Adam,
    critic_2_optimizer: Adam,
    training_starts: int = 1000,
    total_timesteps: int = 50_000,
    batch_size: int = 128,
    target_update_frequency: int = 2,
):
    """Train an actor and two Q-critics with a SAC-style loop.

    Each iteration collects one transition into a replay buffer. Once the
    buffer holds at least ``training_starts`` transitions, every iteration also
    samples a batch, computes the losses with ``sac_loss`` and takes one
    optimizer step for the actor and for each critic.

    Note:
        Running this cell rebinds the name ``sac`` that the notebook imports
        from ``rlib.algorithms.sac``. ``sac_loss`` and ``smooth_update`` are not
        defined in the visible part of the notebook and must be in scope before
        this function is called.

    Args:
        env: environment to collect transitions from.
        actor: policy network being trained.
        critic_1: first Q-network.
        critic_2: second Q-network.
        actor_optimizer: optimizer stepping ``actor``.
        critic_1_optimizer: optimizer stepping ``critic_1``.
        critic_2_optimizer: optimizer stepping ``critic_2``.
        training_starts: buffer size required before gradient updates begin.
        total_timesteps: number of loop iterations (one collected transition each).
        batch_size: size of the batch requested from the replay buffer.
        target_update_frequency: the target critics are refreshed whenever the
            step counter is a multiple of this value.

    Returns:
        None
    """
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    buffer = ReplayBuffer(obs_dim, action_dim)

    logger = TensorBoardLogger(log_dir="./tb_logs/sac_")

    # Target critics start as copies of the online critics.
    critic_1_target = deepcopy(critic_1)
    critic_2_target = deepcopy(critic_2)

    # Loss key -> optimizer, in the order the updates are applied.
    updates = (
        ("actor", actor_optimizer),
        ("critic_1", critic_1_optimizer),
        ("critic_2", critic_2_optimizer),
    )

    steps_n = 0
    while steps_n < total_timesteps:
        buffer.collect_transition(env, actor)
        # The counter advances exactly once per collected transition; it drives
        # the loop bound, the target-update schedule and the logging step.
        steps_n += 1

        # Warm-up: only collect data until the buffer is large enough.
        if buffer.size < training_starts:
            continue

        batch = buffer.get_batch(batch_size)

        loss = sac_loss(
            batch,
            actor,
            critic_1,
            critic_2,
            critic_1_target,
            critic_2_target,
        )

        for loss_name, optimizer in updates:
            # Clear gradients right before the backward pass so that anything an
            # earlier loss may have accumulated in this network's parameters
            # (e.g. the actor loss flowing through the critics) is discarded.
            optimizer.zero_grad()
            loss[loss_name].backward()
            optimizer.step()

        if steps_n % target_update_frequency == 0:
            critic_1_target = smooth_update(critic_1, critic_1_target)
            critic_2_target = smooth_update(critic_2, critic_2_target)

        # Logging
        logger.log_scalars(loss, steps_n)

        if buffer.done:
            trajectory = buffer.get_last_trajectory()
            logger.log_trajectories(trajectory)
