## A3C Half Cheetah Training Continuous Control 

[Mnih, Volodymyr, et al. "Asynchronous methods for deep reinforcement learning."](http://proceedings.mlr.press/v48/mniha16.pdf) introduced A3C (Asynchronous Advantage Actor-Critic), a classic policy gradient method with a special focus on parallel training.

In A3C, the critics learn the value function, $V(s; w)$ while multiple actors are trained in parallel and get synced with global parameters from time to time.  \
Hence, A3C is designed to work well for parallel training.

Let’s use the state-value function as an example. The loss function for state value is to minimize the mean squared error,\
$\mathcal{J}_v (w) = (G_t - V(s; w))^2$ and gradient descent can be applied to find the optimal $w$. This state-value function \
is used as the baseline in the policy gradient update.

<img src="../assets/images/a3c.png" width="auto" height="auto">

Here is the algorithm outline:

1. We have global parameters, $\theta$ and $w$; similar thread-specific parameters  $\theta'$ and $w'$.
2. Initialize the thread step counter $t = 1$ (the global step counter $T$ is shared by all threads).
3. While $T < T_\text{MAX}$:
   1. Reset the gradients: $d\theta = 0$ and $dw = 0$.
   2. Synchronize the thread-specific parameters with the global ones: $\theta' = \theta$ and $w' = w$.
   3. $t_\text{start}$ = t and get $s_t$
   4. While $s_t \neq \text{TERMINAL}$ and $t - t_\text{start} \le t_\text{max}$:
      1. Pick the action $a_t \sim \pi(a_t \vert s_t; \theta')$, then receive a new reward $r_t$ and a new state $s_{t+1}$.
      2. Update $t = t + 1$ and $T = T + 1$.
   5. Initialize the variable that holds the return estimation
   $$ R = \begin{cases} 0 & \text{if } s_t \text{ is TERMINAL} \\
         V(s_t; w’) & \text{otherwise} \end{cases} $$
   6. For $i = t-1, \dots, t_\text{start}$:
      1. $R \leftarrow r_i + \gamma R$, here R is a MC measure of $G_i$
      2. Accumulate gradients w.r.t. $\theta'$: $d\theta \leftarrow d\theta + \nabla_{\theta'} \log \pi(a_i \vert s_i; \theta')\,(R - V(s_i; w'))$ \
         Accumulate gradients w.r.t. $w'$: $dw \leftarrow dw + \nabla_{w'} (R - V(s_i; w'))^2$
   7. Asynchronously update the global parameters: $\theta$ using $d\theta$, and $w$ using $dw$.


### Load dependencies for ENV

In [1]:
import os
import datetime

import numpy as np 
import gymnasium as gym

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None  # type: ignore[misc, assignment]


env = gym.make('HalfCheetah-v5')
observation, info = env.reset()
# Show both spaces, then the observation and action dimensionalities.
action_space, observation_space = env.action_space, env.observation_space
for item in (action_space,
             observation_space,
             observation_space.shape[-1],
             action_space.shape[0]):
    print(item)

Box(-1.0, 1.0, (6,), float32)
Box(-inf, inf, (17,), float64)
17
6


In [2]:
import torch 
a = torch.randn(1, 6)
print(a)
# Drop the batch dimension and clip every component into the action bounds.
low, high = env.action_space.low, env.action_space.high
b = a.squeeze().numpy().clip(low, high)
print(b)

tensor([[ 0.0166, -1.5668, -1.6364,  2.2599, -0.0696, -1.2470]])
[ 0.01662412 -1.         -1.          1.         -0.06960451 -1.        ]


## Define Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

# Prioritize device: CUDA > MPS > CPU
# Detect the best available accelerator (CUDA > MPS > CPU) for information only.
if torch.cuda.is_available():
    _detected_backend = "cuda"
elif torch.backends.mps.is_available():
    _detected_backend = "mps"
else:
    _detected_backend = "cpu"

# NOTE: the A3C workers below run in separate processes that share the global
# model via share_memory(), so training is deliberately kept on the CPU no matter
# which accelerator was detected above.
DEVICE = torch.device("cpu")
print(f"Detected backend: {_detected_backend}. Using CPU for multiprocessing A3C training.")

 
class ActorCriticContinuous(nn.Module):
    """Actor-critic network with a shared trunk and a diagonal Gaussian policy.

    A single hidden layer (``fc1``) feeds both heads. The critic head maps the
    shared features to a scalar state value. The actor head maps them to the
    mean of a Normal distribution. That distribution's standard deviation is a
    learned, state-independent parameter: ``exp(log_std) + 0.001``.
    """

    def __init__(self, state_size, action_size, fc_units=128):
        """Build the layers.

        Args:
            state_size: length of the observation vector.
            action_size: number of continuous action components.
            fc_units: width of every hidden layer (128 by default).
        """
        super().__init__()

        # Shared trunk.
        self.fc1 = nn.Linear(state_size, fc_units)

        # Actor: one hidden layer, then a linear layer producing the per-action
        # mean. The log-std starts at 0, i.e. std = 1.
        self.fc_actor = nn.Linear(fc_units, fc_units)
        self.fc_mean = nn.Linear(fc_units, action_size)
        self.log_std = nn.Parameter(torch.zeros(action_size))

        # Critic: one hidden layer, then a scalar value output.
        self.fc_critic = nn.Linear(fc_units, fc_units)
        self.fc_critic_out = nn.Linear(fc_units, 1)

    def forward(self, state):
        """Evaluate the critic and sample an action from the policy.

        Returns:
            ``(value, action, log_prob, entropy, action_mean)``.
            ``log_prob`` and ``entropy`` are summed over the last (action)
            dimension with ``keepdim=True``. For example, a ``(1, 6)`` action
            gives a ``(1, 1)`` log-probability. ``action`` comes from
            ``sample()``, so no gradient flows through it.
        """
        features = F.relu(self.fc1(state))

        critic_hidden = F.relu(self.fc_critic(features))
        value = self.fc_critic_out(critic_hidden)

        actor_hidden = F.relu(self.fc_actor(features))
        action_mean = self.fc_mean(actor_hidden)

        # The small floor keeps the std strictly positive.
        std = self.log_std.exp() + 0.001
        policy = Normal(action_mean, std)
        action = policy.sample()

        # Action components are independent, so the joint log-prob and entropy
        # are sums over the action dimension.
        log_prob = policy.log_prob(action).sum(dim=-1, keepdim=True)
        entropy = policy.entropy().sum(dim=-1, keepdim=True)
        return value, action, log_prob, entropy, action_mean


class ActorCriticSeparate(nn.Module):
    """Actor-critic network with fully separate actor and critic MLPs.

    Each head has its own two hidden ReLU layers that read the raw state. The
    actor outputs the mean of a diagonal Normal distribution whose standard
    deviation is ``exp(log_std)``, with ``log_std`` a learned,
    state-independent parameter. The critic outputs a scalar state value.
    """

    def __init__(self, state_size, action_size, fc_units=128):
        """Build the layers.

        Args:
            state_size: length of the observation vector.
            action_size: number of continuous action components.
            fc_units: width of every hidden layer (128 by default).
        """
        super().__init__()

        # Actor body.
        self.fc_actor_0 = nn.Linear(state_size, fc_units)
        self.fc_actor_1 = nn.Linear(fc_units, fc_units)

        # Actor output: per-action mean plus a learned log-std starting at 0
        # (std = 1).
        self.fc_mean = nn.Linear(fc_units, action_size)
        self.log_std = nn.Parameter(torch.zeros(action_size))

        # Critic body and scalar value output.
        self.fc_critic_0 = nn.Linear(state_size, fc_units)
        self.fc_critic_1 = nn.Linear(fc_units, fc_units)
        self.fc_critic_out = nn.Linear(fc_units, 1)

    def forward(self, state):
        """Evaluate the critic and sample an action from the policy.

        Returns:
            ``(value, action, log_prob, entropy, action_mean)``.
            ``log_prob`` and ``entropy`` are summed over the last (action)
            dimension with ``keepdim=True``. ``action`` comes from
            ``sample()``, so no gradient flows through it.
        """
        critic_hidden = state
        for layer in (self.fc_critic_0, self.fc_critic_1):
            critic_hidden = F.relu(layer(critic_hidden))
        value = self.fc_critic_out(critic_hidden)

        actor_hidden = state
        for layer in (self.fc_actor_0, self.fc_actor_1):
            actor_hidden = F.relu(layer(actor_hidden))
        action_mean = self.fc_mean(actor_hidden)

        policy = Normal(action_mean, self.log_std.exp())
        action = policy.sample()

        # Independent action components, so sum over the action dimension.
        log_prob = policy.log_prob(action).sum(dim=-1, keepdim=True)
        entropy = policy.entropy().sum(dim=-1, keepdim=True)
        return value, action, log_prob, entropy, action_mean


MPS backend is available. Using MPS.


## Shared Optimizer

In [3]:
import torch 
import torch.optim as optim


class SharedAdam(optim.Adam):
    """Adam optimizer whose per-parameter state lives in shared memory.

    The state is created eagerly in ``__init__`` instead of lazily on the first
    ``step()``. It is then moved to shared memory, so every worker process that
    receives this optimizer updates the same moment estimates and step counter.
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.99),
                 eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        # Eager state initialization.
        # ``step`` must be a 0-d tensor: current torch.optim.Adam rejects
        # integer step counters ("state_steps argument must contain a list of
        # singleton tensors") and increments the counter in place. Being a
        # tensor also lets the counter itself be shared between processes.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(())
                state['exp_avg'] = torch.zeros_like(p.detach())
                state['exp_avg_sq'] = torch.zeros_like(p.detach())

                # Share in memory so all workers see the same optimizer state.
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

## Mujoco ENV 

In [4]:
from gymnasium.wrappers import  NormalizeObservation

# TODO: think if add more envs
def create_mujoco_env(env_id, rgb_array=True):
    """Create ``env_id`` wrapped in ``NormalizeObservation``.

    Args:
        env_id: Gymnasium environment id, e.g. ``'HalfCheetah-v5'``.
        rgb_array: if True, request ``render_mode='rgb_array'`` so frames can be
            captured; otherwise the environment's default render mode is used.

    Returns:
        The wrapped environment.
    """
    make_kwargs = {'render_mode': 'rgb_array'} if rgb_array else {}
    return NormalizeObservation(gym.make(env_id, **make_kwargs))

## Global Hyperparameters

In [5]:
# Return / advantage estimation.
GAMMA = 0.9
GAE_LAMBDA = 1.0
# Loss weights and gradient clipping.
VF_COEF = 0.5
ENTROPY_COEF = 0.01
MAX_GRAD_NORM = 0.5
# Run length and rollout size.
MAX_EP = 3000
N_STEP = 5
# Parallel workers (the development machine reports mp.cpu_count() == 12).
WORKERS_NUM = 4
ENV_ID = 'HalfCheetah-v5'

## Worker

In [None]:
import torch.multiprocessing as mp

class Hyperparameters:
    """Container for A3C training settings.

    Every constructor argument except ``suffix`` is stored as an attribute of
    the same name. ``name`` is a run identifier that encodes the env id, n-step
    length, gamma, GAE lambda, hidden width, learning rate and entropy
    coefficient, followed by ``suffix``.
    """

    def __init__(self,
                 env_id = 'HalfCheetah-v5',
                 lr = 1e-4,
                 gamma = 0.99,
                 gae_lambda = 1.0,
                 vf_coef = 0.5,
                 entropy_coef = 0.01,
                 n_step = 5,
                 max_grad_norm = 0.5,
                 net_units = 64,
                 max_ep = 3000,
                 suffix = ''
                 ):
        self.env_id = env_id
        self.lr = lr
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.vf_coef = vf_coef
        self.entropy_coef = entropy_coef
        self.n_step = n_step
        self.max_ep = max_ep
        self.max_grad_norm = max_grad_norm
        self.net_units = net_units

        # Underscore-joined tags, e.g. 'HalfCheetah-v5_n_5_g_0.99_..._ent_0.01'.
        tags = (
            f'{env_id}',
            f'n_{n_step}',
            f'g_{gamma}',
            f'gae_{gae_lambda}',
            f'net_{net_units}',
            f'lr_{lr}',
            f'ent_{entropy_coef}',
        )
        self.name = '_'.join(tags) + f'{suffix}'


class Worker(mp.Process):
    """A3C worker process.

    Each worker owns its own environment and a local copy of the network. In a
    loop it:

    1. copies the global weights into the local model;
    2. collects at most ``config.n_step`` transitions;
    3. computes the actor-critic loss on the local model;
    4. applies the local gradients to ``global_model`` via ``shared_optimizer``.

    The return of every finished episode is put on ``results_queue`` as
    ``{'type': 'reward', 'worker_id': ..., 'reward': ...}``. A worker stops
    starting new episodes once ``global_episodes_num`` reaches
    ``config.max_ep``.
    """

    def __init__(self, 
                 rank: int, 
                 config: Hyperparameters,
                 global_model: ActorCriticContinuous, 
                 shared_optimizer: SharedAdam, 
                 global_episodes_num: mp.Value,  # type: ignore
                 global_episode_reward: mp.Value,  # type: ignore
                 results_queue: mp.Queue):
        super(Worker, self).__init__()
        self.name = 'W%i' % rank
        # Training hyperparameters.
        self.config = config
        # Private environment instance for this worker.
        self.env = create_mujoco_env(config.env_id)
        self.global_model = global_model
        # Expected to be built over global_model's parameters (see _learn).
        self.shared_optimizer = shared_optimizer
        # Counter of finished episodes, incremented under its lock.
        self.global_episode_num = global_episodes_num
        # Not updated by this class at the moment; kept for interface compatibility.
        self.global_episode_reward = global_episode_reward
        self.results_queue = results_queue
        # Local network that rollouts and gradients are computed with.
        self.local_model = ActorCriticContinuous(self.env.observation_space.shape[0],
                                                 self.env.action_space.shape[0])

        print(f"Worker {self.name} is initialized.")

    @staticmethod
    def _to_tensor(obs):
        """Convert an observation to a float32 tensor, matching the model's weights."""
        return torch.as_tensor(obs, dtype=torch.float32)

    def run(self):
        """Collect n-step rollouts and update the global model until max_ep is reached."""
        print(f"Worker {self.name} is running.")
        while self.global_episode_num.value < self.config.max_ep:
            obs, _ = self.env.reset()
            state = self._to_tensor(obs)

            done = False
            episode_reward = 0
            while not done:
                # Start every rollout from the latest global weights.
                self.local_model.load_state_dict(self.global_model.state_dict())

                values, log_probs, entropies, rewards = [], [], [], []
                terminated = False

                # Collect up to n_step transitions (fewer if the episode ends).
                for _ in range(self.config.n_step):
                    value, action, log_prob, entropy, _ = self.local_model(state.unsqueeze(0))
                    # Drop the batch dim and keep the action inside the env bounds.
                    selected_action = action.squeeze(0).numpy().clip(
                        self.env.action_space.low,
                        self.env.action_space.high)
                    obs, reward, terminated, truncated, _ = self.env.step(selected_action)
                    done = terminated or truncated
                    episode_reward += reward
                    # NOTE: reward clipping to [-1, 1] was experimented with;
                    # rewards are used unclipped here.

                    entropies.append(entropy)
                    values.append(value)
                    rewards.append(reward)
                    log_probs.append(log_prob)

                    state = self._to_tensor(obs)
                    if done:
                        break

                # Bootstrap value of the state after the rollout. It is zero
                # only on a real termination. On truncation (time limit) the
                # next state still has value, so V(s) is used.
                if terminated:
                    last_value = torch.zeros(1, 1)
                else:
                    with torch.no_grad():
                        last_value, _, _, _, _ = self.local_model(state.unsqueeze(0))
                # values has one more entry than rewards/log_probs/entropies.
                values.append(last_value)

                self._learn(values, rewards, log_probs, entropies)

                if done:
                    self._make_record(episode_reward)
                    self.results_queue.put({
                        'type': 'reward',
                        'worker_id': self.name,
                        'reward': episode_reward,
                        })

    def _learn(self, values, rewards, log_probs, entropies):
        """Compute the A3C loss for one rollout and apply it to the global model.

        Args:
            values: ``V(s_t)`` for each rollout step, plus one bootstrap value
                for the state after the last step
                (``len(values) == len(rewards) + 1``).
            rewards: scalar rewards, one per step.
            log_probs: ``log pi(a_t | s_t)`` per step.
            entropies: policy entropy per step.
        """
        if not rewards:
            return

        gamma = self.config.gamma
        gae = torch.zeros(1, 1)
        # Bootstrapped n-step return R, seeded with the (detached) bootstrap value.
        ret = values[-1].detach()
        policy_loss = 0
        value_loss = 0
        entropy_loss = 0
        for step in reversed(range(len(rewards))):
            # Critic: regress V(s_t) onto the n-step return R_t (semi-gradient;
            # no gradient flows into the values of later states).
            ret = rewards[step] + gamma * ret
            value_loss = value_loss + 0.5 * (ret - values[step]).pow(2)

            # Actor: generalized advantage estimate built from detached values.
            delta = rewards[step] + gamma * values[step + 1].detach() - values[step].detach()
            gae = delta + gamma * self.config.gae_lambda * gae
            policy_loss = policy_loss - log_probs[step] * gae

            # Entropy bonus is maximized, hence subtracted from the loss.
            entropy_loss = entropy_loss - entropies[step]

        loss = (policy_loss
                + self.config.vf_coef * value_loss
                + self.config.entropy_coef * entropy_loss)

        # Reset the LOCAL gradients explicitly. The shared optimizer holds the
        # global parameters, so its zero_grad() does not clear the local model's
        # .grad. Without this, backward() would accumulate across updates.
        self.local_model.zero_grad()
        self.shared_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.local_model.parameters(), self.config.max_grad_norm)
        self._ensure_shared_grads()
        self.shared_optimizer.step()

    def _ensure_shared_grads(self):
        """Point each global parameter's gradient at the matching local gradient."""
        for param, global_param in zip(self.local_model.parameters(),
                                       self.global_model.parameters()):
            global_param.grad = param.grad

    def _make_record(self, episode_reward):
        """Increment the shared finished-episode counter under its lock."""
        with self.global_episode_num.get_lock():
            self.global_episode_num.value += 1

## Train and more 

In [None]:
# NOTE(review): this notebook flow relies on the "fork" start method. With
# "spawn" (the default on macOS/Windows) the Worker objects would have to be
# picklable and defined in an importable module.
import queue

import matplotlib.pyplot as plt

env = create_mujoco_env(ENV_ID)
global_model = ActorCriticContinuous(env.observation_space.shape[0], env.action_space.shape[0])
global_model.share_memory()

# Bundle the global hyperparameters for the workers (Worker requires a config).
# net_units=128 matches the ActorCriticContinuous default used above, so the
# run name is accurate.
config = Hyperparameters(
    env_id=ENV_ID,
    lr=1e-3,
    gamma=GAMMA,
    gae_lambda=GAE_LAMBDA,
    vf_coef=VF_COEF,
    entropy_coef=ENTROPY_COEF,
    n_step=N_STEP,
    max_grad_norm=MAX_GRAD_NORM,
    net_units=128,
    max_ep=MAX_EP,
)
shared_optimizer = SharedAdam(global_model.parameters(), lr=config.lr)

global_episode_num = mp.Value('i', 0)
global_episode_reward = mp.Value('d', 0.0)
# Unbounded queue: workers may report slightly more than MAX_EP episodes,
# because each one finishes its current episode after the limit is hit.
results_queue = mp.Queue()

workers = [Worker(i,
                  config,
                  global_model,
                  shared_optimizer,
                  global_episodes_num=global_episode_num,
                  global_episode_reward=global_episode_reward,
                  results_queue=results_queue) for i in range(WORKERS_NUM)]
for w in workers:
    w.start()

# Drain the queue WHILE the workers run. A process that has put items on a
# multiprocessing.Queue does not terminate until they are flushed to the pipe,
# so join() before draining can deadlock.
res = []  # episode rewards in completion order
while any(w.is_alive() for w in workers) or not results_queue.empty():
    try:
        msg = results_queue.get(timeout=1.0)
    except queue.Empty:
        continue
    if msg.get('type') == 'reward':
        res.append(msg['reward'])

for w in workers:
    w.join()

plt.plot(res)
plt.ylabel('Episode reward')
plt.xlabel('Episode')
plt.show()

### Video Play of some examples     

In [6]:
import os
import imageio
import numpy as np
from IPython.display import Video, display, HTML
import torch 

def record_video(env, policy, out_directory, out_name, fps=30):
    """
    Record one episode of ``policy`` in ``env``, save it, and display it in the notebook.

    :param env: Environment to record; it must render frames (e.g. ``render_mode='rgb_array'``).
    :param policy: Callable mapping a float32 observation tensor to
        ``(value, action, log_prob, entropy, action_mean)``; the sampled ``action`` is executed.
    :param out_directory: Directory the video is written to (created if missing).
    :param out_name: File name of the video inside ``out_directory``.
    :param fps: Frames per second.
    """
    obs, _ = env.reset()
    # Keep the initial frame so the video starts at the reset state.
    images = [env.render()]

    done = False
    while not done:
        state = torch.tensor(obs, dtype=torch.float32)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            _, action, _, _, _ = policy(state)
        # Clip like eval_policy, so policies without action squashing stay in bounds.
        selected_action = action.squeeze().numpy().clip(
            env.action_space.low,
            env.action_space.high)
        obs, _, terminated, truncated, _ = env.step(selected_action)
        images.append(env.render())
        done = terminated or truncated

    # Save the video (imageio does not create missing directories).
    os.makedirs(out_directory, exist_ok=True)
    video_path = os.path.join(out_directory, out_name)
    imageio.mimsave(video_path, [np.array(frame) for frame in images], fps=fps)

    # Display the video in the Jupyter notebook.
    display(Video(video_path, embed=True, width=640, height=480))



def eval_policy(env, policy, num_episodes = 10):
    """Run ``policy`` for ``num_episodes`` episodes and report the episode returns.

    Actions are the sampled ``action`` (second element of the policy output),
    clipped into ``env.action_space``. Per-episode totals and their mean and
    standard deviation are printed.

    Args:
        env: Gymnasium-style environment (``reset``/``step`` API) with a Box action space.
        policy: callable mapping a float32 observation tensor to
            ``(value, action, log_prob, entropy, action_mean)``.
        num_episodes: number of evaluation episodes.

    Returns:
        ``(mean_reward, std_reward)`` over the evaluated episodes.
    """
    episode_rewards = []

    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        total_reward = 0  # undiscounted return of this episode
        while not done:
            state = torch.tensor(obs, dtype=torch.float32)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                _, action, _, _, _ = policy(state)
            selected_action = action.squeeze().numpy().clip(
                        env.action_space.low, 
                        env.action_space.high)
            obs, reward, terminated, truncated, _ = env.step(selected_action)
            total_reward += reward
            done = terminated or truncated

        episode_rewards.append(total_reward)
        print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    print(f"Mean Reward: {mean_reward}, Standard Deviation: {std_reward}")
    return mean_reward, std_reward

In [12]:
import gymnasium as gym
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp

def create_mujoco_env(env_id: str, rgb_array: bool = True):
    """Create a Mujoco environment wrapped with running observation normalization.

    Args:
        env_id: Gymnasium environment id, e.g. ``"HalfCheetah-v5"``.
        rgb_array: if True (the default), request ``render_mode='rgb_array'`` so
            frames can be recorded. If False, use the environment's default
            render mode. This keeps the signature compatible with callers of
            the earlier helper that accepts ``rgb_array``.

    Returns:
        The ``NormalizeObservation``-wrapped environment.
    """
    if rgb_array:
        env = gym.make(env_id, render_mode='rgb_array')
    else:
        env = gym.make(env_id)
    env = gym.wrappers.NormalizeObservation(env)
    return env

# --- Actor-Critic Model ---
class ActorCritic(nn.Module):
    """Actor-critic with separate MLP bodies and a tanh-squashed Gaussian policy.

    Actions are sampled from ``Normal(mean, exp(log_std))``, squashed with
    ``tanh``, and then affinely rescaled into
    ``[action_space.low, action_space.high]``. ``log_prob`` includes the
    change-of-variables correction for that transform. ``entropy`` is the
    entropy of the pre-squash Gaussian.
    """

    def __init__(self, 
                 env, 
                 fc_units=128):
        """Build the network for ``env``'s observation and action spaces.

        Args:
            env: environment exposing ``observation_space.shape`` and
                ``action_space.shape``/``low``/``high``.
            fc_units: width of each hidden layer.
        """
        super(ActorCritic, self).__init__()

        # Environment related constants
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]

        # Affine map from the tanh range [-1, 1] to [low, high]. These are
        # registered as non-persistent buffers, so they follow .to(device)
        # but stay out of state_dict (existing checkpoints load unchanged).
        action_low = torch.tensor(env.action_space.low, dtype=torch.float32)
        action_high = torch.tensor(env.action_space.high, dtype=torch.float32)
        self.register_buffer('action_scale', (action_high - action_low) / 2.0, persistent=False)
        self.register_buffer('action_bias', (action_high + action_low) / 2.0, persistent=False)

        # Actor body
        self.fc_actor = nn.Sequential(
            nn.Linear(state_size, fc_units),
            nn.ReLU(),
            nn.Linear(fc_units, fc_units),
            nn.ReLU()
        )

        # Critic body
        self.fc_critic = nn.Sequential(
            nn.Linear(state_size, fc_units),
            nn.ReLU(),
            nn.Linear(fc_units, fc_units),
            nn.ReLU()
        )

        # Actor head: per-action mean plus a state-independent log-std (std = 1 at init).
        self.fc_mean = nn.Linear(fc_units, action_size)
        self.log_std = nn.Parameter(torch.zeros(action_size))

        # Critic head: scalar state value.
        self.fc_value = nn.Linear(fc_units, 1)

    def forward(self, state):
        """Evaluate the critic and sample a squashed, rescaled action.

        Returns:
            ``(value, action, log_prob, entropy, action_mean)``. ``log_prob``
            and ``entropy`` are summed over the last (action) dimension with
            ``keepdim=True``. ``action_mean`` is the pre-squash Gaussian mean.
        """
        value = self.fc_value(self.fc_critic(state))

        actor = self.fc_actor(state)
        action_mean = self.fc_mean(actor)  # Shape: [batch_size, action_size]
        # Broadcast the std to the batch shape and convert log-std to std.
        action_std = self.log_std.expand_as(action_mean).exp()
        distribution = torch.distributions.Normal(action_mean, action_std)

        action_raw = distribution.sample()
        action_tanh = torch.tanh(action_raw)  # Squash action to [-1, 1]
        action = action_tanh * self.action_scale + self.action_bias

        # log|d action / d action_raw| = log(scale * (1 - tanh^2)); 1e-6 avoids log(0).
        log_prob = distribution.log_prob(action_raw)
        log_prob = log_prob - torch.log(self.action_scale * (1 - action_tanh.pow(2)) + 1e-6)
        log_prob = log_prob.sum(-1, keepdim=True)

        entropy = distribution.entropy().sum(-1, keepdim=True)

        return value, action, log_prob, entropy, action_mean

    def get_value(self, state):
        """Return the critic's value estimate for ``state``."""
        return self.fc_value(self.fc_critic(state))

In [None]:
eval_env = create_mujoco_env("HalfCheetah-v5")

# NOTE: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints you produced yourself or otherwise trust.
checkpoint = torch.load('../parallel/runs/HalfCheetah-v5/A3C/lr0.0007_gamma0.99_gae0.95_entcoef0.001_steps40_nproc4_1740698473/best-torch.model', weights_only=False)

# Restore the observation-normalization statistics saved with the policy.
obs_rms = checkpoint["obs_rms"]
print(obs_rms)
eval_env.obs_rms.mean = obs_rms['mean']
eval_env.obs_rms.var = obs_rms['var']
# NOTE(review): intended to freeze the running statistics during evaluation;
# confirm the installed gymnasium NormalizeObservation exposes this setter.
eval_env.update_running_mean = False

eval_model = ActorCritic(eval_env)
eval_model.load_state_dict(checkpoint["actor_critic_state_dict"])
eval_model.eval()

eval_policy(eval_env, eval_model)
# imageio does not create missing directories.
os.makedirs('./videos', exist_ok=True)
record_video(eval_env, eval_model, './videos', 'output_HalfCheetah-v5.mp4')

{'mean': array([-0.23465464,  0.23425916,  0.36460498,  0.12629797,  0.3252023 ,
       -0.07709596, -0.25408328, -0.08087034,  0.8295668 , -0.01423146,
        0.02308888,  0.08331062, -0.05104771,  0.05373305, -0.05551235,
       -0.00956518, -0.02555702], dtype=float32), 'var': array([1.1551289e-02, 3.6418128e-01, 8.9542590e-02, 8.7098792e-02,
       8.9260265e-02, 1.8935443e-01, 1.3567378e-01, 7.1290866e-02,
       9.7234517e-01, 4.6854475e-01, 2.0636289e+00, 1.6379667e+01,
       8.2736649e+01, 1.0510030e+02, 7.8262154e+01, 3.0124855e+01,
       3.6188259e+01], dtype=float32)}
Episode 1: Total Reward = 1738.4717508862163
Episode 2: Total Reward = 1874.0461478293632
Episode 3: Total Reward = 1955.2154433270357
Episode 4: Total Reward = 1929.9961169429855
Episode 5: Total Reward = 1862.267183228698
Episode 6: Total Reward = 1836.951228483455
Episode 7: Total Reward = 1859.2185334799017
Episode 8: Total Reward = 1632.7307366976156
Episode 9: Total Reward = 1683.54989097008
Episode 10

: 

<video width="640" height="480" controls>
  <source src="../assets/videos/a3c_HalfCheetah-v5.mp4" type="video/mp4">
</video>