<a href="https://colab.research.google.com/github/nimamt/machine_learning/blob/master/pytorch/reinforcement/SAC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on arXiv:1812.05905 "Soft Actor-Critic Algorithms and Applications"

In [7]:
import gym
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
from torch.autograd import Variable
from torch.distributions.normal import Normal

# Hyperparameters
BATCH_SIZE = 256
LR_ACTOR = 1e-4
LR_CRITIC = 1e-3
LR_ALPHA = 3e-4
STD_NOISE = 0.3
GAMMA = 0.99
TAU = 1e-3
BUFFER_SIZE = int(1e6)
STEPS = 1000
POLICY_NOISE = 0.2
NOISE_CLIP = 0.5
DELAY_STEPS = 2
EPSILON = 1e-6

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ActorNet(torch.nn.Module):
    def __init__(self, state_size, action_size, max_action):
        super(ActorNet, self).__init__()

        self.fc1 = torch.nn.Linear(state_size, 256, device=device)
        self.fc2 = torch.nn.Linear(256, 128, device=device)
        self.mu = torch.nn.Linear(128, action_size, device=device)
        self.sigma = torch.nn.Linear(128, action_size, device=device)
        self.max_action = max_action

        self.reset_parameters()

    def reset_parameters(self):
        x = self.fc1_init()
        self.fc1.weight.data.uniform_(x[0],x[1])
        x = self.fc2_init()
        self.fc2.weight.data.uniform_(x[0],x[1])

    def fc1_init(self):
        lim = 1. / np.sqrt(self.fc1.weight.data.size()[0])
        return (-lim, lim)

    def fc2_init(self):
        lim = 1. / np.sqrt(self.fc2.weight.data.size()[0])
        return (-lim, lim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # self.x1 = x
        x = F.relu(self.fc2(x))
        # self.x2 = x
        mu = self.mu(x)
        sigma = self.sigma(x)
        sigma = torch.clamp(sigma, min=EPSILON, max=1)

        return mu, sigma

    def sample_normal(self, x, reparameterize=True):
      mu, sigma = self.forward(x)

    #   if torch.isnan(mu).any():
    #     print(self.x1)
    #     print(self.x2)
    #     print(x)
    #     print(mu)
    #     print(sigma)

      probabilities = Normal(mu, sigma)

      if reparameterize:
        actions = probabilities.rsample()
      else:
        actions = probabilities.sample()

      log_probabilities = probabilities.log_prob(actions).sum(-1, keepdim=True)
      action = torch.tanh(actions) * torch.tensor(self.max_action)
    #   log_probabilities -= torch.log(1-action.pow(2) + EPSILON)
    #   log_probabilities = log_probabilities.sum(1, keepdim=True)

      return action, log_probabilities

class CriticNet(torch.nn.Module):
    def __init__(self, state_size, action_size):
        super(CriticNet, self).__init__()

        self.fc1 = torch.nn.Linear(state_size + action_size, 256, device=device)
        self.fc2 = torch.nn.Linear(256, 128, device=device)
        self.fc3 = torch.nn.Linear(128, 1, device=device)

        self.fc4 = torch.nn.Linear(state_size + action_size, 256, device=device)
        self.fc5 = torch.nn.Linear(256, 128, device=device)
        self.fc6 = torch.nn.Linear(128, 1, device=device)

        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*self.fc1_init())
        self.fc2.weight.data.uniform_(*self.fc2_init())
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)

        self.fc4.weight.data.uniform_(*self.fc1_init())
        self.fc5.weight.data.uniform_(*self.fc2_init())
        self.fc6.weight.data.uniform_(-3e-3, 3e-3)

    def fc1_init(self):
        lim = 1. / np.sqrt(self.fc1.weight.data.size()[0])
        return (-lim, lim)

    def fc2_init(self):
        lim = 1. / np.sqrt(self.fc2.weight.data.size()[0])
        return (-lim, lim)

    def forward(self, state, action):
        state_action = torch.cat([state, action], 1).to(device)
        x = F.relu(self.fc1(state_action))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        y = F.relu(self.fc4(state_action))
        y = F.relu(self.fc5(y))
        y = self.fc6(y)
        return x,y

    def Q1(self, state, action):
        state_action = torch.cat([state, action], 1).to(device)
        x = F.relu(self.fc1(state_action))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class SAC:
    def __init__(self, state_dim, action_dim, action_high):
        self.actor = ActorNet(state_dim, action_dim, action_high)
        self.actor_target = ActorNet(state_dim, action_dim, action_high)
        self.critic = CriticNet(state_dim, action_dim)
        self.critic_target = CriticNet(state_dim, action_dim)
        self.memory = deque(maxlen=BUFFER_SIZE)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)
        self.action_high = action_high
        self.step = 0
        self.log_alpha = torch.tensor([-3.0], requires_grad=True, device=device)
        self.optimizer_logalpha = optim.Adam([self.log_alpha], lr=LR_ALPHA)
        self.target_entropy = -action_dim

    def act(self, state):
        state = Variable(torch.from_numpy(state).float().to(device).unsqueeze(0))
        self.actor.eval()
        with torch.no_grad():
            action, _ = self.actor.sample_normal(state.to(device), False)
        self.actor.train()
        return np.clip(action.cpu().numpy()[0] * self.action_high, -self.action_high, self.action_high)

    def memorize(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def learn(self):
        if len(self.memory) < BATCH_SIZE:
            return

        self.step += 1

        batch = random.sample(self.memory, BATCH_SIZE)
        state_batch = torch.from_numpy(np.array([arr[0] for arr in batch])).float().to(device)
        action_batch = torch.from_numpy(np.array([arr[1] for arr in batch])).float().to(device)
        reward_batch = torch.from_numpy(np.array([arr[2] for arr in batch]).reshape(BATCH_SIZE, 1)).float().to(device)
        next_state_batch = torch.from_numpy(np.array([arr[3] for arr in batch])).float().to(device)
        done_batch = torch.from_numpy(np.array([arr[4] for arr in batch], dtype=np.uint8).reshape(BATCH_SIZE, 1)).float().to(device)

        alpha = torch.exp(self.log_alpha)

        next_actions, log_probs = self.actor_target(next_state_batch)
        q_next1, q_next2 = self.critic_target(next_state_batch, next_actions)
        target_V = torch.min(q_next1,q_next2) - alpha.detach() * log_probs
        q_targets = reward_batch + GAMMA * target_V * (1 - done_batch)

        # Update critic
        self.optimizer_critic.zero_grad()
        q_current1, q_current2 = self.critic(state_batch, action_batch)
        critic_loss = F.mse_loss(q_current1, q_targets) + F.mse_loss(q_current2, q_targets)
        critic_loss.backward()
        self.optimizer_critic.step()

        if self.step % DELAY_STEPS == 0:
          # Update actor
          self.optimizer_actor.zero_grad()
          next_actions, log_probs = self.actor.sample_normal(state_batch)
          q_next1, q_next2 = self.critic(state_batch, next_actions)
          actor_Q = torch.min(q_next1,q_next2)
          actor_loss = (alpha.detach() * log_probs - actor_Q).mean()
          actor_loss.backward()
          self.optimizer_actor.step()

          # Update log_alpha
          self.optimizer_logalpha.zero_grad()
          alpha_loss = (alpha *
                              (-log_probs - self.target_entropy).detach()).mean()
          alpha_loss.backward()
          self.optimizer_logalpha.step()

          # Update target networks
          self.update_targets()

    def update_targets(self):
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)

if __name__ == "__main__":
    env = gym.make("Pendulum-v1")
    agent = SAC(env.observation_space.shape[0], env.action_space.shape[0], env.action_space.high[0])
    scores = []
    for i_episode in range(1, 150):
        state = env.reset()
        score = 0
        for t in range(1000):
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            agent.memorize(state, action, reward, next_state, done)
            agent.learn()
            state = next_state
            score += reward
            if done:
                break
        scores.append(score)
        print("Episode {}: Score = {}".format(i_episode, score))

Episode 1: Score = -1284.202383367827
Episode 2: Score = -1249.920901653196
Episode 3: Score = -1417.02455970926
Episode 4: Score = -1498.4402739570194
Episode 5: Score = -1472.9527439268427
Episode 6: Score = -1379.426326102248
Episode 7: Score = -1526.5210593380307
Episode 8: Score = -1371.0197157259233
Episode 9: Score = -1448.9689359824747
Episode 10: Score = -1365.4053218411211
Episode 11: Score = -1307.8871726946345
Episode 12: Score = -1413.5868210332276
Episode 13: Score = -1366.8610809385357
Episode 14: Score = -1355.4255944325412
Episode 15: Score = -1411.3268937172063
Episode 16: Score = -1452.4580688560097
Episode 17: Score = -1425.2875839904834
Episode 18: Score = -1405.4279549784699
Episode 19: Score = -1388.3571280748463
Episode 20: Score = -1475.9038274137683
Episode 21: Score = -1517.5672714993534
Episode 22: Score = -1408.6944657935253
Episode 23: Score = -1341.1781773964137
Episode 24: Score = -1483.228530990245
Episode 25: Score = -1455.3656684153623
Episode 26: Sco