# A3C, Asynchronous Advantage Actor Critic

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
from model import PolicyNetwork, ValueNetwork, HybridNetwork

In [4]:
import warnings
warnings.simplefilter('ignore', UserWarning)

In [5]:
plt.style.use('ggplot')

## Set Configs

In [6]:
is_cuda = torch.cuda.is_available()

if is_cuda: device = torch.device('cuda')
else: device = torch.device('cpu')

In [7]:
GAMMA = 0.99
LR = 1e-3
GLOBAL_NUM_EPISODES = 1000
PRINT_EVERY = 100

## Set Environment

In [8]:
ENV_NAME = 'CartPole-v0'
env = gym.make(ENV_NAME).unwrapped; env.seed(90);

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [9]:
print('Environment Display:')
env.reset() # reset environment to a new, random state
env.render()

print('State space {}'.format(env.observation_space))
print('Action space {}'.format(env.action_space))

Environment Display:
State space Box(4,)
Action space Discrete(2)


## Define Agent Worker

In [10]:
class AgentWorker(mp.Process):
    
    def __init__(self, i, env, seed, global_network, global_optimizer, global_episode, global_num_episodes):
        
        super(AgentWorker, self).__init__()
        
        self.name = f'worker{i}'

        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.hidden_size = 256
        self.action_size = env.action_space.n
        
        self.gamma = GAMMA
        
        self.local_network = HybridNetwork(self.state_size, self.hidden_size, self.action_size, seed).to(device)
        
        self.global_network = global_network
        self.global_optimizer = global_optimizer
        self.global_episode = global_episode
        self.global_num_episodes = global_num_episodes
        
    def act(self, state):
        
        state = torch.FloatTensor(state).to(device)
        probs, _ = self.local_network(state)
        m = Categorical(probs)
        
        action = m.sample()
        
        return action.cpu().detach().item()
    
    def update_global(self, trajectory):
        
        total_loss = self.compute_loss(trajectory)
        
        self.global_optimizer.zero_grad()
        total_loss.backward()
        
        # propagate local gradients to global parameters
        for local_params, global_params in zip(self.local_network.parameters(), self.global_network.parameters()):
            global_params._grad = local_params._grad
        self.global_optimizer.step()
        
        return total_loss
    
    def compute_loss(self, trajectory):
        
        states = torch.FloatTensor([sarsd[0] for sarsd in trajectory])
        states = states.to(device)
        
        actions = torch.LongTensor([sarsd[1] for sarsd in trajectory]).view(-1, 1)
        actions = actions.to(device)
        
        rewards = torch.FloatTensor([sarsd[2] for sarsd in trajectory])
        rewards = rewards.to(device)
        
        next_states = torch.FloatTensor([sarsd[3] for sarsd in trajectory])
        next_states = next_states.to(device)
        
        dones = torch.FloatTensor([sarsd[4] for sarsd in trajectory]).float().view(-1, 1)
        dones = dones.to(device)
        
        # compute value target
        discounted_rewards = [torch.sum(torch.FloatTensor([self.gamma**i for i in \
                              range(rewards[j:].size(0))]) * rewards[j:]) for j in range(rewards.size(0))]
        
        # rewards plus next state values
        rewards = rewards.view(-1, 1) + torch.FloatTensor(discounted_rewards).view(-1, 1).to(device)
        
        # compute policy loss with entropy bonus & value loss
        probs, state_values = self.local_network(states)
        m = Categorical(probs)
        
        value_loss = F.mse_loss(rewards.detach(), state_values)
        
        # compute entropy bonus
        dist_entropy = []
        for prob in probs:
            dist_entropy.append(-torch.sum(prob.mean() * torch.log(prob)))
        dist_entropy = torch.stack(dist_entropy).sum()
        
        advantages = rewards - values
        policy_loss = -m.log_prob(actions.view(actions.size(0))).view(-1, 1) * advantages.detach()
        policy_loss = policy_loss.mean() 
        
        total_loss = policy_loss + value_loss - 0.001 * dist_entropy
        
        return total_loss

    def sync_with_global(self):
        
        self.local_network.load_state_dict(self.global_network.state_dict())
    
    def run(self):
        
        state = env.reset()
        score = 0
        
        trajectory = []
        
        while self.global_episode.value < self.global_num_episodes:
            
            action = self.act(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append([state, action, reward, next_state, done])
            
            score += reward
            
            if done:
                with self.global_episode.get_lock():
                    self.global_episode.value += 1
                
                print(f'\rEpisode: {str(self.global_episode.value)}, Worker: {self.name}, Average Score: {score:.2f}')
                
                total_loss = self.update_global(trajectory)
                self.sync_with_global()
                
                trajectory = []
                score = 0
                state = self.env.reset()
            
            state = next_state

## Define [A3C](https://arxiv.org/pdf/1602.01783.pdf) Agent

In [11]:
class A3CAgent():
    
    def __init__(self, env, global_num_episodes, seed):
        
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.hidden_size = 256
        self.action_size = env.action_space.n
        
        self.gamma = GAMMA
        self.lr = LR
        
        self.global_episode = mp.Value('i', 0)
        self.global_num_episodes = global_num_episodes
        
        self.global_network = HybridNetwork(self.state_size, self.hidden_size, self.action_size, seed)
        self.global_optimizer = optim.Adam(self.global_network.parameters(), lr=self.lr)
        self.workers = [AgentWorker(self.global_episode.value, env, seed, 
                                    self.global_network, self.global_optimizer, self.global_episode, self.global_num_episodes)]
    
    def train(self):
        
        for worker in self.workers: worker.start() 
        for worker in self.workers: worker.join()
        print('Training completed.')        
        
    def save(self, agent_path):
        
        if not os.path.exists('./agents/'): os.makedirs('./agents/')
        torch.save(agent.global_network.state_dict(), agent_path)

In [12]:
agent = A3CAgent(env, GLOBAL_NUM_EPISODES, seed=90)

## Train The Agent

In [13]:
scores = agent.train()
agent.save(f'./agents/A3C_{ENV_NAME}.pth')

Episode: 1, Worker: worker0, Average Score: 38.00
Episode: 2, Worker: worker0, Average Score: 36.00
Episode: 3, Worker: worker0, Average Score: 12.00
Episode: 4, Worker: worker0, Average Score: 55.00
Episode: 5, Worker: worker0, Average Score: 11.00
Episode: 6, Worker: worker0, Average Score: 30.00
Episode: 7, Worker: worker0, Average Score: 23.00
Episode: 8, Worker: worker0, Average Score: 18.00
Episode: 9, Worker: worker0, Average Score: 17.00
Episode: 10, Worker: worker0, Average Score: 62.00
Episode: 11, Worker: worker0, Average Score: 20.00
Episode: 12, Worker: worker0, Average Score: 30.00
Episode: 13, Worker: worker0, Average Score: 129.00
Episode: 14, Worker: worker0, Average Score: 38.00
Episode: 15, Worker: worker0, Average Score: 30.00
Episode: 16, Worker: worker0, Average Score: 64.00
Episode: 17, Worker: worker0, Average Score: 35.00
Episode: 18, Worker: worker0, Average Score: 48.00
Episode: 19, Worker: worker0, Average Score: 36.00
Episode: 20, Worker: worker0, Average S

Episode: 161, Worker: worker0, Average Score: 81.00
Episode: 162, Worker: worker0, Average Score: 102.00
Episode: 163, Worker: worker0, Average Score: 51.00
Episode: 164, Worker: worker0, Average Score: 100.00
Episode: 165, Worker: worker0, Average Score: 65.00
Episode: 166, Worker: worker0, Average Score: 47.00
Episode: 167, Worker: worker0, Average Score: 66.00
Episode: 168, Worker: worker0, Average Score: 67.00
Episode: 169, Worker: worker0, Average Score: 146.00
Episode: 170, Worker: worker0, Average Score: 144.00
Episode: 171, Worker: worker0, Average Score: 72.00
Episode: 172, Worker: worker0, Average Score: 56.00
Episode: 173, Worker: worker0, Average Score: 128.00
Episode: 174, Worker: worker0, Average Score: 52.00
Episode: 175, Worker: worker0, Average Score: 55.00
Episode: 176, Worker: worker0, Average Score: 152.00
Episode: 177, Worker: worker0, Average Score: 77.00
Episode: 178, Worker: worker0, Average Score: 100.00
Episode: 179, Worker: worker0, Average Score: 69.00
Episo

Episode: 318, Worker: worker0, Average Score: 104.00
Episode: 319, Worker: worker0, Average Score: 129.00
Episode: 320, Worker: worker0, Average Score: 105.00
Episode: 321, Worker: worker0, Average Score: 100.00
Episode: 322, Worker: worker0, Average Score: 112.00
Episode: 323, Worker: worker0, Average Score: 172.00
Episode: 324, Worker: worker0, Average Score: 140.00
Episode: 325, Worker: worker0, Average Score: 109.00
Episode: 326, Worker: worker0, Average Score: 132.00
Episode: 327, Worker: worker0, Average Score: 165.00
Episode: 328, Worker: worker0, Average Score: 126.00
Episode: 329, Worker: worker0, Average Score: 105.00
Episode: 330, Worker: worker0, Average Score: 146.00
Episode: 331, Worker: worker0, Average Score: 99.00
Episode: 332, Worker: worker0, Average Score: 85.00
Episode: 333, Worker: worker0, Average Score: 112.00
Episode: 334, Worker: worker0, Average Score: 139.00
Episode: 335, Worker: worker0, Average Score: 116.00
Episode: 336, Worker: worker0, Average Score: 15

Episode: 474, Worker: worker0, Average Score: 132.00
Episode: 475, Worker: worker0, Average Score: 124.00
Episode: 476, Worker: worker0, Average Score: 147.00
Episode: 477, Worker: worker0, Average Score: 127.00
Episode: 478, Worker: worker0, Average Score: 175.00
Episode: 479, Worker: worker0, Average Score: 119.00
Episode: 480, Worker: worker0, Average Score: 163.00
Episode: 481, Worker: worker0, Average Score: 158.00
Episode: 482, Worker: worker0, Average Score: 202.00
Episode: 483, Worker: worker0, Average Score: 142.00
Episode: 484, Worker: worker0, Average Score: 171.00
Episode: 485, Worker: worker0, Average Score: 195.00
Episode: 486, Worker: worker0, Average Score: 188.00
Episode: 487, Worker: worker0, Average Score: 172.00
Episode: 488, Worker: worker0, Average Score: 152.00
Episode: 489, Worker: worker0, Average Score: 227.00
Episode: 490, Worker: worker0, Average Score: 92.00
Episode: 491, Worker: worker0, Average Score: 197.00
Episode: 492, Worker: worker0, Average Score: 1

Episode: 629, Worker: worker0, Average Score: 359.00
Episode: 630, Worker: worker0, Average Score: 256.00
Episode: 631, Worker: worker0, Average Score: 472.00
Episode: 632, Worker: worker0, Average Score: 434.00
Episode: 633, Worker: worker0, Average Score: 380.00
Episode: 634, Worker: worker0, Average Score: 270.00
Episode: 635, Worker: worker0, Average Score: 294.00
Episode: 636, Worker: worker0, Average Score: 199.00
Episode: 637, Worker: worker0, Average Score: 316.00
Episode: 638, Worker: worker0, Average Score: 115.00
Episode: 639, Worker: worker0, Average Score: 617.00
Episode: 640, Worker: worker0, Average Score: 507.00
Episode: 641, Worker: worker0, Average Score: 668.00
Episode: 642, Worker: worker0, Average Score: 552.00
Episode: 643, Worker: worker0, Average Score: 341.00
Episode: 644, Worker: worker0, Average Score: 126.00
Episode: 645, Worker: worker0, Average Score: 404.00
Episode: 646, Worker: worker0, Average Score: 513.00
Episode: 647, Worker: worker0, Average Score: 

Episode: 783, Worker: worker0, Average Score: 781.00
Episode: 784, Worker: worker0, Average Score: 96.00
Episode: 785, Worker: worker0, Average Score: 2774.00
Episode: 786, Worker: worker0, Average Score: 915.00
Episode: 787, Worker: worker0, Average Score: 49.00
Episode: 788, Worker: worker0, Average Score: 77.00
Episode: 789, Worker: worker0, Average Score: 2836.00
Episode: 790, Worker: worker0, Average Score: 1016.00
Episode: 791, Worker: worker0, Average Score: 2881.00
Episode: 792, Worker: worker0, Average Score: 2056.00
Episode: 793, Worker: worker0, Average Score: 2911.00
Episode: 794, Worker: worker0, Average Score: 1869.00
Episode: 795, Worker: worker0, Average Score: 2304.00
Episode: 796, Worker: worker0, Average Score: 1838.00
Episode: 797, Worker: worker0, Average Score: 2631.00
Episode: 798, Worker: worker0, Average Score: 1902.00
Episode: 799, Worker: worker0, Average Score: 57.00
Episode: 800, Worker: worker0, Average Score: 2727.00
Episode: 801, Worker: worker0, Average

Episode: 938, Worker: worker0, Average Score: 137.00
Episode: 939, Worker: worker0, Average Score: 60.00
Episode: 940, Worker: worker0, Average Score: 122.00
Episode: 941, Worker: worker0, Average Score: 61.00
Episode: 942, Worker: worker0, Average Score: 143.00
Episode: 943, Worker: worker0, Average Score: 163.00
Episode: 944, Worker: worker0, Average Score: 61.00
Episode: 945, Worker: worker0, Average Score: 30.00
Episode: 946, Worker: worker0, Average Score: 69.00
Episode: 947, Worker: worker0, Average Score: 74.00
Episode: 948, Worker: worker0, Average Score: 46.00
Episode: 949, Worker: worker0, Average Score: 42.00
Episode: 950, Worker: worker0, Average Score: 152.00
Episode: 951, Worker: worker0, Average Score: 56.00
Episode: 952, Worker: worker0, Average Score: 154.00
Episode: 953, Worker: worker0, Average Score: 182.00
Episode: 954, Worker: worker0, Average Score: 56.00
Episode: 955, Worker: worker0, Average Score: 36.00
Episode: 956, Worker: worker0, Average Score: 49.00
Episo

---