# A3C for Kung Fu

## Part 0 - Installing the required packages and importing the libraries

### Installing Gymnasium

In [11]:
# !pip install gymnasium
# !pip install "gymnasium[atari, accept-rom-license]"
# !apt-get install -y swig
# !pip install gymnasium[box2d]

### Importing the libraries

In [12]:
import cv2
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributions as distributions
from torch.distributions import Categorical
import gymnasium as gym
from gymnasium import ObservationWrapper
from gymnasium.spaces import Box

## Part 1 - Building the AI

### Creating the architecture of the Neural Network

In [13]:
class Network(nn.Module):
    def __init__(self, action_size):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 128)
        self.fc2a = nn.Linear(128, action_size)
        self.fc2s = nn.Linear(128, 1)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        action_value = self.fc2a(x)
        state_value = self.fc2s(x)[0]
        return action_value, state_value


## Part 2 - Training the AI

### Setting up the environment

In [14]:
class PreprocessAtari(ObservationWrapper):

  def __init__(self, env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4):
    super(PreprocessAtari, self).__init__(env)
    self.img_size = (height, width)
    self.crop = crop
    self.dim_order = dim_order
    self.color = color
    self.frame_stack = n_frames
    n_channels = 3 * n_frames if color else n_frames
    obs_shape = {'tensorflow': (height, width, n_channels), 'pytorch': (n_channels, height, width)}[dim_order]
    self.observation_space = Box(0.0, 1.0, obs_shape)
    self.frames = np.zeros(obs_shape, dtype = np.float32)

  def reset(self):
    self.frames = np.zeros_like(self.frames)
    obs, info = self.env.reset()
    self.update_buffer(obs)
    return self.frames, info

  def observation(self, img):
    img = self.crop(img)
    img = cv2.resize(img, self.img_size)
    if not self.color:
      if len(img.shape) == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img.astype('float32') / 255.
    if self.color:
      self.frames = np.roll(self.frames, shift = -3, axis = 0)
    else:
      self.frames = np.roll(self.frames, shift = -1, axis = 0)
    if self.color:
      self.frames[-3:] = img
    else:
      self.frames[-1] = img
    return self.frames

  def update_buffer(self, obs):
    self.frames = self.observation(obs)

def make_env():
  env = gym.make("MsPacmanDeterministic-v4", render_mode = 'rgb_array')
  env = PreprocessAtari(env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4)
  return env

env = make_env()

state_shape = env.observation_space.shape
number_actions = env.action_space.n
print("Observation shape:", state_shape)
print("Number actions:", number_actions)
print("Action names:", env.unwrapped.get_action_meanings())

Observation shape: (4, 42, 42)
Number actions: 9
Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT']


### Initializing the hyperparameters

In [15]:
learning_rate = 1e-4
gamma = 0.99
no_of_env = 10

### Implementing the A3C class

In [16]:
class Agent:
  def __init__(self, action_size) -> None:
    self.device = torch.device("cpu")
    self.action_size = action_size
    self.network = Network(action_size).to(self.device)
    self.optimizer = optim.Adam(self.network.parameters(), lr = learning_rate)
  
  def act(self, state):
    if state.ndim == 3:
      state = [state]
    state = torch.tensor(state, dtype = torch.float32, device = self.device)
    action_values, _ = self.network(state)
    policy = F.softmax(action_values, -1)
    return np.array([np.random.choice(len(p), p = p) for p in policy.detach().numpy()])
    
  def step(self, state, action, reward, next_state, done):
    batch_size = state.shape[0]
    state = torch.tensor(state, dtype = torch.float32, device = self.device)
    next_state = torch.tensor(next_state, dtype = torch.float32, device = self.device)
    reward = torch.tensor(reward, dtype = torch.float32, device = self.device)
    done = torch.tensor(done, dtype = torch.bool, device = self.device).to(dtype = torch.float32)
    action_values, state_value = self.network(state)
    _, next_state_value = self.network(next_state)
    target_state_value = reward + gamma * next_state_value * (1 - done)
    advantage = target_state_value - state_value
    probs = F.softmax(action_values, dim = -1)
    logprobs = F.log_softmax(action_values, dim = -1)
    entropy = -torch.sum(probs * logprobs, dim = -1)
    batch_idx = np.arange(batch_size)
    logp_actions = logprobs[batch_idx, action]
    actor_loss = -(logp_actions * advantage.detach()).mean() + 0.001 * entropy.mean()
    critic_loss = F.mse_loss(target_state_value.detach(), state_value).to(dtype = torch.float32)
    loss = (actor_loss + critic_loss)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    

### Initializing the A3C agent

In [17]:
agent = Agent(number_actions)

### Evaluating our A3C agent on a certain number of episodes

In [18]:
def evaluate(agent, env, n_episodes = 1):
  episodes_rewards = []
  for _ in range(n_episodes):
    state, _ = env.reset()
    total_reward = 0
    while True:
      action = agent.act(state)
      state, reward, done, info, _ = env.step(action[0])
      total_reward += reward
      if done:
        break
    episodes_rewards.append(total_reward)
  return episodes_rewards

### Testing multiple agents on multiple environments at the same time

In [19]:
class EnvBatch:
  
  def __init__(self, n_envs= 10) -> None:
    self.envs = [make_env() for _ in range(n_envs)]
  
  def reset(self):
    return np.array([env.reset()[0] for env in self.envs])
  
  def step(self, actions):
    new_states, rewards, dones, infos, _ = map(np.array, zip(*[env.step(action) for env, action in zip(self.envs, actions)]))
    for i in range(len(self.envs)):
      if dones[i]:
        new_states[i] = self.envs[i].reset()[0]
    return new_states, rewards, dones, infos

### Training the A3C agent

In [20]:
import tqdm

env_batch = EnvBatch(no_of_env)
batch_states = env_batch.reset()

for i in range(1, 2001):
  batch_actions = agent.act(batch_states)
  batch_next_states, batch_rewards, batch_dones, _ = env_batch.step(batch_actions)
  batch_rewards *= 0.01
  agent.step(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
  batch_states = batch_next_states
  
  print('Computing... ', i, end = '\r')
  
  if i % 1000 == 0:
    print('Computing rewards after', i, 'iterations', end = '\r')
    rewards = evaluate(agent, env, 10)
    print(f"Mean reward: {np.mean(rewards):.3f} +/- {np.std(rewards):.3f}")
      


  critic_loss = F.mse_loss(target_state_value.detach(), state_value).to(dtype = torch.float32)


Computing rewards after 1000 iterations

  state = torch.tensor(state, dtype = torch.float32, device = self.device)


Mean reward: 396.000 +/- 353.757
Mean reward: 253.000 +/- 72.395erations


## Part 3 - Visualizing the results

In [22]:
import glob
import io
import base64
import imageio
from IPython.display import HTML, display
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
os.environ['IMAGEIO_FFMPEG_EXE'] = "/opt/homebrew/bin/ffmpeg"

def show_video_of_model(agent, env):
  state, _ = env.reset()
  done = False
  frames = []
  while not done:
    frame = env.render()
    frames.append(frame)
    action = agent.act(state)
    state, reward, done, _, _ = env.step(action[0])
  env.close()
  imageio.mimsave('video.mp4', frames, fps=30)

show_video_of_model(agent, env)

def show_video():
    mp4list = glob.glob('*.mp4')
    if len(mp4list) > 0:
        mp4 = mp4list[0]
        video = io.open(mp4, 'r+b').read()
        encoded = base64.b64encode(video)
        display(HTML(data='''<video alt="test" autoplay
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    else:
        print("Could not find video")

show_video()

