<a href="https://colab.research.google.com/github/majd-adawieh/multi-agent-rl/blob/main/multi_agent_pong.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup cell: installs every package the notebook imports below.
# NOTE(review): versions are unpinned, so a fresh run may pull incompatible releases — consider pinning.
# PettingZoo (base package, then with its optional extras) and the SuperSuit wrappers.
!pip install pettingzoo
!pip install pettingzoo[all]
!pip install supersuit
# Atari emulator packages and the AutoROM ROM installer.
!pip install multi-agent-ale-py
!pip install ale-py
!pip install AutoROM
# Virtual display helper and the training framework.
!pip install pyvirtualdisplay pytorch-lightning
# Run the ROM installer non-interactively (presumably -y accepts the licence prompt — TODO confirm).
!AutoROM -y
# Xvfb system package; presumably the backend used by pyvirtualdisplay — TODO confirm.
!apt-get install -y xvfb

In [235]:
import random
import copy
from collections import deque
import itertools

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML
from pyvirtualdisplay import Display
from supersuit import resize_v1, color_reduction_v0, reshape_v0, frame_skip_v0
from pettingzoo.atari import pong_v3

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint



# Start a non-visible 1400x900 virtual display for the whole session.
# Presumably required so environment rendering works on a machine without a screen — TODO confirm.
Display(visible=False, size=(1400, 900)).start()

<pyvirtualdisplay.display.Display at 0x7f5c7414fd10>

In [104]:
# Copied from: https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=gKc1FNhKiVJX
def display_video(frames, framerate=30):
  """Render a sequence of image frames as an inline HTML5 video.

  Args:
    frames: Non-empty sequence of image arrays that all share the shape of
      ``frames[0]``. The first two dimensions are read as (height, width);
      a trailing channel dimension is optional.
    framerate: Playback speed in frames per second.

  Returns:
    An ``IPython.display.HTML`` object wrapping the encoded video.

  Raises:
    IndexError: If ``frames`` is empty.
  """
  # Only height and width are needed, so 2-D (grayscale) frames work as well.
  height, width = frames[0].shape[:2]
  dpi = 70
  orig_backend = matplotlib.get_backend()
  # Create the figure under the non-interactive 'Agg' backend, then always
  # restore the caller's backend, even if figure creation fails.
  matplotlib.use('Agg')
  try:
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
  finally:
    matplotlib.use(orig_backend)
  try:
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])  # let the image fill the whole figure
    im = ax.imshow(frames[0])
    def update(frame):
      im.set_data(frame)
      return [im]
    interval = 1000/framerate  # milliseconds between frames
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                    interval=interval, blit=True, repeat=False)
    return HTML(anim.to_html5_video())
  finally:
    # The video is fully encoded at this point; release the figure so it is
    # neither kept alive by pyplot nor rendered a second time by the notebook.
    plt.close(fig)

In [139]:
def create_env(num_agents, max_cycles=900, frame_skip=16):
  """Build the multi-agent Pong environment used in this notebook.

  Args:
    num_agents: Forwarded to ``pong_v3.env`` as ``num_players``.
    max_cycles: Forwarded to ``pong_v3.env`` as ``max_cycles``.
    frame_skip: Second argument of the ``frame_skip_v0`` wrapper.

  Returns:
    The ``pong_v3`` environment wrapped in ``frame_skip_v0``.
  """
  env = pong_v3.env(num_players=num_agents, max_cycles=max_cycles)
  # Optional preprocessing, currently disabled:
  #env = resize_v1(env, x_size=64, y_size=64)
  env = frame_skip_v0(env, frame_skip)
  #env = color_reduction_v0(env, 'full')
  return env

In [231]:
def slice_deque(buffer, start, stop, step):
    """Return the items ``start .. stop-1`` (every ``step``-th) of a deque as a list.

    The deque is rotated so that index ``start`` comes first, the leading
    ``stop - start`` items are read, and the rotation is undone. Because of
    the rotation, a window that runs past the end of the deque wraps around
    to its beginning.

    Args:
        buffer: ``collections.deque`` to read from; restored to its original
            order before returning, including when an error is raised.
        start: Index of the first item of the window.
        stop: Index one past the last item of the window.
        step: Positive stride between returned items.

    Returns:
        A new list with the selected items.

    Raises:
        ValueError: If ``step`` is not positive or ``stop < start``
            (raised by ``itertools.islice``).
    """
    buffer.rotate(-start)
    try:
        return list(itertools.islice(buffer, 0, stop - start, step))
    finally:
        # Always undo the rotation; otherwise a failed slice would leave the
        # replay buffer permanently reordered.
        buffer.rotate(start)

class ReplayBuffer:
    """Bounded FIFO store of experiences that is sampled in contiguous sequences.

    ``sample`` reads index 1 of every stored experience as the reward, so
    experiences must be indexable with the reward at position 1.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` experiences."""
        # A bounded deque drops the oldest entry once ``capacity`` is reached.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Add one experience at the newest end of the buffer."""
        self.buffer.append(experience)

    def sample(self, sample_size, sequence_len):
        """Draw random contiguous sequences until ``sample_size`` items are collected.

        Roughly the first half of the batch consists of sequences containing
        at least one experience whose reward is ``1`` or ``-1``; the rest is
        drawn uniformly. If the buffer holds no such experience, the whole
        batch is drawn uniformly instead of searching forever.

        Args:
            sample_size: Minimum number of experiences to return. The result
                is always a whole number of sequences, so it can be longer.
            sequence_len: Length of every contiguous sequence.

        Returns:
            A flat list of experiences made of back-to-back sequences of
            ``sequence_len`` items each.

        Raises:
            ValueError: If ``sequence_len`` is smaller than 1 or larger than
                the number of stored experiences.
        """
        if sample_size <= 0:
            return []
        if sequence_len < 1:
            # A zero-length window would never grow the batch.
            raise ValueError("sequence_len must be at least 1")
        if len(self.buffer) < sequence_len:
            raise ValueError(
                f"cannot sample sequences of length {sequence_len} "
                f"from a buffer holding {len(self.buffer)} experiences"
            )

        def is_scoring(experience):
            return experience[1] == 1 or experience[1] == -1

        def draw_sequence():
            start = random.randint(0, len(self.buffer) - sequence_len)
            return slice_deque(self.buffer, start, start + sequence_len, 1)

        batch = []
        scoring_quota = int(sample_size / 2)
        # Only look for scoring sequences when at least one exists; any stored
        # scoring experience is covered by some window, so the loop terminates.
        if scoring_quota > 0 and any(is_scoring(exp) for exp in self.buffer):
            while len(batch) < scoring_quota:
                sequence = draw_sequence()
                # Add the window once, however many scoring steps it contains.
                if any(is_scoring(exp) for exp in sequence):
                    batch += sequence

        while len(batch) < sample_size:
            batch += draw_sequence()

        return batch

In [234]:
class RLDataset(IterableDataset):
    """Iterable dataset that streams experiences drawn from a replay buffer.

    Every iteration pass asks the buffer for one fresh sample and yields its
    items one by one.
    """

    def __init__(self, buffer, sample_size=400, sequence_len=5):
        """Remember the buffer and the arguments later passed to ``buffer.sample``."""
        self.buffer, self.sample_size, self.sequence_len = (
            buffer,
            sample_size,
            sequence_len,
        )

    def __iter__(self):
        """Yield each experience of one ``buffer.sample(sample_size, sequence_len)`` call."""
        yield from self.buffer.sample(self.sample_size, self.sequence_len)

In [141]:
# Show the recorded frames as an inline video.
# NOTE(review): `frames` is not defined in any cell visible in this notebook export —
# confirm a rollout cell that collects it runs before this one.
display_video(frames)

In [None]:
class DeepQLearning(LightningModule):
    """Recurrent deep Q-learning trained from a replay buffer of sequences.

    Experiences are stored as ``(state, reward, action, done, next_state)``
    tuples. A single Q-network and a single replay buffer are shared by all
    agents of the environment; each agent's transitions are written to the
    buffer as one contiguous run so that sampled sequences follow one agent.
    """

    def __init__(self, env_name, policy=epsilon_greedy, capacity=100_000, 
               batch_size=256, lr=1e-3, hidden_size=128, gamma=0.99, 
               loss_fn=nn.MSELoss(), optim=AdamW, eps_start=1.0, eps_end=0.15, 
               eps_last_episode=400, samples_per_epoch=1024, sync_rate=10,
               sequence_length = 8):
        """Build the networks and pre-fill the replay buffer with random play.

        Args:
            env_name: Not read by this constructor; only saved as a hyperparameter.
            policy: Callable used by ``play_episode`` to choose actions; called as
                ``policy(state, env, q_net, hidden, epsilon=...)`` and expected
                to return ``(action, hidden)``.
            capacity: Maximum number of experiences kept in the replay buffer.
            batch_size: Number of sequences per training batch.
            lr: Learning rate handed to the optimizer.
            hidden_size: Not read by this class; only saved as a hyperparameter.
            gamma: Discount factor of the Q-learning target.
            loss_fn: Loss applied between predicted and target Q-values.
            optim: Optimizer class instantiated in ``configure_optimizers``.
            eps_start: Initial exploration rate.
            eps_end: Lower bound of the exploration rate.
            eps_last_episode: Divisor of the epoch count in the epsilon schedule.
            samples_per_epoch: Experiences sampled per epoch; also the minimum
                buffer size reached before training starts.
            sync_rate: Target network is synchronised every ``sync_rate`` epochs.
            sequence_length: Length of the sequences fed through the recurrent net.
        """
        super().__init__()
        # NOTE(review): `create_environment`, `DRQN` and the default `epsilon_greedy`
        # are not defined in the visible cells (only `create_env(num_agents)` is) —
        # confirm they exist before this cell runs.
        self.env = create_environment()

        obs_size = self.env.observation_space.shape
        n_actions = self.env.action_space.n
        # Kept so that random exploration draws from the real action range.
        self.n_actions = n_actions

        self.q_net = DRQN(obs_size, n_actions)

        self.target_q_net = copy.deepcopy(self.q_net)

        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)
        self.save_hyperparameters()

        # Warm-up: random episodes until one epoch's worth of samples exists.
        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f"{len(self.buffer)} samples in experience buffer. Filling...")
            self.play_episode(epsilon=self.hparams.eps_start)

    @torch.no_grad()
    def play_episode(self, policy=None, epsilon=0.):
        """Play one episode and store every agent's transitions in the buffer.

        Args:
            policy: Action-selection callable; when falsy, actions are drawn
                uniformly at random.
            epsilon: Exploration rate forwarded to ``policy``.
        """
        self.env.reset()
        pending = {}       # agent -> [state, reward, action, done, next_state] awaiting its outcome
        trajectories = {}  # agent -> completed transitions in time order
        hiddens = {}       # agent -> recurrent state returned by the policy

        for agent in self.env.agent_iter():
            state, last_reward, done, _ = self.env.last()

            # What the agent observes now is the outcome of its previous action.
            last_exp = pending.pop(agent, None)
            if last_exp:
                last_exp[1] = last_reward
                last_exp[3] = done  # terminal flag of the transition, not of the pre-action state
                last_exp[4] = state
                trajectories.setdefault(agent, []).append(tuple(last_exp))
            if done:
                break

            if policy:
                # NOTE(review): `state.unsqueeze` assumes the environment yields
                # torch tensors — confirm against the environment wrappers.
                action, hiddens[agent] = policy(
                    state.unsqueeze(dim=0), self.env, self.q_net,
                    hiddens.get(agent), epsilon=epsilon,
                )
            else:
                action = random.randrange(self.n_actions)

            self.env.step(action)
            pending[agent] = [state, None, action, done, None]

        # Flush agent by agent so that windows sampled from the buffer are not
        # interleaved across agents.
        for transitions in trajectories.values():
            for transition in transitions:
                self.buffer.append(transition)

        self.env.close()

    def forward(self, x):
        """Return the output of the online Q-network for ``x``."""
        return self.q_net(x)

    def configure_optimizers(self):
        """Create the optimizer for the online Q-network."""
        q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer]

    # Create dataloader.
    def train_dataloader(self):
        """Build a loader whose batches hold ``batch_size`` whole sequences."""
        dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch, self.hparams.sequence_length)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size * self.hparams.sequence_length
        )
        return dataloader

    def training_step(self, batch, batch_idx):
        """Compute the Q-learning loss over a batch of back-to-back sequences.

        The batch is laid out as consecutive sequences, so the slice
        ``[i::sequence_length]`` selects time step ``i`` of every sequence.
        """
        # Field order matches the tuples written by `play_episode`.
        states, rewards, actions, dones, next_states = batch
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)

        seq_len = self.hparams.sequence_length
        stack_q_values = []
        stack_actions = []
        stack_rewards = []
        stack_dones = []
        hidden = None
        # Unroll the online network through time, carrying the hidden state.
        for i in range(seq_len):
            stack_actions.append(actions[i::seq_len])
            stack_dones.append(dones[i::seq_len])
            stack_rewards.append(rewards[i::seq_len])
            q_values, hidden = self.q_net(states[i::seq_len], hidden)
            stack_q_values.append(q_values)
        stack_q_values = torch.cat(stack_q_values, dim=0)
        stack_actions = torch.cat(stack_actions, dim=0)
        stack_dones = torch.cat(stack_dones, dim=0)
        stack_rewards = torch.cat(stack_rewards, dim=0)

        state_action_values = torch.gather(stack_q_values, -1, stack_actions)

        with torch.no_grad():
            stack_next_q_values = []
            hidden = None
            for i in range(seq_len):
                q_values, hidden = self.target_q_net(next_states[i::seq_len], hidden)
                stack_next_q_values.append(q_values)

        stack_next_q_values = torch.cat(stack_next_q_values, dim=0)
        next_action_values = torch.max(stack_next_q_values, dim=1)[0].unsqueeze(dim=1)
        # Terminal transitions have no future value; cast so the index is a mask.
        next_action_values[stack_dones.bool()] = 0.0

        expected_state_action_values = stack_rewards + self.hparams.gamma * next_action_values

        loss = self.hparams.loss_fn(state_action_values.float(), expected_state_action_values.float())
        self.log('episode/Q-Error', loss)
        return loss

    # Training epoch end.
    def training_epoch_end(self, training_step_outputs):
        """Play one policy episode, log its return and periodically sync the target net."""
        # NOTE(review): this hook was renamed `on_train_epoch_end` in newer
        # pytorch-lightning releases — confirm the installed version supports it.
        # Linear decay from eps_start, floored at eps_end.
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        self.play_episode(policy=self.policy, epsilon=epsilon)
        # NOTE(review): relies on the environment exposing `return_queue` — confirm.
        self.log('episode/Return', self.env.return_queue[-1])

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())

    def save_model(self):
        """Write the online Q-network weights to ``./model``."""
        torch.save(self.q_net.state_dict(), "./model")

    def load_model(self):
        """Load the online Q-network weights from ``./model``."""
        self.q_net.load_state_dict(torch.load( "./model"))
