In [None]:
! pip install pettingzoo[mpe]

In [7]:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from pettingzoo.mpe import simple_speaker_listener_v3, simple_reference_v2, simple_world_comm_v2
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
import tensorflow as tf
from collections import deque
import time
import random
from tqdm import tqdm
import os
import json
import matplotlib.pyplot as plt
import pickle

In [1]:
# ---------------------------------------------------------------------------
# Global configuration shared by every cell below
# ---------------------------------------------------------------------------

# Learning
DISCOUNT = 0.99            # discount factor applied to the bootstrapped Q value
MINIBATCH_SIZE = 32        # number of transitions drawn from replay memory per training step
UPDATE_TARGET_EVERY = 5    # target nets are synced once their terminal-state counter exceeds this

# Environment
EPISODES = 500
MAX_CYCLES = 25

# Replay memory: capacity, and the fill levels required before training starts
REPLAY_MEMORY_SIZE = int(EPISODES * MAX_CYCLES / 5)
CRITIC_MIN_REPLAY_MEMORY_SIZE = int(REPLAY_MEMORY_SIZE / 5 / 2)
AGENT_MIN_REPLAY_MEMORY_SIZE = CRITIC_MIN_REPLAY_MEMORY_SIZE + 200  # agents start later than the critic

# Exploration (epsilon-greedy with multiplicative decay down to a floor)
EPSILON = 0.1
EPSILON_DECAY = 0.99975
MIN_EPSILON = 0.001

# Checkpoint locations, one per network
checkpoint_path_sp = "./models/speaker"
checkpoint_path_ls = "./models/listener"
checkpoint_path_critic = "./models/critic"

# Parent directory of each checkpoint file
checkpoint_dir_sp, checkpoint_dir_ls, checkpoint_dir_critic = (
    os.path.dirname(path)
    for path in (checkpoint_path_sp, checkpoint_path_ls, checkpoint_path_critic)
)


def _make_checkpoint_callback(filepath):
    """Build a weights-only ModelCheckpoint callback with save_freq=50 for `filepath`."""
    return tf.keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        verbose=0,
        save_weights_only=True,
        save_freq=50,
    )


cp_callback_sp = _make_checkpoint_callback(checkpoint_path_sp)
cp_callback_ls = _make_checkpoint_callback(checkpoint_path_ls)
cp_callback_critic = _make_checkpoint_callback(checkpoint_path_critic)

## Nash DQN

#### Replay Memory

The replay memory is implemented as a separate class because its contents are shared by the listener, the speaker and the centralized DQN network. Keeping a single object saves memory at runtime (i.e. each transition is appended once instead of once per network).

In [2]:
class ReplayMemory:
    """Bounded FIFO buffer of joint transitions shared by all networks.

    Each stored sample has the layout (sp = speaker, ls = listener):
    (S_sp, S_ls, A_sp, A_ls, R_sp, R_ls, S_next_sp, S_next_ls, done_sp, done_ls)

    On construction the buffer is restored from the pickle written by a
    previous run; when that file does not exist the buffer starts empty.
    """

    # Where the training cell pickles the buffer at the end of a run.
    DEFAULT_PATH = './models/replay_buffer/my_deque.pickle'

    def __init__(self, max_len=None, path=None):
        """Create the buffer.

        Args:
            max_len: capacity of the buffer; None means REPLAY_MEMORY_SIZE,
                looked up at call time so that re-running the config cell
                takes effect without redefining this class.
            path: pickle file to resume from; None means DEFAULT_PATH.
        """
        if max_len is None:
            max_len = REPLAY_MEMORY_SIZE
        if path is None:
            path = self.DEFAULT_PATH

        try:
            with open(path, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code -- only
                # point this at buffer files produced by this notebook.
                stored = pickle.load(f)
        except FileNotFoundError:
            # First run: nothing to resume from.
            stored = ()

        # Re-wrap so the requested capacity is honoured even when the pickled
        # deque was created with a different maxlen; when there are too many
        # stored samples, the most recent `max_len` are kept.
        self.replay_memory = deque(stored, maxlen=max_len)

    def add_sample(self, sample):
        """Append one joint transition (see the class docstring for its layout)."""
        self.replay_memory.append(sample)

    def get_size(self):
        """Return the number of transitions currently stored."""
        return len(self.replay_memory)

    def sample_minibatch(self, minibatch_size=None):
        """Return `minibatch_size` distinct transitions chosen uniformly at random.

        None means MINIBATCH_SIZE. Raises ValueError (from random.sample) when
        the buffer holds fewer transitions than requested.
        """
        if minibatch_size is None:
            minibatch_size = MINIBATCH_SIZE
        return random.sample(self.replay_memory, minibatch_size)

    def get_mem(self):
        """Return the underlying deque (not a copy)."""
        return self.replay_memory

#### Nash DQN agent (For each game env agent)

In [3]:
class NashDQNAgent:
    """Per-agent DQN (speaker or listener) trained against a shared joint critic.

    The agent's own network maps its private observation to one Q value per
    action. The bootstrap target for a non-terminal transition is the maximum
    Q value the critic assigns to the joint next observation.
    """

    def __init__(self, input_layer_size, action_space_size, ReplayMemoryObject, is_speaker, critic):
        """Build the main/target networks and attach the shared memory and critic.

        Args:
            input_layer_size: length of this agent's observation vector.
            action_space_size: number of discrete actions of this agent.
            ReplayMemoryObject: shared replay memory (needs get_size / sample_minibatch).
            is_speaker: True for the speaker, False for the listener.
            critic: object exposing get_qs(joint_observation).
        """
        # Must be set before create_model(), which reads it to pick the checkpoint.
        self.is_speaker = is_speaker

        # Main model which we use to train
        self.model = self.create_model(input_layer_size, action_space_size)

        # Target network; kept in sync with the main network by train(),
        # but not consulted when building the targets in this class.
        self.target_model = self.create_model(input_layer_size, action_space_size)
        self.target_model.set_weights(self.model.get_weights())

        # Shared memory holding the most recent joint transitions
        self.replay_memory = ReplayMemoryObject

        # Counts terminal states seen by train() since the last target sync
        self.target_update_counter = 0

        # Centralized network giving Q values of joint actions for joint observations
        self.critic = critic

    def create_model(self, input_layer_size, action_space_size):
        """Return a compiled 64-64 MLP with weights restored from this agent's checkpoint."""
        model = Sequential()
        model.add(Dense(64, activation='relu', input_shape=(input_layer_size,)))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(action_space_size, activation='linear'))
        model.compile(loss="mse", optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])

        # Restore from the very file the matching ModelCheckpoint callback writes
        # (checkpoint_path_*). The checkpoint_dir_* values all point at the shared
        # parent directory and cannot tell the speaker, listener and critic
        # checkpoints apart.
        weights_path = checkpoint_path_sp if self.is_speaker else checkpoint_path_ls
        model.load_weights(weights_path)
        return model

    def train(self, terminal_state):
        """Run one minibatch update of the main network.

        Does nothing until the replay memory holds at least
        AGENT_MIN_REPLAY_MEMORY_SIZE transitions.

        Args:
            terminal_state: truthy when the step that triggered this call ended
                the episode; counted towards the next target-network sync.
        """
        # Start training only if enough transition samples have been collected
        if self.replay_memory.get_size() < AGENT_MIN_REPLAY_MEMORY_SIZE:
            return

        # Get a minibatch from the replay memory
        minibatch = self.replay_memory.sample_minibatch(minibatch_size=MINIBATCH_SIZE)

        # Current Q estimates for this agent's own observation in every sample
        # (index 0 = speaker observation, index 1 = listener observation)
        own_state_idx = 0 if self.is_speaker else 1
        current_states = np.array([transition[own_state_idx] for transition in minibatch])
        current_qs_list = self.model.predict(current_states, verbose=0)

        X = []
        y = []

        for index, (S_sp, S_ls, A_sp, A_ls, R_sp, R_ls, S_next_sp, S_next_ls, done_sp, done_ls) in enumerate(minibatch):
            # Pick this agent's slice of the joint transition
            if self.is_speaker:
                current_state, action, reward, done = S_sp, A_sp, R_sp, done_sp
            else:
                current_state, action, reward, done = S_ls, A_ls, R_ls, done_ls

            if not done:
                # Nash Q = best joint-action value of the joint next observation
                joint_observation = np.concatenate((S_next_sp, S_next_ls), axis=None)
                joint_q_vals = self.critic.get_qs(joint_observation)
                nash_q = np.max(joint_q_vals)
                # Nash update
                new_q = reward + DISCOUNT * nash_q
            else:
                new_q = reward

            # Only the taken action's target changes; the other outputs keep
            # their current prediction and therefore contribute zero error.
            current_qs = current_qs_list[index]
            current_qs[action] = new_q

            # Prepare training data
            X.append(current_state)
            y.append(current_qs)

        checkpoint_callback = cp_callback_sp if self.is_speaker else cp_callback_ls
        self.model.fit(np.array(X), np.array(y), batch_size=MINIBATCH_SIZE, shuffle=False, verbose=0,
                       callbacks=[checkpoint_callback])

        if terminal_state:
            self.target_update_counter += 1

        # Sync the target network once enough terminal states have been seen
        if self.target_update_counter > UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0

    def get_qs(self, state):
        """Return the main network's Q values for one observation.

        `state` may be any array-like; it is converted to an ndarray and given
        a leading batch axis of size 1 before prediction.
        """
        state = np.asarray(state)
        return self.model.predict(state[np.newaxis, ...], verbose=0)[0]

#### Critic DQN Agent (For providing the Nash Q value of joint states/actions)

This DQN agent serves as the critic for our Nash DQN algorithm. It takes in the joint states observed by the two agents and outputs an array of Q values, one for each combination of the agents' actions. In a fully collaborative setting, the joint action that yields the maximal Q value for a given state is the Nash-equilibrium move, and this maximal Q value is the Nash Q value.

In [4]:
class JointCritic:
    """Centralized DQN over joint observations and joint actions.

    The input is the concatenation of the speaker and listener observations;
    output i is the Q value of the joint action whose flat index is
    np.ravel_multi_index((A_sp, A_ls), joint_action_dims).
    """

    def __init__(self, input_layer_size, action_space_size, ReplayMemoryObject, joint_action_dims=(3, 5)):
        """Build the main/target networks and attach the shared replay memory.

        Args:
            input_layer_size: length of the concatenated joint observation.
            action_space_size: number of network outputs (joint actions).
            ReplayMemoryObject: shared replay memory (needs get_size / sample_minibatch).
            joint_action_dims: (speaker actions, listener actions) used to
                flatten an action pair into an output index. The default
                matches the previously hard-coded (3, 5).
        """
        self.joint_action_dims = tuple(joint_action_dims)

        # Main model which we use to train
        self.model = self.create_model(input_layer_size, action_space_size)

        # Target network used for the bootstrapped value, for stable updates
        self.target_model = self.create_model(input_layer_size, action_space_size)
        self.target_model.set_weights(self.model.get_weights())

        # Shared memory holding the most recent joint transitions
        self.replay_memory = ReplayMemoryObject

        # Counts terminal states seen by train() since the last target sync
        self.target_update_counter = 0

    def create_model(self, input_layer_size, action_space_size):
        """Return a compiled 64-64 MLP with weights restored from the critic checkpoint."""
        model = Sequential()
        model.add(Dense(64, activation='relu', input_shape=(input_layer_size,)))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(action_space_size, activation='linear'))
        model.compile(loss="mse", optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])
        model.load_weights(checkpoint_path_critic)
        return model

    def train(self, terminal_state):
        """Run one minibatch update of the main network.

        Does nothing until the replay memory holds at least
        CRITIC_MIN_REPLAY_MEMORY_SIZE transitions.

        Args:
            terminal_state: truthy when the step that triggered this call ended
                the episode; counted towards the next target-network sync.
        """
        # Start training only if enough transition samples have been collected
        if self.replay_memory.get_size() < CRITIC_MIN_REPLAY_MEMORY_SIZE:
            return

        # Get a minibatch from the replay memory
        minibatch = self.replay_memory.sample_minibatch(minibatch_size=MINIBATCH_SIZE)

        # Joint current observations and the main network's Q values for them
        current_states = np.array([np.concatenate((transition[0], transition[1]), axis=None) for transition in minibatch])
        current_qs_list = self.model.predict(current_states, verbose=0)

        # Joint next observations and the target network's Q values for them
        new_current_states = np.array([np.concatenate((transition[6], transition[7]), axis=None) for transition in minibatch])
        future_qs_list = self.target_model.predict(new_current_states, verbose=0)

        X = []
        y = []

        for index, (S_sp, S_ls, A_sp, A_ls, R_sp, R_ls, S_next_sp, S_next_ls, done_sp, done_ls) in enumerate(minibatch):
            # The team reward is the sum of both agents' rewards
            joint_reward = R_sp + R_ls
            done = done_sp or done_ls
            if not done:
                max_future_q = np.max(future_qs_list[index])
                new_q = joint_reward + DISCOUNT * max_future_q
            else:
                new_q = joint_reward

            # Only the taken joint action's target changes; the other outputs
            # keep their current prediction and contribute zero error.
            current_qs = current_qs_list[index]
            action_idx = np.ravel_multi_index((A_sp, A_ls), dims=self.joint_action_dims)
            current_qs[action_idx] = new_q

            # Prepare training data
            X.append(current_states[index])
            y.append(current_qs)

        self.model.fit(np.array(X), np.array(y), batch_size=MINIBATCH_SIZE, shuffle=False, verbose=0, callbacks=[cp_callback_critic])

        if terminal_state:
            self.target_update_counter += 1

        # Sync the target network once enough terminal states have been seen
        if self.target_update_counter > UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0

    def get_qs(self, state):
        """Return the main network's joint-action Q values for one joint observation.

        `state` may be any array-like; it is converted to an ndarray and given
        a leading batch axis of size 1 before prediction.
        """
        state = np.asarray(state)
        return self.model.predict(state[np.newaxis, ...], verbose=0)[0]

#### Experiments

In [5]:
def eps_greedy_act_selection(epsilon, action_space_size, q_values):
    """Pick an action index epsilon-greedily.

    With probability `epsilon` a uniformly random index in
    [0, action_space_size) is returned; otherwise the index of the largest
    entry of `q_values`.
    """
    explore = np.random.random() < epsilon
    if not explore:
        # Exploit: greedy action under the current Q estimates
        return np.argmax(q_values)
    # Explore: uniformly random action
    return np.random.randint(0, action_space_size)

In [None]:
AGENT_NAMES = ['speaker_0', 'listener_0']
AGENT_INFOS = {name: {"agent_idx": 0 if name == 'speaker_0' else 1,
                      "action_space_size": 3 if name == 'speaker_0' else 5,
                      "input_layer_size": 3 if name == 'speaker_0' else 11,
                      "is_speaker": name == 'speaker_0'
                      } for name in AGENT_NAMES}

# Files used to resume from, and persist, the state of a training run
REPLAY_BUFFER_PATH = './models/replay_buffer/my_deque.pickle'
REWARD_DICT_PATH = './models/reward_dict/my_dict.json'

# Reward history {agent_name: [avg reward per step of each past episode]}.
# Resume it when a previous run saved one, otherwise start from scratch.
if os.path.exists(REWARD_DICT_PATH):
    with open(REWARD_DICT_PATH, 'r') as f:
        ALL_REWARDS = json.load(f)
else:
    ALL_REWARDS = {}
for name in AGENT_NAMES:
    ALL_REWARDS.setdefault(name, [])

epsilon = EPSILON

# Create a replay buffer
replay_buff = ReplayMemory(max_len = REPLAY_MEMORY_SIZE)

# The critic sees both observations (3 + 11 = 14 inputs) and scores every
# (speaker action, listener action) pair (3 * 5 = 15 outputs).
JOINT_INPUT_SIZE = sum(AGENT_INFOS[name]["input_layer_size"] for name in AGENT_NAMES)
JOINT_ACTION_SIZE = int(np.prod([AGENT_INFOS[name]["action_space_size"] for name in AGENT_NAMES]))

# Create the critic DQN Agent
critic_dqn = JointCritic(input_layer_size=JOINT_INPUT_SIZE, action_space_size=JOINT_ACTION_SIZE, ReplayMemoryObject=replay_buff)

# Create the Nash DQN Agents
NASH_DQN_AGENTS = {name: NashDQNAgent(input_layer_size=AGENT_INFOS[name]["input_layer_size"],
                                      action_space_size=AGENT_INFOS[name]["action_space_size"],
                                      ReplayMemoryObject=replay_buff,
                                      is_speaker=AGENT_INFOS[name]["is_speaker"],
                                      critic=critic_dqn) for name in AGENT_NAMES}

# Create the environment
env = simple_speaker_listener_v3.env(max_cycles=MAX_CYCLES, continuous_actions=False)

for episode in tqdm(range(1, EPISODES + 1), ascii=True, unit='episodes'):
    # Reset the environment and the reward for this new episode
    env.reset()
    episode_agent_reward = {agent_name: 0 for agent_name in AGENT_NAMES}

    # (observation, action) of every agent that has already acted in the current cycle
    pending = {}

    for agent in env.agent_iter():
        if env.truncations[agent] or env.terminations[agent]:
            # A finished agent must still be stepped with None
            env.step(None)
            continue

        # Get the current agent
        nash_dqn_agent = NASH_DQN_AGENTS[agent]
        # Observe the current state
        state_curr = env.observe(agent)
        # Get the Q values for each action of the curr state
        q_state_curr = nash_dqn_agent.get_qs(state_curr)
        # Choose and take an action
        action = eps_greedy_act_selection(epsilon, AGENT_INFOS[agent]["action_space_size"], q_state_curr)
        env.step(action)
        pending[agent] = (state_curr, action)

        # The outcome of a cycle is only known once every agent has acted.
        # Right after env.step() the selected agent has already advanced, so
        # env.last() would report the *next* agent's reward and done flags
        # rather than those of the agent that just moved.
        if len(pending) < len(AGENT_NAMES):
            continue

        S_sp, A_sp = pending['speaker_0']
        S_ls, A_ls = pending['listener_0']
        pending = {}

        # Read each agent's own outcome of the completed cycle
        rewards = {name: float(env.rewards[name]) for name in AGENT_NAMES}
        dones = {name: bool(env.terminations[name] or env.truncations[name]) for name in AGENT_NAMES}
        next_states = {name: env.observe(name) for name in AGENT_NAMES}

        for name in AGENT_NAMES:
            episode_agent_reward[name] += rewards[name]

        # (S_sp, S_ls, A_sp, A_ls, R_sp, R_ls, S_next_sp, S_next_ls, done_sp, done_ls)
        transition = (S_sp, S_ls, A_sp, A_ls,
                      rewards['speaker_0'], rewards['listener_0'],
                      next_states['speaker_0'], next_states['listener_0'],
                      dones['speaker_0'], dones['listener_0'])
        replay_buff.add_sample(transition)

        # Train the centralized critic. The done flag is taken from the values
        # captured above; it used to be read after the per-cycle variables had
        # been reset to None, so the critic's target network never synced.
        critic_dqn.train(dones['speaker_0'] or dones['listener_0'])

        # Train each agent (no-op until the memory holds enough samples)
        for name in AGENT_NAMES:
            NASH_DQN_AGENTS[name].train(dones[name])

    # store the average reward per step of this episode
    for name in AGENT_NAMES:
        ALL_REWARDS[name].append(episode_agent_reward[name]/MAX_CYCLES)

    # Perform epsilon decay
    if epsilon > MIN_EPSILON:
        epsilon *= EPSILON_DECAY
        epsilon = max(MIN_EPSILON, epsilon)

env.close()


# Finally, plot the average reward per step per episode per agent
# plt.plot(range(EPISODES), ALL_REWARDS['speaker_0'])
# plt.title('Avg reward per step for speaker_0')
# plt.xlabel('num_episodes')
# plt.ylabel('reward')
# plt.show()

# plt.plot(range(EPISODES), ALL_REWARDS['listener_0'])
# plt.title('Avg reward per step for listener_0')
# plt.xlabel('num_episodes')
# plt.ylabel('reward')
# plt.show()

# Make sure the target directories exist so a long run is not lost at save time
os.makedirs(os.path.dirname(REPLAY_BUFFER_PATH), exist_ok=True)
os.makedirs(os.path.dirname(REWARD_DICT_PATH), exist_ok=True)

# Save the replay buffer
with open(REPLAY_BUFFER_PATH, 'wb') as f:
    pickle.dump(replay_buff.get_mem(), f)

# Save the current reward dictionary
with open(REWARD_DICT_PATH, 'w') as f:
    # use json.dump to serialize the dictionary object to JSON and write it to the file
    json.dump(ALL_REWARDS, f)

# Signal the tmux session waiting on this script that it has finished
os.system("tmux wait-for -S script_finished")
