In [28]:
import gymnasium as gym
import numpy as np
from itertools import count
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from tqdm import tqdm

#Code implementations derived from https://github.com/mimoralea/gdrl

In [25]:
class FCDAP(nn.Module):
    """Fully-connected discrete-action policy network (state observation -> action logits).

    Args:
        input_dim: number of features in a state observation.
        output_dim: number of discrete actions (size of the logits vector).
        hidden_dims: tuple where each element is the number of neurons of one hidden layer.
        activation_fc: activation module class instantiated after every non-output layer.
    """
    def __init__(self, 
                 input_dim, 
                 output_dim,
                 hidden_dims=(32,32), #define hidden layers as tuple where each element is an int representing # of neurons at a layer
                 activation_fc=nn.ReLU):
        super(FCDAP, self).__init__()
        self.activation_fc = activation_fc

        #hidden-to-hidden layers are created first so that parameter initialisation
        #(and therefore seeded results) keeps the same order as before
        hidden_layers = []
        for i in range(len(hidden_dims)-1):
            hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            hidden_layers.append(activation_fc())
        
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            activation_fc(),
            *hidden_layers,
            nn.Linear(hidden_dims[-1], output_dim)
        )

        device = "cpu"
        #if torch.cuda.is_available():
        #    device = "cuda"
        self.device = torch.device(device)
        self.to(self.device)

    def _format(self, state):
        """Return `state` as a tensor; non-tensor input becomes a float32 batch of size 1."""
        x = state
        if not isinstance(x, torch.Tensor):
            #create the tensor on the model's device (same as FCV._format) so the
            #network keeps working if the device selection above is switched to cuda
            x = torch.tensor(x, 
                             device=self.device, 
                             dtype=torch.float32)
            x = x.unsqueeze(0)
        return x
        
    def forward(self, state):
        """Map a state (tensor or array-like) to unnormalised action logits."""
        x = self._format(state)
        return self.layers(x)

    #select and return action, corresponding log prob of the action, and entropy of the distribution
    def select_action(self, state):
        """Sample an action from the categorical policy for `state`.

        Returns:
            (action, log_prob, entropy): the sampled action as a Python int, and the
            log-probability of that action and the distribution entropy, each with a
            trailing dimension added.
        """
        logits = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action).unsqueeze(-1), dist.entropy().unsqueeze(-1)

In [26]:
#Fully-connected value network (state observation -> state value)
class FCV(nn.Module):
    """Fully-connected state-value network: maps a state observation to a single value.

    Args:
        input_dim: number of features in a state observation.
        hidden_dims: tuple where each element is the number of neurons of one hidden layer.
        activation_fc: activation module class instantiated after every non-output layer.
    """
    def __init__(self, 
                 input_dim, 
                 hidden_dims=(32,32), #define hidden layers as tuple where each element is an int representing # of neurons at a layer
                 activation_fc=nn.ReLU):
        super().__init__()

        #hidden-to-hidden blocks are built before the input/output layers; this keeps
        #the order in which parameters are randomly initialised
        inner_blocks = []
        for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            inner_blocks += [nn.Linear(fan_in, fan_out), activation_fc()]

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            activation_fc(),
            *inner_blocks,
            nn.Linear(hidden_dims[-1], 1),
        )

        #the model is pinned to the cpu (cuda selection is intentionally disabled)
        self.device = torch.device("cpu")
        self.to(self.device)

    def _format(self, state):
        """Pass tensors through untouched; wrap anything else as a float32 batch of one."""
        if isinstance(state, torch.Tensor):
            return state
        as_tensor = torch.tensor(state, device=self.device, dtype=torch.float32)
        return as_tensor.unsqueeze(0)

    def forward(self, state):
        """Return the estimated value of `state` (tensor or array-like)."""
        return self.layers(self._format(state))

In [None]:
class SharedAdam(torch.optim.Adam):
    """Adam optimizer whose per-parameter state tensors are placed in shared memory.

    The state (step counter and moment estimates) is allocated eagerly in
    `__init__` and moved to shared memory with `share_memory_()`, so it can be
    shared by the worker processes that use this optimizer.

    Args:
        params, lr, betas, eps, weight_decay, amsgrad: forwarded unchanged to
            `torch.optim.Adam`.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        super(SharedAdam, self).__init__(
            params, lr=lr, betas=betas, eps=eps, 
            weight_decay=weight_decay, amsgrad=amsgrad)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                #Adam reads the step counter from state['step'] and current PyTorch
                #requires it to be a singleton tensor (a plain int makes step() raise)
                state['step'] = torch.zeros(()).share_memory_()
                state['shared_step'] = torch.zeros(1).share_memory_()
                state['exp_avg'] = torch.zeros_like(p.data).share_memory_()
                state['exp_avg_sq'] = torch.zeros_like(p.data).share_memory_()
                if weight_decay:
                    state['weight_decay'] = torch.zeros_like(p.data).share_memory_()
                if amsgrad:
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data).share_memory_()

    def step(self, closure=None):
        """Synchronise each step counter with the shared counter, then run Adam.

        Args:
            closure: optional callable re-evaluating the model, forwarded to Adam.

        Returns:
            Whatever `torch.optim.Adam.step` returns (the closure's loss, or None).
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                shared_count = state['shared_step'].item()
                #write to 'step' (the key Adam actually reads), not 'steps'
                if torch.is_tensor(state['step']):
                    state['step'].fill_(shared_count)
                else:
                    #state was replaced from outside (e.g. loaded); keep its format
                    state['step'] = shared_count
                state['shared_step'] += 1
        return super().step(closure)

![Alt text](image.png)
Source: [Asynchronous Methods for Deep Reinforcement Learning (Mnih et al., 2016)](https://arxiv.org/pdf/1602.01783.pdf)

In [27]:
from copy import deepcopy
class A3CWorker(mp.Process):
    """Worker process that collects n-step rollouts and updates the shared A3C models.

    Each worker keeps local copies of the shared policy and value models, gathers
    up to `max_td_steps` transitions with the local policy, computes gradients on
    the local copies, hands them to the shared models, steps the shared
    optimizers and reloads the updated shared weights into the local copies.

    Args:
        rank: index of this worker; added to `seed` to give each worker its own seed.
        make_env_fn: zero-argument callable returning the environment.
        shared_policy_model / shared_value_model: models shared between workers.
        shared_policy_optimizer / shared_value_optimizer: optimizers of the shared models.
        shared_T: global step counter. It is compared with `max_T` using `<` and
            advanced with `+= 1` (presumably a shared-memory tensor -- confirm
            against the caller).
        max_T: the worker stops once the global counter reaches this value.
        max_td_steps: maximum rollout length per update (expected to be >= 1).
        gamma: discount factor.
        entropy_weight: weight of the entropy bonus in the policy loss.
        seed: base random seed.
    """
    def __init__(self, rank, make_env_fn, shared_policy_model, shared_value_model, shared_policy_optimizer, shared_value_optimizer, shared_T, max_T, max_td_steps=5, gamma=1.0, entropy_weight=1e-4, seed=0):
        super(A3CWorker, self).__init__()
        self.seed = seed + rank
        self.rank = rank
        self.env = make_env_fn()
        self.gamma = gamma
        self.entropy_weight = entropy_weight
        self.shared_policy_model = shared_policy_model
        self.shared_value_model = shared_value_model
        self.shared_policy_optimizer = shared_policy_optimizer
        self.shared_value_optimizer = shared_value_optimizer
        self.max_td_steps = max_td_steps
        self.T = shared_T
        self.max_T = max_T

    @staticmethod
    def _apply_update(loss, local_model, shared_model, shared_optimizer):
        """Backpropagate `loss` locally, step the shared model, then resync the local copy."""
        #clear local grads explicitly: zeroing only through the shared optimizer
        #can leave them in place, so they would accumulate over successive updates
        local_model.zero_grad()
        shared_optimizer.zero_grad()
        loss.backward()
        #transfer gradients from local model to shared model and step
        for param, shared_param in zip(local_model.parameters(),
                                       shared_model.parameters()):
            shared_param._grad = param.grad
        shared_optimizer.step()
        #load updated shared model back into local model
        local_model.load_state_dict(shared_model.state_dict())
    
    def _optimize_model(self, rewards, values, log_probs, entropies):
        """Run one policy update and one value update from a single rollout.

        Args:
            rewards: list of per-step rewards followed by one extra entry, the
                bootstrap value of the state reached after the last step.
            values: list of value estimates, one per step (tensors with grad).
            log_probs: list of log-probabilities of the taken actions.
            entropies: list of policy entropies, one per step.
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        T = len(rewards)
        #Calculate n_step returns
        discounts = np.logspace(0, T, num=T, base=self.gamma, endpoint=False)
        returns = np.array([np.sum(discounts[:T-t] * rewards[t:]) for t in range(T)])
        #drop return of final td step next_state
        returns = torch.tensor(returns[:-1], dtype=torch.float32).unsqueeze(1)
        
        log_probs = torch.cat(log_probs)
        entropies = torch.cat(entropies)
        values = torch.cat(values)

        #calculate "policy loss" as the negative policy gradient with weighted entropy
        advantages = returns - values #use advantage estimates (A_t = G_t-V(S_t)) instead of returns for policy gradient
        policy_grad = (advantages.detach() * log_probs).mean()
        policy_loss = -(policy_grad + self.entropy_weight*entropies.mean())
        self._apply_update(policy_loss, self.local_policy_model,
                           self.shared_policy_model, self.shared_policy_optimizer)

        value_loss = advantages.pow(2).mul(0.5).mean() #mean square error
        self._apply_update(value_loss, self.local_value_model,
                           self.shared_value_model, self.shared_value_optimizer)


    def run(self):
        """Collect rollouts and update the shared models until `T` reaches `max_T`."""
        torch.manual_seed(self.seed) ; np.random.seed(self.seed) ; random.seed(self.seed) #set seeds
        #initialize local models by deep-copying shared models
        self.local_policy_model = deepcopy(self.shared_policy_model)
        self.local_value_model = deepcopy(self.shared_value_model)
        #seed the environment once; re-seeding on every reset would replay the
        #same initial state in every episode
        state = self.env.reset(seed=self.seed)[0]
        while self.T < self.max_T:
            log_probs, rewards, values, entropies = [], [], [], []
            #gather data for n_step td
            for _ in range(self.max_td_steps):
                action, log_prob, entropy = self.local_policy_model.select_action(state) #select action and get corresponding log prob and dist entropy
                #get estimated value of current state
                values.append(self.local_value_model(state))
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                #gather log probs, rewards, and entropies to calculate policy gradient
                log_probs.append(log_prob)
                rewards.append(reward)
                entropies.append(entropy)
                state = next_state
                self.T += 1
                #stop at the episode boundary so a return never mixes two episodes
                if terminated or truncated:
                    break

            #value of next state for final td step (plain float, detached from the graph)
            R = 0.0 if terminated else self.local_value_model(next_state).detach().item()
            rewards.append(R)
            self._optimize_model(rewards, values, log_probs, entropies)

            if terminated or truncated:
                state = self.env.reset()[0]

In [1]:
class A3C():
    """Configuration holder and model factory for an A3C agent.

    Stores the factories used to build the policy/value models and their
    optimizers together with the default learning rates and the entropy weight.
    `train` is currently a placeholder that does nothing.
    """
    def __init__(self, 
                policy_model_fn = lambda num_obs, nA: FCDAP(num_obs, nA), #state vars, nA -> model
                policy_optimizer_fn = lambda params, lr : SharedAdam(params, lr), #model params, lr -> optimizer
                policy_optimizer_lr = 1e-4, #optimizer learning rate
                value_model_fn = lambda num_obs: FCV(num_obs), #state vars  -> model
                value_optimizer_fn = lambda params, lr : SharedAdam(params, lr), #model params, lr -> optimizer
                value_optimizer_lr = 1e-4, #optimizer learning rate
                entropy_weight = 1e-4
                ):
        #policy side: model factory, optimizer factory, default learning rate
        (self.policy_model_fn,
         self.policy_optimizer_fn,
         self.policy_optimizer_lr) = policy_model_fn, policy_optimizer_fn, policy_optimizer_lr
        #value side: model factory, optimizer factory, default learning rate
        (self.value_model_fn,
         self.value_optimizer_fn,
         self.value_optimizer_lr) = value_model_fn, value_optimizer_fn, value_optimizer_lr
        self.entropy_weight = entropy_weight

    def _init_model(self, env, policy_lr=None, value_lr=None):
        """Build the policy/value models and their optimizers for `env`.

        The observation size is taken from the length of a sampled observation
        and the number of actions from `env.action_space.n`. A falsy learning
        rate falls back to the default given at construction time.
        """
        policy_lr = policy_lr or self.policy_optimizer_lr
        value_lr = value_lr or self.value_optimizer_lr

        policy_obs_size = len(env.observation_space.sample())
        self.policy_model = self.policy_model_fn(policy_obs_size, env.action_space.n)
        self.policy_optimizer = self.policy_optimizer_fn(
            self.policy_model.parameters(), lr=policy_lr)

        value_obs_size = len(env.observation_space.sample())
        self.value_model = self.value_model_fn(value_obs_size)
        self.value_optimizer = self.value_optimizer_fn(
            self.value_model.parameters(), lr=value_lr)

    def train(self, make_env_fn, gamma=1.0, num_episodes=100, policy_lr=None, value_lr=None, save_models=None):
        """Not implemented yet: accepts the training arguments and returns None."""
        return None
