In [24]:
from __future__ import annotations

import glob
import os
import re
import time

import supersuit as ss
from stable_baselines3 import PPO, DQN
from stable_baselines3.ppo import MlpPolicy

from pettingzoo.sisl import multiwalker_v9
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.results_plotter import load_results


logdir = "logs"

class RewardLoggerCallback(BaseCallback):
    """
    A custom callback that prints the reward at each step.
    """

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.
        """
        # Get the current reward
        reward = self.training_env.get_attr('reward')[0]

        # Print the reward
        print(f"Step: {self.num_timesteps}, Reward: {reward}")

        return True
    

def objective(trial):
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-1)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128, 256, 512])
    
    env = multiwalker_v9.parallel_env(n_walkers = 2)
    env = ss.frame_stack_v1(env, 3)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
    
    model = PPO(MlpPolicy,
                env,
                verbose=1,
                learning_rate=learning_rate
                ,batch_size=batch_size,
                normalize_advantage=True,
                n_steps=2048,
                n_epochs=10,
                gae_lambda=0.95,
                gamma=0.99,
                clip_range=0.2,
                ent_coef=0.001,
                )
    
    model.learn(total_timesteps=10000)
    #return the rewward of the last episode
    return model.ep_info_buffer[0]['r']
    
    
    
    

def train_butterfly_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    # Train a single model to play as each agent in a cooperative Parallel environment
    env = env_fn.parallel_env(**env_kwargs, n_walkers = 2)

    env.reset(seed=seed)

    print(f"Starting training on {str(env.metadata['name'])}.")
    
    env = ss.frame_stack_v1(env, 3)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
    
    callback = RewardLoggerCallback()
    
    

    model = PPO(
        MlpPolicy,
        env,
        verbose=1,
        learning_rate=2.5e-4,
        batch_size=256,
        normalize_advantage=True,
        n_steps=2048,
        n_epochs=10,
        gae_lambda=0.95,
        gamma=0.99,
        clip_range=0.2,
        ent_coef=0.001,
        tensorboard_log=logdir,
        callback = callback
    )
    

    model.learn(total_timesteps=steps)

    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")

    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.")

    env.close()


def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    # Evaluate a trained agent vs a random agent
    env = env_fn.env(render_mode=render_mode, **env_kwargs, n_walkers = 2)
    
    # Apply the same frame stacking to the evaluation environment
    env = ss.frame_stack_v1(env, 3)
    

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    # Note: We train using the Parallel API but evaluate using the AEC API
    # SB3 models are designed for single-agent settings, we get around this by using the same model for every agent
    for i in range(num_games):
        env.reset(seed=i)   

        for agent in env.agent_iter():
            
            obs, reward, termination, truncation, info = env.last()

            for a in env.agents:
                rewards[a] += env.rewards[a]
            if termination or truncation:
                break
            else:
                act = model.predict(obs, deterministic=True)[0]

            env.step(act)
    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    print("Rewards: ", rewards)
    print(f"Avg reward: {avg_reward}")
    return avg_reward

In [25]:
env_fn = multiwalker_v9
env_kwargs = {}

In [26]:
train_butterfly_supersuit(env_fn, **env_kwargs)

Starting training on multiwalker_v9.


TypeError: __init__() got an unexpected keyword argument 'callback'

In [None]:
import optuna
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100, timeout=600)

In [6]:

eval(env_fn, num_games=2, render_mode="human", **env_kwargs)


Starting evaluation on multiwalker_v9 (num_games=2, render_mode=human)
Rewards:  {'walker_0': 51.442794658243656, 'walker_1': 51.442794658243656}
Avg reward: 51.442794658243656


51.442794658243656