# SAC Practice

In [1]:
import time
import csv
import gym
import copy
import os
import numpy as np
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Independent
from torch.distributions.normal import Normal

from utils import *
from buffer import *

In [2]:
# Pick the GPU when PyTorch reports one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('current device =', device)

current device = cpu


  return torch._C._cuda_getDeviceCount() > 0


# 0. Define Q-network & policy-network

In [3]:
##################################################
##  Policy network with multi-layer perceptron  ##
##################################################

# Input - |S|
# Output - normal distribution of size |A|

class SACActor(nn.Module):
    """Squashed-Gaussian policy network used by SAC.

    A two-hidden-layer ReLU MLP maps a state to the mean and log-std of a
    diagonal Gaussian over pre-squash actions ``u``; the returned action is
    ``ctrl_range * tanh(u)``.

    Args:
        dimS: state dimension (input size of ``fc1``).
        dimA: action dimension (output size of the mean / log-std heads).
        hidden1, hidden2: widths of the two hidden layers.
        ctrl_range: post-tanh scale, so actions lie in
            ``(-ctrl_range, ctrl_range)``. Its logarithm is taken when
            ``with_log_prob=True``, so it must be positive.
    """

    def __init__(self, dimS, dimA, hidden1, hidden2, ctrl_range):
        super(SACActor, self).__init__()
        self.fc1 = nn.Linear(dimS, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, dimA)   # mean head
        self.fc4 = nn.Linear(hidden2, dimA)   # log-std head

        self.ctrl_range = ctrl_range

    def forward(self, state, eval=False, with_log_prob=False):
        """Compute an action (and optionally its log-density) for ``state``.

        Args:
            state: tensor whose last dimension is ``dimS``; may be a single
                state ``(dimS,)`` or a batch ``(B, dimS)``.
            eval: if True, return the deterministic action
                ``ctrl_range * tanh(mu)`` and no log-probability.
            with_log_prob: if True (and ``eval`` is False), also return
                ``log pi(a|s)`` including the tanh-squashing correction.

        Returns:
            ``(a, log_prob)``. ``a`` has the shape of ``mu``. ``log_prob`` has
            the shape of ``a`` without its last dimension, or is None.
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        # Gaussian parameters from the two heads
        mu = self.fc3(x)
        log_sigma = self.fc4(x)

        # clip value of log_sigma, as was done in Haarnoja's implementation of SAC:
        # https://github.com/haarnoja/sac.git
        log_sigma = torch.clamp(log_sigma, -20.0, 2.0)

        sigma = torch.exp(log_sigma)
        # Independent(..., 1) treats the last dim as the event dim, so
        # log_prob sums over action components.
        distribution = Independent(Normal(mu, sigma), 1)

        if eval:
            # Deterministic policy (centered at mu) for evaluation
            u = mu
            log_prob = None
        else:
            # rsample() (reparameterised) keeps the sample differentiable
            # w.r.t. mu and sigma; sample() would block back-propagation.
            u = distribution.rsample()
            log_prob = None
            if with_log_prob:
                # Change of variables for a = c * tanh(u):
                #   log pi(a) = log N(u) - sum_i log(c * (1 - tanh(u_i)^2))
                # with the numerically stable identity
                #   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
                # Summing over dim=-1 (not dim=1) also supports unbatched states.
                log_jacobian = np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u - F.softplus(-2.0 * u)
                log_prob = distribution.log_prob(u) - 2.0 * torch.sum(log_jacobian, dim=-1)

        # Squash with tanh so the action lies in (-ctrl_range, ctrl_range)^D
        a = self.ctrl_range * torch.tanh(u)

        return a, log_prob
    

##################################################
##  Critic network with multi-layer perceptron  ##
##################################################

# Input - |S|+|A|
# Output - single value

class DoubleCritic(nn.Module):
    """Twin Q-networks (the clipped double-Q idea from TD3).

    Each head is an independent two-hidden-layer ReLU MLP that takes
    ``[state, action]`` concatenated along dim 1 and outputs one value per row.
    """

    def __init__(self, dimS, dimA, hidden1, hidden2):
        super(DoubleCritic, self).__init__()
        in_features = dimS + dimA

        # Head 1
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)

        # Head 2
        self.fc4 = nn.Linear(in_features, hidden1)
        self.fc5 = nn.Linear(hidden1, hidden2)
        self.fc6 = nn.Linear(hidden2, 1)

    @staticmethod
    def _head(sa, first, second, out):
        """Apply one Q head: Linear-ReLU-Linear-ReLU-Linear."""
        hidden = F.relu(first(sa))
        hidden = F.relu(second(hidden))
        return out(hidden)

    def forward(self, state, action):
        """Return ``(Q1(s, a), Q2(s, a))``."""
        sa = torch.cat([state, action], dim=1)
        return (self._head(sa, self.fc1, self.fc2, self.fc3),
                self._head(sa, self.fc4, self.fc5, self.fc6))

    def Q1(self, state, action):
        """Return only the first head's value ``Q1(s, a)``."""
        sa = torch.cat([state, action], dim=1)
        return self._head(sa, self.fc1, self.fc2, self.fc3)

# 1. Define SAC agent

In [4]:
class SACAgent:
    """Soft Actor-Critic agent.

    Bundles a squashed-Gaussian actor (``pi``), a twin critic (``Q``), a
    frozen target copy of the critic (``target_Q``), a replay buffer and one
    Adam optimizer each for actor and critic.

    Args:
        dimS, dimA: state / action dimensions.
        ctrl_range: action scale forwarded to ``SACActor``.
        gamma: discount factor.
        pi_lr, q_lr: Adam learning rates for actor and critic.
        polyak: soft target-update coefficient; each ``target_update`` call
            sets ``target = polyak * online + (1 - polyak) * target``.
        alpha: entropy temperature (a fixed constant here).
        hidden1, hidden2: hidden-layer widths for actor and critic.
        buffer_size: replay-buffer capacity.
        batch_size: stored as ``self.batch_size`` (not read by the methods below).
        device: torch device the networks are moved to.
        render: stored as ``self.render`` (not read by the methods below).
    """

    def __init__(self,
                 dimS,
                 dimA,
                 ctrl_range,
                 gamma=0.99,
                 pi_lr=1e-4,
                 q_lr=1e-3,
                 polyak=1e-3,
                 alpha=0.2,
                 hidden1=400,
                 hidden2=300,
                 buffer_size=1000000,
                 batch_size=128,
                 device='cpu',
                 render=False):

        self.dimS = dimS
        self.dimA = dimA
        self.ctrl_range = ctrl_range

        self.gamma = gamma
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.polyak = polyak
        self.alpha = alpha

        self.batch_size = batch_size

        # networks definition
        # pi : actor network, Q : twin critic network
        self.pi = SACActor(dimS, dimA, hidden1, hidden2, ctrl_range).to(device)
        self.Q = DoubleCritic(dimS, dimA, hidden1, hidden2).to(device)

        # Target critic: starts as a copy of Q and afterwards only tracks it
        # through target_update(). `freeze` comes from utils; presumably it
        # disables gradients for the copy (TODO confirm).
        self.target_Q = copy.deepcopy(self.Q).to(device)
        freeze(self.target_Q)

        self.buffer = ReplayBuffer(dimS, dimA, limit=buffer_size)

        self.Q_optimizer = Adam(self.Q.parameters(), lr=self.q_lr)
        self.pi_optimizer = Adam(self.pi.parameters(), lr=self.pi_lr)

        self.device = device
        self.render = render

        return

    def act(self, state, eval=False):
        """Return the policy's action for a single ``state`` as a NumPy array.

        ``eval=True`` gives the deterministic action; otherwise the action is
        sampled from the policy.
        """
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        with torch.no_grad():
            action, _ = self.pi(state, eval=eval, with_log_prob=False)
        # Already outside autograd (no_grad), so no detach() is needed.
        return action.cpu().numpy()

    def target_update(self):
        """Polyak-average the online critic's parameters into the target critic."""
        for params, target_params in zip(self.Q.parameters(), self.target_Q.parameters()):
            target_params.data.copy_(self.polyak * params.data + (1.0 - self.polyak) * target_params.data)

        return

    def save_model(self, path):
        """Save networks and optimizers to ``path + 'model.pth.tar'``.

        ``path`` is used as a filename prefix, not a directory; any missing
        parent directory of the resulting file is created.
        """
        print('adding checkpoints...')
        checkpoint_path = path + 'model.pth.tar'
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
                    {'actor': self.pi.state_dict(),
                     'critic': self.Q.state_dict(),
                     'target_critic': self.target_Q.state_dict(),
                     'actor_optimizer': self.pi_optimizer.state_dict(),
                     'critic_optimizer': self.Q_optimizer.state_dict()
                    },
                    checkpoint_path)

        return

    def load_model(self, path):
        """Restore networks and optimizers from a file written by ``save_model``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on GPU
        can be loaded on a CPU-only machine (and vice versa).
        """
        print('networks loading...')
        checkpoint = torch.load(path, map_location=self.device)

        self.pi.load_state_dict(checkpoint['actor'])
        self.Q.load_state_dict(checkpoint['critic'])
        self.target_Q.load_state_dict(checkpoint['target_critic'])
        self.pi_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.Q_optimizer.load_state_dict(checkpoint['critic_optimizer'])

        return

# 2. Implement one-step param update

In [5]:
def update(agent, batch):
    """Perform one SAC update: critic step, actor step, then soft target update.

    Args:
        agent: the ``SACAgent`` whose networks and optimizers are updated.
        batch: object with array-like fields ``obs``, ``act``, ``next_obs``,
            ``rew`` and ``done`` (as returned by ``agent.buffer.sample_batch``).
            ``rew`` / ``done`` may be stored as shape (B,) or (B, 1); both
            are flattened to (B,) here.
    """
    dev = agent.device

    # Move the batch onto the agent's device
    obs_batch = torch.tensor(batch.obs, dtype=torch.float).to(dev)
    act_batch = torch.tensor(batch.act, dtype=torch.float).to(dev)
    next_obs_batch = torch.tensor(batch.next_obs, dtype=torch.float).to(dev)
    # Flatten to (B,) so every term below broadcasts element-wise with the
    # (B,) log-probabilities. Mixing (B, 1) and (B,) would silently broadcast
    # to (B, B) and corrupt the Bellman target.
    rew_batch = torch.tensor(batch.rew, dtype=torch.float).to(dev).reshape(-1)
    done_batch = torch.tensor(batch.done, dtype=torch.float).to(dev).reshape(-1)
    # A Python scalar keeps the result on `dev` (a CPU tensor would not).
    masks = 1.0 - done_batch

    #########################
    ##    Critic Update    ##
    #########################
    # Build the soft Bellman target without tracking gradients
    with torch.no_grad():
        # a' ~ pi(.|s') together with log pi(a'|s')
        next_actions, next_log_probs = agent.pi(next_obs_batch, with_log_prob=True)

        # Clipped double-Q (idea from TD3) to mitigate overestimation
        target_q1, target_q2 = agent.target_Q(next_obs_batch, next_actions)
        target_q = torch.min(target_q1, target_q2).reshape(-1)

        # TQ = r + gamma * (1 - done) * [ Q(s', a') - alpha * log pi(a'|s') ]
        # where -log pi(a'|s') is a one-sample estimate of the entropy H(pi(.|s')).
        TQ = rew_batch + agent.gamma * masks * (target_q - agent.alpha * next_log_probs)

    # Mean-squared error of both critic heads against the shared target
    Q1, Q2 = agent.Q(obs_batch, act_batch)
    Q_loss1 = torch.mean((Q1.reshape(-1) - TQ) ** 2)
    Q_loss2 = torch.mean((Q2.reshape(-1) - TQ) ** 2)
    Q_loss = Q_loss1 + Q_loss2

    agent.Q_optimizer.zero_grad()
    Q_loss.backward()
    agent.Q_optimizer.step()

    ########################
    ##    Actor Update    ##
    ########################
    actions, log_probs = agent.pi(obs_batch, with_log_prob=True)

    # Freeze the critic while back-propagating the actor loss through it;
    # `finally` guarantees it is unfrozen even if the actor step raises.
    freeze(agent.Q)
    try:
        q1, q2 = agent.Q(obs_batch, actions)
        q = torch.min(q1, q2).reshape(-1)

        # Maximise E[Q - alpha * log pi] == minimise E[alpha * log pi - Q]
        pi_loss = torch.mean(agent.alpha * log_probs - q)

        agent.pi_optimizer.zero_grad()
        pi_loss.backward()
        agent.pi_optimizer.step()
    finally:
        unfreeze(agent.Q)

    #####################################
    ##    Soft Target Critic Update    ##
    #####################################
    agent.target_update()

# 3. Putting these together

In [12]:
def render_agent(agent, env_id):
    """Play one rendered evaluation episode of ``agent`` on ``env_id``."""
    eval_agent(agent, env_id, render=True, eval_num=1)


def eval_agent(agent, env_id, eval_num=5, render=False):
    """Return the mean undiscounted episode return of the deterministic policy.

    Args:
        agent: object exposing ``act(state, eval=True)``.
        env_id: gym environment id; one environment is created and reused
            for all episodes, and it is always closed on exit.
        eval_num: number of evaluation episodes (must be >= 1).
        render: if True, render the first episode only.

    Returns:
        Average of the per-episode reward sums.

    Raises:
        ValueError: if ``eval_num`` < 1.
    """
    if eval_num < 1:
        raise ValueError('eval_num must be at least 1, got {}'.format(eval_num))

    env = gym.make(env_id)
    returns = []
    try:
        for ep in range(eval_num):
            show = render and ep == 0
            state = env.reset()
            ep_reward = 0
            done = False

            while not done:
                if show:
                    env.render()
                action = agent.act(state, eval=True)
                state, reward, done, _ = env.step(action)
                ep_reward += reward

            returns.append(ep_reward)
    finally:
        # Previously only the rendered env was closed; the others leaked.
        env.close()

    return sum(returns) / eval_num

In [13]:
def run_sac(
            env_id,
            max_iter=1e6,
            eval_interval=2000,
            start_train=10000,
            train_interval=50,
            buffer_size=1e6,
            fill_buffer=20000,
            truncate=1000,
            gamma=0.99,
            pi_lr=3e-4,
            q_lr=3e-4,
            polyak=5e-3,
            alpha=0.2,
            hidden1=256,
            hidden2=256,
            batch_size=128,
            device='cpu',
            render=False
            ):
    """Train a SAC agent on ``env_id``, logging returns and saving checkpoints.

    Main loop per environment step ``t`` (``0 .. max_iter``):
      * act randomly for ``t < fill_buffer``, otherwise with the stochastic policy;
      * store the transition in the replay buffer;
      * every ``train_interval`` steps once ``t >= start_train``, run
        ``train_interval`` updates on mini-batches of ``batch_size``;
      * every ``eval_interval`` steps, evaluate the deterministic policy;
      * every ``max_iter // 4`` steps, save a checkpoint.

    ``render`` is only stored on the agent; no rendering happens during training.

    Files written (relative to the working directory):
        ./train_log/<env_id>/SAC_<time>.csv  rows of [step, episode return]
        ./eval_log/<env_id>/SAC_<time>.csv   rows of [step, mean eval return]
        ./eval_log/<env_id>/SAC_<time>.txt   the arguments of this call
        ./checkpoints/<env_id>/sac_<t>th_iter_model.pth.tar
    """
    # Must run before any other local is created: snapshot only the call
    # arguments, which are dumped to the run's .txt file below.
    params = locals()

    max_iter = int(max_iter)
    buffer_size = int(buffer_size)
    env = gym.make(env_id)

    dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)

    if truncate is not None:
        max_ep_len = truncate

    # Instantiate the agent
    agent = SACAgent(
                     dimS,
                     dimA,
                     ctrl_range,
                     gamma=gamma,
                     pi_lr=pi_lr,
                     q_lr=q_lr,
                     polyak=polyak,
                     alpha=alpha,
                     hidden1=hidden1,
                     hidden2=hidden2,
                     buffer_size=buffer_size,
                     batch_size=batch_size,
                     device=device,
                     render=render
                     )

    set_log_dir(env_id)

    # Logging & saving weights
    num_checkpoints = 5
    # max(1, ...) prevents `t % 0` when max_iter < num_checkpoints - 1.
    checkpoint_interval = max(1, max_iter // (num_checkpoints - 1))
    current_time = time.strftime("%m%d-%H%M%S")
    train_path = './train_log/' + env_id + '/SAC_' + current_time + '.csv'
    path = './eval_log/' + env_id + '/SAC_' + current_time

    with open(path + '.txt', 'w') as f:
        for key, val in params.items():
            print(key, '=', val, file=f)

    # `with` / `finally` guarantee the CSV logs are flushed and closed and the
    # environment is closed even if training raises.
    try:
        with open(train_path, 'w', encoding='utf-8', newline='') as train_log, \
                open(path + '.csv', 'w', encoding='utf-8', newline='') as eval_log:
            train_logger = csv.writer(train_log)
            eval_logger = csv.writer(eval_log)

            ##############################
            ##    Main training loop    ##
            ##############################
            obs = env.reset()
            step_count, ep_reward = 0, 0

            for t in range(max_iter + 1):
                # Rollout: random actions early on to promote exploration
                if t < fill_buffer:
                    action = env.action_space.sample()
                else:
                    action = agent.act(obs)

                next_obs, reward, done, _ = env.step(action)
                step_count += 1

                # Hitting the artificial length limit is a time-out, not a
                # true terminal state: store done=False so the critic still
                # bootstraps from next_obs.
                if step_count == max_ep_len:
                    done = False

                agent.buffer.append(obs, action, next_obs, reward, done)

                obs = next_obs
                ep_reward += reward

                # Reset environment if the trajectory ends
                if done or (step_count == max_ep_len):
                    train_logger.writerow([t, ep_reward])
                    obs = env.reset()
                    step_count, ep_reward = 0, 0

                # Actor-critic updates
                if (t >= start_train) and (t % train_interval == 0):
                    for _ in range(train_interval):
                        batch = agent.buffer.sample_batch(batch_size=batch_size)
                        update(agent, batch)

                # Evaluate agent
                if t % eval_interval == 0:
                    eval_score = eval_agent(agent, env_id, render=False)
                    log = [t, eval_score]
                    print('step {} : {:.4f}'.format(t, eval_score))
                    eval_logger.writerow(log)

                # Save agent weights while training
                if t % checkpoint_interval == 0:
                    agent.save_model('./checkpoints/' + env_id + '/sac_{}th_iter_'.format(t))
    finally:
        env.close()

    return

# 4. Let's train our agent!

### Hyperparameter setting

In [9]:
# Use continuous control!
# Environment / episode settings
env_id = 'LunarLanderContinuous-v2'
truncate = 1000
render = False

# Training schedule (run_sac casts max_iter to int; the others are compared as-is)
max_iter = 5e5
start_train = 1e4
fill_buffer = 2e4
train_interval = 50
eval_interval = 2000

# Optimisation: shared learning rate, and tau is passed to run_sac as `polyak`
lr = 3e-4
tau = 5e-3

# Hidden-layer widths for actor and critic
hidden1 = hidden2 = 256

### Setup environment

In [10]:
env = gym.make(env_id)
# Leading size of each space (assumes 1-D observation/action spaces — TODO confirm).
obs_dim, num_act = env.observation_space.shape[0], env.action_space.shape[0]

print('observation space dim. : {} / # actions : {}'.format(obs_dim, num_act))

observation space dim. : 8 / # actions : 2


### Run experiment!

In [14]:
# Launch training with the hyperparameters defined above
# (keyword arguments listed in run_sac's signature order).
run_sac(
    env_id,
    max_iter=max_iter,
    eval_interval=eval_interval,
    start_train=start_train,
    train_interval=train_interval,
    buffer_size=1e6,
    fill_buffer=fill_buffer,
    truncate=truncate,
    gamma=0.99,
    pi_lr=lr,
    q_lr=lr,
    polyak=tau,
    alpha=0.2,
    hidden1=hidden1,
    hidden2=hidden2,
    batch_size=128,
    device=device,
    render=render,
)

environment : LunarLanderContinuous-v2
obs dim :  (8,) / ctrl dim :  (2,)
--------------------------------------------------------------------------------
ctrl range : (-1.00, 1.00)
max_ep_len :  1000
--------------------------------------------------------------------------------
step 0 : -287.4919


AttributeError: 'SACAgent' object has no attribute 'save_model'

# 5. Watch the trained agent!

In [None]:
# run_sac() returns None, so no trained `agent` exists at module level:
# rebuild one with the same architecture before restoring the snapshot.
dimS, dimA, ctrl_range, _ = get_env_spec(env)
agent = SACAgent(dimS, dimA, ctrl_range, hidden1=hidden1, hidden2=hidden2, device=device)

obs = env.reset()
done = False
score = 0.
load_model(agent, path='./snapshots/trained.pth.tar', device=device)
try:
    while not done:
        env.render()
        # eval=True: deterministic action, as in eval_agent
        obs, rew, done, _ = env.step(agent.act(obs, eval=True))
        score += rew
finally:
    env.close()
print('score : ', score)