<a href="https://colab.research.google.com/github/majd-adawieh/multi-agent-rl/blob/main/multi_agent_pong.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pettingzoo
!pip install pettingzoo[all]
!pip install supersuit
!pip install multi-agent-ale-py
!pip install ale-py
!pip install AutoROM
!pip install pyvirtualdisplay pytorch-lightning
!AutoROM -y
!apt-get install -y xvfb

In [46]:
import random
import copy
from collections import deque
import itertools

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML
from pyvirtualdisplay import Display
from supersuit import resize_v1, color_reduction_v0, normalize_obs_v0, reshape_v0, frame_skip_v0
from pettingzoo.atari import pong_v3

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint



# Invisible virtual X display (Xvfb is installed in the first cell) used for rendering.
Display(size=(1400, 900), visible=False).start()

# First CUDA GPU when one is present, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
num_gpus = torch.cuda.device_count()

In [3]:
# Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=gKc1FNhKiVJX
def display_video(frames, framerate=30):
  """Render a sequence of frames as an inline HTML5 video.

  Args:
    frames: non-empty indexable sequence of images, each either
      (height, width, channels) or (height, width); the first frame sets the
      figure size.
    framerate: playback rate in frames per second.

  Returns:
    ``IPython.display.HTML`` embedding the encoded video. Encoding goes through
    matplotlib's ``Animation.to_html5_video``, which needs a movie writer
    (e.g. ffmpeg) to be available.

  Raises:
    ValueError: if ``frames`` is empty.
  """
  if len(frames) == 0:
    raise ValueError("display_video() needs at least one frame")
  # Only the spatial size is needed, so single-channel (2-D) frames work too.
  height, width = frames[0].shape[:2]
  dpi = 70
  orig_backend = matplotlib.get_backend()
  # Create the figure under the non-interactive Agg backend (presumably so the
  # notebook does not also show it as a static image), and restore the
  # caller's backend even if figure creation fails.
  matplotlib.use('Agg')
  try:
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
  finally:
    matplotlib.use(orig_backend)
  ax.set_axis_off()
  ax.set_aspect('equal')
  ax.set_position([0, 0, 1, 1])
  im = ax.imshow(frames[0])

  def update(frame):
    im.set_data(frame)
    return [im]

  interval = 1000 / framerate  # milliseconds between frames
  anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                 interval=interval, blit=True, repeat=False)
  try:
    video = anim.to_html5_video()
  finally:
    # Release the figure once it is encoded. This is harmless if the backend
    # switch above may already have closed it.
    plt.close(fig)
  return HTML(video)

In [69]:
def create_env(num_agents, max_cycles=900, frame_skip=8, size=64, grayscale=False):
  """Build the preprocessed PettingZoo Atari Pong (AEC) environment.

  Args:
    num_agents: number of players, passed as ``num_players`` to ``pong_v3.env``.
    max_cycles: passed through to ``pong_v3.env``.
    frame_skip: passed to ``supersuit.frame_skip_v0``.
    size: observations are resized to ``size`` x ``size`` pixels.
    grayscale: if True, reduce frames to one channel with
      ``color_reduction_v0(env, 'full')``.

  Returns:
    The wrapped environment. Observations have shape ``(3, size, size)``,
    or ``(1, size, size)`` when ``grayscale`` is True.
  """
  env = pong_v3.env(num_players=num_agents, max_cycles=max_cycles)
  env = frame_skip_v0(env, frame_skip)
  env = resize_v1(env, x_size=size, y_size=size)
  if grayscale:
    # Single-channel frames are (size, size). Reshaping to (1, size, size)
    # only adds a leading axis, so no pixels are reordered.
    env = color_reduction_v0(env, 'full')
    return reshape_v0(env, (1, size, size))
  # NOTE(review): reshape_v0 reinterprets the resized (size, size, 3) image
  # buffer rather than transposing it. The result is therefore not a true
  # channels-first image: colour channels and pixels end up interleaved.
  # Code that reshapes observations back to (size, size, 3) for display relies
  # on this layout. A proper fix would transpose the axes and update those
  # callers together.
  return reshape_v0(env, (3, size, size))

def normalize_observations(o):
  """Min-max scale an observation linearly into the range [0, 1].

  The smallest value maps to 0 and the largest to 1. A constant or empty
  array has no range to scale by and comes back as zeros of the same shape.

  Args:
    o: array-like of numbers (any shape).

  Returns:
    A float64 ``numpy.ndarray`` with the same shape as ``o``.
  """
  # Convert first so integer (e.g. uint8 pixel) inputs cannot wrap around
  # during the subtraction below.
  values = np.asarray(o, dtype=np.float64)
  if values.size == 0:
    return values.copy()
  lo, hi = values.min(), values.max()
  if hi == lo:
    return np.zeros_like(values)
  return (values - lo) / (hi - lo)


# Two-player Pong environment used by the random-play demo below.
env = create_env(num_agents=2)
frames = list()  # filled by play_episode(), rendered by display_video()

def play_episode(verbose=True, buffer=None):
  """Play one episode of Pong in the global ``env`` with every paddle acting at random.

  The module-level ``frames`` list is emptied first. It then receives each
  observation seen by ``first_0``, reshaped to (64, 64, 3) for
  ``display_video``.

  For ``first_0`` the transition
  ``(last_observation, last_action, last_reward, done, new_observation)`` is
  built once the reward for an action taken is known. This includes the final,
  terminal transition.

  Args:
    verbose: if True, print the previous action of ``first_0`` and the reward
      it earned each time a transition is built.
    buffer: optional object with an ``append`` method (e.g. a
      ``ReplayBuffer``) that receives every transition.
  """
  frames.clear()
  env.reset()
  last_observation = None  # None until first_0 has actually acted
  last_action = None
  done = False
  while not done:
    current_agent = env.agent_selection
    new_observation, last_reward, terminated, truncated, _ = env.last()
    agent_done = terminated or truncated
    if current_agent != "first_0":
      # Opponent: a random action, or None once its episode is over. None is
      # the only action PettingZoo accepts for a finished agent.
      env.step(None if agent_done else random.randint(0, 5))
      continue

    done = agent_done
    frames.append(new_observation.reshape((64, 64, 3)))
    if last_observation is not None:
      # env.last() reports the reward first_0 collected since its previous
      # action, so it belongs with (last_observation, last_action).
      exp = (last_observation, last_action, last_reward, done, new_observation)
      if buffer is not None:
        buffer.append(exp)
      if verbose:
        print(f"last_action {last_action}, last_reward {last_reward}")
    if not done:
      action = random.randint(0, 5)
      env.step(action)
      last_action = action
      last_observation = new_observation


play_episode();  # run one random-play episode; ';' suppresses any echoed return value



last_action 0, last_reward 0.0
last_action 5, last_reward 0.0
last_action 3, last_reward 0.0
last_action 3, last_reward 0.0
last_action 3, last_reward 0.0
last_action 4, last_reward 0.0
last_action 4, last_reward 0.0
last_action 0, last_reward 0.0
last_action 5, last_reward 0.0
last_action 0, last_reward 0.0
last_action 3, last_reward 0.0
last_action 1, last_reward 0.0
last_action 5, last_reward 0.0
last_action 1, last_reward 0.0
last_action 3, last_reward 0.0
last_action 3, last_reward 0.0
last_action 3, last_reward 0.0
last_action 0, last_reward 1.0
last_action 3, last_reward 0.0
last_action 1, last_reward 0.0
last_action 4, last_reward 0.0
last_action 4, last_reward 0.0
last_action 2, last_reward 0.0
last_action 0, last_reward 0.0
last_action 4, last_reward 0.0
last_action 2, last_reward 0.0
last_action 0, last_reward 0.0
last_action 2, last_reward 0.0
last_action 2, last_reward 0.0
last_action 0, last_reward 0.0
last_action 1, last_reward 0.0
last_action 0, last_reward 0.0
last_act

In [70]:
# Render the frames recorded during the episode at 30 frames per second.
display_video(frames, framerate=30)

In [44]:
class ReplayBuffer:
    """Bounded first-in-first-out store of experience tuples.

    Once ``capacity`` experiences are held, appending a new one silently drops
    the oldest.
    """

    def __init__(self, capacity):
        # A deque with maxlen evicts from the left automatically when full.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return how many experiences are currently stored."""
        stored = self.buffer
        return len(stored)

    def append(self, experience):
        """Store one experience, evicting the oldest one if at capacity."""
        stored = self.buffer
        stored.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen uniformly at random.

        Raises ValueError (from ``random.sample``) if ``batch_size`` exceeds
        the number of stored experiences.
        """
        population = self.buffer
        return random.sample(population, k=batch_size)

In [45]:
class RLDataset(IterableDataset):
    """Iterable dataset that yields a fresh random draw from a replay buffer.

    Every new iteration (e.g. each DataLoader epoch) calls
    ``buffer.sample(sample_size)`` and yields the experiences it returns.
    """

    def __init__(self, buffer, sample_size=400):
        self.buffer, self.sample_size = buffer, sample_size

    def __iter__(self):
        yield from self.buffer.sample(self.sample_size)

In [71]:
def epsilon_greedy(state, env, net, epsilon=0.0):
    """Choose an action epsilon-greedily from the Q-values ``net`` assigns to ``state``.

    With probability ``epsilon`` a random action is sampled from the
    environment's action space. Otherwise the index of the largest Q-value
    is returned.

    Args:
        state: tensor holding a batch of exactly one observation (the greedy
            branch calls ``.item()`` on the per-row argmax).
        env: either a gym-style environment whose ``action_space`` attribute
            is a space, or a PettingZoo AEC environment whose ``action_space``
            is a method taking an agent name. For the latter, the space of
            ``env.agent_selection`` is used.
        net: callable mapping a batch of states to Q-values of shape
            ``(batch, n_actions)``.
        epsilon: exploration probability in ``[0, 1]``.

    Returns:
        The chosen action: an ``int`` on the greedy branch, or whatever
        ``space.sample()`` returns on the exploratory branch.
    """
    if np.random.random() < epsilon:
        space = env.action_space
        if callable(space):
            # PettingZoo AEC API: action_space(agent) returns that agent's space.
            space = space(env.agent_selection)
        return space.sample()

    # Evaluate on the network's own device rather than a global default, so
    # the call works wherever the model currently lives. Fall back to the
    # notebook-level `device` for parameterless callables.
    params = getattr(net, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    target_device = first_param.device if first_param is not None else device
    state = state.to(target_device)
    with torch.no_grad():
        q_values = net(state)
    _, action = torch.max(q_values, dim=1)
    return int(action.item())

In [47]:
class DQN(nn.Module):
    """Dueling Q-network: a convolutional encoder with separate value and advantage heads.

    Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a').
    """

    def __init__(self, state_size , n_actions):
        super().__init__()
        self.state_size = state_size

        # Four 3x3 stride-2 convolutions with 32 filters, each followed by ELU.
        # The convs are kept at even indices of one Sequential so the state_dict
        # keys stay conv.0/.2/.4/.6. The build order also fixes the order in
        # which weights are randomly initialised.
        encoder_layers = []
        channels_in = state_size[0]
        for _ in range(4):
            encoder_layers.append(nn.Conv2d(channels_in, 32, 3, stride=2, padding=1))
            encoder_layers.append(nn.ELU())
            channels_in = 32
        self.conv = nn.Sequential(*encoder_layers)

        flat_features = self._get_conv_out(state_size)
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc_adv = nn.Linear(256, n_actions)
        self.fc_value = nn.Linear(256, 1)

    def _get_conv_out(self, shape):
        """Return the flattened encoder output size for a single input of ``shape``."""
        probe = self.conv(torch.zeros(1, *shape))
        return int(probe.numel())

    def forward(self, x):
        """Map a batch of observations to Q-values of shape ``(batch, n_actions)``.

        ``x`` is cast to float32 but not rescaled.
        """
        batch = x.shape[0]
        features = F.relu(self.fc1(self.conv(x.float()).view(batch, -1)))
        advantage = self.fc_adv(features)
        state_value = self.fc_value(features)
        return state_value + advantage - advantage.mean(dim=1, keepdim=True)

In [None]:
class DeepQLearning(LightningModule):
    """Lightning module that trains a dueling DQN for one paddle of multi-player Pong.

    The agent named ``learning_agent`` is driven by ``policy`` (epsilon-greedy
    over ``q_net`` by default); every other agent acts uniformly at random.
    The learner's transitions go into a replay buffer. Each training epoch:
    sample ``samples_per_epoch`` transitions, regress ``q_net`` onto a one-step
    TD target from ``target_q_net``, then play one more episode to add fresh
    experience.

    Args:
        env_name: stored in the hyperparameters only; the environment is
            always the Pong environment built by ``create_env``.
        policy: callable ``policy(state, env, net, epsilon=...)`` returning an
            action for a batch of one state.
        capacity: maximum number of transitions kept in the replay buffer.
        batch_size: transitions per gradient step.
        lr: optimizer learning rate.
        hidden_size: currently unused.
        gamma: discount factor of the TD target.
        loss_fn: loss between predicted and target Q-values.
        optim: optimizer class, instantiated as ``optim(params, lr=lr)``.
        eps_start, eps_end, eps_last_episode: epsilon starts at ``eps_start``
            and decreases by ``1 / eps_last_episode`` per epoch, floored at
            ``eps_end``.
        samples_per_epoch: transitions drawn from the buffer per epoch. Before
            training, the buffer is pre-filled with at least this many
            transitions of random play.
        sync_rate: copy ``q_net`` into ``target_q_net`` every this many epochs.
        sequence_length: currently unused.
        num_agents: number of Pong players passed to ``create_env``.
        learning_agent: PettingZoo name of the agent controlled by ``q_net``.
    """

    def __init__(self, env_name, policy=epsilon_greedy, capacity=100_000,
                 batch_size=256, lr=1e-3, hidden_size=128, gamma=0.99,
                 loss_fn=nn.MSELoss(), optim=AdamW, eps_start=1.0, eps_end=0.15,
                 eps_last_episode=400, samples_per_epoch=1024, sync_rate=10,
                 sequence_length=8, num_agents=2, learning_agent="first_0"):
        super().__init__()
        # create_env() only takes the player count; env_name is not forwarded.
        self.env = create_env(num_agents)

        # Size the network from the environment instead of hard-coding it, so
        # the first conv layer matches the observation's leading axis.
        obs_size = self.env.observation_space(learning_agent).shape
        n_actions = self.env.action_space(learning_agent).n

        self.q_net = DQN(obs_size, n_actions)
        # The target network only changes by copying q_net's weights, so it
        # never needs gradients.
        self.target_q_net = copy.deepcopy(self.q_net)
        self.target_q_net.requires_grad_(False)

        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)
        self.save_hyperparameters()

        # Pre-fill the replay buffer with uniformly random play (no policy).
        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f"{len(self.buffer)} samples in experience buffer. Filling...")
            self.play_episode(epsilon=self.hparams.eps_start)

    @torch.no_grad()
    def play_episode(self, policy=None, epsilon=0.):
        """Play one full episode and store the learning agent's transitions.

        Uses the PettingZoo AEC loop: ``agent_iter`` / ``last`` / ``step``.
        Finished agents are stepped with ``None``.

        Args:
            policy: action selector for the learning agent. When falsy, the
                learning agent acts uniformly at random.
            epsilon: exploration rate forwarded to ``policy``.

        Returns:
            The learning agent's undiscounted return for the episode.
        """
        from types import SimpleNamespace

        env = self.env
        learner = self.hparams.learning_agent
        env.reset()
        state = None   # learner's previous observation; None before it first acts
        action = None
        episode_return = 0.0

        for agent in env.agent_iter():
            observation, reward, terminated, truncated, _ = env.last()
            finished = terminated or truncated

            if agent != learner:
                # Opponents act at random. None is the only action PettingZoo
                # accepts for an agent whose episode is over.
                env.step(None if finished else env.action_space(agent).sample())
                continue

            # Copy, in case the environment reuses its observation array.
            next_state = torch.tensor(np.asarray(observation))
            if state is not None:
                # env.last() reports the reward accumulated since the learner's
                # previous action. Only a real termination stops bootstrapping;
                # a truncated (time-limited) episode still has a valid next state.
                self.buffer.append(
                    (state, action, float(reward), bool(terminated), next_state))
                episode_return += float(reward)

            if finished:
                env.step(None)
                continue

            if policy:
                # Expose the learner's own space under the gym-style
                # `env.action_space.sample()` interface the policy may use.
                policy_env = SimpleNamespace(action_space=env.action_space(agent),
                                             agent_selection=agent)
                action = policy(next_state.unsqueeze(dim=0).to(self.device),
                                policy_env, self.q_net, epsilon=epsilon)
            else:
                action = env.action_space(agent).sample()
            action = int(action)
            state = next_state
            env.step(action)

        return episode_return

    def forward(self, x):
        """Return the Q-values ``q_net`` assigns to the batch ``x``."""
        return self.q_net(x)

    def configure_optimizers(self):
        """Optimize only ``q_net``; the target network is synced by copying."""
        q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer]

    def train_dataloader(self):
        """Serve ``samples_per_epoch`` replay transitions per epoch in batches of ``batch_size``."""
        dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size
        )
        return dataloader

    def training_step(self, batch, batch_idx):
        """Take one TD(0) regression step of ``q_net`` toward the target network's bootstrap."""
        states, actions, rewards, dones, next_states = batch
        actions = actions.long().unsqueeze(1)
        rewards = rewards.float().unsqueeze(1)
        dones = dones.bool().unsqueeze(1)

        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.q_net(states).gather(1, actions)

        # Target r + gamma * max_a' Q_target(s', a'), with no bootstrap from
        # terminal states. It is kept out of the autograd graph.
        with torch.no_grad():
            next_action_values, _ = self.target_q_net(next_states).max(dim=1, keepdim=True)
            next_action_values = next_action_values.masked_fill(dones, 0.0)
            expected_state_action_values = rewards + self.hparams.gamma * next_action_values

        loss = self.hparams.loss_fn(state_action_values.float(), expected_state_action_values.float())
        self.log('episode/Q-Error', loss)
        return loss

    def on_train_epoch_end(self):
        """Decay epsilon, play an episode, log its return and periodically sync the target net.

        This uses ``on_train_epoch_end`` rather than ``training_epoch_end``:
        the latter hook was removed in Lightning 2.0, while this one exists in
        1.x and 2.x.
        """
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        episode_return = self.play_episode(policy=self.policy, epsilon=epsilon)
        self.log('episode/Return', episode_return)

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())

    def save_model(self, path="./model"):
        """Save ``q_net``'s state_dict to ``path``."""
        torch.save(self.q_net.state_dict(), path)

    def load_model(self, path="./model"):
        """Load ``q_net``'s weights from ``path``, mapping tensors onto this module's device.

        ``torch.load`` unpickles the file, so only load checkpoints you trust.
        """
        self.q_net.load_state_dict(torch.load(path, map_location=self.device))
