<a href="https://colab.research.google.com/github/kuds/rl-connect-four/blob/main/%5BConnect%20Four%5D%20Self%20Play.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Playing Connect Four using Self Play

In [None]:
!pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl

In [None]:
!pip install gputil open_spiel gymnasium

In [None]:
import functools
import numpy as np
import multiprocessing as mp
import ray
from ray import tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule
from ray.rllib.examples.multi_agent.utils import (
    ask_user_for_action,
    SelfPlayCallback,
    SelfPlayCallbackOldAPIStack,
)
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env
import platform

import torch
from importlib.metadata import version

# NOTE: this import executes before try_import_open_spiel() is called (in a
# later cell), so a missing open_spiel install surfaces here as an ImportError.
from open_spiel.python.rl_environment import Environment  # noqa: E402

In [None]:
# Report the interpreter and library versions in use, so a run can be
# reproduced or debugged later. Each label names the package whose version is
# printed next to it (the Open Spiel version is reported once, under its own
# label).
print(f"Python Version: {platform.python_version()}")
print(f"Torch Version: {version('torch')}")
print(f"Is Cuda Available: {torch.cuda.is_available()}")
print(f"Cuda Version: {torch.version.cuda}")
print(f"Numpy Version: {version('numpy')}")
print(f"Ray Version: {version('ray')}")
print(f"Gymnasium Version: {version('Gymnasium')}")
print(f"Open Spiel Version: {version('open_spiel')}")

In [None]:
# Show how many CPUs `multiprocessing` reports for this machine.
print(f"Number of CPUs Available: {mp.cpu_count()}")

In [None]:
# Resolve the OpenSpiel packages through RLlib's import helpers and bind the
# results to module-level names used by later cells.
# NOTE(review): `error=True` presumably makes the helpers raise when the
# package is not installed - confirm against the RLlib helper implementation.
open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)

In [None]:
class Args:
    """Hard-coded replacement for the command-line arguments of the script.

    A plain attribute container: every setting is assigned a fixed value in
    ``__init__``. The single instance created below (``args``) is read by the
    config-building code and is passed to
    ``run_rllib_example_script_experiment``.
    """

    def __init__(self):
        # Game / experiment settings.
        self.env = "connect_four"
        self.checkpoint_freq = 1
        self.checkpoint_at_end = True
        self.win_rate_threshold = 0.95
        self.min_league_size = 3
        self.num_episodes_human_play = 10
        self.from_checkpoint = None
        # Add other necessary attributes from parser arguments
        self.algo = 'PPO'  # Assuming PPO is the default algorithm
        self.num_env_runners = 2
        self.enable_new_api_stack = True
        self.stop_timesteps = 2000000
        self.stop_iters = 100
        self.as_release_test = False
        self.num_cpus = 10
        self.local_mode = False
        self.framework = 'torch'
        self.num_gpus = 0
        self.num_gpus_per_learner = 1
        self.num_learners = 1
        self.evaluation_interval = 0
        self.log_level = None
        self.output = None
        self.no_tune = False
        self.num_agents = 0
        self.verbose = 2
        self.num_samples = 1
        self.max_concurrent_trials = None
        self.as_test = False
        # Must read the flag from `self`: the module-level `args` object does
        # not exist yet while this constructor is running.
        self.num_envs_per_env_runner = 1 if self.enable_new_api_stack else 5


args = Args()

In [None]:
def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
    """Pick the module ("main" or "random") that controls `agent_id`.

    The choice depends on the parity of the hash of the episode ID, so that
    both modules sometimes play agent 0 (start player) and sometimes agent 1
    (player to move second).
    """
    episode_parity = hash(episode.id_) % 2
    if episode_parity == agent_id:
        return "main"
    return "random"

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    """Pick the policy ("main" or "random") that controls `agent_id`.

    The agent whose ID equals the parity of `episode.episode_id` plays
    "main"; the other agent plays "random".
    """
    plays_main = (episode.episode_id % 2) == agent_id
    if plays_main:
        return "main"
    return "random"

# Register the OpenSpiel game under the name used in `.environment(...)`
# below. Without this registration the string "open_spiel_env" does not refer
# to any environment known to Tune/RLlib.
register_env(
    "open_spiel_env",
    lambda _: OpenSpielEnv(pyspiel.load_game(args.env)),
)

# Build the algorithm config for league-based self-play: start from the
# default config of the selected algorithm (`args.algo`) and layer the
# environment, callback, env-runner, multi-agent and RLModule settings on top.
config = (
    get_trainable_cls(args.algo)
    .get_default_config()
    .environment("open_spiel_env")
    # Set up the main piece in this experiment: The league-bases self-play
    # callback, which controls adding new policies/Modules to the league and
    # properly matching the different policies in the league with each other.
    .callbacks(
        functools.partial(
            (
                SelfPlayCallback
                if args.enable_new_api_stack
                else SelfPlayCallbackOldAPIStack
            ),
            win_rate_threshold=args.win_rate_threshold,
        )
    )
    .env_runners(
        # Fall back to 2 env runners if the setting is falsy (0 or None).
        num_env_runners=(args.num_env_runners or 2),
        num_envs_per_env_runner=1 if args.enable_new_api_stack else 5,
    )
    .multi_agent(
        # Initial policy map: Random and default algo one. This will be expanded
        # to more policy snapshots taken from "main" against which "main"
        # will then play (instead of "random"). This is done in the
        # custom callback defined above (`SelfPlayCallback`).
        policies=(
            {
                # Our main policy, we'd like to optimize.
                "main": PolicySpec(),
                # An initial random opponent to play against.
                "random": PolicySpec(policy_class=RandomPolicy),
            }
            if not args.enable_new_api_stack
            else {"main", "random"}
        ),
        # Assign agent 0 and 1 randomly to the "main" policy or
        # to the opponent ("random" at first). Make sure (via episode_id)
        # that "main" always plays against "random" (and not against
        # another "main").
        policy_mapping_fn=(
            agent_to_module_mapping_fn
            if args.enable_new_api_stack
            else policy_mapping_fn
        ),
        # Always just train the "main" policy.
        policies_to_train=["main"],
    )
    .rl_module(
        model_config=DefaultModelConfig(fcnet_hiddens=[512, 512]),
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                # Trainable module using the algorithm's default class.
                "main": RLModuleSpec(),
                # Initial opponent backed by `RandomRLModule`.
                "random": RLModuleSpec(module_class=RandomRLModule),
            }
        ),
    )
)

# The `num_epochs` override is applied only when the algorithm is PPO.
if args.algo == "PPO":
    config.training(num_epochs=20)

# Stopping criteria handed to the experiment runner: sampled env steps,
# training iterations and league size.
stop = {}
stop[NUM_ENV_STEPS_SAMPLED_LIFETIME] = args.stop_timesteps
stop[TRAINING_ITERATION] = args.stop_iters
stop["league_size"] = args.min_league_size

# Train the "main" policy via self-play, unless a checkpoint to restore from
# was supplied; in that case training is skipped and `results` stays None.
results = (
    None
    if args.from_checkpoint
    else run_rllib_example_script_experiment(config, args, stop=stop)
)

# Evaluation phase: rebuild the Algorithm with exploration switched off, load
# trained weights into it and let a human play against "main" on the command
# line.
if args.num_episodes_human_play > 0:
    num_episodes = 0
    config.explore = False
    algo = config.build()

    # Prefer an explicitly supplied checkpoint; otherwise use the checkpoint
    # of the result returned by `results.get_best_result()`.
    restore_from = args.from_checkpoint
    if not restore_from:
        restore_from = results.get_best_result().checkpoint
        if not restore_from:
            raise ValueError("No last checkpoint found in results!")
    algo.restore(restore_from)

    if args.enable_new_api_stack:
        rl_module = algo.get_module("main")

    # The games are played in an actual (non-RLlib-wrapped) open-spiel env.
    human_player = 1
    env = Environment(args.env)

    while num_episodes < args.num_episodes_human_play:
        print("You play as {}".format("o" if human_player else "x"))
        time_step = env.reset()
        while not time_step.last():
            mover = time_step.observations["current_player"]
            if mover == human_player:
                action = ask_user_for_action(time_step)
            else:
                observation = np.array(
                    time_step.observations["info_state"][mover]
                )
                if not args.enable_new_api_stack:
                    action = algo.compute_single_action(
                        observation, policy_id="main"
                    )
                else:
                    # Batch of one observation; take the argmax over the
                    # returned action-distribution inputs.
                    batch = {
                        "obs": torch.from_numpy(observation).unsqueeze(0).float()
                    }
                    dist_inputs = rl_module.forward_inference(batch)[
                        "action_dist_inputs"
                    ][0]
                    action = np.argmax(dist_inputs.numpy())
                # If the computer picked an action that is not legal, replace
                # it with a random legal one.
                allowed = time_step.observations["legal_actions"][mover]
                if action not in allowed:
                    action = np.random.choice(allowed)
            time_step = env.step([action])
            print(f"\n{env.get_state}")

        print(f"\n{env.get_state}")

        print("End of game!")
        human_reward = time_step.rewards[human_player]
        if human_reward > 0:
            print("You win")
        elif human_reward < 0:
            print("You lose")
        else:
            print("Draw")

        # Swap sides for the next game.
        human_player = 1 - human_player
        num_episodes += 1

    algo.stop()