In [1]:
from collections import namedtuple, deque
import random

# One environment step as stored in the replay buffer.
Transition = namedtuple(
  'Transition',
  ['state', 'action', 'next_state', 'reward', 'done'],
)


class ReplayBuffer:
  """Fixed-capacity FIFO store of `Transition` records with uniform sampling."""

  def __init__(self, capacity):
    """Create an empty buffer that retains at most `capacity` transitions."""
    # A bounded deque silently evicts the oldest entry once it is full.
    self.memory = deque(maxlen=capacity)

  def push(self, *args):
    """Wrap the positional fields in a `Transition` and append it."""
    record = Transition(*args)
    self.memory.append(record)

  def sample(self, batch_size):
    """Return `batch_size` distinct stored transitions chosen uniformly at random."""
    return random.sample(self.memory, batch_size)

  def __len__(self):
    """Number of transitions currently held."""
    return len(self.memory)

In [2]:
import matplotlib.pyplot as plt
from IPython import display

# Total reward of every finished episode, in order; run() appends to it and
# plot_durations() reads it to draw the learning curve.
episode_rewards = []

def plot_durations(show_result=False):
  """Redraw the per-episode reward curve held in `episode_rewards`.

  Args:
    show_result: when False (the default) the figure is cleared first, titled
      'Training...', and the notebook output is marked to be replaced by the
      next draw; when True the figure is titled 'Result' and left in place.
  """
  rewards = torch.tensor(episode_rewards, dtype=torch.float)

  plt.figure(1)
  if not show_result:
    plt.clf()
  plt.title('Result' if show_result else 'Training...')
  plt.xlabel('Episode')
  plt.ylabel('Rewards')
  plt.plot(rewards.numpy())

  # Overlay the 100-episode moving average once enough episodes exist; the
  # first 99 positions are padded with zeros so both curves share an x-axis.
  window = 100
  if rewards.shape[0] >= window:
    rolling = rewards.unfold(0, window, 1).mean(dim=1).reshape(-1)
    padded = torch.cat((torch.zeros(window - 1), rolling))
    plt.plot(padded.numpy())

  plt.pause(0.001)  # give the backend a moment to refresh the figure
  display.display(plt.gcf())
  if not show_result:
    display.clear_output(wait=True)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, Independent
import gymnasium as gym
import numpy as np
from itertools import count

# Discount factor applied to the bootstrapped next-state Q-value in the critic target.
GAMMA = 0.99
# Polyak coefficient: fraction of the target critic's old weights kept on each soft update.
TAU = 0.995
# Transitions sampled per gradient update; updates start once the buffer holds this many.
BATCH_SIZE = 128
# Upper bound on environment steps per episode enforced by run().
MAX_EPISODE_LENGTH = 1000
# Number of training episodes executed by run().
NUM_EPISODE = 250

class Actor(nn.Module):
  """Policy network that maps states to a diagonal Gaussian over action dimensions.

  A shared two-layer ReLU trunk feeds two linear heads: one for the mean and
  one for the log standard deviation of each action dimension.
  """

  def __init__(self, state_dim, hidden_dim, action_dim):
    super().__init__()
    trunk_layers = [
      nn.Linear(state_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
    ]
    self.base = nn.Sequential(*trunk_layers)
    self.means = nn.Linear(hidden_dim, action_dim)
    self.log_stds = nn.Linear(hidden_dim, action_dim)

  def forward(self, states: torch.Tensor):
    """Return an `Independent(Normal)` whose event dimension is the action vector."""
    # Predicting log-std and exponentiating gives much better-behaved gradients
    # than predicting std directly through nn.Softplus().
    # ref: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py#L26
    log_std_min, log_std_max = -20, 2

    features = self.base(states)
    loc = self.means(features)
    scale = self.log_stds(features).clamp(log_std_min, log_std_max).exp()
    return Independent(Normal(loc=loc, scale=scale), reinterpreted_batch_ndims=1)


class Critic(nn.Module):
  """Q-network: scores a (state, action) pair with a single scalar."""

  def __init__(self, state_dim, action_dim, hidden_dim):
    super().__init__()
    # The state and action vectors are concatenated, hence the summed input width.
    layers = [
      nn.Linear(state_dim + action_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, 1),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, states: torch.Tensor, actions: torch.Tensor):
    """Return Q(s, a) with one output column per row of the batch."""
    state_action = torch.cat((states, actions), dim=1)
    return self.net(state_action)


class Agent(object):
  """Actor-critic agent with a tanh-squashed Gaussian policy and one Q-critic.

  The critic is regressed onto a one-step bootstrapped target computed with a
  slowly-updated target critic; the actor is trained to maximise the critic's
  value of its own reparameterised actions (no entropy term is used).
  """

  # log(2), used by the numerically stable tanh log-det-Jacobian below.
  _LOG_2 = float(np.log(2.0))

  def __init__(self, input_dim, hidden_dim, action_dim, device=None):
    """Build the networks and optimisers.

    Args:
      input_dim: size of the state vector.
      hidden_dim: width of the hidden layers in actor and critics.
      action_dim: size of the action vector.
      device: torch device (or device string) to run on. Defaults to CUDA when
        it is available and falls back to CPU otherwise.
    """
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device)

    self.actor = Actor(state_dim=input_dim, hidden_dim=hidden_dim, action_dim=action_dim)
    self.actor.to(self.device)
    self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)

    self.critic = Critic(state_dim=input_dim, hidden_dim=hidden_dim, action_dim=action_dim)
    self.critic.to(self.device)
    self.critic_target = Critic(state_dim=input_dim, hidden_dim=hidden_dim, action_dim=action_dim)
    self.critic_target.to(self.device)
    # The target critic starts as an exact copy of the online critic.
    self.critic_target.load_state_dict(self.critic.state_dict())
    self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

  def select_action(self, state: torch.Tensor, use_reparametrization_trick: bool) -> tuple:
    """Sample a squashed action and its log-probability.

    Args:
      state: batch of states, one row per sample.
      use_reparametrization_trick: draw with `rsample()` (gradients flow back
        into the actor) instead of `sample()`.

    Returns:
      `(a, log_pi_a_given_s)` where `a = tanh(u)` lies in [-1, 1] and
      `log_pi_a_given_s` has one entry per batch row.
    """
    dist = self.actor(state)
    u = dist.rsample() if use_reparametrization_trick else dist.sample()
    a = torch.tanh(u)
    # Change of variables for a = tanh(u):
    #   log pi(a|s) = log mu(u|s) - sum_i log(1 - tanh(u_i)^2)
    # Evaluating log(1 - tanh(u)^2) directly is not numerically stable, so the
    # equivalent form 2 * (log 2 - u - softplus(-2u)) is used instead.
    # ref: https://github.com/vitchyr/rlkit/blob/0073d73235d7b4265cd9abe1683b30786d863ffe/rlkit/torch/distributions.py#L358
    # ref: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73
    log_det_jacobian = 2.0 * (self._LOG_2 - u - F.softplus(-2.0 * u))
    log_pi_a_given_s = dist.log_prob(u) - log_det_jacobian.sum(dim=-1)
    return a, log_pi_a_given_s

  def clip_gradient(self, net: nn.Module) -> None:
    """Clamp every existing gradient of `net` element-wise to [-1, 1] in place.

    Parameters that have no gradient yet are skipped.
    """
    nn.utils.clip_grad_value_(net.parameters(), clip_value=1.0)

  def update_target(self, target: nn.Module, current: nn.Module) -> None:
    """Polyak-average `current` into `target`: target <- TAU*target + (1-TAU)*current."""
    with torch.no_grad():
      for target_param, param in zip(target.parameters(), current.parameters()):
        target_param.copy_(target_param * TAU + param * (1 - TAU))

  def _column(self, batch_data, index: int, width: int) -> torch.Tensor:
    """Stack field `index` of every transition into a (batch, width) float tensor."""
    # Going through one contiguous NumPy array avoids the slow element-by-element
    # conversion torch performs on a Python list of ndarrays.
    values = np.asarray([entry[index] for entry in batch_data], dtype=np.float32)
    return torch.as_tensor(values, device=self.device).view(len(batch_data), width)

  def update_networks(self, batch_data) -> None:
    """Run one critic step, one actor step and one soft target update.

    Args:
      batch_data: sequence of transitions indexable as
        (state, action, next_state, reward, done).
    """
    states = self._column(batch_data, 0, -1)
    actions = self._column(batch_data, 1, -1)  # -1 keeps multi-dimensional actions intact
    next_states = self._column(batch_data, 2, -1)
    rewards = self._column(batch_data, 3, 1)
    dones = self._column(batch_data, 4, 1)

    # Critic: regress Q(s, a) onto the one-step bootstrapped target.
    with torch.no_grad():
      next_actions, _ = self.select_action(next_states, use_reparametrization_trick=False)
      targets = rewards + GAMMA * (1 - dones) * self.critic_target(next_states, next_actions)

    q_pred = self.critic(states, actions)
    critic_loss = torch.mean((q_pred - targets) ** 2)

    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.clip_gradient(net=self.critic)
    self.critic_optimizer.step()

    # Actor: freeze the critic so the policy loss only produces actor gradients.
    for param in self.critic.parameters():
      param.requires_grad = False
    try:
      a, _ = self.select_action(states, use_reparametrization_trick=True)
      policy_loss = -torch.mean(self.critic(states, a))

      self.actor_optimizer.zero_grad()
      policy_loss.backward()
      self.clip_gradient(net=self.actor)
      self.actor_optimizer.step()
    finally:
      # Always unfreeze, even if the actor step raised.
      for param in self.critic.parameters():
        param.requires_grad = True

    self.update_target(self.critic_target, self.critic)

  def act(self, state: np.ndarray) -> np.ndarray:
    """Sample one action in [-1, 1] for a single un-batched state."""
    state_t = torch.as_tensor(state, dtype=torch.float, device=self.device).unsqueeze(0)
    # No gradients are needed when interacting with the environment.
    with torch.no_grad():
      action, _ = self.select_action(state_t, use_reparametrization_trick=False)
    return action.cpu().numpy()[0]



def run():
  """Train an `Agent` on Pendulum-v1 for NUM_EPISODE episodes.

  Every transition is pushed to a replay buffer, one network update is
  performed per environment step once the buffer holds BATCH_SIZE transitions,
  and the reward curve is redrawn at the end of each episode.
  """
  env = gym.make('Pendulum-v1')
  # The policy emits tanh-squashed actions in [-1, 1]; they are mapped affinely
  # onto the environment's bounds before stepping (assumes finite bounds).
  action_low = env.action_space.low
  action_high = env.action_space.high

  replay_buffer = ReplayBuffer(int(1e6))
  agent = Agent(
    input_dim=env.observation_space.shape[0],
    hidden_dim=16,
    action_dim=env.action_space.shape[0],
  )

  try:
    for _ in range(NUM_EPISODE):
      state, _ = env.reset()
      total_reward = 0.0
      for t in count():
        action = agent.act(state)
        env_action = action_low + 0.5 * (action + 1.0) * (action_high - action_low)
        # Gymnasium's step returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, _ = env.step(env_action)
        # t counts from 0, so the step just taken is number t + 1.
        episode_over = terminated or truncated or (t + 1 >= MAX_EPISODE_LENGTH)

        total_reward += float(reward)
        # Store the normalised action the networks actually produced, and only
        # the true termination flag: a time-limit cut-off must not stop the
        # critic target from bootstrapping off next_state.
        replay_buffer.push(
          state,
          action,
          next_state,
          reward,
          terminated,
        )

        if len(replay_buffer) >= BATCH_SIZE:
          agent.update_networks(replay_buffer.sample(BATCH_SIZE))

        if episode_over:
          episode_rewards.append(total_reward)
          plot_durations()
          break
        state = next_state
  finally:
    # Release the environment's resources even if training is interrupted.
    env.close()


# Start training; plot_durations() redraws the reward curve after every episode.
run()

<Figure size 640x480 with 0 Axes>