In [1]:
import gymnasium as gym
import numpy as np
from typing import Any, Dict, List, Tuple, Optional
from ray.rllib.env.multi_agent_env import MultiAgentEnv



In [2]:
import os
from ray.tune.registry import register_env
from ray.rllib.algorithms import ppo, a2c, dqn
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
from rl.envs.othello import OthelloEnv
from ray.rllib.models.torch.visionnet import VisionNetwork

register_env("othello", lambda _: OthelloEnv({}))

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    agent_id = int(agent_id[-1]) - 1
    return "agent_1" if episode.episode_id % 2 == agent_id else "agent_2"

config = ppo.PPOConfig().environment("othello").framework("torch").rollouts(num_rollout_workers=12).resources(num_gpus=0)
# config = a2c.A2CConfig().environment("othello").framework("torch").rollouts(num_rollout_workers=12).resources(num_gpus=0)
# config = dqn.DQNConfig().environment("othello").framework("torch").rollouts(num_rollout_workers=12).resources(num_gpus=0)
# replay_config = config.replay_buffer_config.update( 
#     {
#         "capacity": 60000,
#         "prioritized_replay_alpha": 0.5,
#         "prioritized_replay_beta": 0.5,
#         "prioritized_replay_eps": 3e-6,
#     }
# )
config = config.multi_agent(policies={"agent_1": PolicySpec(), "agent_2": PolicySpec()}, policy_mapping_fn=policy_mapping_fn, policies_to_train=["agent_1", "agent_2"])
# config = config.training(num_sgd_iter=10, model={"conv_filters": [[32, [3, 3], 1], [64, [3, 3], 1]]})
# config = config.training(model={"conv_filters": [[32, [3, 3], 2], [64, [3, 3], 2]]}, _enable_learner_api=False)
config = config.training(_enable_learner_api=False)
config = config.rl_module(_enable_rl_module_api=False)



In [3]:
# import random
# # multiagent environment for othello
# env = OthelloEnv({})
# obs, _ = env.reset()
# terminated = False
# current_player = "agent_1"
# while not terminated:
#     env.render()
#     # action = env.action_space.sample()
#     valid_actions = env.get_valid_moves(current_player)
#     print(current_player, obs.keys())
#     print(obs[current_player].shape)
#     action = algo.compute_single_action(obs[current_player], policy_id=current_player)
#     print(action)
#     if len(valid_actions) > 0:
#         action = random.choice(valid_actions)
#     else:
#         action = 64
#     action = {current_player: action}
#     obs, reward, terminated, truncated, _= env.step(action)
#     reward = reward[current_player]
#     terminated = terminated[current_player]
#     truncated = truncated[current_player]
#     current_player = env.current_player
#     print(terminated, truncated, reward)

In [4]:
from ray import air
from ray import tune

results = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
            checkpoint_frequency=5,
        )
    ),
).fit()

0,1
Current time:,2023-08-10 16:17:27
Running for:,01:01:10.00
Memory:,14.4/31.9 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_othello_6ab59_00000,RUNNING,127.0.0.1:24940,405,3372.23,1620000,16.5805,44,-15,15.1873


[2m[36m(PPO pid=24940)[0m Install gputil for GPU system monitoring.
[2m[36m(PPO pid=24940)[0m Caught sync error: Sync process failed: [WinError 32] Failed copying 'C:/Users/yoshi/ray_results/PPO/PPO_othello_6ab59_00000_0_2023-08-10_15-16-17/checkpoint_000005/.is_checkpoint' to 'c:///Users/yoshi/ray_results/PPO/PPO_othello_6ab59_00000_0_2023-08-10_15-16-17/checkpoint_000005/.is_checkpoint'. Detail: [Windows error 32] �v���Z�X�̓t�@�C���ɃA�N�Z�X�ł��܂���B�ʂ̃v���Z�X���g�p���ł��B
[2m[36m(PPO pid=24940)[0m . Retrying after sleeping for 1.0 seconds...
[2m[36m(PPO pid=24940)[0m Caught sync error: Sync process failed: [WinError 32] Failed copying 'C:/Users/yoshi/ray_results/PPO/PPO_othello_6ab59_00000_0_2023-08-10_15-16-17/checkpoint_000005/.is_checkpoint' to '/Users/yoshi/ray_results/PPO/PPO_othello_6ab59_00000_0_2023-08-10_15-16-17/checkpoint_000005/.is_checkpoint'. Detail: [Windows error 32] �v���Z�X�̓t�@�C���ɃA�N�Z�X�ł��܂���B�ʂ̃v���Z�X���g�p���ł��B
[2m[36m(PPO pid=24940)[0m . 

In [5]:
results

ResultGrid<[
  Result(
    metrics={'custom_metrics': {}, 'episode_media': {}, 'info': {'learner': {'agent_2': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 1.676207290465633, 'cur_kl_coeff': 0.5931615286377674, 'cur_lr': 5.0000000000000016e-05, 'total_loss': 7.1374688337246575, 'policy_loss': -0.034164507951936686, 'vf_loss': 7.167583111921946, 'vf_explained_var': 0.18800327243904272, 'kl': 0.006828161656739192, 'entropy': 0.2875105511707564, 'entropy_coeff': 0.0}, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': 124.8125, 'num_grad_updates_lifetime': 196590.5, 'diff_num_grad_updates_vs_sampler_policy': 239.5}, 'agent_1': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 1.7385510587443909, 'cur_kl_coeff': 0.28152002238081536, 'cur_lr': 5.0000000000000016e-05, 'total_loss': 6.906781392296155, 'policy_loss': -0.03183785179110903, 'vf_loss': 6.933691987395287, 'vf_explained_var': -0.1055408635487159, 'kl': 0.017502398914850895, 'entropy': 0.29952402260

In [6]:
checkpoint = results.get_best_result().checkpoint
print(checkpoint)

Checkpoint(uri=c://\Users\yoshi\ray_results\PPO\PPO_othello_6ab59_00000_0_2023-08-10_15-16-17\checkpoint_000405)


In [7]:
import ray
ray.shutdown()
register_env("othello", lambda _: OthelloEnv({}))

In [8]:
config.expolore = False
algo = config.build()
algo.restore(checkpoint)

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-08-10 16:17:43,856	INFO worker.py:1621 -- Started a local Ray instance.
2023-08-10 16:17:52,727	INFO trainable.py:172 -- Trainable.setup took 11.506 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
2023-08-10 16:17:52,738	INFO

In [9]:
algo.get_policy("agent_1").export_model("othello_policy", onnx=18)

verbose: False, log level: Level.ERROR



In [10]:
import random
# multiagent environment for othello
env = OthelloEnv({})
obs, _ = env.reset()
terminated = False
current_player = "agent_1"
while not terminated:
    env.render()
    # action = env.action_space.sample()
    valid_actions = env.get_valid_moves(current_player)
    action = algo.compute_single_action(obs[current_player], policy_id=current_player)
    # if len(valid_actions) > 0:
    #     action = random.choice(valid_actions)
    # else:
    #     action = 64
    action = {current_player: action}
    print(action)
    obs, reward, terminated, truncated, _= env.step(action)
    reward = reward[current_player]
    terminated = terminated[current_player]
    truncated = truncated[current_player]
    current_player = env.current_player

.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|X|.|.|.
.|.|.|X|O|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

{'agent_1': 29}
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|O|O|.|.
.|.|.|X|O|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

{'agent_2': 37}
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|O|O|.|.
.|.|.|X|X|X|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

{'agent_1': 19}


In [11]:
env.render()

.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|O|O|O|.|.
.|.|.|X|X|X|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.
.|.|.|.|.|.|.|.

