<a href="https://colab.research.google.com/github/kimsooyoung/rl_oc_python/blob/main/oc_lec3_actor_critic/SAC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Soft Actor Critic Implementation with PyTorch

* This code is adapted from https://github.com/seungeunrho/minimalRL/blob/master/sac.py

In [None]:
# System packages: xvfb, python-opengl and ffmpeg. All apt output (stdout and stderr) is discarded.
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
# Python packages. `renderlab` is imported later in this notebook to record and play back episodes.
# NOTE(review): swig is presumably needed to build the box2d extra — confirm.
!pip install swig
!pip install renderlab
!pip install gymnasium
# NOTE(review): the box2d extra does not appear to be needed by Pendulum-v1, the only env used here — confirm.
!pip install gymnasium[box2d]

### Import the Necessary Packages

In [2]:
import gymnasium as gym
import collections
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.distributions import Normal

## Render Test

In [3]:
import renderlab as rl

# Wrap a Pendulum environment so that its frames are recorded under ./output.
env = rl.RenderFrame(gym.make("Pendulum-v1", render_mode="rgb_array"), "./output")

observation, info = env.reset()
score = 0

# Roll out one episode with uniformly sampled actions and accumulate the reward.
episode_over = False
while not episode_over:
  action = env.action_space.sample()
  observation, reward, terminated, truncated, info = env.step(action)
  score += reward
  episode_over = terminated or truncated

print("Score : ", score)

# Play back the recorded episode inside the notebook.
env.play()

Score :  -1176.292837993472
Moviepy - Building video temp-{start}.mp4.
Moviepy - Writing video temp-{start}.mp4





Moviepy - Done !
Moviepy - video ready temp-{start}.mp4


## Before Diving into the Code

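* `Normal.rsample()` draws a reparameterized sample, so gradients can flow back through the sample to `mu` and `sigma`.
* `Normal.log_prob(x)` returns the log of the probability density of the distribution evaluated at `x`.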

In [4]:
# Parameters of a 1-D Gaussian: mean and standard deviation.
mu, sigma = 0.0, 5.0
pd = Normal(loc=mu, scale=sigma)

# Draw one reparameterized sample from the Gaussian and show it.
action = pd.rsample()
print(action)

# Log of the probability density of the distribution evaluated at that sample.
print(pd.log_prob(value=action))

tensor(-4.3310)
tensor(-2.9035)


## Initialize Hyperparameters

In [5]:
# Adam step sizes
lr_pi           = 5e-4     # policy network
lr_q            = 1e-3     # Q networks
lr_alpha        = 1e-3     # entropy temperature (automated alpha update)

# Entropy temperature
init_alpha      = 1e-2     # initial value of alpha
target_entropy  = -1.0     # entropy target for the automated alpha update

# TD learning and replay
gamma           = 0.98     # discount factor used in the TD target
tau             = 1e-2     # mixing rate for the target network soft update
batch_size      = 32       # transitions drawn per mini-batch
buffer_limit    = 50_000   # maximum number of stored transitions

# Environment
max_torque      = 2.0      # pendulum max torque [-2.0 ~ 2.0]

## Define ReplayBuffer Class

In [66]:
class ReplayBuffer(object):
  """Fixed-capacity FIFO store of transitions with uniform random sampling.

  A transition is a 5-tuple ``(s, a, r, sp, done)``: state, scalar action,
  scalar reward, next state and a truthy/falsy terminal flag. Once the
  capacity is reached, the oldest transitions are dropped first.
  """

  def __init__(self, limit=None):
    """Create an empty buffer.

    Args:
      limit: maximum number of stored transitions. When omitted, the
        notebook-level ``buffer_limit`` is used.
    """
    if limit is None:
      limit = buffer_limit
    self.buffer = collections.deque(maxlen=limit)

  def put(self, transition):
    """Append one ``(s, a, r, sp, done)`` transition."""
    self.buffer.append(transition)

  def sample(self, n):
    """Draw ``n`` distinct transitions uniformly at random.

    Returns:
      A tuple ``(s, a, r, sp, done_mask)`` of float32 tensors. ``a``, ``r``
      and ``done_mask`` carry a trailing dimension of size 1; ``done_mask``
      is 0.0 for terminal transitions and 1.0 otherwise.

    Raises:
      ValueError: if ``n`` exceeds the number of stored transitions
        (raised by ``random.sample``).
    """
    mini_batch = random.sample(self.buffer, n)
    s_list, a_list, r_list, sp_list, done_mask_list = [], [], [], [], []

    for s, a, r, sp, done in mini_batch:
      s_list.append(s)
      a_list.append([a])
      r_list.append([r])
      sp_list.append(sp)
      done_mask_list.append([0.0 if done else 1.0])

    # Stack through NumPy first: building a tensor directly from a list of
    # ndarrays is slow and makes PyTorch emit a warning.
    return tuple(
      torch.from_numpy(np.asarray(column, dtype=np.float32))
      for column in (s_list, a_list, r_list, sp_list, done_mask_list)
    )

  def size(self):
    """Return the number of transitions currently stored."""
    return len(self.buffer)

In [15]:
# Smoke test: construct an empty buffer. The training cell further down builds its own `replay_buffer`.
rb = ReplayBuffer()

## Define PolicyNet Class

In [62]:
class PolicyNet(nn.Module):
  """Gaussian policy with tanh squashing and a learnable entropy temperature.

  The network maps a 3-dimensional input to the mean and (softplus) scale of
  a 1-dimensional Normal distribution. The temperature is stored as
  ``log_alpha`` in a plain tensor with its own Adam optimizer, separate from
  the optimizer that updates the layers.
  """

  def __init__(self, lr_pi, lr_alpha):
    """Build the layers and both optimizers.

    Args:
      lr_pi: Adam learning rate for the network layers.
      lr_alpha: Adam learning rate for ``log_alpha``.
    """
    super().__init__()

    # Shared hidden layer followed by separate mean / scale heads.
    self.fc_in = nn.Linear(3, 128)
    self.fc_mu = nn.Linear(128, 1)
    self.fc_sigma = nn.Linear(128, 1)
    self.optimizer = optim.Adam(self.parameters(), lr=lr_pi)

    # log(alpha) starts at log(init_alpha) and is trained by its own optimizer.
    self.log_alpha = torch.tensor(np.log(init_alpha), requires_grad=True)
    self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=lr_alpha)

  def forward(self, x):
    """Sample a squashed action and its log-probability for input ``x``.

    Returns:
      ``(action, log_prob)`` where ``action`` is the tanh of a
      reparameterized Gaussian sample and ``log_prob`` is the Gaussian
      log-density minus the tanh change-of-variables term.
    """
    hidden = F.relu(self.fc_in(x))
    gaussian = Normal(self.fc_mu(hidden), F.softplus(self.fc_sigma(hidden)))

    raw_action = gaussian.rsample()
    raw_log_prob = gaussian.log_prob(raw_action)

    squashed_action = torch.tanh(raw_action)
    # 1e-7 keeps the log finite when tanh saturates at +/-1.
    tanh_correction = torch.log(1 - torch.tanh(raw_action).pow(2) + 1e-7)

    return squashed_action, raw_log_prob - tanh_correction

  def train_net(self, q1, q2, mini_batch):
    """Run one policy update followed by one temperature update.

    Args:
      q1, q2: callables ``q(s, a)`` returning one value per row.
      mini_batch: 5-tuple whose first element is the batch of states; the
        remaining elements are ignored here.
    """
    states, _actions, _rewards, _next_states, _masks = mini_batch
    actions, log_probs = self.forward(states)
    entropy_bonus = -self.log_alpha.exp() * log_probs

    # Use the smaller of the two critics' estimates for each row.
    q_pair = torch.cat([q1(states, actions), q2(states, actions)], dim=1)
    min_q = q_pair.min(dim=1, keepdim=True).values

    policy_loss = -(min_q + entropy_bonus).mean()
    self.optimizer.zero_grad()
    policy_loss.backward()
    self.optimizer.step()

    # The log-probabilities are detached so only log_alpha receives gradient.
    entropy_gap = (log_probs + target_entropy).detach()
    alpha_loss = -(self.log_alpha.exp() * entropy_gap).mean()
    self.log_alpha_optimizer.zero_grad()
    alpha_loss.backward()
    self.log_alpha_optimizer.step()


In [18]:
# Smoke test: construct a policy. The training cell further down rebinds `pi` to a fresh network.
pi = PolicyNet(lr_pi, lr_alpha)

In [41]:
class QNet(nn.Module):
  """Critic that maps a (state, action) pair to a single value.

  A 3-dimensional state and a 1-dimensional action are embedded separately,
  concatenated, and passed through two further linear layers.
  """

  def __init__(self, lr_q):
    """Build the layers and the Adam optimizer (learning rate ``lr_q``)."""
    super().__init__()

    self.fc_s = nn.Linear(3, 64)
    self.fc_a = nn.Linear(1, 64)
    self.fc_cat = nn.Linear(128, 32)
    self.fc_out = nn.Linear(32, 1)
    self.optimizer = optim.Adam(self.parameters(), lr=lr_q)

  def forward(self, s, a):
    """Return one value per row for states ``s`` and actions ``a``."""
    x_s = F.relu(self.fc_s(s))
    x_a = F.relu(self.fc_a(a))
    x_sa = torch.cat([x_s, x_a], dim=1)
    x_cat = F.relu(self.fc_cat(x_sa))
    return self.fc_out(x_cat)

  def train_net(self, target, mini_batch):
    """Take one optimizer step towards ``target`` on the batch's (s, a) pairs.

    Args:
      target: regression target for the batch.
      mini_batch: 5-tuple whose first two elements are states and actions;
        the remaining elements are ignored here.
    """
    s, a, _, _, _ = mini_batch
    # smooth_l1_loss already averages over the batch, so no extra .mean().
    loss = F.smooth_l1_loss(self.forward(s, a), target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

  def soft_update(self, q_target, rate=None):
    """Blend this network's parameters into ``q_target`` in place.

    Each target parameter becomes ``(1 - rate) * target + rate * self``.

    Args:
      q_target: network with the same parameter layout, updated in place.
      rate: mixing rate; when omitted, the notebook-level ``tau`` is used.
    """
    mix = tau if rate is None else rate
    # The blend is bookkeeping, not part of any loss, so keep it out of autograd.
    with torch.no_grad():
      for target_param, online_param in zip(q_target.parameters(), self.parameters()):
        target_param.copy_((1.0 - mix) * target_param + mix * online_param)

In [21]:
# Smoke test: construct a critic. `q_net` is not referenced elsewhere in the visible notebook.
q_net = QNet(lr_q)

In [40]:
def calc_target(pi, q1, q2, mini_batch):
  """Compute the entropy-regularized TD target for a mini-batch.

  Args:
    pi: policy; ``pi(sp)`` returns ``(action, log_prob)`` and ``pi.log_alpha``
      holds the log of the entropy temperature.
    q1, q2: callables ``q(s, a)`` returning one value per row.
    mini_batch: 5-tuple ``(s, a, r, sp, mask)``; only the reward, the next
      state and the mask are used. Rows with a mask of 0 do not bootstrap.

  Returns:
    ``r + gamma * mask * (min(q1, q2) - alpha * log_prob)``, computed with
    gradient tracking disabled.
  """
  _, _, reward, next_state, not_done = mini_batch

  with torch.no_grad():
    next_action, next_log_prob = pi(next_state)
    entropy_term = -pi.log_alpha.exp() * next_log_prob

    # Use the smaller of the two critics' estimates for each row.
    q_candidates = torch.cat(
      [q1(next_state, next_action), q2(next_state, next_action)], dim=1
    )
    soft_value = q_candidates.min(dim=1, keepdim=True).values + entropy_term

    return reward + gamma * not_done * soft_value

## Main loop

In [None]:
env = gym.make('Pendulum-v1')
q1, q2, q1_target, q2_target = QNet(lr_q), QNet(lr_q), QNet(lr_q), QNet(lr_q)
pi = PolicyNet(lr_pi, lr_alpha)
replay_buffer = ReplayBuffer()

# The target critics start as exact copies of the online critics.
q1_target.load_state_dict(q1.state_dict())
q2_target.load_state_dict(q2.state_dict())

num_episodes = 10000
max_episode_steps = 200     # hard cap on steps per episode
warmup_transitions = 1000   # no updates until the buffer holds more than this
updates_per_episode = 20    # gradient rounds after each episode
print_interval = 20

score = 0.0

for i in range(num_episodes):
  s, _ = env.reset()
  done = False
  truncated = False
  count = 0

  # Collect one episode; stop on termination, truncation, or the step cap.
  while not (done or truncated) and count < max_episode_steps:
    # Acting needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
      a, _ = pi(torch.from_numpy(s).float())
    # The policy outputs values in [-1, 1]; scale to the torque range.
    sp, r, done, truncated, info = env.step( [max_torque * a.item()] )
    # Only true termination is stored as `done`, so truncated episodes
    # still bootstrap from the next state. Rewards are scaled by 1/10.
    replay_buffer.put( (s, a.item(), r/10.0, sp, done) )
    score += r
    count += 1
    s = sp

  if replay_buffer.size() > warmup_transitions:
    for batch_epoch in range(updates_per_episode):
      mini_batch = replay_buffer.sample(batch_size)
      td_target = calc_target( pi, q1_target, q2_target, mini_batch )
      q1.train_net(td_target, mini_batch)
      q2.train_net(td_target, mini_batch)
      pi.train_net(q1, q2, mini_batch)
      q1.soft_update(q1_target)
      q2.soft_update(q2_target)

  # Report after every `print_interval` completed episodes so the average
  # always covers exactly that many episodes.
  if (i + 1) % print_interval == 0:
    print("# of episode :{}, avg score : {:.1f} alpha:{:.4f}".format(i + 1, score/print_interval, pi.log_alpha.exp().item()))
    score = 0.0

env.close()

In [71]:
import renderlab as rl

# Wrap a fresh Pendulum environment so that its frames are recorded under ./output.
env = gym.make("Pendulum-v1", render_mode = "rgb_array")
env = rl.RenderFrame(env, "./output")

s, info = env.reset()
score = 0
# Both flags are assigned in this cell, so it does not rely on values left
# behind by an earlier cell.
terminated = truncated = False

# Run the trained policy for one episode.
while not (terminated or truncated):
  # Acting needs no gradients, so skip building the autograd graph.
  with torch.no_grad():
    a, log_p = pi(torch.from_numpy(s).float())
  # The policy outputs values in [-1, 1]; scale to the torque range.
  sp, r, terminated, truncated, info = env.step( [max_torque * a.item()] )

  s = sp
  score += r

print("Score : ", score)

# Play back the recorded episode inside the notebook.
env.play()

Score :  -5.269485975829426
Moviepy - Building video temp-{start}.mp4.
Moviepy - Writing video temp-{start}.mp4





Moviepy - Done !
Moviepy - video ready temp-{start}.mp4
