<a href="https://colab.research.google.com/github/zhongjie-wu/579project/blob/main/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
! pip install pettingzoo[mpe]

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pettingzoo[mpe]
  Downloading PettingZoo-1.22.3-py3-none-any.whl (816 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m816.1/816.1 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gymnasium>=0.26.0
  Downloading gymnasium-0.28.1-py3-none-any.whl (925 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m925.5/925.5 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
Collecting pygame==2.1.3.dev8
  Downloading pygame-2.1.3.dev8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.7/13.7 MB[0m [31m56.2 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting jax-jumpy>=1.0.0
  Downloading jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Installing collected packages: farama-notif

In [15]:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
import tensorflow as tf
from collections import deque
import time
import random
from tqdm import tqdm
import os

from pettingzoo.mpe import simple_speaker_listener_v3, simple_reference_v2, simple_world_comm_v2

In [11]:
# Global variables: hyperparameters read by DQNAgent and by the training loop.
DISCOUNT = 0.99  # Discount factor applied to the max future Q value in DQNAgent.train
REPLAY_MEMORY_SIZE = 50_000  # How many last steps to keep for model training (deque maxlen)
MIN_REPLAY_MEMORY_SIZE = 1_000  # Minimum number of steps in a memory to start training
MINIBATCH_SIZE = 64  # How many steps (samples) to use for training
UPDATE_TARGET_EVERY = 5  # Threshold on the count of terminal states (ends of episodes) that triggers a target-network sync
MODEL_NAME = 'model'  # Not referenced by the visible cells; presumably a name for saved models/logs - TODO confirm
MIN_REWARD = -200  # For model save (not referenced by the visible cells)
MEMORY_FRACTION = 0.20  # Not referenced by the visible cells; presumably a GPU memory fraction - TODO confirm

# Environments
EPISODES = 100  # Number of episodes the training loop iterates over

# Exploration settings
epsilon = 1  # Exploration rate for epsilon-greedy action selection; starts fully random and is meant to decay
EPSILON_DECAY = 0.99975  # Decay factor for epsilon (presumably multiplicative - TODO confirm against the training loop)
MIN_EPSILON = 0.001  # Lower bound intended for epsilon

# Stats settings
AGGREGATE_STATS_EVERY = 50  # episodes

In [12]:
class DQNAgent:
    """Deep Q-Network agent with a replay memory and a separate target network.

    The agent owns two identically shaped Keras models: ``model`` is trained on
    minibatches sampled from ``replay_memory``, while ``target_model`` supplies
    the bootstrap Q values and is only re-synced with ``model`` after a number
    of terminal states, which keeps the regression targets stable.

    Transitions are stored as tuples of
    ``(state, action, reward, new_state, done)`` where ``action`` is an integer
    index into the Q-value vector.
    """

    def __init__(self, input_layer_size, action_space_size):
        """Build both networks and an empty replay memory.

        Args:
            input_layer_size: Length of the flat observation vector.
            action_space_size: Number of discrete actions (network outputs).
        """
        # Main model which we use to train
        self.model = self.create_model(input_layer_size, action_space_size)

        # Target network to make sure the updating is stable
        self.target_model = self.create_model(input_layer_size, action_space_size)
        self.target_model.set_weights(self.model.get_weights())

        # The array to keep the memory for the last n steps for training
        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)

        # Count when to update target network with main network's weights
        self.target_update_counter = 0

    def create_model(self, input_layer_size, action_space_size):
        """Create and compile a small fully connected Q network.

        Args:
            input_layer_size: Length of the flat observation vector.
            action_space_size: Number of discrete actions (network outputs).

        Returns:
            A compiled ``Sequential`` model mapping an observation to one
            Q value per action.
        """
        model = Sequential()
        model.add(Dense(64, activation='relu', input_shape=(input_layer_size,)))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(action_space_size, activation='linear'))

        # Keras refuses to fit() an uncompiled model, so the loss and optimizer
        # must be attached here. Q-learning is a regression onto TD targets,
        # hence mean squared error.
        model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))

        return model

    def update_replay_memory(self, transition):
        """Append one ``(state, action, reward, new_state, done)`` transition."""
        self.replay_memory.append(transition)

    def train(self, terminal_state, step):
        """Run one minibatch update of the main network.

        Does nothing until the replay memory holds enough transitions.

        Args:
            terminal_state: True when the step that triggered this call ended
                an episode; used to count episodes for the target-network sync.
            step: Step number within the episode. Currently unused; kept so
                existing callers keep working.
        """
        # Start training only if enough transition samples have been collected.
        # The max() also protects random.sample from a too-small population.
        if len(self.replay_memory) < max(MIN_REPLAY_MEMORY_SIZE, MINIBATCH_SIZE):
            return

        # Get a minibatch from the replay memory and split it into columns
        minibatch = random.sample(self.replay_memory, MINIBATCH_SIZE)
        states, actions, rewards, new_states, dones = zip(*minibatch)

        current_states = np.array(states)
        new_current_states = np.array(new_states)
        actions = np.array(actions, dtype=int)
        rewards = np.array(rewards, dtype=float)
        dones = np.array(dones, dtype=bool)

        # Q values of the main network for the sampled states; these become the
        # regression targets after the taken action's entry is overwritten.
        target_qs = self.model.predict(current_states, verbose=0)

        # Bootstrap values come from the (frozen) target network
        future_qs = self.target_model.predict(new_current_states, verbose=0)
        max_future_qs = np.max(future_qs, axis=1)

        # Terminal transitions have no future value; others use the TD target
        new_qs = np.where(dones, rewards, rewards + DISCOUNT * max_future_qs)
        target_qs[np.arange(len(minibatch)), actions] = new_qs

        self.model.fit(current_states, target_qs, batch_size=MINIBATCH_SIZE,
                       shuffle=False, verbose=0)

        if terminal_state:
            self.target_update_counter += 1

        # Sync the target network every UPDATE_TARGET_EVERY terminal states
        # (>= rather than >, which would wait one extra episode).
        if self.target_update_counter >= UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0

    # Queries main network for Q values given current observation space (environment state)
    def get_qs(self, state):
        """Return the main network's Q values for a single observation.

        Args:
            state: One observation, as a NumPy array or any array-like.

        Returns:
            The first row of the model's prediction, i.e. one Q value per
            action for ``state``.
        """
        # asarray first so that plain lists/tuples (which lack .shape) work too
        state = np.asarray(state)
        return self.model.predict(state.reshape(-1, *state.shape), verbose=0)[0]


In [13]:
# One DQN per agent. Sizes match simple_speaker_listener with discrete actions:
# the listener observes 11 values and has 5 actions, the speaker observes 3
# values and has 3 actions.
listener_dqn = DQNAgent(input_layer_size=11, action_space_size=5)
speaker_dqn = DQNAgent(input_layer_size=3, action_space_size=3)
env = simple_speaker_listener_v3.env(max_cycles=200, continuous_actions=False)

# env.agent_iter() yields agent names (strings), so map each name to its learner
dqn_agents = {'speaker_0': speaker_dqn, 'listener_0': listener_dqn}

for episode in tqdm(range(1, EPISODES + 1), ascii=True, unit='episodes'):

    # Restarting episode and environment
    episode_reward = 0  # summed over both agents' turns
    step = 1
    env.reset()

    # Per agent: the (observation, action) of its previous turn. In the AEC API
    # the reward and next observation for that action only become available the
    # next time the same agent is selected.
    pending = {}

    for agent in env.agent_iter():
        observation, reward, termination, truncation, _ = env.last()
        done = termination or truncation
        dqn = dqn_agents[agent]
        episode_reward += reward

        # Complete and store the transition started on this agent's last turn
        if agent in pending:
            previous_observation, previous_action = pending.pop(agent)
            # Only a true termination removes the bootstrap term; hitting
            # max_cycles (truncation) does not mean the future value is zero.
            dqn.update_replay_memory(
                (previous_observation, previous_action, reward, observation, termination)
            )
            dqn.train(done, step)

        if done:
            # PettingZoo requires a None action for terminated/truncated agents
            action = None
        elif np.random.random() > epsilon:
            # Get action from the Q network
            action = int(np.argmax(dqn.get_qs(observation)))
        else:
            # Get random action
            action = int(np.random.randint(0, env.action_space(agent).n))

        if action is not None:
            pending[agent] = (observation, action)

        env.step(action)
        step += 1

    # Decay exploration once per episode, never going below MIN_EPSILON
    epsilon = max(MIN_EPSILON, epsilon * EPSILON_DECAY)



