In [3]:
import os
import sys
from copy import deepcopy

import gymnasium as gym
import numpy as np
import torch
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))


from typing import Dict

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
from torch.optim import Adam

from rlib.algorithms.model_free.ddpg import ddpg
from rlib.algorithms.model_free.sac import sac
from rlib.algorithms.model_free.td3 import td3
from rlib.common.buffer import ReplayBuffer, RolloutBuffer
from rlib.common.evaluation import get_trajectory, validation
from rlib.common.logger import TensorBoardLogger
from rlib.common.losses import sac_loss
from rlib.common.policies import DeterministicMlpPolicy, MlpQCritic, StochasticMlpPolicy
from rlib.common.utils import smooth_update

%load_ext autoreload
%autoreload 2

In [34]:
env = gym.make("Pendulum-v1", render_mode="rgb_array")

min_action, max_action = -1, 1
env = RescaleAction(env, min_action, max_action)

discrete = False

In [35]:
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

print(obs_dim, action_dim)

3 1


In [36]:
def get_filled_buffer(obs_dim, action_dim, data):
    rb = ReplayBuffer(obs_dim, action_dim)

    rb.observations = torch.tensor(data["observations"][:-1, :], dtype=torch.float32).squeeze(dim=1).detach()
    rb.next_observations = torch.tensor(data["observations"][1:, :], dtype=torch.float32).squeeze(dim=1).detach()
    rb.actions = torch.tensor(data["actions"][:-1, :], dtype=torch.float32).detach()
    rb.rewards = torch.tensor(data["rewards"][:-1, :], dtype=torch.float32).reshape(-1, 1).detach()
    rb.terminated = torch.tensor(data["terminated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()
    rb.truncated = torch.tensor(data["truncated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()

    rb.size = rb.observations.shape[0]

    return rb

In [37]:
def offline_sac(
    data,
    actor: StochasticMlpPolicy,
    critic_1: MlpQCritic,
    critic_2: MlpQCritic,
    actor_optimizer: Adam,
    critic_1_optimizer: Adam,
    critic_2_optimizer: Adam,
    total_episodes: int = 10_000,
    batch_size: int = 256,
    target_update_frequency: int = 1,
):
    buffer = get_filled_buffer(
        data["observations"][0].shape[0], data["actions"][0].shape[0], data
    )

    logger = TensorBoardLogger(log_dir="./tb_logs/offline_sac_")

    critic_1_target = deepcopy(critic_1)
    critic_2_target = deepcopy(critic_2)

    for episode_n in range(total_episodes):
        batch = buffer.get_batch(batch_size)

        loss = sac_loss(
            batch,
            actor,
            critic_1,
            critic_2,
            critic_1_target,
            critic_2_target,
        )

        actor_optimizer.zero_grad()
        loss["actor"].backward()
        actor_optimizer.step()

        critic_1_optimizer.zero_grad()
        loss["critic_1"].backward()
        critic_1_optimizer.step()

        critic_2_optimizer.zero_grad()
        loss["critic_2"].backward()
        critic_2_optimizer.step()

        if episode_n % target_update_frequency == 0:
            critic_1_target = smooth_update(critic_1, critic_1_target)
            critic_2_target = smooth_update(critic_2, critic_2_target)

        # Logging
        logger.log_scalars(loss, episode_n)

        if episode_n % 100:
            logs = {}            
            trajectory = get_trajectory(env, actor)
            logs["eval_traj_reward"] = np.sum(trajectory["rewards"])
            logger.log_scalars(logs, episode_n)

In [38]:
data = torch.load("./models/pendulum_expert_data")

In [39]:
actor = StochasticMlpPolicy(obs_dim, action_dim)
critic_1 = MlpQCritic(obs_dim, action_dim)
critic_2 = MlpQCritic(obs_dim, action_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_1_optimizer = Adam(critic_1.parameters(), lr=1e-3)
critic_2_optimizer = Adam(critic_2.parameters(), lr=1e-3)

In [40]:
offline_sac(
    data,
    actor,
    critic_1,
    critic_2,
    actor_optimizer,
    critic_1_optimizer,
    critic_2_optimizer,
    total_episodes=10_000,
)

  rb.observations = torch.tensor(data["observations"][:-1, :], dtype=torch.float32).squeeze(dim=1).detach()
  rb.next_observations = torch.tensor(data["observations"][1:, :], dtype=torch.float32).squeeze(dim=1).detach()
  rb.actions = torch.tensor(data["actions"][:-1, :], dtype=torch.float32).detach()
  rb.rewards = torch.tensor(data["rewards"][:-1, :], dtype=torch.float32).reshape(-1, 1).detach()
  rb.terminated = torch.tensor(data["terminated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()
  rb.truncated = torch.tensor(data["truncated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()


KeyboardInterrupt: 

In [None]:
def cql_loss(
    data,
    actor: StochasticMlpPolicy,
    critic_1: MlpQCritic,
    critic_2: MlpQCritic,
    alpha: float,
) -> Dict[str, torch.Tensor]:

    loss = {}

    observations = data["observations"]
    actions = data["actions"]

    random_actions = torch.FloatTensor(actions.shape).uniform_(-1, 1)
    
    curr_policy_actions, curr_policy_log_probs = actor.get_action(observations)
    
    q1_values_dataset = critic_1(observations, actions)
    q2_values_dataset = critic_2(observations, actions)
    
    q1_values_random = critic_1(observations, random_actions)
    q2_values_random = critic_2(observations, random_actions)

    q1_values_curr_policy = critic_1(observations, curr_policy_actions)
    q2_values_curr_policy = critic_2(observations, curr_policy_actions)
    
    random_density = 0.5 ** actions.shape[1]  # Uniform(-1,1) probability
    random_density_log_prob = torch.log(torch.tensor(random_density))

    cat_q1_values = torch.cat([
        q1_values_random - random_density_log_prob,
        q1_values_curr_policy - curr_policy_log_probs.detach()
    ], dim=1)
    logsumexp_q1_values = torch.logsumexp(cat_q1_values, dim=1)

    cat_q2_values = torch.cat([
        q2_values_random - random_density_log_prob,
        q2_values_curr_policy - curr_policy_log_probs.detach()
    ], dim=1)
    logsumexp_q2_values = torch.logsumexp(cat_q2_values, dim=1)
    
    loss["critic_1_reg"] = alpha * (logsumexp_q1_values - q1_values_dataset).mean()
    loss["critic_2_reg"] = alpha * (logsumexp_q2_values - q2_values_dataset).mean()
    
    return loss

In [55]:
def cql_sac(
    data,
    actor: StochasticMlpPolicy,
    critic_1: MlpQCritic,
    critic_2: MlpQCritic,
    actor_optimizer: Adam,
    critic_1_optimizer: Adam,
    critic_2_optimizer: Adam,
    total_episodes: int = 10_000,
    batch_size: int = 256,
    target_update_frequency: int = 1,
    alpha: float = 3,
):
    buffer = get_filled_buffer(
        data["observations"][0].shape[0], data["actions"][0].shape[0], data
    )

    logger = TensorBoardLogger(log_dir="./tb_logs/cql_sac_")

    critic_1_target = deepcopy(critic_1)
    critic_2_target = deepcopy(critic_2)

    for episode_n in range(total_episodes):
        batch = buffer.get_batch(batch_size)

        loss_sac = sac_loss(
            batch,
            actor,
            critic_1,
            critic_2,
            critic_1_target,
            critic_2_target,
        )

        loss_cql = cql_loss(
            batch,
            actor,
            critic_1,
            critic_2,
            alpha,
        )

        actor_optimizer.zero_grad()
        loss_sac["actor"].backward()
        actor_optimizer.step()

        critic_1_optimizer.zero_grad()
        (loss_sac["critic_1"] + loss_cql["critic_1_reg"]).backward()
        critic_1_optimizer.step()

        critic_2_optimizer.zero_grad()
        (loss_sac["critic_2"] + loss_cql["critic_2_reg"]).backward()
        critic_2_optimizer.step()

        if episode_n % target_update_frequency == 0:
            critic_1_target = smooth_update(critic_1, critic_1_target)
            critic_2_target = smooth_update(critic_2, critic_2_target)

        # Logging
        logger.log_scalars(loss_sac, episode_n)
        logger.log_scalars(loss_cql, episode_n)

        if episode_n % 100 == 0:
            logs = {}            
            trajectory = get_trajectory(env, actor)
            logs["eval_traj_reward"] = np.sum(trajectory["rewards"])
            logger.log_scalars(logs, episode_n)

In [56]:
data = torch.load("./models/pendulum_expert_data")

In [57]:
actor = StochasticMlpPolicy(obs_dim, action_dim)
critic_1 = MlpQCritic(obs_dim, action_dim)
critic_2 = MlpQCritic(obs_dim, action_dim)

actor_optimizer = Adam(actor.parameters(), lr=1e-3)
critic_1_optimizer = Adam(critic_1.parameters(), lr=1e-3)
critic_2_optimizer = Adam(critic_2.parameters(), lr=1e-3)

In [58]:
cql_sac(
    data,
    actor,
    critic_1,
    critic_2,
    actor_optimizer,
    critic_1_optimizer,
    critic_2_optimizer,
    total_episodes=10_000,
    alpha=0,
)

  rb.observations = torch.tensor(data["observations"][:-1, :], dtype=torch.float32).squeeze(dim=1).detach()
  rb.next_observations = torch.tensor(data["observations"][1:, :], dtype=torch.float32).squeeze(dim=1).detach()
  rb.actions = torch.tensor(data["actions"][:-1, :], dtype=torch.float32).detach()
  rb.rewards = torch.tensor(data["rewards"][:-1, :], dtype=torch.float32).reshape(-1, 1).detach()
  rb.terminated = torch.tensor(data["terminated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()
  rb.truncated = torch.tensor(data["truncated"][:-1, :], dtype=torch.int8).reshape(-1, 1).detach()
  "observations": torch.tensor(
  "next_observations": torch.tensor(
  "actions": torch.tensor(self.actions[indices], dtype=torch.float32),
  "rewards": torch.tensor(self.rewards[indices], dtype=torch.float32),
  "terminated": torch.tensor(self.terminated[indices], dtype=torch.int8),
  "truncated": torch.tensor(self.truncated[indices], dtype=torch.int8),


In [58]:
rb = RolloutBuffer()

In [60]:
rb.collect_rollouts(env, actor, rollout_size=10)

In [61]:
data = rb.get_data()

  return torch.tensor(value, dtype=dtype)


In [None]:
observations = data["observations"]
actions = data["actions"]
rewards = data["rewards"]
terminated = data["terminated"]

# 1. Сэмплируем действия из текущей политики
actor_actions, _ = actor.get_action(observations)

# 2. Сэмплируем случайные действия из равномерного распределения
random_actions = 2 * torch.rand(actor_actions.shape) - 1

# 3. Берём смесь случайных и политических действий
mix_mask = torch.randint(0, 2, rewards.shape, dtype=torch.bool)
sampled_actions = torch.where(mix_mask, actor_actions, random_actions)

# 4. Вычисляем Q-значения для sampled действий
q1_sampled = critic_1(observations, sampled_actions)
q2_sampled = critic_2(observations, sampled_actions)

# 5. Вычисляем CQL-потерю
alpha = 0.3
cql1_loss = alpha * (q1_sampled.mean() - critic_1(observations, actions).mean())
cql2_loss = alpha * (q2_sampled.mean() - critic_2(observations, actions).mean())


torch.Size([10, 1]) torch.Size([10, 1]) torch.Size([])
