In [1]:
import gymnasium as gym

# Gymnasium environment shared by every later cell.
env_name = 'Pendulum-v1'
env = gym.make(id=env_name)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, Independent


class Actor(nn.Module):
  """Gaussian policy: an MLP predicts the mean, the standard deviation is fixed.

  ``forward`` returns a distribution over the *pre-squash* action ``u``; callers
  in this notebook map samples into (-1, 1) with ``torch.tanh``.

  Args:
    state_dim: size of the state vector fed to the first Linear layer.
    hidden_dim: width of both hidden layers.
    action_dim: size of the action vector (event dimension of the distribution).
    std: fixed standard deviation of the Gaussian; defaults to 0.1.
  """
  def __init__(self, state_dim, hidden_dim, action_dim, std=0.1):
    super(Actor, self).__init__()
    self.base = nn.Sequential(
      nn.Linear(state_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
    )
    self.means = nn.Linear(hidden_dim, action_dim)
    # Plain float attribute: not a parameter, so state_dict keys are unchanged.
    self.std = std

  def forward(self, states: torch.Tensor) -> Independent:
    """Return ``Independent(Normal(mean(states), std), 1)``.

    The last dimension is reinterpreted as the event dimension, so
    ``log_prob`` sums over action components and returns one value per state.
    """
    means = self.means(self.base(states))
    # A learned, state-dependent std would instead use a log-std head clamped
    # to a range such as [-20, 2] followed by torch.exp, whose gradients are
    # better behaved than predicting the std directly (spinningup SAC core.py).
    return Independent(Normal(loc=means, scale=self.std), reinterpreted_batch_ndims=1)


class Critic(nn.Module):
  """Q-network scoring a (state, action) pair with a single scalar.

  States and actions are concatenated along dim 1 and passed through an MLP
  with two ReLU hidden layers and a final Linear layer of width 1.
  """
  def __init__(self, state_dim, action_dim, hidden_dim):
    super().__init__()
    layers = [
      nn.Linear(state_dim + action_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, 1),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, states: torch.tensor, actions: torch.tensor):
    state_action = torch.cat([states, actions], dim=1)
    return self.net(state_action)


# Network sizes used throughout this notebook (state_dim=3, action_dim=1).
STATE_DIM = 3
ACTION_DIM = 1
HIDDEN_DIM = 64

# Online networks plus target copies that track them via soft updates.
# Each target must start as an exact copy of its online network; otherwise
# the early TD targets come from an unrelated random network.
actor = Actor(state_dim=STATE_DIM, hidden_dim=HIDDEN_DIM, action_dim=ACTION_DIM)
actor.to('cuda')
actor_target = Actor(state_dim=STATE_DIM, hidden_dim=HIDDEN_DIM, action_dim=ACTION_DIM)
actor_target.to('cuda')
actor_target.load_state_dict(actor.state_dict())
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)

critic = Critic(state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden_dim=HIDDEN_DIM)
critic.to('cuda')
critic_target = Critic(state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden_dim=HIDDEN_DIM)
critic_target.to('cuda')
critic_target.load_state_dict(critic.state_dict())
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

  _torch_pytree._register_pytree_node(


In [3]:
# Sanity check: one state in -> a batch of one distribution over a 1-D action.
test_state = torch.Tensor([[0.1, 0.2, 0.3]])
dist = actor(test_state.to('cuda'))
print(dist)
sample = dist.sample()
print(sample.shape)
log_prob = dist.log_prob(sample)
print(log_prob.shape)
# Fail loudly on a shape regression instead of relying on reading the prints.
assert sample.shape == (1, 1), sample.shape
assert log_prob.shape == (1,), log_prob.shape

Independent(Normal(loc: tensor([[-0.0329]], device='cuda:0', grad_fn=<AddmmBackward0>), scale: tensor([[0.1000]], device='cuda:0')), 1)
torch.Size([1, 1])
torch.Size([1])


In [4]:
def select_action(state: torch.Tensor, policy=None):
  """Sample a tanh-squashed action with the reparameterization trick.

  Args:
    state: batch of states on the policy's device — assumed (batch, state_dim);
      TODO confirm for other callers.
    policy: module mapping states to a distribution over the pre-squash action
      ``u`` whose event dim is the last axis (e.g. ``Actor``); defaults to the
      global ``actor``.

  Returns:
    (action, log_prob): ``action = tanh(u)`` with ``u`` drawn via ``rsample`` so
    gradients flow back into the policy, and the log-density of ``action``
    (one value per state), including the tanh change-of-variables term.
  """
  import math

  if policy is None:
    policy = actor
  dist = policy(state)
  u = dist.rsample()
  action = torch.tanh(u)
  # log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2). The identity
  # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)) stays finite for large |u|.
  log_det = 2.0 * (math.log(2.0) - u - nn.functional.softplus(-2.0 * u))
  log_prob = dist.log_prob(u) - log_det.sum(dim=-1)
  return action, log_prob

In [5]:
import matplotlib.pyplot as plt
from IPython import display

# Per-episode returns, appended by the training loop and plotted below.
episode_rewards = []


def plot_durations(show_result=False):
  """Plot per-episode rewards and, once available, their 100-episode moving average.

  Args:
    show_result: if True, title the figure 'Result' and keep it in the output;
      otherwise redraw a live 'Training...' figure, replacing the previous one.
  """
  plt.figure(1)
  rewards_t = torch.tensor(episode_rewards, dtype=torch.float)
  if show_result:
    plt.title('Result')
  else:
    plt.clf()
    plt.title('Training...')
  plt.xlabel('Episode')
  plt.ylabel('Rewards')
  plt.plot(rewards_t.numpy())

  if len(rewards_t) >= 100:
    window_means = rewards_t.unfold(0, 100, 1).mean(1).view(-1)
    # Pad with NaN, not 0: the first 99 episodes have no full window, and a
    # zero-filled prefix would draw a fake average of 0 (rewards here can be
    # far below 0, so that line would be misleading).
    padding = torch.full((99,), float('nan'))
    plt.plot(torch.cat((padding, window_means)).numpy())

  plt.pause(0.001)  # pause a bit so that plots are updated
  display.display(plt.gcf())
  if not show_result:
    # Replace the previous live frame instead of stacking figures.
    display.clear_output(wait=True)

In [6]:
TAU = 0.995  # Polyak coefficient: fraction of the old target weights kept per soft update
ALPHA = 0.002  # NOTE(review): not used anywhere in this notebook
GAMMA = 0.99  # discount factor for the TD target


def clip_gradient(model: nn.Module, clip_value: float = 1.0) -> None:
  """Clamp every existing gradient of ``model`` element-wise to [-clip_value, clip_value].

  Parameters whose ``.grad`` is None (e.g. not reached by the last backward
  pass) are skipped rather than raising ``AttributeError``.
  """
  nn.utils.clip_grad_value_(model.parameters(), clip_value)


def update_target(target: nn.Module, model: nn.Module, tau=None) -> None:
  """Soft (Polyak) update in place: ``target <- tau * target + (1 - tau) * model``.

  Args:
    target: network whose parameters are overwritten.
    model: online network; parameters are paired with ``target``'s in
      ``parameters()`` order, so both must share the same architecture.
    tau: fraction of the old target weights to keep; defaults to ``TAU``.
  """
  if tau is None:
    tau = TAU
  with torch.no_grad():
    for target_param, param in zip(target.parameters(), model.parameters()):
      target_param.mul_(tau).add_(param, alpha=1 - tau)


def update_models(saved_data):
  """Run one off-policy actor-critic step (DDPG-style) on a minibatch.

  1. Critic: regress Q(s, a) onto r + GAMMA * (1 - done) * Q_target(s', a'),
     with a' = tanh(u), u ~ actor_target(s').
  2. Actor: maximize Q(s, tanh(u)), u ~ actor(s) (reparameterized), with the
     critic's parameters frozen.
  3. Soft-update both target networks.

  Args:
    saved_data: sequence of transitions indexable as
      (state, action, next_state, reward, done), e.g. ``Transition`` tuples
      sampled from the replay buffer.
  """
  import numpy as np

  def _column(index):
    # Stack one field into a single ndarray first: building a tensor from a
    # Python list of ndarrays is very slow and emits a PyTorch UserWarning.
    # float32 matches the networks' parameter dtype.
    stacked = np.array([entry[index] for entry in saved_data])
    return torch.as_tensor(stacked, dtype=torch.float32, device='cuda')

  states = _column(0)
  actions = _column(1)
  next_states = _column(2)
  rewards = _column(3)
  done = _column(4)

  # --- critic update -------------------------------------------------------
  with torch.no_grad():
    next_actions = torch.tanh(actor_target(next_states).sample())
    next_q_values = torch.squeeze(critic_target(next_states, next_actions), dim=1)
    targets = rewards + GAMMA * (1.0 - done) * next_q_values

  q_pred = torch.squeeze(critic(states, actions), dim=1)
  critic_loss = torch.mean((q_pred - targets) ** 2)

  critic_optimizer.zero_grad()
  critic_loss.backward()
  clip_gradient(critic)
  critic_optimizer.step()

  # --- actor update --------------------------------------------------------
  # Freeze the critic so the policy loss only produces gradients for the actor.
  # try/finally guarantees the critic is unfrozen even if this step raises;
  # otherwise later critic updates would silently stop learning.
  for param in critic.parameters():
    param.requires_grad_(False)
  try:
    policy_actions, _ = select_action(states)
    policy_loss = -torch.mean(critic(states, policy_actions))

    actor_optimizer.zero_grad()
    policy_loss.backward()
    clip_gradient(actor)
    actor_optimizer.step()
  finally:
    for param in critic.parameters():
      param.requires_grad_(True)

  # --- target networks -----------------------------------------------------
  with torch.no_grad():
    update_target(critic_target, critic)
    update_target(actor_target, actor)

In [7]:
from collections import namedtuple, deque
import random

Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward', 'done'])


class ReplayBuffer:
  """FIFO store of ``Transition`` tuples; the oldest are evicted past ``capacity``."""

  def __init__(self, capacity):
    self.memory = deque(maxlen=capacity)

  def push(self, *args):
    """Append one transition built as ``Transition(*args)``."""
    transition = Transition(*args)
    self.memory.append(transition)

  def sample(self, batch_size):
    """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
    return random.sample(self.memory, k=batch_size)

  def __len__(self):
    return self.memory.__len__()

In [8]:
import numpy as np
from itertools import count


NUM_EPISODE = 250
MAX_TIME = 200  # maximum number of env steps per episode
REPLAY_BUFFER_SIZE = 1000000
BATCH_SIZE = 64


def run():
  """Train for NUM_EPISODE episodes, doing one model update per env step.

  Each episode lasts at most MAX_TIME steps, or fewer if the env reports
  termination or truncation. Each episode's total reward is appended to the
  global ``episode_rewards`` and plotted live.
  """
  replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)
  # The policy emits tanh-squashed actions in [-1, 1]; map them affinely onto
  # the env's action bounds before stepping (for Pendulum-v1 the torque range
  # is wider than [-1, 1]). The buffer keeps the unscaled action because that
  # is the space the critic and actor operate in.
  action_low = env.action_space.low
  action_high = env.action_space.high

  for episode in range(NUM_EPISODE):
    state, _ = env.reset()
    total_reward = 0.0
    for t in range(MAX_TIME):
      # Rollout only needs the action, not an autograd graph.
      with torch.no_grad():
        state_t = torch.as_tensor(state, dtype=torch.float32, device='cuda').unsqueeze(0)
        action_tensor, _ = select_action(state_t)
      action = action_tensor.cpu().numpy()[0]
      env_action = action_low + (action + 1.0) * 0.5 * (action_high - action_low)
      next_state, reward, terminated, truncated, _ = env.step(env_action)
      total_reward += reward

      # Store (s, a, s') BEFORE advancing `state`. Only true termination stops
      # bootstrapping; a time-limit cut-off does not make V(s') zero.
      replay_buffer.push(state, action, next_state, reward, int(terminated))
      state = next_state

      if len(replay_buffer) >= BATCH_SIZE:
        update_models(replay_buffer.sample(BATCH_SIZE))

      if terminated or truncated or t == MAX_TIME - 1:
        episode_rewards.append(total_reward)
        plot_durations()
        break


run()

<Figure size 640x480 with 0 Axes>