APPENDIX 3: SOFT ACTOR CRITIC

Adapted from the following sources: https://towardsdatascience.com/soft-actor-critic-demystified-b8427df61665 , https://github.com/philtabor/Youtube-Code-Repository/blob/master/ReinforcementLearning/PolicyGradient/SAC/sac_torch.py , and https://github.com/pranz24/pytorch-soft-actor-critic

In [None]:
# Upgrade gym to the latest release (IPython shell escape; only valid inside a notebook).
!pip install -U gym

In [None]:
# Install environment dependencies (IPython shell escapes; only valid inside a notebook).
# NOTE(review): `Box2D` and `box2d-py` both provide Box2D bindings, so one of them is
# probably redundant. gym's Box2D extra is usually spelled `box2d`, so `gym[Box_2D]` may
# not match any extra (pip only warns) — verify.
!pip install Box2D
!pip install pygame
!pip install box2d-py
!pip install gym[all]
!pip install gym[Box_2D]
!pip install git+https://github.com/ngc92/space-wrappers.git
import gym

In [1]:
import os
import torch as T
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
import numpy as np




class Critic(nn.Module):
    """Q-network: maps a (state, action) pair to a scalar action value.

    Args:
        beta: learning rate of this network's Adam optimizer.
        input: observation shape as a sequence; only ``input[0]`` is used,
            so a flat observation vector is expected (e.g. ``[8]``).
        actions_n: dimensionality of the action vector.
        fc1: width of the first hidden layer.
        fc2: width of the second hidden layer.
    """

    def __init__(self, beta, input, actions_n, fc1=128, fc2=128):
        super().__init__()
        self.input = input
        self.actions_n = actions_n
        self.fc1_dims = fc1
        self.fc2_dims = fc2

        # The critic evaluates the (s, a) pair, so state and action are
        # concatenated before the first layer. Widths come from the integer
        # arguments; self.fc1 / self.fc2 hold the Linear layers themselves.
        self.fc1 = nn.Linear(input[0] + actions_n, fc1)
        self.fc2 = nn.Linear(fc1, fc2)
        self.q = nn.Linear(fc2, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state, action):
        """Return Q(state, action) with a trailing dimension of size 1.

        ``state`` and ``action`` must be batch-first 2-D tensors, because
        they are concatenated along dimension 1.
        """
        action_value = F.relu(self.fc1(T.cat([state, action], dim=1)))
        action_value = F.relu(self.fc2(action_value))
        return self.q(action_value)


# Value network
class Value(nn.Module):
    """State-value network V(s).

    Args:
        beta: learning rate of this network's Adam optimizer.
        input: observation shape; unpacked into the first Linear layer, so a
            one-element sequence such as ``[8]`` is expected.
        fc1: width of the first hidden layer.
        fc2: width of the second hidden layer.
    """

    def __init__(self, beta, input, fc1=128, fc2=128):
        super(Value, self).__init__()
        self.input = input
        self.fc1_dims = fc1
        self.fc2_dims = fc2

        # Layer widths come from the integer arguments; self.fc1 / self.fc2
        # are rebound to the Linear layers, so they cannot be used as sizes.
        self.fc1 = nn.Linear(*input, fc1)
        self.fc2 = nn.Linear(fc1, fc2)
        self.v = nn.Linear(fc2, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return V(state) with a trailing dimension of size 1."""
        state_value = F.relu(self.fc1(state))
        state_value = F.relu(self.fc2(state_value))
        return self.v(state_value)

# Policy: samples are squashed to (-1, 1) by tanh, then multiplied by max_action.
class Actor(nn.Module):
    """Squashed-Gaussian policy network.

    Args:
        alpha: learning rate of this network's Adam optimizer.
        input: observation shape; unpacked into the first Linear layer, so a
            one-element sequence such as ``[8]`` is expected.
        max_action: action bound (scalar or per-dimension sequence/array);
            tanh outputs in (-1, 1) are multiplied by it.
        actions_n: action dimensionality.
        fc1: width of the first hidden layer.
        fc2: width of the second hidden layer.
    """

    def __init__(self, alpha, input, max_action, actions_n, fc1=128,
            fc2=128):
        super(Actor, self).__init__()
        self.input = input
        self.fc1_dims = fc1
        self.fc2_dims = fc2
        self.actions_n = actions_n
        self.max_action = max_action
        # Small constant: lower bound for sigma and offset inside the log of
        # the tanh correction so it never evaluates log(0).
        self.reparam_noise = 1e-6

        # Widths come from the integer arguments; self.fc1 / self.fc2 hold
        # the Linear layers.
        self.fc1 = nn.Linear(*input, fc1)
        self.fc2 = nn.Linear(fc1, fc2)
        self.mu = nn.Linear(fc2, actions_n)     # mean head
        self.sigma = nn.Linear(fc2, actions_n)  # standard-deviation head

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return ``(mu, sigma)`` of the pre-squash Gaussian.

        Both outputs have a last dimension of size ``actions_n``.
        """
        prob = F.relu(self.fc1(state))
        prob = F.relu(self.fc2(prob))

        mu = self.mu(prob)
        sigma = self.sigma(prob)
        # Clamp the standard deviation: it must be strictly positive for
        # Normal, and the upper bound constrains the distribution's width.
        sigma = T.clamp(sigma, min=self.reparam_noise, max=1)

        return mu, sigma

    def sample_normal(self, state, reparametrize=True, reparameterize=None):
        """Sample an action from the policy together with its log-probability.

        Args:
            state: batch-first 2-D tensor of observations.
            reparametrize: if True, draw with ``rsample`` (reparameterization
                trick) so gradients flow through the sample; otherwise use
                ``sample``, which carries no gradient through the action.
            reparameterize: alternative spelling of ``reparametrize``; takes
                precedence when given.

        Returns:
            action: ``max_action * tanh(u)`` for a Gaussian sample ``u``.
            log_probs: shape ``(batch, 1)``, log-density of ``action``
                including the change-of-variables term for tanh and scaling.
        """
        if reparameterize is not None:
            reparametrize = reparameterize

        mu, sigma = self.forward(state)
        probabilities = Normal(mu, sigma)
        if reparametrize:
            actions = probabilities.rsample()
        else:
            actions = probabilities.sample()

        squashed = T.tanh(actions)
        scale = T.as_tensor(self.max_action, dtype=squashed.dtype,
                            device=squashed.device)
        action = squashed * scale

        # Change of variables for a = scale * tanh(u):
        #   log pi(a) = log N(u) - sum log(scale * (1 - tanh(u)^2)).
        # Using tanh(u) (in (-1, 1)) rather than the scaled action keeps the
        # log argument non-negative even when max_action > 1.
        log_probs = probabilities.log_prob(actions)
        log_probs = log_probs - T.log(
            scale * (1 - squashed.pow(2)) + self.reparam_noise)
        # One log-probability per sample: sum over the action dimensions.
        log_probs = log_probs.sum(1, keepdim=True)

        return action, log_probs

# Code below taken from https://www.youtube.com/watch?v=ioidsRlf79o&t=1947s
class Agent:
    """Soft Actor-Critic agent (variant with a separate value network).

    Holds an actor, two critics (their element-wise minimum is used when
    evaluating the current policy), a value network, and a target value
    network that follows it through soft updates.

    NOTE(review): ``ReplayBuffer``, ``ActorNetwork``, ``CriticNetwork`` and
    ``ValueNetwork`` are not defined in this section. The classes above are
    named ``Actor``/``Critic``/``Value`` and take different constructor
    arguments (no ``name=``; ``actions_n`` instead of ``n_actions``). Confirm
    where these names come from before running.

    Args:
        alpha: first argument to the actor constructor (presumably its
            learning rate).
        beta: first argument to the critic and value constructors
            (presumably their learning rate).
        input_dims: observation shape.
        env: environment; only ``env.action_space.high`` is read, as the
            actor's action bound.
        gamma: discount factor in the critic target.
        n_actions: action dimensionality.
        max_size: replay-buffer capacity.
        tau: soft-update coefficient for the target value network.
        layer1_size, layer2_size: hidden widths. NOTE(review): these are
            currently not forwarded to any network.
        batch_size: minibatch size used by ``learn``.
        reward_scale: multiplier applied to rewards in the critic target.
    """

    def __init__(self, alpha=0.0003, beta=0.0003, input_dims=[8],
            env=None, gamma=0.99, n_actions=2, max_size=1000000, tau=0.005,
            layer1_size=256, layer2_size=256, batch_size=256, reward_scale=2):
        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.n_actions = n_actions

        self.actor = ActorNetwork(alpha, input_dims, n_actions=n_actions,
                    name='actor', max_action=env.action_space.high)
        self.critic_1 = CriticNetwork(beta, input_dims, n_actions=n_actions,
                    name='critic_1')
        self.critic_2 = CriticNetwork(beta, input_dims, n_actions=n_actions,
                    name='critic_2')
        self.value = ValueNetwork(beta, input_dims, name='value')
        self.target_value = ValueNetwork(beta, input_dims, name='target_value')

        self.scale = reward_scale
        # tau=1: start the target value network as an exact copy.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return one action for a single observation as a NumPy array."""
        # Batch of one; go through np.array because building a tensor from a
        # list of ndarrays is slow in PyTorch.
        state = T.tensor(np.array([observation]), dtype=T.float).to(self.actor.device)
        actions, _ = self.actor.sample_normal(state, reparameterize=False)

        return actions.cpu().detach().numpy()[0]

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)

    def update_network_parameters(self, tau=None):
        """Soft-update the target value network toward the value network.

        Each target parameter becomes ``tau * value + (1 - tau) * target``.
        ``tau=1`` copies the value network; ``tau=None`` uses ``self.tau``.
        """
        if tau is None:
            tau = self.tau

        with T.no_grad():
            target_value_state_dict = dict(self.target_value.named_parameters())
            value_state_dict = {
                name: tau * param.clone()
                      + (1 - tau) * target_value_state_dict[name].clone()
                for name, param in self.value.named_parameters()
            }

        self.target_value.load_state_dict(value_state_dict)

    def save_models(self):
        """Save checkpoints of all five networks."""
        print('.... saving models ....')
        self.actor.save_checkpoint()
        self.value.save_checkpoint()
        self.target_value.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()

    def load_models(self):
        """Load checkpoints of all five networks."""
        print('.... loading models ....')
        self.actor.load_checkpoint()
        self.value.load_checkpoint()
        self.target_value.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()

    def learn(self):
        """Run one SAC update on a replay minibatch.

        Updates the value network, then the actor, then both critics, and
        finally soft-updates the target value network. Does nothing until
        the buffer holds at least ``batch_size`` transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        state, action, reward, new_state, done = \
                self.memory.sample_buffer(self.batch_size)

        device = self.actor.device
        reward = T.tensor(reward, dtype=T.float).to(device)
        # Force a boolean mask: an int/float array would be interpreted as
        # integer indices by `value_[done]` below.
        done = T.tensor(done, dtype=T.bool).to(device)
        state_ = T.tensor(new_state, dtype=T.float).to(device)
        state = T.tensor(state, dtype=T.float).to(device)
        action = T.tensor(action, dtype=T.float).to(device)

        value = self.value(state).view(-1)
        # Target-network value of the next state, zero for terminal
        # transitions; it is only ever used as a regression target.
        with T.no_grad():
            value_ = self.target_value(state_).view(-1)
            value_[done] = 0.0

        # --- Value update: V(s) -> min(Q1, Q2)(s, a~pi) - log pi(a|s) ---
        actions, log_probs = self.actor.sample_normal(state, reparameterize=False)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1.forward(state, actions)
        q2_new_policy = self.critic_2.forward(state, actions)
        critic_value = T.min(q1_new_policy, q2_new_policy).view(-1)

        self.value.optimizer.zero_grad()
        value_target = (critic_value - log_probs).detach()
        value_loss = 0.5 * F.mse_loss(value, value_target)
        value_loss.backward()
        self.value.optimizer.step()

        # --- Actor update: minimise log pi(a|s) - min(Q1, Q2)(s, a) with a
        # reparameterized sample so gradients reach the policy. ---
        actions, log_probs = self.actor.sample_normal(state, reparameterize=True)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1.forward(state, actions)
        q2_new_policy = self.critic_2.forward(state, actions)
        critic_value = T.min(q1_new_policy, q2_new_policy).view(-1)

        actor_loss = T.mean(log_probs - critic_value)
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        # --- Critic update on the stored actions. The actor backward pass
        # also wrote grads into the critics, which zero_grad clears here. ---
        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()
        q_hat = self.scale * reward + self.gamma * value_
        q1_old_policy = self.critic_1.forward(state, action).view(-1)
        q2_old_policy = self.critic_2.forward(state, action).view(-1)
        critic_1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic_2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

        critic_loss = critic_1_loss + critic_2_loss
        critic_loss.backward()
        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()

        self.update_network_parameters()


In [None]:
def sample_normal(self, state, reparatemize = True)