In [None]:
import os
import sys
from copy import deepcopy

import gymnasium as gym
import torch
from torch import nn
from torch.distributions import Normal
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from rlib.common.logger import Logger
from rlib.common.policies import (
    MlpCritic,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Create the Gymnasium "Pendulum-v1" environment and bind it to the
# notebook-level name `env`.
env = gym.make("Pendulum-v1")

### DDPG

In [None]:
class DeterministicMlpPolicy(nn.Module):
    """Deterministic MLP actor with temporally correlated Gaussian exploration noise.

    The network maps an observation to a tanh-squashed action in [-1, 1].
    Exploration noise is kept as state on the module and evolves as
    ``epsilon <- alpha * epsilon + Normal(mu, sigma).sample()``; it is added
    on top of the squashed action, so noisy actions can leave [-1, 1].
    """

    def __init__(self, obs_dim, action_dim, hidden_size=128):
        """
        Args:
            obs_dim (int): size of a single observation vector
            action_dim (int): size of a single action vector
            hidden_size (int): width of both hidden layers
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_size = hidden_size

        # Exploration-noise parameters: decay factor, mean and std of the
        # per-step Gaussian increment, and the running noise value itself.
        self.alpha = 0.15
        self.mu = 0
        self.sigma = 0.2
        self.epsilon = 0

        loc = self.mu * torch.ones((1, self.action_dim))
        scale = self.sigma * torch.ones((1, self.action_dim))
        self.dist = Normal(loc, scale)

        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_dim),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): (B, obs_dim)

        Returns:
            output: (torch.Tensor): (B, action_dim), squashed to [-1, 1] by tanh
        """
        output = torch.tanh(self.net(input))
        return output

    def get_action(self, input, deterministic=False):
        """
        Args:
            input (torch.Tensor): (B, obs_dim)
            deterministic (bool): if True, return the noise-free network output

        Returns:
            action: (torch.Tensor): (B, action_dim)
        """
        action = self.forward(input)

        if not deterministic:
            # The (1, action_dim) noise sample broadcasts over the batch.
            self.epsilon = self.alpha * self.epsilon + self.dist.sample()
            # Out-of-place add: tanh keeps its output for the backward pass,
            # so `action += ...` would invalidate the autograd graph and make
            # any later `.backward()` through this action raise.
            action = action + self.epsilon

        return action

    def predict(self, observation, action=None, deterministic=False):
        """
        Called for env observation

        Args:
            observation (np.ndarray): (obs_dim,)
            action: unused; kept so existing call sites keep working
            deterministic (bool): if True, skip the exploration noise

        Returns:
            action: (np.ndarray): (action_dim,)

        Raises:
            ValueError: if ``observation.shape`` is not ``(obs_dim,)``
        """
        expected_shape = (self.obs_dim,)
        if observation.shape != expected_shape:
            raise ValueError(
                f"Expected shape {expected_shape}, but got {observation.shape}"
            )

        # as_tensor + explicit dtype also accepts float64 observations.
        input = torch.as_tensor(observation, dtype=torch.float32).reshape(
            1, self.obs_dim
        )

        # Acting in the environment never needs gradients.
        with torch.no_grad():
            action = self.get_action(input, deterministic)

        return action.numpy().reshape((self.action_dim,))

In [None]:
class ReplayBuffer:
    """Bounded FIFO store of environment transitions with uniform sampling.

    Transitions are kept in six parallel lists (one per field). Once more than
    ``max_size`` transitions are stored, the oldest ones are dropped.
    """

    # Names of the parallel per-field lists, in storage order.
    _FIELDS = (
        "observations",
        "next_observations",
        "actions",
        "rewards",
        "terminated",
        "truncated",
    )

    def __init__(self, gamma: float = 0.99, min_size: int = 1000, max_size: int = 10000):
        """
        Args:
            gamma: discount factor; stored on the buffer, not used by it
            min_size: number of transitions callers should wait for before sampling
            max_size: maximum number of stored transitions (None disables eviction)
        """
        self.gamma = gamma
        self.min_size = min_size
        self.max_size = max_size
        self.clear()

    def clear(self):
        """Drop all stored transitions and forget the in-progress episode."""
        self.observations = []
        self.next_observations = []
        self.actions = []
        self.rewards = []
        self.terminated = []
        self.truncated = []
        # Observation the next collected step starts from; None means the
        # environment must be reset first.
        self._current_obs = None

    @property
    def size(self) -> int:
        """Number of transitions currently stored."""
        return len(self.observations)

    def add_transition(self, obs, next_obs, action, reward, terminated, truncated):
        """Append one transition and evict the oldest ones beyond ``max_size``."""
        self.observations.append(obs)
        self.next_observations.append(next_obs)
        self.actions.append(action)
        self.rewards.append(reward)
        self.terminated.append(terminated)
        self.truncated.append(truncated)

        if self.max_size is not None:
            overflow = self.size - self.max_size
            if overflow > 0:
                for name in self._FIELDS:
                    del getattr(self, name)[:overflow]

    def collect_transition(self, env, policy):
        """Advance ``env`` by one step using ``policy`` and store the transition.

        The environment is reset on the first call and again after every
        episode end. The buffer remembers the current observation between
        calls, so it should be driven with a single environment.

        Args:
            env: Gymnasium-style environment; ``reset()`` returns
                ``(obs, info)`` and ``step(action)`` returns
                ``(next_obs, reward, terminated, truncated, info)``
            policy: object exposing ``predict(observation) -> action``
        """
        if self._current_obs is None:
            self._current_obs, _ = env.reset()

        obs = self._current_obs
        action = policy.predict(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        self.add_transition(obs, next_obs, action, reward, terminated, truncated)

        episode_over = terminated or truncated
        self._current_obs = None if episode_over else next_obs

    def get_batch(self, batch_size):
        """Sample ``batch_size`` transitions uniformly, with replacement.

        Args:
            batch_size (int): number of transitions to draw, must be positive

        Returns:
            dict mapping each field name (``observations``,
            ``next_observations``, ``actions``, ``rewards``, ``terminated``,
            ``truncated``) to a float32 tensor of shape (batch_size, D);
            scalar fields get D == 1.

        Raises:
            ValueError: if the buffer is empty or ``batch_size`` is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if self.size == 0:
            raise ValueError("Cannot sample from an empty replay buffer")

        indices = torch.randint(0, self.size, (batch_size,)).tolist()

        batch = {}
        for name in self._FIELDS:
            storage = getattr(self, name)
            rows = [torch.as_tensor(storage[i], dtype=torch.float32) for i in indices]
            batch[name] = torch.stack(rows).reshape(batch_size, -1)
        return batch

In [None]:
def ddpg_loss(batch, actor, critic, actor_target, critic_target, gamma: float = 0.99):
    """Compute the DDPG actor and critic losses for one batch of transitions.

    Critic loss is the mean squared error between ``Q(s, a)`` and the
    bootstrapped target ``r + gamma * (1 - terminated) * Q'(s', mu'(s'))``.
    Actor loss is ``-mean(Q(s, mu(s)))``.

    NOTE(review): the critics are called with one tensor, the observation and
    action concatenated along the last dimension. Confirm this matches the
    critic class actually used (e.g. MlpCritic) before relying on it.

    Args:
        batch (dict): tensors under the keys ``"observations"``,
            ``"next_observations"``, ``"actions"``, ``"rewards"`` and
            ``"terminated"``; observations/actions are (B, dim), rewards and
            terminated are (B, 1) floats
        actor: online policy, callable as ``actor(obs) -> actions``
        critic: online Q-network
        actor_target: target policy used to build the bootstrap target
        critic_target: target Q-network used to build the bootstrap target
        gamma: discount factor

    Returns:
        dict with scalar tensors under ``"actor"`` and ``"critic"``
    """
    obs = batch["observations"]
    next_obs = batch["next_observations"]
    actions = batch["actions"]
    rewards = batch["rewards"]
    terminated = batch["terminated"]

    def q_value(q_net, states, acts):
        return q_net(torch.cat([states, acts], dim=-1))

    # The target is a constant for the critic regression: no gradient may
    # flow into the target networks.
    with torch.no_grad():
        next_actions = actor_target(next_obs)
        next_q = q_value(critic_target, next_obs, next_actions)
        # Only true termination stops bootstrapping; time-limit truncation
        # does not.
        target_q = rewards + gamma * (1.0 - terminated) * next_q

    critic_loss = nn.functional.mse_loss(q_value(critic, obs, actions), target_q)
    actor_loss = -q_value(critic, obs, actor(obs)).mean()

    return {"actor": actor_loss, "critic": critic_loss}

In [None]:
def smooth_update(model, target_model, tau: float = 0.99):
    """Move ``target_model`` towards ``model`` by exponential averaging.

    Every target parameter becomes ``tau * target + (1 - tau) * source``,
    updated in place. Parameters are paired in ``parameters()`` order, so both
    models are expected to share the same architecture.

    Args:
        model: source network whose parameters are tracked
        target_model: network updated in place
        tau: fraction of the old target value kept at each update

    Returns:
        The same ``target_model`` object, after the update.
    """
    # The update is plain bookkeeping and must not be recorded by autograd.
    with torch.no_grad():
        for param, target_param in zip(model.parameters(), target_model.parameters()):
            target_param.copy_(tau * target_param + (1 - tau) * param)

    return target_model

In [None]:
def ddpg(
    env,
    actor: DeterministicMlpPolicy,
    critic: MlpCritic,
    actor_optimizer: Adam,
    critic_optimizer: Adam,
    total_timesteps: int = 50_000,
    batch_size: int = 256,
):
    """Train ``actor`` and ``critic`` in place with the DDPG loop.

    Each iteration collects one environment transition. Once the replay
    buffer holds at least ``buffer.min_size`` transitions, it also performs
    one gradient step on each network and a soft update of both targets.

    Args:
        env: environment passed to ``ReplayBuffer.collect_transition``
        actor: policy network being trained
        critic: Q-network being trained
        actor_optimizer: optimizer over the actor's parameters
        critic_optimizer: optimizer over the critic's parameters
        total_timesteps: number of environment steps to collect
        batch_size: number of transitions sampled per gradient step
    """
    buffer = ReplayBuffer()
    logger = Logger()

    # Target networks start as exact copies of the online networks.
    actor_target = deepcopy(actor)
    critic_target = deepcopy(critic)

    steps_n = 0
    while steps_n < total_timesteps:
        buffer.collect_transition(env, actor)
        steps_n += 1

        # Warm-up: wait until enough transitions are stored to sample from.
        if len(buffer.observations) < buffer.min_size:
            continue

        batch = buffer.get_batch(batch_size)

        loss = ddpg_loss(batch, actor, critic, actor_target, critic_target)

        actor_optimizer.zero_grad()
        loss["actor"].backward()

        # The actor loss is differentiated through the critic, so the call
        # above also filled the critic's gradients. Clear them so the critic
        # step only sees gradients of its own loss.
        critic_optimizer.zero_grad()
        loss["critic"].backward()

        # Step only after both backward passes: both losses were computed
        # from the pre-update weights.
        actor_optimizer.step()
        critic_optimizer.step()

        actor_target = smooth_update(actor, actor_target)
        critic_target = smooth_update(critic, critic_target)

        logger.log(steps_n, batch)  # TODO: log training metrics here, not the batch
