# A3C, Asynchronous Advantage Actor Critic

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

from collections import namedtuple, deque

In [2]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
from model import PolicyNetwork, ValueNetwork, HybridNetwork

In [4]:
import warnings
warnings.simplefilter('ignore', UserWarning)

In [5]:
plt.style.use('ggplot')

## Set Configs

In [6]:
is_cuda = torch.cuda.is_available()

if is_cuda: device = torch.device('cuda')
else: device = torch.device('cpu')

In [7]:
GAMMA = 0.99
LR = 1e-3
GLOBAL_NUM_EPISODES = 1000
PRINT_EVERY = 100

## Set Environment

In [8]:
ENV_NAME = 'CartPole-v0'
env = gym.make(ENV_NAME).unwrapped; env.seed(90);

[2020-02-21 16:17:22,560] Making new env: CartPole-v0
  result = entry_point.load(False)


In [9]:
print('Environment Display:')
env.reset() # reset environment to a new, random state
env.render()

print('State space {}'.format(env.observation_space))
print('Action space {}'.format(env.action_space))

Environment Display:
State space Box(4,)
Action space Discrete(2)


## Define Agent Worker

In [10]:
class AgentWorker(mp.Process):
    
    def __init__(self, i, env, seed, global_network, global_optimizer, global_episode, global_num_episodes):
        
        super(AgentWorker, self).__init__()
        
        self.name = f'worker{i}'

        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.hidden_size = 256
        self.action_size = env.action_space.n
        
        self.gamma = GAMMA
        
        self.local_network = HybridNetwork(self.state_size, self.hidden_size, self.action_size, seed).to(device)
        
        self.global_network = global_network
        self.global_optimizer = global_optimizer
        self.global_episode = global_episode
        self.global_num_episodes = global_num_episodes
        
    def act(self, state):
        
        state = torch.FloatTensor(state).to(device)
        probs, _ = self.local_network(state)
        m = Categorical(probs)
        
        action = m.sample()
        
        return action.cpu().detach().item()
    
    def update_global(self, trajectory):
        
        total_loss = self.compute_loss(trajectory)
        
        self.global_optimizer.zero_grad()
        total_loss.backward()
        
        # propagate local gradients to global parameters
        for local_params, global_params in zip(self.local_network.parameters(), self.global_network.parameters()):
            global_params._grad = local_params._grad
        self.global_optimizer.step()
        
        return total_loss
    
    def compute_loss(self, trajectory):
        
        states = torch.FloatTensor([sarsd[0] for sarsd in trajectory])
        states = states.to(device)
        
        actions = torch.LongTensor([sarsd[1] for sarsd in trajectory]).view(-1, 1)
        actions = actions.to(device)
        
        rewards = torch.FloatTensor([sarsd[2] for sarsd in trajectory])
        rewards = rewards.to(device)
        
        next_states = torch.FloatTensor([sarsd[3] for sarsd in trajectory])
        next_states = next_states.to(device)
        
        dones = torch.FloatTensor([sarsd[4] for sarsd in trajectory]).float().view(-1, 1)
        dones = dones.to(device)
        
        # compute value target
        discounted_rewards = [torch.sum(torch.FloatTensor([self.gamma**i for i in \
                              range(rewards[j:].size(0))]) * rewards[j:]) for j in range(rewards.size(0))]
        
        value_targets = rewards.view(-1, 1) + torch.FloatTensor(discounted_rewards).view(-1, 1).to(device)
        
        # compute policy loss with entropy bonus & value loss
        probs, values = self.local_network(states)
        m = Categorical(probs)
        
        value_loss = F.mse_loss(values, value_targets.detach())
        
        # compute entropy bonus
        entropy = []
        for prob in probs:
            entropy.append(-torch.sum(prob.mean() * torch.log(prob)))
        entropy = torch.stack(entropy).sum()
        
        advantage = value_targets - values
        policy_loss = -m.log_prob(actions.view(actions.size(0))).view(-1, 1) * advantage.detach()
        policy_loss = policy_loss.mean() 
        
        total_loss = policy_loss + value_loss - 0.001 * entropy
        
        return total_loss

    def sync_with_global(self):
        
        self.local_network.load_state_dict(self.global_network.state_dict())
    
    def run(self):
        
        state = env.reset()
        score = 0
        
        trajectory = []
        
        while self.global_episode.value < self.global_num_episodes:
            
            action = self.act(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append([state, action, reward, next_state, done])
            
            score += reward
            
            if done:
                with self.global_episode.get_lock():
                    self.global_episode.value += 1
                
                print(f'\rEpisode: {str(self.global_episode.value)}, Worker: {self.name}, Average Score: {score:.2f}')
                
                total_loss = self.update_global(trajectory)
                self.sync_with_global()
                
                trajectory = []
                score = 0
                state = self.env.reset()
            
            state = next_state

## Define [A3C](https://arxiv.org/pdf/1602.01783.pdf) Agent

In [11]:
class A3CAgent():
    
    def __init__(self, env, global_num_episodes, seed):
        
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.hidden_size = 256
        self.action_size = env.action_space.n
        
        self.gamma = GAMMA
        self.lr = LR
        
        self.global_episode = mp.Value('i', 0)
        self.global_num_episodes = global_num_episodes
        
        self.global_network = HybridNetwork(self.state_size, self.hidden_size, self.action_size, seed)
        self.global_optimizer = optim.Adam(self.global_network.parameters(), lr=self.lr)
        self.workers = [AgentWorker(self.global_episode.value, env, seed, 
                                    self.global_network, self.global_optimizer, self.global_episode, self.global_num_episodes)]
    
    def train(self):
        
        for worker in self.workers: worker.start() 
        for worker in self.workers: worker.join()
        print('Training completed.')        
        
    def save(self, agent_path):
        
        if not os.path.exists('./agents/'): os.makedirs('./agents/')
        torch.save(agent.global_network.state_dict(), agent_path)

In [12]:
agent = A3CAgent(env, GLOBAL_NUM_EPISODES, seed=90)

## Train The Agent

In [13]:
scores = agent.train()
agent.save(f'./agents/A3C_{ENV_NAME}.pth')

Episode: 1, Worker: worker0, Average Score: 38.00
Episode: 2, Worker: worker0, Average Score: 36.00
Episode: 3, Worker: worker0, Average Score: 12.00
Episode: 4, Worker: worker0, Average Score: 55.00
Episode: 5, Worker: worker0, Average Score: 11.00
Episode: 6, Worker: worker0, Average Score: 30.00
Episode: 7, Worker: worker0, Average Score: 23.00
Episode: 8, Worker: worker0, Average Score: 18.00
Episode: 9, Worker: worker0, Average Score: 17.00
Episode: 10, Worker: worker0, Average Score: 62.00
Episode: 11, Worker: worker0, Average Score: 20.00
Episode: 12, Worker: worker0, Average Score: 30.00
Episode: 13, Worker: worker0, Average Score: 129.00
Training completed.


---