In [34]:
import math
import random
import sys

import gym
import numpy as np
import time 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from tensorboardX import SummaryWriter
import roboschool

from IPython.display import clear_output
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display

%matplotlib inline

use_cuda = torch.cuda.is_available()
device   = torch.device("cuda" if use_cuda else "cpu")

# Networks

In [35]:
class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=[400,300], init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size[0])
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.linear3 = nn.Linear(hidden_size[1], 1)
        
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [36]:
class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=[400,300], 
                 init_w=3e-3, log_std_min=-20, log_std_max=2, epsilon=1e-6):
        super(PolicyNetwork, self).__init__()
        
        self.epsilon = epsilon
        
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        
        self.linear1 = nn.Linear(num_inputs, hidden_size[0])
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        
        self.mean_linear = nn.Linear(hidden_size[1], num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)
        
        self.log_std_linear = nn.Linear(hidden_size[1], num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)
        
    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (
            self.normal_mean +
            self.normal_std *
            Normal(
                ptu.zeros(self.normal_mean.size()),
                ptu.ones(self.normal_std.size())
            ).sample()
        )
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
    
    def forward(self, state, deterministic=False):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        
        mean    = self.mean_linear(x)
        
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        
        std = torch.exp(log_std)
        
        log_prob = None
        
        if deterministic:
            action = torch.tanh(mean)
        else:
            # assumes actions have been normalized to (0,1)
            normal = Normal(0, 1)
            z = mean + std * normal.sample().requires_grad_()
            action = torch.tanh(z)
            log_prob = Normal(mean, std).log_prob(z) - torch.log(1 - action * action + self.epsilon)
            
        return action, mean, log_std, log_prob, std
    
    def get_action(self, state, deterministic=False):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        action,_,_,_,_ =  self.forward(state, deterministic)
        act = action.cpu()[0][0]
        return act

In [37]:
class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=[400,300], init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size[0])
        self.linear2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.linear3 = nn.Linear(hidden_size[1], 1)
        
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

# Memory

In [38]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
    
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done
    
    def __len__(self):
        return len(self.buffer)

In [39]:
class NormalizedActions(gym.ActionWrapper):
    def action(self, action):
        low  = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)
        
        return action

def normalize_action(action, low, high):
    action = low + (action + 1.0) * 0.5 * (high - low)
    action = np.clip(action, low, high)
    
    return action

# SAC Agent

In [40]:
class SAC(object):
    
    def __init__(self, env, replay_buffer, seed=0, hidden_dim=[400,300],
        steps_per_epoch=200, epochs=1000, discount=0.99,
        tau=1e-2, lr=1e-3, auto_alpha=True, batch_size=100, start_steps=10000,
        max_ep_len=200, logger_kwargs=dict(), save_freq=1):
        
        # Set seeds
        self.env = env
        self.env.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        
        # env space
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0] 
        self.hidden_dim = hidden_dim
        
        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # init networks
        
        # Soft Q
        self.soft_q_net1 = SoftQNetwork(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        self.soft_q_net2 = SoftQNetwork(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        
        self.target_soft_q_net1 = SoftQNetwork(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        self.target_soft_q_net2 = SoftQNetwork(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        
        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(param.data)
        
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(param.data)
            
        # Policy
        self.policy_net = PolicyNetwork(self.state_dim, self.action_dim, self.hidden_dim).to(device)
        
        # Optimizers/Loss
        self.soft_q_criterion = nn.MSELoss()
        
        self.soft_q_optimizer1 = optim.Adam(self.soft_q_net1.parameters(), lr=lr)
        self.soft_q_optimizer2 = optim.Adam(self.soft_q_net2.parameters(), lr=lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        
        # alpha tuning
        self.auto_alpha = auto_alpha
        
        if self.auto_alpha:
            self.target_entropy = -np.prod(env.action_space.shape).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            
        self.replay_buffer = replay_buffer
        self.discount = discount
        self.batch_size = batch_size
        self.tau = tau
        
    def get_action(self, state, deterministic=False, explore=False):
        
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        if explore:
            return self.env.action_space.sample()
        else:
            action  = self.policy_net.get_action(state, deterministic).detach()
            return action.numpy()
           
    def update(self, iterations, batch_size = 100):
        
        for _ in range(0,iterations):
        
            state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)

            state      = torch.FloatTensor(state).to(device)
            next_state = torch.FloatTensor(next_state).to(device)
            action     = torch.FloatTensor(action).to(device)
            reward     = torch.FloatTensor(reward).unsqueeze(1).to(device)
            done       = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)

            new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy_net(state)

            if self.auto_alpha:
                alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                alpha = self.log_alpha.exp()
            else:
                alpha_loss = 0
                alpha = 0.2 # constant used by OpenAI

            # Update Policy 
            q_new_actions = torch.min(
                self.soft_q_net1(state, new_actions), 
                self.soft_q_net2(state, new_actions)
            )

            policy_loss = (alpha*log_pi - q_new_actions).mean()

            # Update Soft Q Function
            q1_pred = self.soft_q_net1(state, action)
            q2_pred = self.soft_q_net2(state, action)

            new_next_actions, _, _, new_log_pi, *_ = self.policy_net(next_state)

            target_q_values = torch.min(
                self.target_soft_q_net1(next_state, new_next_actions),
                self.target_soft_q_net2(next_state, new_next_actions),
            ) - alpha * new_log_pi

            q_target = reward + (1 - done) * self.discount * target_q_values
            q1_loss = self.soft_q_criterion(q1_pred, q_target.detach())
            q2_loss = self.soft_q_criterion(q2_pred, q_target.detach())

            # Update Networks
            self.soft_q_optimizer1.zero_grad()
            q1_loss.backward()
            self.soft_q_optimizer1.step()

            self.soft_q_optimizer2.zero_grad()
            q2_loss.backward()
            self.soft_q_optimizer2.step()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Soft Updates
            for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
                target_param.data.copy_(
                    target_param.data * (1.0 - self.tau) + param.data * self.tau
                )

            for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
                target_param.data.copy_(
                    target_param.data * (1.0 - self.tau) + param.data * self.tau
                )
                
    def save(self, filename, directory):
        torch.save(self.soft_q_net1.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))


    def load(self, filename="best_avg", directory="./saves"):
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
     

# Train

In [43]:
def train(agent, steps_per_epoch=1000, epochs=100, start_steps=1000, max_ep_len=200):
    
    writer = SummaryWriter(comment="-SAC-Pendulum")
    
    # start tracking time
    start_time = time.time()
    total_rewards = []
    avg_reward = None
    
    # set initial values
    o, r, d, ep_reward, ep_len, ep_num = env.reset(), 0, False, 0, 0, 1
    
    # track total steps
    total_steps = steps_per_epoch * epochs
    
    for t in range(1,total_steps):
        
        explore = t < start_steps
        a = agent.get_action(o, explore=explore)
        
        # Step the env
        o2, r, d, _ = env.step(a)
        ep_reward += r
        ep_len += 1
        
        writer.add_scalar("reward_step", r, t)

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.push(o, a, r, o2, d)

        # update observation
        o = o2
        
        if d or (ep_len == max_ep_len):
        
            # carry out update for each step experienced (episode length)
            if not explore:
                agent.update(ep_len)
            
            # log progress
            total_rewards.append(ep_reward)
            avg_reward = np.mean(total_rewards[-10:])
            
            writer.add_scalar("avg_reward", avg_reward, t)
            writer.add_scalar("episode_reward", ep_reward, t)
            
            print("Steps:{} Episode:{} Reward:{} Avg Reward:{}".format(t, ep_num, ep_reward, avg_reward))
            o, r, d, ep_reward, ep_len = env.reset(), 0, False, 0, 0
            ep_num += 1
            
            if avg_reward > -50:
                print("saving....")
                break

        # End of epoch wrap-up
        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch
            


# Main

In [44]:
env = "Pendulum-v0"#"RoboschoolHalfCheetah-v1"

replay_buffer = ReplayBuffer(int(1e6))

env = NormalizedActions(gym.make(env))

agent = SAC(env, replay_buffer)

train(agent)

Steps:200 Episode:1 Reward:-1427.0917955599123 Avg Reward:-1427.0917955599123
Steps:400 Episode:2 Reward:-886.2195408825114 Avg Reward:-1156.6556682212117
Steps:600 Episode:3 Reward:-1486.8925541287147 Avg Reward:-1266.7346301903792
Steps:800 Episode:4 Reward:-1067.9983838134756 Avg Reward:-1217.0505685961534
Steps:1000 Episode:5 Reward:-1662.3183718191065 Avg Reward:-1306.104129240744
Steps:1200 Episode:6 Reward:-1578.10968520089 Avg Reward:-1351.4383885674351
Steps:1400 Episode:7 Reward:-866.228313120975 Avg Reward:-1282.122663503655
Steps:1600 Episode:8 Reward:-1604.0023937745216 Avg Reward:-1322.3576297875134
Steps:1800 Episode:9 Reward:-1065.768460088575 Avg Reward:-1293.8477220431869
Steps:2000 Episode:10 Reward:-1091.1967771112525 Avg Reward:-1273.5826275499935
Steps:2200 Episode:11 Reward:-888.4377463418524 Avg Reward:-1219.7172226281875
Steps:2400 Episode:12 Reward:-640.3823628178775 Avg Reward:-1195.133504821724
Steps:2600 Episode:13 Reward:-1325.8349045463726 Avg Reward:-117

Steps:20800 Episode:104 Reward:-237.5071764164259 Avg Reward:-154.8331367185735
Steps:21000 Episode:105 Reward:-240.32777087117452 Avg Reward:-178.7379752167529
Steps:21200 Episode:106 Reward:-120.29701253363616 Avg Reward:-166.14468727367466
Steps:21400 Episode:107 Reward:-119.60980908426764 Avg Reward:-165.82366749501284
Steps:21600 Episode:108 Reward:-229.1542630891135 Avg Reward:-176.96819409950945
Steps:21800 Episode:109 Reward:-118.93596002173055 Avg Reward:-176.88458084315576
Steps:22000 Episode:110 Reward:-237.1052057699421 Avg Reward:-177.859086046281
Steps:22200 Episode:111 Reward:-128.16329005887312 Avg Reward:-178.8091936914944
Steps:22400 Episode:112 Reward:-127.123431699418 Avg Reward:-179.86264609747232
Steps:22600 Episode:113 Reward:-229.59276121663922 Avg Reward:-178.78166807612206
Steps:22800 Episode:114 Reward:-119.3299102811465 Avg Reward:-166.96394146259414
Steps:23000 Episode:115 Reward:-118.27671818817059 Avg Reward:-154.75883619429376
Steps:23200 Episode:116 Rew

Steps:41000 Episode:205 Reward:-121.01087826659898 Avg Reward:-167.9971022840898
Steps:41200 Episode:206 Reward:-125.28891611312274 Avg Reward:-155.81920076908216
Steps:41400 Episode:207 Reward:-125.78569875433439 Avg Reward:-155.9938583964828
Steps:41600 Episode:208 Reward:-119.56564881247616 Avg Reward:-145.06620415642632
Steps:41800 Episode:209 Reward:-117.8243728136476 Avg Reward:-133.8171050076083
Steps:42000 Episode:210 Reward:-1.4310259329983965 Avg Reward:-121.13786049946734
Steps:42200 Episode:211 Reward:-122.51224994866664 Avg Reward:-109.62485400777707
Steps:42400 Episode:212 Reward:-117.59813865783114 Avg Reward:-109.52482297820966
Steps:42600 Episode:213 Reward:-120.86666327141572 Avg Reward:-109.21234205118438
Steps:42800 Episode:214 Reward:-1.7651033502310498 Avg Reward:-97.3648695921323
Steps:43000 Episode:215 Reward:-348.5586369760247 Avg Reward:-120.11964546307486
Steps:43200 Episode:216 Reward:-0.40493637690588996 Avg Reward:-107.63124748945316
Steps:43400 Episode:21

Steps:61200 Episode:306 Reward:-120.74751136510478 Avg Reward:-140.4309990338383
Steps:61400 Episode:307 Reward:-122.4719340888489 Avg Reward:-152.54263905649776
Steps:61600 Episode:308 Reward:-1.4348082668576012 Avg Reward:-129.33873799427158
Steps:61800 Episode:309 Reward:-118.99027662921945 Avg Reward:-129.41211903029983
Steps:62000 Episode:310 Reward:-126.52376121508688 Avg Reward:-130.6132952663793
Steps:62200 Episode:311 Reward:-124.13105891199771 Avg Reward:-130.77554644637462
Steps:62400 Episode:312 Reward:-119.81702480893908 Avg Reward:-120.72382158666593
Steps:62600 Episode:313 Reward:-230.86645346411487 Avg Reward:-132.31694542007082
Steps:62800 Episode:314 Reward:-125.32058024139246 Avg Reward:-132.78644388579232
Steps:63000 Episode:315 Reward:-227.60245507077605 Avg Reward:-131.79058640623376
Steps:63200 Episode:316 Reward:-0.8794729470805196 Avg Reward:-119.80378256443137
Steps:63400 Episode:317 Reward:-122.0237468754695 Avg Reward:-119.75896384309343
Steps:63600 Episode:

KeyboardInterrupt: 