In [1]:
import gymnasium as gym
import numpy as np
from typing import Any, Dict, List, Tuple, Optional
from ray.rllib.env.multi_agent_env import MultiAgentEnv



In [3]:
from ray.tune.registry import register_env
from ray.rllib.algorithms import ppo
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
from rl.envs import OthelloEnv

register_env("othello", lambda _: OthelloEnv({}))

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    agent_id = int(agent_id[-1]) - 1
    return "agent_1" if episode.episode_id % 2 == agent_id else "agent_2"

config = ppo.PPOConfig().environment("othello").framework("torch").rollouts(num_rollout_workers=8)
config = config.multi_agent(policies={"agent_1": PolicySpec(), "agent_2": PolicySpec()}, policy_mapping_fn=policy_mapping_fn, policies_to_train=["agent_1"])
config = config.training(model={"conv_filters": [[32, [3, 3], 1], [64, [3, 3], 1]]})



In [4]:
from ray import air
from ray import tune

results = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
        )
    ),
).fit()

0,1
Current time:,2023-08-10 13:01:30
Running for:,00:05:13.92
Memory:,14.7/31.9 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_othello_dafdb_00000,RUNNING,127.0.0.1:21564,58,293.949,232000,-0.976601,-0.90625,-1,1.63051


[2m[36m(PPO pid=21564)[0m Install gputil for GPU system monitoring.



In [None]:
results

ResultGrid<[
  Result(
    error='TuneError',
    metrics={'trial_id': '79536_00000'},
    path='c://\\Users\\yoshi\\ray_results\\PPO\\PPO_othello_79536_00000_0_2023-08-10_12-53-32',
    checkpoint=None
  )
]>

In [None]:
checkpoint = results.get_best_result().checkpoint
print(checkpoint)

None


[2m[36m(PPO pid=23440)[0m 2023-08-10 12:53:41,398	ERROR actor_manager.py:500 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, [36mray::RolloutWorker.__init__()[39m (pid=24508, ip=127.0.0.1, actor_id=0752922d0a6b44d53ffa963b01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x0000017992C9F100>)
[2m[36m(PPO pid=23440)[0m   File "c:\Users\yoshi\Documents\Repository\rl-mini-app\rl\envs.py", line 134, in step
[2m[36m(PPO pid=23440)[0m     assert (
[2m[36m(PPO pid=23440)[0m AssertionError: Only one agent can take action at a time. {}
[2m[36m(PPO pid=23440)[0m 
[2m[36m(PPO pid=23440)[0m The above exception was the direct cause of the following exception:
[2m[36m(PPO pid=23440)[0m 
[2m[36m(PPO pid=23440)[0m [36mray::RolloutWorker.__init__()[39m (pid=24508, ip=127.0.0.1, actor_id=0752922d0a6b44d53ffa963b01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000

In [None]:
import random
# multiagent environment for othello
env = OthelloEnv({})
obs, _ = env.reset()
terminated = False
current_player = "agent_1"
while not terminated:
    env.render()
    # action = env.action_space.sample()
    valid_actions = env.get_valid_moves(current_player)
    if len(valid_actions) > 0:
        action = random.choice(valid_actions)
    else:
        action = 64
    action = {current_player: action}
    obs, reward, terminated, truncated, _= env.step(action)
    obs = obs["agent_1" if current_player == "agent_2" else "agent_2"]
    reward = reward[current_player]
    terminated = terminated[current_player]
    truncated = truncated[current_player]
    current_player = env.current_player
    print(terminated, truncated, reward)

.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|X|.|.|.
.|.|.|X|O|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

False False 0.046875
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|X|.|.|.
.|.|.|O|O|.|.|.
.|.|.|O|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

False False -0.0
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|X|.|.|.
.|.|.|O|X|.|.|.
.|.|.|O|X|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

False False 0.046875
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|O|.|.
.|.|.|O|O|.|.|.
.|.|.|O|X|.|.|.
.|.|.|O|X|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

False False -0.0
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|O|.|.
.|.|.|O|O|.|.|.
.|.|.|O|X|.|.|.
.|.|.|X|X|.|.|.
.|.|X|.|.|.|.|.
.|.|.|.|.|.|.|.

False False 0.046875
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|O|.|.
.|.|.|O|O|.|.|.
.|.|.|O|O|O|.|.
.|.|.|X|X|.|.|.
.|.|X|.|.|.|.|.
.|.|.|.|.|.|.|.

False False -0.0
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|O|.|.
.|.|.|O|O|.|X|.
.|.|.|O|O|X|.|.
.|.|.|X|X|.|.|.
.|.|X|.|.|.|.|.


In [None]:
env.render()

O|O|O|O|O|O|O|O
X|O|O|O|O|O|O|O
X|O|O|O|O|O|O|O
O|O|X|O|O|X|O|O
X|X|X|X|O|O|X|O
X|X|X|X|X|O|X|O
X|X|X|X|X|X|X|O
O|O|O|O|O|O|X|O

