In [None]:
import sys

sys.path.append("../../")

import pickle
from hyperopt import hp
from utils import CategoricalCrossentropyLoss, MSELoss, generate_layer_widths
import gymnasium as gym
import torch


def action_function(x):
    """Encode the flat board index ``x`` as a one-hot plane.

    Returns a float tensor of shape ``(1, 3, 3)`` that is zero everywhere
    except at flat (row-major) position ``x``, which is set to 1.
    """
    flat_board = torch.zeros(9)
    flat_board[x] = 1
    return flat_board.view(1, 3, 3)


# Hyperopt search space for the hyperparameter search.
# Every entry is an `hp.choice` whose label matches its dictionary key. An entry
# with a single option pins that hyperparameter to a fixed value; options that
# are commented out are candidates currently excluded from the search.
search_space = {
    # --- Weight initialisation and optimiser settings ---
    "kernel_initializer": hp.choice(
        "kernel_initializer",
        [
            "he_uniform",
            "he_normal",
            "glorot_uniform",
            "glorot_normal",
            "orthogonal",
        ],
    ),
    "learning_rate": hp.choice("learning_rate", [0.01, 0.001, 0.0001, 0.00001]),
    "adam_epsilon": hp.choice("adam_epsilon", [0.3125, 0.03125, 0.003125, 0.0003125]),
    "known_bounds": hp.choice("known_bounds", [[-1, 1]]),
    # --- Network layer layouts ---
    # Each option repeats one layer tuple 1, 3 or 5 times.
    # NOTE(review): the tuple is presumably (filters, kernel_size, stride) —
    # confirm against the network/config code.
    "residual_layers": hp.choice(
        "residual_layers",
        [
            [(16, 3, 1)] * 1,
            [(32, 3, 1)] * 1,
            [(64, 3, 1)] * 1,
            [(128, 3, 1)] * 1,
            [(16, 3, 1)] * 3,
            [(32, 3, 1)] * 3,
            [(64, 3, 1)] * 3,
            [(128, 3, 1)] * 3,
            [(16, 3, 1)] * 5,
            [(32, 3, 1)] * 5,
            [(64, 3, 1)] * 5,
            [(128, 3, 1)] * 5,
            # [(16, 3, 1)] * 10,
            # [(32, 3, 1)] * 10,
            # [(64, 3, 1)] * 10,
            # [(128, 3, 1)] * 10,
            # [(16, 3, 1)] * 20,
            # [(32, 3, 1)] * 20,
            # [(64, 3, 1)] * 20,
            # [(128, 3, 1)] * 20,
        ],
    ),
    # Empty lists below mean "no layers of this kind" for the pinned entries.
    "conv_layers": hp.choice("conv_layers", [[]]),
    "dense_layers": hp.choice("dense_layers", [[]]),
    "actor_conv_layers": hp.choice(
        "actor_conv_layers", [[], [(32, 1, 1)], [(64, 1, 1)], [(128, 1, 1)]]
    ),
    "critic_conv_layers": hp.choice(
        "critic_conv_layers", [[], [(32, 1, 1)], [(64, 1, 1)], [(128, 1, 1)]]
    ),
    "reward_conv_layers": hp.choice("reward_conv_layers", [[]]),
    "actor_dense_layer_widths": hp.choice("actor_dense_layer_widths", [[]]),
    "critic_dense_layer_widths": hp.choice("critic_dense_layer_widths", [[]]),
    "reward_dense_layer_widths": hp.choice("reward_dense_layer_widths", [[]]),
    "dense_layer_widths": hp.choice("dense_layer_widths", [[]]),
    "noisy_sigma": hp.choice("noisy_sigma", [0.0]),
    # --- Game generation and search settings ---
    # NOTE(review): names suggest self-play / tree-search parameters; their exact
    # semantics live in the agent config, not in this notebook.
    "games_per_generation": hp.choice(
        "games_per_generation",
        [
            32,
            64,
            # 128
        ],
    ),
    "value_loss_factor": hp.choice("value_loss_factor", [0.25, 1.0]),
    "root_dirichlet_alpha": hp.choice("root_dirichlet_alpha", [0.3, 1.0, 2.0]),
    "root_exploration_fraction": hp.choice("root_exploration_fraction", [0.25]),
    "num_simulations": hp.choice(
        "num_simulations",
        [
            25,
            50,
            100,
            200,
            # 400,
            # 800
        ],
    ),
    "num_sampling_moves": hp.choice("num_sampling_moves", [0, 1, 2, 3, 5, 9]),
    "exploration_temperature": hp.choice("exploration_temperature", [1.0]),
    "exploitation_temperature": hp.choice("exploitation_temperature", [0.1]),
    "clip_low_prob": hp.choice("clip_low_prob", [0.0]),
    "pb_c_base": hp.choice("pb_c_base", [19652]),
    "pb_c_init": hp.choice("pb_c_init", [1.25]),
    # --- Loss functions ---
    # These are loss *instances*; they are stored inside the space and therefore
    # pickled together with it when the space is saved.
    "value_loss_function": hp.choice("value_loss_function", [MSELoss()]),
    "reward_loss_function": hp.choice("reward_loss_function", [MSELoss()]),
    "policy_loss_function": hp.choice(
        "policy_loss_function", [CategoricalCrossentropyLoss()]
    ),
    # --- Training loop and replay buffer sizes ---
    "training_steps": hp.choice("training_steps", [200]),
    "minibatch_size": hp.choice(
        "minibatch_size",
        [
            32,
            64,
            # 128
        ],
    ),
    "min_replay_buffer_size": hp.choice(
        "min_replay_buffer_size",
        [
            32,
            1024,
            # 2048
        ],
    ),
    "replay_buffer_size": hp.choice("replay_buffer_size", [4000, 8000, 16000, 32000]),
    "unroll_steps": hp.choice("unroll_steps", [5]),
    "n_step": hp.choice("n_step", [9]),
    "clipnorm": hp.choice("clipnorm", [0.0, 1.0, 10.0]),
    "weight_decay": hp.choice("weight_decay", [1e-5, 1e-4, 1e-3]),
    # NOTE(review): `per_*` presumably configure prioritized experience replay —
    # confirm against the agent config.
    "per_alpha": hp.choice("per_alpha", [0.0, 0.5, 1.0]),
    "per_beta": hp.choice("per_beta", [0.0, 0.5, 1.0]),
    "per_beta_final": hp.choice("per_beta_final", [0.0, 0.5, 1.0]),
    "per_epsilon": hp.choice("per_epsilon", [1e-4]),
    # The notebook-level `action_function` (one-hot board encoding) is passed
    # through as the only option.
    "action_function": hp.choice("action_function", [action_function]),
}

# Seed configurations for the search; currently a single empty config.
initial_best_config = [{}]


# Persist the search space and the seed configs so the search cell can reload
# them. Context managers make sure each file is flushed and closed right away
# (the previous bare `open(...)` calls left the handles to the garbage collector).
with open("./search_spaces/search_space.pkl", "wb") as search_space_file:
    pickle.dump(search_space, search_space_file)
with open("./search_spaces/initial_best_config.pkl", "wb") as best_config_file:
    pickle.dump(initial_best_config, best_config_file)

In [None]:
import pandas as pd
import random
from tqdm import tqdm
import sys

sys.path.append("../../")
from elo.elo import StandingsTable

import pickle

# Shared, notebook-global tournament state. The functions defined in later
# cells read and mutate these names directly.
players = []  # trained agents that have joined the Elo pool
games_per_pair = 10  # passed to play_matches, which plays half of them per seat

player_names = []
table = StandingsTable(player_names, start_elo=1000)

print(table.bayes_elo())
print(table.get_win_table())
print(table.get_draw_table())

# Path of the pickled standings table; rewritten whenever the table changes.
file = "tictactoe_table.pkl"
# Use a context manager so the handle is closed (and the data flushed) immediately.
with open(file, "wb") as table_file:
    pickle.dump(table, table_file)


def play_1v1_tournament(players, games_per_pair, play_game):
    """Run ``play_matches`` once for every player and collect all results.

    Args:
        players: iterable of agents; each one takes a turn as the focal player.
        games_per_pair: forwarded unchanged to ``play_matches``.
        play_game: callable ``(first, second) -> result`` forwarded to
            ``play_matches``.

    Returns:
        A ``pandas.DataFrame`` with columns ``player1``, ``player2`` and
        ``result``, one row per game, in the order the games were played.
    """
    game_rows = [
        row
        for focal_player in players
        for row in play_matches(focal_player, players, games_per_pair, play_game)
    ]
    return pd.DataFrame(game_rows, columns=["player1", "player2", "result"])


def play_matches(player1, players, games_per_pair, play_game):
    """Play ``player1`` against every other entry of ``players`` and record it.

    For each opponent, ``games_per_pair // 2`` games are played with ``player1``
    as the first argument of ``play_game``; afterwards the same number is played
    with the seats swapped. (An odd ``games_per_pair`` therefore drops one game.)

    Side effects: the results are added to the notebook-global ``table``, the
    updated Bayes-Elo output is printed, and ``table`` is pickled to the
    notebook-global path ``file``.

    Args:
        player1: the focal agent; must expose ``model_name``.
        players: iterable of agents (each exposing ``model_name``); entries equal
            to ``player1`` are skipped.
        games_per_pair: total games per opponent, split evenly across both seats.
        play_game: callable ``(first, second) -> result``.

    Returns:
        List of ``(first_name, second_name, result)`` tuples in play order.
    """
    games_per_side = games_per_pair // 2
    # `!=` (rather than `is not`) is kept on purpose so that agents with custom
    # equality are filtered exactly as before.
    opponents = [opponent for opponent in players if opponent != player1]

    results = []
    # First pass: player1 takes the first seat against every opponent.
    # Second pass: seats are swapped. This keeps the original game order.
    for player1_goes_first in (True, False):
        for opponent in opponents:
            if player1_goes_first:
                first, second = player1, opponent
            else:
                first, second = opponent, player1
            for game_number in range(1, games_per_side + 1):
                print(
                    f"Playing {first.model_name} vs {second.model_name} game {game_number}"
                )
                result = play_game(first, second)
                results.append((first.model_name, second.model_name, result))

    table.add_results_from_array(results)
    print(table.bayes_elo())
    # Close the file deterministically instead of leaking one handle per call.
    with open(file, "wb") as table_file:
        pickle.dump(table, table_file)
    return results

In [None]:
import torch
from packages.utils.utils.utils import process_petting_zoo_obs

from pettingzoo.classic import tictactoe_v3


def play_game(player1, player2):
    """Play one tic-tac-toe game between two agents and return player_0's reward.

    ``player1`` acts whenever the acting agent has index 0 in ``env.agents``;
    ``player2`` acts otherwise. Each agent must provide
    ``predict(state, info, env=env)`` and ``select_actions(prediction, info)``,
    the latter returning an object with ``.item()``.

    Args:
        player1: agent controlling the agent at index 0.
        player2: agent controlling the other agent.

    Returns:
        ``env.rewards["player_0"]`` as reported by the environment when the loop
        ends (game terminated/truncated, or the 1000-step safety limit was hit).
    """
    env = tictactoe_v3.env(render_mode="rgb_array")
    try:
        with torch.no_grad():  # No gradient computation during testing
            env.reset()
            state, _, termination, truncation, info = env.last()
            done = termination or truncation
            current_player = env.agents.index(env.agent_selection)
            state, info = process_petting_zoo_obs(state, info, current_player)

            episode_length = 0
            while not done and episode_length < 1000:  # Safety limit
                episode_length += 1

                # Pick whichever agent owns the current seat.
                acting_player = player1 if current_player == 0 else player2
                prediction = acting_player.predict(state, info, env=env)
                action = acting_player.select_actions(prediction, info).item()

                # Step the environment and observe from the next agent's view.
                env.step(action)
                state, _, termination, truncation, info = env.last()
                current_player = env.agents.index(env.agent_selection)
                state, info = process_petting_zoo_obs(state, info, current_player)
                done = termination or truncation
            print(env.rewards)
            return env.rewards["player_0"]
    finally:
        # A fresh environment is created for every game; release it even if an
        # agent raises so repeated matches do not accumulate open environments.
        env.close()

In [None]:
import os
import pickle
from agent_configs.muzero_config import MuZeroConfig
from game_configs.tictactoe_config import TicTacToeConfig
import numpy as np
import pandas
import gymnasium as gym
from hyperopt import tpe, fmin, space_eval, STATUS_OK, STATUS_FAIL
from agent_configs import RainbowConfig
import gc
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from muzero.muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch

import sys

sys.path.append("../..")
from dqn.rainbow.rainbow_agent import RainbowAgent

# from rainbow_agent import RainbowAgent
from game_configs import CartPoleConfig

# NOTE: a `global` statement at module (cell) level has no effect; these two
# lines only signal that `file_name` and `eval_method` are notebook-wide names.
# They are assigned in the search cell and read inside `run_training` and
# `objective`, so that cell must run before either function is called.
global file_name
global eval_method


def run_training(args):
    """Train one MuZero agent and return the loss for Hyperopt to minimise.

    The evaluation strategy is selected by the notebook-global ``eval_method``:

    * ``"final_score"``: negated score of a 10-trial test run.
    * ``"rolling_average"``: negated mean (rounded to 1 decimal) of the last 10
      recorded test scores.
    * ``"final_score_rolling_average"``: average of the two quantities above,
      except that the rolling part uses ``[-10:-1]`` (see note below).
    * ``"elo"``: the agent joins the notebook-global ``players``/``table`` pool,
      plays up to 5 randomly chosen earlier agents, and the negated Elo of the
      last row of the Bayes-Elo table is returned (0 when there is nobody to
      play against yet).

    Args:
        args: sequence ``[params, env, name]`` — the sampled hyperparameters, the
            environment handed to the agent, and the model name.

    Returns:
        A number where lower is better (negated score / Elo).

    Raises:
        ValueError: if ``eval_method`` is not one of the values above.
    """
    params, env, name = args[0], args[1], args[2]
    m = MuZeroAgent(
        env=env,
        config=MuZeroConfig(params, TicTacToeConfig()),
        name="{}".format(name),
    )
    m.checkpoint_interval = 10
    m.checkpoint_trials = 100
    m.train()
    print("Training complete")

    if eval_method == "final_score":
        return -m.test(num_trials=10, step=5000, dir="./checkpoints/")["score"]

    if eval_method == "rolling_average":
        recent_scores = [stat_dict["score"] for stat_dict in m.stats["test_score"][-10:]]
        return -np.around(np.mean(recent_scores), 1)

    if eval_method == "final_score_rolling_average":
        final_score = m.test(num_trials=10, step=5000, dir="./checkpoints/")["score"]
        # NOTE(review): this slice skips the most recent entry ([-10:-1]) whereas
        # "rolling_average" uses [-10:]; kept as-is — confirm it is intentional.
        earlier_scores = [
            stat_dict["score"] for stat_dict in m.stats["test_score"][-10:-1]
        ]
        return (-final_score - np.around(np.mean(earlier_scores), 1)) / 2

    if eval_method == "elo":
        # Sample opponents *before* registering the new agent so it never plays
        # against itself.
        opponents = np.random.choice(players, size=min(5, len(players)), replace=False)
        print(f"Testing against opponents: {list(opponents)}")
        table.add_player(m.model_name)
        players.append(m)
        print(table.players)
        print(players)
        with open(file, "wb") as table_file:
            pickle.dump(table, table_file)
        if len(opponents) == 0:
            return 0
        play_matches(m, opponents, games_per_pair, play_game)
        with open(file, "wb") as table_file:
            pickle.dump(table, table_file)
        bayes_elo = table.bayes_elo()["Elo table"]
        # NOTE(review): assumes the newest player is the last row of the Elo
        # table — verify the row order produced by StandingsTable.bayes_elo().
        elo = bayes_elo.iloc[-1]["Elo"]
        print(bayes_elo)
        print(f"Elo: {elo}")
        return -elo

    # Previously an unknown method fell through and silently returned None,
    # which would only surface later as an invalid loss.
    raise ValueError(f"Unknown eval_method: {eval_method!r}")


def _csv_safe(value):
    """Return a CSV-friendly representation of one hyperparameter value."""
    # Plain scalars are logged verbatim so the results file keeps real numbers.
    if value is None or isinstance(value, (bool, int, float, str, np.generic)):
        return value
    # Containers (e.g. layer layouts) are stringified so they fit in one cell.
    if isinstance(value, (list, tuple, dict, set)):
        return str(value)
    # Functions and classes are logged by name.
    if hasattr(value, "__name__"):
        return value.__name__
    # Any other object (e.g. a loss instance) is logged by its class name.
    return type(value).__name__


def objective(params):
    """Hyperopt objective: log the sampled params, train, and report the loss.

    The model name is derived from the notebook-global ``file_name`` and the
    number of trials already stored in ``./{file_name}_trials.p`` (``_1`` when
    that file does not exist). It is written into ``params["model_name"]``.
    One row describing the parameters is appended to
    ``./{file_name}_results.csv`` before training starts.

    Earlier versions replaced every value lacking ``__name__`` by its class name
    (the check ``hasattr(value.__class__, "__name__")`` is true for all objects),
    so numbers were logged as ``"int"``/``"float"``. Scalars are now logged
    verbatim and containers as their ``str()``.

    Args:
        params: dict of sampled hyperparameters (mutated: ``model_name`` added).

    Returns:
        ``{"status": STATUS_OK, "loss": score}`` on success, or
        ``{"status": STATUS_FAIL, "loss": 0}`` when training raises
        ``AssertionError`` (treated as an invalid hyperparameter combination).
    """
    gc.collect()
    print("Params: ", params)
    print("Making environments")
    env = tictactoe_v3.env(render_mode="rgb_array")

    trials_path = f"./{file_name}_trials.p"
    if os.path.exists(trials_path):
        # Close the handle promptly; this function runs once per trial.
        with open(trials_path, "rb") as trials_file:
            trials = pickle.load(trials_file)
        name = "{}_{}".format(file_name, len(trials.trials) + 1)
    else:
        name = f"{file_name}_1"
    params["model_name"] = name

    params_for_csv = {key: _csv_safe(value) for key, value in params.items()}
    entry = pandas.DataFrame.from_dict(
        params_for_csv,
        orient="index",
    ).T
    entry.to_csv(
        f"./{file_name}_results.csv",
        mode="a",
        header=False,
    )

    try:
        # add other illegal hyperparameter combinations here
        # assert params["min_replay_buffer_size"] >= params["minibatch_size"]
        # assert params["replay_buffer_size"] > params["min_replay_buffer_size"]
        score = run_training([params, env, name])
    except AssertionError as e:
        print(f"exited due to invalid hyperparameter combination: {e}")
        return {"status": STATUS_FAIL, "loss": 0}

    print("parallel programs done")
    return {"status": STATUS_OK, "loss": score}

In [None]:
# --- Search configuration -------------------------------------------------
search_space_path, initial_best_config_path = (
    "./search_spaces/search_space.pkl",
    "./search_spaces/initial_best_config.pkl",
)
# These pickles are produced by the first cell of this notebook; only load
# pickles you created yourself (unpickling can execute arbitrary code).
with open(search_space_path, "rb") as search_space_file:
    search_space = pickle.load(search_space_file)
with open(initial_best_config_path, "rb") as best_config_file:
    initial_best_config = pickle.load(best_config_file)

file_name = "tictactoe_muzero"
eval_method = "elo"
assert eval_method in (
    "final_score",
    "rolling_average",
    "final_score_rolling_average",
    "elo",
)
max_trials = 1
trials_step = 64  # how many additional trials to do after loading the last ones
trials_path = f"./{file_name}_trials.p"

# --- Resume from a saved Trials object if one exists ----------------------
try:
    with open(trials_path, "rb") as trials_file:
        trials = pickle.load(trials_file)
    print("Found saved Trials! Loading...")
    max_trials = len(trials.trials) + 1
    print(
        "Rerunning from {} trials to {} (+{}) trials".format(
            len(trials.trials), max_trials, trials_step
        )
    )
except FileNotFoundError:
    # No previous run: let fmin create a new Trials object. Other errors (e.g. a
    # corrupt pickle) are no longer swallowed by a bare `except`.
    trials = None

# --- Run one additional trial per iteration --------------------------------
for i in range(trials_step):
    best = fmin(
        fn=objective,  # Objective Function to optimize
        space=search_space,  # Hyperparameter's Search Space
        algo=tpe.suggest,  # Optimization algorithm (representative TPE)
        max_evals=max_trials,  # Number of optimization attempts
        trials=trials,  # Record the results
        # early_stop_fn=no_progress_loss(5, 1),
        trials_save_file=trials_path,
        # points_to_evaluate=initial_best_config,
        show_progressbar=False,
    )

    with open(trials_path, "rb") as trials_file:
        trials = pickle.load(trials_file)
    print("Found saved Trials! Loading and Updating...")
    try:
        # Elo ratings shift as more games are played, so refresh the stored loss
        # of every earlier trial from the current table.
        elo_table = table.bayes_elo()["Elo table"]
        if len(elo_table) < len(trials.trials):
            # Happens e.g. when trials were resumed from disk but the in-memory
            # table was rebuilt; positional lookup would raise IndexError.
            print(
                "Elo table has fewer rows than there are trials; "
                "skipping loss update."
            )
        else:
            # NOTE(review): rows are matched to trials by position — verify that
            # the Elo table preserves player insertion order.
            for trial_index, trial in enumerate(trials.trials):
                trial_elo = elo_table.iloc[trial_index]["Elo"]
                print(f"Trial {trial['tid']} ELO: {trial_elo}")
                trial["result"]["loss"] = -trial_elo
            # Write once after all trials are updated instead of once per trial.
            with open(trials_path, "wb") as trials_file:
                pickle.dump(trials, trials_file)
    except ZeroDivisionError:
        print("Not enough players to calculate elo.")
    max_trials = len(trials.trials) + 1
    print(best)
    best_trial = space_eval(search_space, best)

100%|██████████| 64/64 [06:52<00:00,  6.45s/it]


Samples: {'observations': array([[[[0, 1, 1],
         [0, 0, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 0, 1],
         [0, 1, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]],

        [[1, 0, 1],
         [0, 1, 1],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [08:08<00:00,  7.63s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 0, 1],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [1, 0, 0]]],


       [[[1, 0, 0],
         [0, 1, 0],
         [1, 1, 0]],

        [[0, 1, 1],
         [1, 0, 1],
         [0, 0, 1]]],


       [[[1, 1, 0],
         [0, 0, 1],
         [0, 0, 1]],

        [[0, 0, 1],
         [1, 1, 0],
         [1, 1, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [06:48<00:00,  6.38s/it]


Samples: {'observations': array([[[[0, 1, 1],
         [0, 1, 0],
         [0, 0, 0]],

        [[1, 0, 0],
         [1, 0, 1],
         [0, 1, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[1, 0, 1],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 1, 0],
         [1, 1, 0],
         [1, 0, 0]]],


       ...,


       [[[0, 0, 0],
         [1, 0, 1],
         [1, 1, 0]],

        [[1, 1, 1],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [1, 1, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [07:46<00:00,  7.29s/it]


Samples: {'observations': array([[[[1, 0, 0],
         [0, 1, 1],
         [0, 1, 0]],

        [[0, 1, 1],
         [1, 0, 0],
         [1, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 1],
         [0, 0, 1]]],


       [[[0, 1, 0],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 0, 1],
         [0, 0, 1],
         [0, 1, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [07:24<00:00,  6.94s/it]


Samples: {'observations': array([[[[0, 1, 0],
         [0, 0, 0],
         [1, 0, 1]],

        [[1, 0, 1],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 0, 1],
         [0, 0, 1],
         [1, 0, 0]],

        [[1, 1, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 1, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [1, 1, 0]]],


       [[[0, 1, 0],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 1]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [06:59<00:00,  6.56s/it]


Samples: {'observations': array([[[[1, 0, 0],
         [0, 0, 1],
         [0, 0, 1]],

        [[0, 1, 1],
         [0, 0, 0],
         [1, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 1, 0]],

        [[0, 1, 1],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 0, 1],
         [0, 0, 1],
         [1, 0, 0]],

        [[0, 1, 0],
         [1, 1, 0],
         [0, 1, 0]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]]],


       [[[0, 0, 0],
         [1, 0, 1],
         [0, 1, 0]],

        [[0, 1, 1],
         [0, 1, 0],
         [1, 0, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [06:23<00:00,  6.00s/it]


Samples: {'observations': array([[[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 1, 0]]],


       ...,


       [[[0, 0, 0],
         [1, 1, 0],
         [1, 0, 0]],

        [[0, 1, 0],
         [0, 0, 1],
         [0, 0, 1]]],


       [[[0, 1, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 1, 0]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 0],
         [1, 1, 0],
         [1, 0, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [06:08<00:00,  5.75s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[1, 0, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 1, 0],
         [0, 0, 0]]],


       ...,


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

100%|██████████| 64/64 [06:55<00:00,  6.49s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[1, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [1, 0, 0]]],


       ...,


       [[[0, 0, 1],
         [0, 0, 1],
         [1, 1, 0]],

        [[1, 1, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]]]], dtype=int8), 'rewards': array([list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, None, None, None, None]),
       list([None, None, Non

  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Prediction: (tensor([0.0900, 0.1900, 0.1600, 0.0600, 0.1000, 0.1800, 0.0900, 0.0700, 0.0600]), tensor([0.0900, 0.1900, 0.1600, 0.0600, 0.1000, 0.1800, 0.0900, 0.0700, 0.0600]), [0, 1, 2, 3, 4, 5, 6, 7, 8], tensor([0.0068]))
Player 1 Random Action: 2
Player 0 Prediction: (tensor([0.1100, 0.3100, 0.1400, 0.1500, 0.1200, 0.0900, 0.0800]), tensor([0.1100, 0.3100, 0.0000, 0.1400, 0.1500, 0.1200, 0.0000, 0.0900, 0.0800]), [0, 1, 3, 4, 5, 7, 8], tensor([0.0284]))
Player 1 Random Action: 8


  1%|          | 1/100 [00:02<04:52,  2.96s/it]

Player 0 Prediction: (tensor([0.3500, 0.1400, 0.1500, 0.2100, 0.1500]), tensor([0.0000, 0.3500, 0.0000, 0.1400, 0.1500, 0.2100, 0.0000, 0.1500, 0.0000]), [1, 3, 4, 5, 7], tensor([-0.1125]))
Player 1 Random Action: 5


 99%|█████████▉| 99/100 [06:03<00:03,  3.25s/it]

Started recording episode 2499 to checkpoints/tictactoe_muzero_3/step_90/videos/random/tictactoe_muzero_3/None/episode_002499.mp4


100%|██████████| 100/100 [06:07<00:00,  3.68s/it]


Stopped recording episode 2499. Recorded 10 frames.
Testing Player 1 vs Agent random


  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Random Action: 5
Player 1 Prediction: (tensor([0.0900, 0.3000, 0.1800, 0.1400, 0.0900, 0.1000, 0.0800, 0.0200]), tensor([0.0900, 0.3000, 0.1800, 0.1400, 0.0900, 0.0000, 0.1000, 0.0800, 0.0200]), [0, 1, 2, 3, 4, 6, 7, 8], tensor([0.0091]))
Player 0 Random Action: 8
Player 1 Prediction: (tensor([0.1700, 0.2400, 0.1300, 0.1300, 0.1700, 0.1600]), tensor([0.1700, 0.0000, 0.2400, 0.1300, 0.1300, 0.0000, 0.1700, 0.1600, 0.0000]), [0, 2, 3, 4, 6, 7], tensor([-0.0126]))
Player 0 Random Action: 3


  1%|          | 1/100 [00:02<03:36,  2.18s/it]

Player 1 Prediction: (tensor([0.1800, 0.3400, 0.2000, 0.2800]), tensor([0.1800, 0.0000, 0.3400, 0.0000, 0.2000, 0.0000, 0.0000, 0.2800, 0.0000]), [0, 2, 4, 7], tensor([0.0063]))
Player 0 Random Action: 4


 99%|█████████▉| 99/100 [04:30<00:02,  2.47s/it]

Started recording episode 2599 to checkpoints/tictactoe_muzero_3/step_90/videos/random/tictactoe_muzero_3/None/episode_002599.mp4


100%|██████████| 100/100 [04:33<00:00,  2.74s/it]


Stopped recording episode 2599. Recorded 9 frames.


 99%|█████████▉| 99/100 [10:28<00:06,  6.44s/it]

Started recording episode 2699 to checkpoints/tictactoe_muzero_3/step_90/videos/tictactoe_muzero_3/90/episode_002699.mp4


100%|██████████| 100/100 [10:34<00:00,  6.34s/it]

Stopped recording episode 2699. Recorded 8 frames.
average score: 0.35





Plotting score...
scores are win/loss plotting a rolling average of the scores
Plotting policy_loss...
Plotting value_loss...
Plotting loss...
Plotting test_score...
Plotting test_score_vs_random...


  3%|▎         | 2/64 [00:10<05:26,  5.26s/it]


Stopping game generation
Stopping training
Finished Training
Testing Player 0 vs Agent random


  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Prediction: (tensor([0.0900, 0.1900, 0.1600, 0.0600, 0.1000, 0.1800, 0.0900, 0.0700, 0.0600]), tensor([0.0900, 0.1900, 0.1600, 0.0600, 0.1000, 0.1800, 0.0900, 0.0700, 0.0600]), [0, 1, 2, 3, 4, 5, 6, 7, 8], tensor([0.0068]))
Player 1 Random Action: 7
Player 0 Prediction: (tensor([0.0900, 0.4100, 0.1000, 0.1200, 0.1200, 0.1200, 0.0400]), tensor([0.0900, 0.4100, 0.1000, 0.0000, 0.1200, 0.1200, 0.1200, 0.0000, 0.0400]), [0, 1, 2, 4, 5, 6, 8], tensor([-0.0668]))
Player 1 Random Action: 5
Player 0 Prediction: (tensor([0.2200, 0.3400, 0.1500, 0.2500, 0.0400]), tensor([0.2200, 0.0000, 0.3400, 0.0000, 0.1500, 0.0000, 0.2500, 0.0000, 0.0400]), [0, 2, 4, 6, 8], tensor([0.0169]))
Player 1 Random Action: 0
Player 0 Prediction: (tensor([0.5100, 0.3300, 0.1600]), tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.5100, 0.0000, 0.3300, 0.0000, 0.1600]), [4, 6, 8], tensor([-0.0113]))
Player 1 Random Action: 8


  1%|          | 1/100 [00:04<07:27,  4.52s/it]

Player 0 Prediction: (tensor([1.]), tensor([0., 0., 0., 0., 1., 0., 0., 0., 0.]), [4], tensor([0.0008]))


 99%|█████████▉| 99/100 [10:58<00:03,  3.09s/it]  

Started recording episode 2799 to checkpoints/tictactoe_muzero_3/step_90/videos/random/tictactoe_muzero_3/None/episode_002799.mp4


100%|██████████| 100/100 [11:01<00:00,  6.61s/it]


Stopped recording episode 2799. Recorded 7 frames.
Testing Player 1 vs Agent random


  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Random Action: 8
Player 1 Prediction: (tensor([0.0800, 0.2500, 0.1000, 0.1200, 0.1400, 0.0400, 0.1800, 0.0900]), tensor([0.0800, 0.2500, 0.1000, 0.1200, 0.1400, 0.0400, 0.1800, 0.0900, 0.0000]), [0, 1, 2, 3, 4, 5, 6, 7], tensor([-0.0527]))
Player 0 Random Action: 2
Player 1 Prediction: (tensor([0.3600, 0.1100, 0.1300, 0.1500, 0.1500, 0.1000]), tensor([0.0000, 0.3600, 0.0000, 0.1100, 0.1300, 0.1500, 0.1500, 0.1000, 0.0000]), [1, 3, 4, 5, 6, 7], tensor([0.0489]))
Player 0 Random Action: 4


  1%|          | 1/100 [00:02<03:43,  2.26s/it]

Player 1 Prediction: (tensor([0.4500, 0.1300, 0.2100, 0.2100]), tensor([0.0000, 0.4500, 0.0000, 0.1300, 0.0000, 0.2100, 0.0000, 0.2100, 0.0000]), [1, 3, 5, 7], tensor([-0.0077]))
Player 0 Random Action: 5


 99%|█████████▉| 99/100 [04:24<00:03,  3.24s/it]

Started recording episode 2899 to checkpoints/tictactoe_muzero_3/step_90/videos/random/tictactoe_muzero_3/None/episode_002899.mp4


100%|██████████| 100/100 [04:28<00:00,  2.69s/it]


Stopped recording episode 2899. Recorded 8 frames.


 99%|█████████▉| 99/100 [09:50<00:05,  5.88s/it]

Started recording episode 2999 to checkpoints/tictactoe_muzero_3/step_90/videos/tictactoe_muzero_3/90/episode_002999.mp4


100%|██████████| 100/100 [10:00<00:00,  6.01s/it]

Stopped recording episode 2999. Recorded 10 frames.
average score: 0.34





Plotting score...
scores are win/loss plotting a rolling average of the scores
Plotting policy_loss...
Plotting value_loss...
Plotting loss...
Plotting test_score...
Plotting test_score_vs_random...
Training complete
Testing against opponents: [<muzero.muzero_agent_torch.MuZeroAgent object at 0x3245a0100>, <muzero.muzero_agent_torch.MuZeroAgent object at 0x322026050>]
['tictactoe_muzero_1', 'tictactoe_muzero_2', 'tictactoe_muzero_3']
[<muzero.muzero_agent_torch.MuZeroAgent object at 0x3245a0100>, <muzero.muzero_agent_torch.MuZeroAgent object at 0x322026050>, <muzero.muzero_agent_torch.MuZeroAgent object at 0x3220c39d0>]
Playing tictactoe_muzero_3 vs tictactoe_muzero_1 game 1
{'player_0': 1, 'player_1': -1}
Playing tictactoe_muzero_3 vs tictactoe_muzero_1 game 2
{'player_0': 1, 'player_1': -1}
Playing tictactoe_muzero_3 vs tictactoe_muzero_1 game 3
{'player_0': 1, 'player_1': -1}
Playing tictactoe_muzero_3 vs tictactoe_muzero_1 game 4
{'player_0': -1, 'player_1': 1}
Playing tictactoe_mu

  +0.5 * drawTable[i][j] * np.log(
  +0.5 * drawTable[i][j] * np.log(
  return 1.0 / (1 + np.power(10, D / 400))
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  df = fun(x) - f0
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  +0.5 * drawTable[i][j] * np.log(
  +0.5 * drawTable[i][j] * np.log(
  +0.5 * drawTable[i][j] * np.log(
  return 1.0 / (1 + np.power(10, D / 400))
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  df = fun(x) - f0
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  +0.5 * drawTable[i][j] * np.log(


Found saved Trials! Loading and Updating...
Trial 0 ELO: 917
Trial 1 ELO: 1035
Trial 2 ELO: 1047
{'action_function': 0, 'actor_conv_layers': 2, 'actor_dense_layer_widths': 0, 'adam_epsilon': 0, 'clip_low_prob': 0, 'clipnorm': 2, 'conv_layers': 0, 'critic_conv_layers': 2, 'critic_dense_layer_widths': 0, 'dense_layer_widths': 0, 'dense_layers': 0, 'exploitation_temperature': 0, 'exploration_temperature': 0, 'games_per_generation': 1, 'kernel_initializer': 2, 'known_bounds': 0, 'learning_rate': 3, 'min_replay_buffer_size': 1, 'minibatch_size': 1, 'n_step': 0, 'noisy_sigma': 0, 'num_sampling_moves': 3, 'num_simulations': 2, 'pb_c_base': 0, 'pb_c_init': 0, 'per_alpha': 0, 'per_beta': 1, 'per_beta_final': 0, 'per_epsilon': 0, 'policy_loss_function': 0, 'replay_buffer_size': 3, 'residual_layers': 11, 'reward_conv_layers': 2, 'reward_dense_layer_widths': 0, 'reward_loss_function': 0, 'root_dirichlet_alpha': 0, 'root_exploration_fraction': 0, 'training_steps': 0, 'unroll_steps': 0, 'value_loss_

  +0.5 * drawTable[i][j] * np.log(
  +0.5 * drawTable[i][j] * np.log(
  return 1.0 / (1 + np.power(10, D / 400))
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  df = fun(x) - f0
  l += winTable[i][j] * np.log(f(elos[i] - elos[j] - eloAdvantage + eloDraw))
  +0.5 * drawTable[i][j] * np.log(
This process is not trusted! Input event monitoring will not be possible until it is added to accessibility clients.


Params:  {'action_function': <function action_function at 0x105e0f400>, 'actor_conv_layers': ((128, 1, 1),), 'actor_dense_layer_widths': (), 'adam_epsilon': 0.03125, 'clip_low_prob': 0.0, 'clipnorm': 1.0, 'conv_layers': (), 'critic_conv_layers': ((32, 1, 1),), 'critic_dense_layer_widths': (), 'dense_layer_widths': (), 'dense_layers': (), 'exploitation_temperature': 0.1, 'exploration_temperature': 1.0, 'games_per_generation': 64, 'kernel_initializer': 'glorot_uniform', 'known_bounds': (-1, 1), 'learning_rate': 0.0001, 'min_replay_buffer_size': 32, 'minibatch_size': 32, 'n_step': 9, 'noisy_sigma': 0.0, 'num_sampling_moves': 2, 'num_simulations': 200, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'per_alpha': 0.0, 'per_beta': 0.5, 'per_beta_final': 1.0, 'per_epsilon': 0.0001, 'policy_loss_function': <utils.utils.CategoricalCrossentropyLoss object at 0x324192800>, 'replay_buffer_size': 32000, 'residual_layers': ((128, 3, 1), (128, 3, 1), (128, 3, 1)), 'reward_conv_layers': ((32, 1, 1),), 'reward_

100%|██████████| 64/64 [08:51<00:00,  8.30s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 1, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 0, 1]],

        [[1, 0, 0],
         [1, 1, 0],
         [0, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 1],
         [0, 0, 1]],

        [[0, 1, 1],
         [1, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 1, 0],
         [

100%|██████████| 64/64 [08:43<00:00,  8.17s/it]


Samples: {'observations': array([[[[0, 0, 1],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 1, 0],
         [1, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[1, 0, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 0],
         [1, 1, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 1, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]],

        [[1, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 1],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [1, 0, 1],
         [0, 0, 0]],

        [[1, 1, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 1, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [10:02<00:00,  9.41s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 1, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 1, 1],
         [1, 0, 0]],

        [[0, 1, 1],
         [1, 0, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 1, 1]],

        [[1, 1, 0],
         [1, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 1, 0]],

        [[0, 0, 0],
         [1, 1, 1],
         [0, 0, 1]]],


       [[[1, 0, 1],
         [1, 1, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 1],
         [1, 1, 1]]],


       [[[0, 0, 1],
         [1, 0, 0],
         [0, 1, 1]],

        [[1, 1, 0],
         [0, 1, 1],
         [1, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [07:30<00:00,  7.04s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 1, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 0],
         [1, 1, 0],
         [0, 0, 0]],

        [[0, 0, 1],
         [0, 0, 1],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [1, 0, 0],
         [0, 1, 0]],

        [[0, 1, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [1, 1, 0],
         [0, 0, 0]],

        [[0, 1, 1],
         [0, 0, 1],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [08:23<00:00,  7.87s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]]],


       [[[0, 1, 1],
         [0, 1, 1],
         [0, 0, 0]],

        [[1, 0, 0],
         [1, 0, 0],
         [1, 1, 1]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 1, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[1, 1, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[0, 0, 1],
         [1, 0, 0],
         [0, 0, 1]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 0],
         [0, 0, 0],
         [0, 1, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [08:24<00:00,  7.88s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 1]]],


       [[[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [1, 0, 1],
         [0, 0, 1]],

        [[1, 1, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[1, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [08:39<00:00,  8.12s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [1, 0, 0],
         [0, 0, 1]],

        [[1, 0, 1],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 1]],

        [[1, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 1, 0],
         [1, 0, 0],
         [1, 1, 0]],

        [[0, 0, 1],
         [0, 1, 1],
         [0, 0, 1]]],


       [[[1, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 1]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[1, 0, 1],
         [0, 1, 1],
         [0, 0, 0]],

        [[0, 1, 0],
         [1, 0, 0],
         [1, 1, 1]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [1, 0, 0]],

        [[1, 1, 0],
         [

100%|██████████| 64/64 [23:43<00:00, 22.25s/it] 


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [1, 0, 1],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 1, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[1, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 0],
         [0, 1, 1],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [

100%|██████████| 64/64 [09:13<00:00,  8.65s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 1, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 0, 1],
         [0, 0, 0],
         [0, 1, 1]],

        [[0, 1, 0],
         [1, 1, 0],
         [1, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 1, 0],
         [

100%|██████████| 64/64 [06:19<00:00,  5.93s/it]


Samples: {'observations': array([[[[0, 1, 0],
         [0, 1, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 1]]],


       [[[1, 0, 0],
         [0, 0, 1],
         [1, 1, 0]],

        [[0, 1, 1],
         [1, 1, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 1],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 1, 0],
         [0, 0, 1]]],


       [[[0, 1, 1],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 1],
         [1, 0, 1]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 1]]],


       [[[0, 0, 1],
         [0, 1, 0],
         [0, 0, 0]],

        [[1, 1, 0],
         [

  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Prediction: (tensor([0.1000, 0.1750, 0.1500, 0.0850, 0.0750, 0.0900, 0.1300, 0.0600, 0.1350]), tensor([0.1000, 0.1750, 0.1500, 0.0850, 0.0750, 0.0900, 0.1300, 0.0600, 0.1350]), [0, 1, 2, 3, 4, 5, 6, 7, 8], tensor([-0.0170]))
Player 1 Random Action: 8
Player 0 Prediction: (tensor([0.1300, 0.1500, 0.1500, 0.1700, 0.1400, 0.1850, 0.0750]), tensor([0.1300, 0.1500, 0.0000, 0.1500, 0.1700, 0.1400, 0.1850, 0.0750, 0.0000]), [0, 1, 3, 4, 5, 6, 7], tensor([0.0147]))
Player 1 Random Action: 5
Player 0 Prediction: (tensor([0.2150, 0.2300, 0.1850, 0.2450, 0.1250]), tensor([0.2150, 0.2300, 0.0000, 0.1850, 0.0000, 0.0000, 0.2450, 0.1250, 0.0000]), [0, 1, 3, 6, 7], tensor([-0.0022]))
Player 1 Random Action: 0
Player 0 Prediction: (tensor([0.5100, 0.3100, 0.1800]), tensor([0.0000, 0.5100, 0.0000, 0.0000, 0.0000, 0.0000, 0.3100, 0.1800, 0.0000]), [1, 6, 7], tensor([-0.0047]))
Player 1 Random Action: 7


  1%|          | 1/100 [00:02<04:05,  2.48s/it]

Player 0 Prediction: (tensor([1.]), tensor([0., 0., 0., 0., 0., 0., 1., 0., 0.]), [6], tensor([0.0125]))


 99%|█████████▉| 99/100 [03:23<00:01,  1.71s/it]

Started recording episode 99 to checkpoints/tictactoe_muzero_4/step_10/videos/random/tictactoe_muzero_4/None/episode_000099.mp4


100%|██████████| 100/100 [03:24<00:00,  2.05s/it]


Stopped recording episode 99. Recorded 6 frames.
Testing Player 1 vs Agent random


  0%|          | 0/100 [00:00<?, ?it/s]

Player 0 Random Action: 0
Player 1 Prediction: (tensor([0.1700, 0.1050, 0.1050, 0.2150, 0.1000, 0.0750, 0.0600, 0.1700]), tensor([0.0000, 0.1700, 0.1050, 0.1050, 0.2150, 0.1000, 0.0750, 0.0600, 0.1700]), [1, 2, 3, 4, 5, 6, 7, 8], tensor([0.0133]))
Player 0 Random Action: 5
Player 1 Prediction: (tensor([0.2600, 0.1650, 0.1400, 0.0950, 0.0800, 0.2600]), tensor([0.0000, 0.2600, 0.0000, 0.1650, 0.1400, 0.0000, 0.0950, 0.0800, 0.2600]), [1, 3, 4, 6, 7, 8], tensor([0.0361]))
Player 0 Random Action: 7
Player 1 Prediction: (tensor([0.2050, 0.2450, 0.1850, 0.3650]), tensor([0.0000, 0.0000, 0.0000, 0.2050, 0.2450, 0.0000, 0.1850, 0.0000, 0.3650]), [3, 4, 6, 8], tensor([0.0210]))
Player 0 Random Action: 3


  1%|          | 1/100 [00:01<02:52,  1.74s/it]

Player 1 Prediction: (tensor([0.4700, 0.5300]), tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4700, 0.0000, 0.5300]), [6, 8], tensor([0.0441]))
Player 0 Random Action: 6


 99%|█████████▉| 99/100 [02:36<00:01,  1.64s/it]

Started recording episode 199 to checkpoints/tictactoe_muzero_4/step_10/videos/random/tictactoe_muzero_4/None/episode_000199.mp4


100%|██████████| 100/100 [02:38<00:00,  1.58s/it]


Stopped recording episode 199. Recorded 10 frames.


 99%|█████████▉| 99/100 [05:41<00:03,  3.37s/it]

Started recording episode 299 to checkpoints/tictactoe_muzero_4/step_10/videos/tictactoe_muzero_4/10/episode_000299.mp4


100%|██████████| 100/100 [05:44<00:00,  3.45s/it]

Stopped recording episode 299. Recorded 8 frames.
average score: 0.27





Plotting score...
scores are win/loss plotting a rolling average of the scores
Plotting policy_loss...
Plotting value_loss...
Plotting loss...
Plotting test_score...
Plotting test_score_vs_random...


  axs[row][col].legend()
  axs[row][col].set_xlim(1, len(scores))
  axs[row][col].legend()
100%|██████████| 64/64 [03:32<00:00,  3.32s/it]


Samples: {'observations': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 0, 0]]],


       [[[1, 0, 0],
         [0, 0, 0],
         [0, 0, 1]],

        [[0, 0, 0],
         [1, 0, 1],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]],


       [[[0, 0, 0],
         [0, 1, 0],
         [0, 0, 1]],

        [[1, 1, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[1, 0, 1],
         [0, 0, 1],
         [0, 0, 0]],

        [[0, 1, 0],
         [1, 1, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 1, 0],
         [

 95%|█████████▌| 61/64 [07:15<00:29,  9.91s/it]

In [None]:
import pickle

# Base name of the experiment; the saved trials live in "./<file_name>_trials.p".
file_name = "tictactoe_muzero"

# Reload the pickled trials (presumably a hyperopt Trials object — confirm).
# The file is opened in a `with` block so the handle is closed deterministically
# instead of being left to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load trials files that were produced locally by this notebook.
with open(f"./{file_name}_trials.p", "rb") as trials_file:
    trials = pickle.load(trials_file)
print("Found saved Trials! Loading...")

print(trials.trials)
# NOTE(review): `table` is not defined in this cell; it must already exist in the
# kernel (hidden state from an earlier cell), otherwise this raises NameError on
# a fresh Restart & Run All.
print(table.bayes_elo()["Elo table"])

In [None]:
# shared network but not shared buffer?
# 1 vs 2 minibatches
import sys

sys.path.append("../..")

from dqn.NFSP.nfsp_agent_clean import NFSPDQN
from agent_configs import NFSPDQNConfig
from game_configs import TicTacToeConfig
from utils import KLDivergenceLoss, CategoricalCrossentropyLoss, HuberLoss, MSELoss
from torch.optim import Adam, SGD

# Hyperparameters for the "Standard" NFSP-DQN run on tic-tac-toe.
# Unprefixed keys configure one network/buffer and "sl_"-prefixed keys a second
# one (presumably NFSP's RL best-response side and supervised average-policy
# side respectively — confirm against NFSPDQNConfig).
config_dict = {
    "shared_networks_and_buffers": False,
    "training_steps": 10000,
    "anticipatory_param": 0.1,
    "replay_interval": 128,
    "num_minibatches": 1,  # or 2, could be 2 minibatches per network, or 2 minibatches (1 for each network/player)
    # --- optimizer / loss (unprefixed side) ---
    "learning_rate": 0.1,
    "momentum": 0.0,
    "optimizer": SGD,
    "loss_function": MSELoss(),
    # --- replay buffer (unprefixed side) ---
    "min_replay_buffer_size": 128,
    "minibatch_size": 128,
    # Capacity is written as an int (previously the float literal 2e5) so it is
    # safe to use as an array size and matches the integer
    # "sl_replay_buffer_size" below.
    "replay_buffer_size": 200000,
    "transfer_interval": 300,
    # --- network architecture (unprefixed side) ---
    "residual_layers": [(128, 3, 1)] * 3,
    "conv_layers": [(32, 3, 1)],
    "dense_layer_widths": [],
    "value_hidden_layer_widths": [],
    "advantage_hidden_layer_widths": [],
    # --- exploration ---
    "noisy_sigma": 0.0,
    "eg_epsilon": 0.06,
    # "eg_epsilon_final": 0.06,
    "eg_epsilon_decay_type": "inverse_sqrt",
    "eg_epsilon_decay_final_step": 0,
    # --- "sl_" side: optimizer / loss ---
    "sl_learning_rate": 0.005,
    "sl_momentum": 0.0,
    # "sl_weight_decay": 1e-9,
    # "sl_clipnorm": 1.0,
    "sl_optimizer": SGD,
    "sl_loss_function": CategoricalCrossentropyLoss(),
    # --- "sl_" side: replay buffer ---
    "sl_min_replay_buffer_size": 128,
    "sl_minibatch_size": 128,
    "sl_replay_buffer_size": 2000000,
    # --- "sl_" side: network architecture ---
    "sl_residual_layers": [(128, 3, 1)] * 3,
    "sl_conv_layers": [(32, 3, 1)],
    "sl_dense_layer_widths": [],
    "sl_clip_low_prob": 0.0,
    # --- prioritized replay / n-step / distributional / dueling settings ---
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "per_epsilon": 0.00001,
    "n_step": 1,
    "atom_size": 1,
    "dueling": False,
    # --- gradient clipping for both sides ---
    "clipnorm": 10.0,
    "sl_clipnorm": 10.0,
}
# Wrap the raw dict in the project's config object for the tic-tac-toe game.
config = NFSPDQNConfig(
    config_dict=config_dict,
    game_config=TicTacToeConfig(),
)
# Set after construction rather than passed through config_dict.
config.save_intermediate_weights = True

In [None]:
from pettingzoo.classic import tictactoe_v3

# Create the PettingZoo tic-tac-toe environment with render_mode="rgb_array".
env = tictactoe_v3.env(render_mode="rgb_array")

# Print the observation space of the "player_0" agent as a quick sanity check.
print(env.observation_space("player_0"))

# Build the NFSP agent from `config` (created in the preceding config cell, which
# must be run first) and place it on the CPU.
agent = NFSPDQN(env, config, name="NFSP-TicTacToe-Standard", device="cpu")

In [None]:
# Override the agent's checkpoint settings before training.
# NOTE(review): presumably a checkpoint is taken every 100 training steps and
# evaluated over 100 test games — confirm against NFSPDQN.
agent.checkpoint_interval = 100
agent.checkpoint_trials = 100
# Start training the agent created in the previous cell.
agent.train()

In [None]:
# shared network but not shared buffer?
# 1 vs 2 minibatches
import sys

sys.path.append("../..")

from dqn.NFSP.nfsp_agent_clean import NFSPDQN
from agent_configs import NFSPDQNConfig
from game_configs import TicTacToeConfig
from utils import KLDivergenceLoss, CategoricalCrossentropyLoss, HuberLoss, MSELoss
from torch.optim import Adam, SGD

# Hyperparameters for the "Rainbow" NFSP-DQN run on tic-tac-toe: noisy layers
# (noisy_sigma > 0, epsilon-greedy off), prioritized replay (per_alpha > 0),
# 3-step returns, 51 atoms and dueling enabled.
# Unprefixed keys configure one network/buffer and "sl_"-prefixed keys a second
# one (presumably NFSP's RL best-response side and supervised average-policy
# side respectively — confirm against NFSPDQNConfig).
config_dict = {
    "shared_networks_and_buffers": False,
    "training_steps": 10000,
    "anticipatory_param": 0.1,
    "replay_interval": 128,
    "num_minibatches": 1,  # or 2, could be 2 minibatches per network, or 2 minibatches (1 for each network/player)
    # --- optimizer / loss (unprefixed side) ---
    "learning_rate": 0.1,
    "momentum": 0.0,
    "optimizer": SGD,
    "loss_function": KLDivergenceLoss(),
    # --- replay buffer (unprefixed side) ---
    "min_replay_buffer_size": 1000,
    "minibatch_size": 128,
    # Capacity is written as an int (previously the float literal 2e5) so it is
    # safe to use as an array size and matches the integer
    # "sl_replay_buffer_size" below.
    "replay_buffer_size": 200000,
    "transfer_interval": 300,
    # --- network architecture (unprefixed side) ---
    "residual_layers": [(128, 3, 1)] * 3,
    "conv_layers": [(32, 3, 1)],
    "dense_layer_widths": [],
    "value_hidden_layer_widths": [],
    "advantage_hidden_layer_widths": [],
    # --- exploration ---
    "noisy_sigma": 0.06,
    "eg_epsilon": 0.0,
    # "eg_epsilon_final": 0.06,
    "eg_epsilon_decay_type": "inverse_sqrt",
    "eg_epsilon_decay_final_step": 0,
    # --- "sl_" side: optimizer / loss ---
    "sl_learning_rate": 0.005,
    "sl_momentum": 0.0,
    # "sl_weight_decay": 1e-9,
    # "sl_clipnorm": 1.0,
    "sl_optimizer": SGD,
    "sl_loss_function": CategoricalCrossentropyLoss(),
    # --- "sl_" side: replay buffer ---
    "sl_min_replay_buffer_size": 1000,
    "sl_minibatch_size": 128,
    "sl_replay_buffer_size": 2000000,
    # --- "sl_" side: network architecture ---
    "sl_residual_layers": [(128, 3, 1)] * 3,
    "sl_conv_layers": [(32, 3, 1)],
    "sl_dense_layer_widths": [],
    "sl_clip_low_prob": 0.0,
    # --- prioritized replay / n-step / distributional / dueling settings ---
    "per_alpha": 0.5,
    "per_beta": 0.5,
    "per_beta_final": 1.0,
    "per_epsilon": 0.00001,
    "n_step": 3,
    "atom_size": 51,
    "dueling": True,
    # --- gradient clipping for both sides ---
    "clipnorm": 10.0,
    "sl_clipnorm": 10.0,
}
# Wrap the raw dict in the project's config object for the tic-tac-toe game.
config = NFSPDQNConfig(
    config_dict=config_dict,
    game_config=TicTacToeConfig(),
)
# Set after construction rather than passed through config_dict.
config.save_intermediate_weights = True

In [None]:
from pettingzoo.classic import tictactoe_v3

# Create the PettingZoo tic-tac-toe environment with render_mode="rgb_array".
# NOTE(review): this rebinds `env`, replacing the environment created for the
# earlier "Standard" agent in the same kernel.
env = tictactoe_v3.env(render_mode="rgb_array")

# Print the observation space of the "player_0" agent as a quick sanity check.
print(env.observation_space("player_0"))

# Build the "Rainbow" NFSP agent from `config` (created in the preceding config
# cell, which must be run first) and place it on the CPU.
agent = NFSPDQN(env, config, name="NFSP-TicTacToe-Rainbow", device="cpu")

In [None]:
# Override the agent's checkpoint settings before training.
# NOTE(review): presumably a checkpoint is taken every 100 training steps and
# evaluated over 100 test games — confirm against NFSPDQN.
agent.checkpoint_interval = 100
agent.checkpoint_trials = 100
# Start training the agent created in the previous cell.
agent.train()