In [1]:
import warnings
# Hide DeprecationWarning noise for the rest of the session. Note that some Ray
# deprecation notices still show up in the training output further down, so this
# filter evidently does not catch everything.
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
from ray.rllib.algorithms import ppo
from envs import TicTacToeEnv



In [3]:
# Smoke-test the environment: play uniformly random actions until the episode ends.
# Random actions can target already-occupied squares; in the recorded output below
# the episode ends with reward -1 right after action 5 hits an occupied cell
# (presumably an illegal-move penalty in TicTacToeEnv — confirm in envs).
env = TicTacToeEnv(env_config={"train": True})
# Gymnasium requires reset() before the first step().
obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"Action: {action}, Reward: {reward}, Terminated: {terminated}, Info: {info}")
    print(f"Observation: {obs}")
    print(f"State: {env.state}")
    print()
    # An episode is over on either termination or truncation.
    done = terminated or truncated

Action: 8, Reward: 0, Terminated: False, Info: {}
Observation: [<Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_1: 1>]
State: [<Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_1: 1>]

Action: 6, Reward: 0, Terminated: False, Info: {}
Observation: [<Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.PLAYER_1: 1>, <Mark.EMPTY: 0>, <Mark.PLAYER_1: 1>]
State: [<Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.PLAYER_1: 1>, <Mark.EMPTY: 0>, <Mark.PLAYER_1: 1>]

Action: 5, Reward: -1, Terminated: True, Info: {}
Observation: [<Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.EMPTY: 0>, <Mark.PLAYER_2: 2>, <Mark.PLAYER_1: 1>, <Mark.EMPTY: 0

In [4]:
from tqdm import tqdm

def train(max_episodes=10):
    """Build a PPO algorithm on ``TicTacToeEnv`` and train it.

    Args:
        max_episodes: number of ``algo.train()`` calls to make. Each call is one
            RLlib training *iteration*, not one game episode; the parameter name
            is kept for backward compatibility.

    Returns:
        The trained algorithm. Its Ray workers are still running; call
        ``algo.stop()`` when finished with it.

    Raises:
        Whatever ``config.build()`` / ``algo.train()`` raise. If training fails or
        is interrupted, the algorithm is stopped before the exception propagates.
    """
    config = ppo.PPOConfig().environment(env=TicTacToeEnv, env_config={"train": True})
    # Turn off the RLModule and Learner APIs so the Policy-based stack is used
    # (later cells rely on ``algo.get_policy()``).
    config = config.rl_module(_enable_rl_module_api=False).training(_enable_learner_api=False)
    algo = config.build()

    try:
        with tqdm(range(max_episodes)) as pbar:
            for _ in pbar:
                result = algo.train()
                # .get(): tolerate result dicts that lack this key (its location
                # presumably differs across RLlib versions — confirm for yours).
                pbar.set_description(f"episode_reward_mean: {result.get('episode_reward_mean')}")
    except BaseException:
        # Includes KeyboardInterrupt (common in notebooks): release the workers
        # started by config.build() instead of leaking them, then re-raise.
        algo.stop()
        raise

    return algo


def evaluate(algo, env=None):
    """Play one episode with the trained policy, printing each observation.

    Args:
        algo: trained algorithm; only ``compute_single_action(observation)`` is used.
        env: optional Gymnasium-style environment to play in. When omitted, a fresh
            ``TicTacToeEnv()`` is created and closed afterwards. A caller-supplied
            env is left open for the caller to manage.

    Returns:
        The sum of rewards collected over the episode.
    """
    owns_env = env is None
    if owns_env:
        env = TicTacToeEnv()
    try:
        observation, info = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = algo.compute_single_action(observation)
            observation, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            print("observation: ", observation)
            # Stop on termination OR truncation. The previous version reset the
            # env here, which was wasted work on termination and restarted a
            # new episode on truncation.
            done = terminated or truncated
    finally:
        if owns_env:
            env.close()
    return total_reward

In [5]:
# Train for 100 ``algo.train()`` iterations, then write the policy network as an
# ONNX graph (opset 18) into ./policy_model/ for the onnx/onnxruntime cells below.
algo = train(100)
algo.export_policy_model("policy_model", onnx=18)

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-08-04 14:44:22,371	INFO worker.py:1621 -- Started a local Ray instance.
[2m[36m(RolloutWorker pid=40388)[0m   logger.warn("Casting input x to numpy array.")
2023-08-04 14:44:32,750	INFO trainable.py:172 -- Trainable.setup took 16.110 seconds. If your trainable is slow to initialize, consider se

verbose: False, log level: Level.ERROR






In [6]:
import onnx
# Deserialize the graph written by export_policy_model and run the checker's
# default validation; check_model raises if the model is invalid.
onnx_model = onnx.load_model("policy_model/model.onnx")
onnx.checker.check_model(model=onnx_model)

In [7]:
# Stricter validation: full_check=True additionally runs ONNX shape inference.
onnx.checker.check_model(model=onnx_model, full_check=True)

In [8]:
import onnxruntime as ort
import numpy as np

# Run the exported ONNX policy directly with onnxruntime on a freshly reset board.
env = TicTacToeEnv({})
obs, info = env.reset()

# The exported graph is fed a batched observation under "obs" (shape (1, 9)
# int32 in the recorded output) plus a dummy "state_ins" tensor. Presumably RLlib
# exports that input even for non-recurrent models; confirm via
# ort_sess.get_inputs().
obs = np.asarray(obs, dtype=np.int32).reshape(1, -1)
state_ins = np.zeros((1,), dtype=np.float32)
print(obs.shape)
print(obs)

ort_sess = ort.InferenceSession('policy_model/model.onnx', providers=['CPUExecutionProvider'])
outputs = ort_sess.run(None, {'obs': obs, 'state_ins': state_ins})
outputs

(1, 9)
[[0 0 0 0 0 2 1 0 0]]


[array([[  1.5947901 ,   9.396042  ,   0.59201837,   1.4645842 ,
           1.6824638 , -17.239763  ,  -7.9333267 ,   7.101857  ,
           4.3303375 ]], dtype=float32),
 array([0.], dtype=float32)]

In [9]:
# First output of the ONNX session: per the recorded output below, a (1, 9)
# float32 array, presumably the policy's unnormalized action logits with one
# entry per board cell (TODO confirm against the exported graph's output names).
outputs[0]

array([[  1.5947901 ,   9.396042  ,   0.59201837,   1.4645842 ,
          1.6824638 , -17.239763  ,  -7.9333267 ,   7.101857  ,
          4.3303375 ]], dtype=float32)

In [10]:
# Fetch the default policy behind the trained algorithm and show its string form.
policy = algo.get_policy()
print(str(policy))

PPOTorchPolicy


In [11]:
# Display the policy's observation space. Presumably this is the original space
# before RLlib preprocessing; the recorded output is Box(0, 2, (9,), int32), which
# matches the (1, 9) int32 "obs" tensor fed to the ONNX model.
policy.observation_space_struct

Box(0, 2, (9,), int32)

In [12]:
# Draw one random action from the environment's action space (recorded output: 0).
# Presumably Discrete(9), one action per board cell, matching the 9 logits above.
# TODO confirm in envs.
env.action_space.sample()

0