In [None]:
import math, random

import gymnasium as gym
import numpy as np
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F
from torch.distributions import Categorical

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
USE_CUDA = torch.cuda.is_available()


def device(inp):
    """Return ``inp`` moved to the GPU via ``.cuda()`` when CUDA is available.

    Works for anything exposing a ``.cuda()`` method (tensors and modules are
    both passed through it in this notebook); without CUDA the argument is
    returned unchanged.
    """
    if USE_CUDA:
        return inp.cuda()
    return inp

In [None]:
## ENVIRONMENT

# Gymnasium environment id. It is handed to the agent (which builds the
# vectorised training environments from it) and to `gym.make` for the
# rendered test episode at the end of the notebook.
env_id = "CartPole-v1"

In [None]:
## NEURAL NETWORK

class ActorNet(nn.Module):
    """Policy network mapping a batch of observations to action probabilities.

    Three hidden layers of 16 ReLU units followed by a linear layer and a
    softmax over dim 1, so inputs must be batched as ``[batch, num_inputs]``.
    """

    def __init__(self, num_inputs, num_actions):
        super().__init__()

        hidden = 16
        stack = [
            nn.Linear(num_inputs, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_actions),
            nn.Softmax(dim=1),
        ]
        self.layers = nn.Sequential(*stack)
        self.num_actions = num_actions

    def forward(self, state):
        """Return per-action probabilities, shape ``[batch, num_actions]``."""
        action_probs = self.layers(state)
        return action_probs

    def greedy_act(self, state):
        """Return the most probable action for each row of ``state`` as a NumPy array."""
        batch = device(torch.FloatTensor(state))
        with torch.no_grad():
            action_probs = self.forward(batch)
        # `max` over dim 1 yields (values, indices); the indices are the greedy actions.
        _, best_action = action_probs.max(1)
        return best_action.cpu().numpy()

class CriticNet(nn.Module):
    """State-value network: three hidden layers of 16 ReLU units and a scalar head."""

    def __init__(self, num_inputs):
        super().__init__()

        widths = (num_inputs, 16, 16, 16)
        stack = []
        # One Linear + ReLU pair per consecutive pair of widths, then the value head.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, state):
        """Return the value estimate for each row of ``state``, shape ``[batch, 1]``."""
        value = self.layers(state)
        return value

A2C: Synchronous Advantage Actor Critic

$$L_w = \left[R_t + \gamma\hat{v}_w\left(S_{t+1}\right) - \hat{v}_w\left(S_t\right)\right]^2$$
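$$L_\theta = -\left[R_t + \gamma\hat{v}_w\left(S_{t+1}\right) - \hat{v}_w\left(S_t\right)\right]\ln\pi_\theta\left(A_t{\vert}S_t\right)$$

Note: the agent implemented below reuses each rollout for several gradient steps and replaces the policy loss above with PPO's clipped surrogate. With the probability ratio $r_t(\theta) = \pi_\theta\left(A_t{\vert}S_t\right) / \pi_{\theta_{old}}\left(A_t{\vert}S_t\right)$ and the advantage estimate $\hat{A}_t$ (n-step return minus $\hat{v}_w\left(S_t\right)$):

$$L_\theta^{CLIP} = -\min\left(r_t(\theta)\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon\right)\hat{A}_t\right)$$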

In [None]:
## A2C Agent

class A2CAgent:
    """Actor-critic agent trained on a vectorised environment with PPO-style clipped updates.

    Despite the class name, ``train`` optimises the PPO clipped surrogate: every
    rollout of ``num_steps`` steps from ``batch_size`` parallel environments is
    reused for ``ppo_epochs`` gradient steps on randomly sampled minibatches.

    Args:
        env_id: Gymnasium environment id passed to ``gym.vector.make``.
        gamma: Discount factor used for the bootstrapped returns.
        lr: Adam learning rate shared by the actor and the critic.
        num_frames: Number of vectorised environment steps to train for.
        num_steps: Rollout length (steps per environment) between updates.
        batch_size: Number of parallel environments.
        ppo_epochs: Number of gradient steps taken on each rollout.
        ppo_rel_batch_size: Minibatch size as a fraction of the rollout size
            (``num_steps * batch_size``).
        ppo_clip_param: Clipping range for the probability ratio.
    """

    def __init__(self, env_id, gamma, lr, num_frames, num_steps, batch_size, ppo_epochs, ppo_rel_batch_size, ppo_clip_param):
        # NOTE(review): `gym.vector.make` and the `info['final_observation']` key used in
        # `train` look like the pre-1.0 Gymnasium vector API -- confirm the installed version.
        self.envs = gym.vector.make(env_id, num_envs=batch_size)
        self.gamma = gamma
        self.lr = lr
        self.num_frames = num_frames
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.ppo_epochs = ppo_epochs
        # At least one sample per minibatch, otherwise every loss term is a mean over nothing (NaN).
        self.ppo_batch_size = max(1, int(ppo_rel_batch_size * num_steps * batch_size))
        self.ppo_clip_param = ppo_clip_param

        obs_dim = self.envs.single_observation_space.shape[0]
        self.actor = device(ActorNet(obs_dim, self.envs.single_action_space.n))
        self.critic = device(CriticNet(obs_dim))
        self.optimizer = optim.Adam([{'params': self.actor.parameters()}, {'params': self.critic.parameters()}], lr=lr)

    def train(self):
        """Alternate rollout collection and clipped policy/value updates.

        Runs until ``num_frames`` vectorised steps have been taken, redraws the
        training plot every 200 steps, and closes the environments at the end.
        """
        losses = [0.]
        all_rewards = []
        episode_reward = np.zeros(self.batch_size)

        state, _ = self.envs.reset()
        frame_idx = 0
        while frame_idx < self.num_frames:
            log_probs = []
            values = []
            rewards = []
            not_term_masks = []
            trunc_masks = []
            next_values = []
            states = []
            actions = []

            for _ in range(self.num_steps):
                state = device(torch.FloatTensor(state))
                with torch.no_grad():
                    prob = self.actor(state)
                    value = self.critic(state)
                dist = Categorical(probs=prob)

                action = dist.sample()
                next_state, reward, terminated, truncated, info = self.envs.step(action.cpu().numpy())

                log_probs.append(dist.log_prob(action))
                values.append(value)
                rewards.append(device(torch.FloatTensor(reward).unsqueeze(1)))
                not_term_masks.append(device(torch.FloatTensor(1 - terminated).unsqueeze(1)))
                trunc_masks.append(device(torch.BoolTensor(truncated).unsqueeze(1)))

                # For truncated environments the value must be bootstrapped from the last
                # observation of the cut-off episode rather than from `next_state`.
                trunc_next_state = np.stack(
                    [
                        info['final_observation'][i] if truncated[i] else next_state[i, :]
                        for i in range(truncated.shape[0])
                    ],
                    axis=0
                )
                with torch.no_grad():
                    next_values.append(self.critic(device(torch.FloatTensor(trunc_next_state))))

                states.append(state)
                actions.append(action)

                state = next_state
                frame_idx += 1

                # Book-keeping of finished episodes for the reward plot.
                episode_reward += reward
                done = np.logical_or(terminated, truncated)
                all_rewards.extend(episode_reward[done].tolist())
                episode_reward[done] = 0

                if frame_idx % 200 == 0:
                    self.plot_training(frame_idx, all_rewards, losses)

            returns = self.compute_returns(rewards, not_term_masks, trunc_masks, next_values)

            # Flatten everything per-sample to 1-D so that the element-wise products in the
            # loss pair sample i with sample i (mixing [N] and [N, 1] would broadcast to [N, N]).
            log_probs = torch.cat(log_probs)            # [num_steps*batch_size]
            returns = torch.cat(returns).squeeze(-1)    # [num_steps*batch_size]
            values = torch.cat(values).squeeze(-1)      # [num_steps*batch_size]
            states = torch.cat(states)                  # [num_steps*batch_size, obs_dim]
            actions = torch.cat(actions)                # [num_steps*batch_size]

            advantages = returns - values

            loss_item = 0.
            for _ in range(self.ppo_epochs):
                # Minibatch indices are drawn with replacement from the whole rollout.
                idc = torch.randint(low=0, high=self.num_steps * self.batch_size, size=(self.ppo_batch_size,))
                old_state = states[idc]
                old_action = actions[idc]
                old_log_prob = log_probs[idc]
                old_return = returns[idc]
                old_advantage = advantages[idc]

                new_prob = self.actor(old_state)
                new_value = self.critic(old_state).squeeze(-1)
                new_dist = Categorical(probs=new_prob)

                new_log_prob = new_dist.log_prob(old_action)
                new_entropy = new_dist.entropy().mean()

                actor_loss = self._clipped_surrogate(new_log_prob, old_log_prob, old_advantage, self.ppo_clip_param)
                critic_loss = (old_return - new_value).pow(2).mean()

                # Maximise the surrogate and the entropy bonus, minimise the value error.
                loss = -actor_loss + 0.5 * critic_loss - 0.001 * new_entropy

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                loss_item += loss.item()

            # Guard against ppo_epochs == 0 (no update performed, logged loss is 0).
            losses.append(loss_item / max(self.ppo_epochs, 1))

        self.envs.close()

    @staticmethod
    def _clipped_surrogate(new_log_prob, old_log_prob, advantage, clip_param):
        """Return the mean PPO clipped surrogate objective (a scalar to be maximised).

        All inputs are flattened first so the ratio of sample i is always
        multiplied by the advantage of sample i, regardless of trailing
        singleton dimensions.
        """
        advantage = advantage.reshape(-1)
        ratio = (new_log_prob.reshape(-1) - old_log_prob.reshape(-1)).exp()
        unclipped = ratio * advantage
        clipped = torch.clamp(ratio, 1. - clip_param, 1. + clip_param) * advantage
        return torch.min(unclipped, clipped).mean()

    def compute_returns(self, rewards, not_term_masks, trunc_masks, next_values):
        """Compute discounted, bootstrapped returns for one rollout.

        Args:
            rewards: Per-step reward tensors.
            not_term_masks: Per-step float masks, 0 where the episode terminated.
            trunc_masks: Per-step bool masks, True where the episode was truncated.
            next_values: Per-step critic values of the successor state (for
                truncated environments, of the final observation of the
                truncated episode).

        Returns:
            A list with one return tensor per step, in rollout order. None of
            the input tensors is modified.
        """
        if not rewards:
            return []

        R = next_values[-1]
        returns = [None] * len(rewards)
        for step in reversed(range(len(rewards))):
            # Truncated episode: restart the recursion from the critic's estimate. This is done
            # out of place; writing into R would overwrite the return already stored for step + 1.
            R = torch.where(trunc_masks[step], next_values[step], R)
            R = rewards[step] + self.gamma * R * not_term_masks[step]
            returns[step] = R
        return returns

    @staticmethod
    def plot_training(frame_idx, rewards, losses):
        """Redraw the episode-reward curve (means of 100 episodes) and the loss curve."""
        clear_output(True)
        plt.figure(figsize=(20,5))
        plt.subplot(131)
        # Avoid NumPy's mean-of-empty warning before the first episode has finished.
        recent_reward = np.mean(rewards[-10:]) if len(rewards) > 0 else float('nan')
        plt.title('episode: {}, total reward(ma-10): {}'.format(len(rewards), recent_reward))
        plt.plot(np.array(rewards)[:100 * (len(rewards) // 100)].reshape(-1, 100).mean(axis=1))
        plt.subplot(132)
        plt.title('frame: {}, loss(ma-10): {:.4f}'.format(frame_idx, np.mean(losses[-10:])))
        plt.plot(losses)
        plt.show()

In [None]:
## Training

# Hyperparameters gathered in one place so they are easy to inspect and tune.
a2c_hyperparams = dict(
    env_id=env_id,
    gamma=0.99,              # discount factor
    lr=5e-4,                 # Adam learning rate (actor and critic)
    num_frames=50000,        # vectorised environment steps in total
    num_steps=10,            # rollout length between updates
    batch_size=128,          # number of parallel environments
    ppo_epochs=4,            # gradient steps per rollout
    ppo_rel_batch_size=1.,   # minibatch size relative to num_steps * batch_size
    ppo_clip_param=0.2,      # clipping range of the probability ratio
)

a2c_agent = A2CAgent(**a2c_hyperparams)
a2c_agent.train()

In [None]:
## Visualization (Test)

# Roll out one episode with the greedy policy in a rendered environment.
env = gym.make(env_id, render_mode='human')
try:
    state, _ = env.reset()
    done = False
    while not done:
        # The actor expects a batch dimension, so wrap the single observation.
        action = a2c_agent.actor.greedy_act(np.expand_dims(state, 0))
        state, reward, terminated, truncated, _ = env.step(action[0])
        done = terminated or truncated
        env.render()
finally:
    # Always release the render window, even if the episode is interrupted or raises.
    env.close()