In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


# MuZero hyperparameters for CartPole; wrapped into a MuZeroConfig further down.
# "TODO: tune" marks values that are still being explored.
config = dict(
    known_bounds=[0, 500],
    # Network architecture.
    residual_layers=[],  # TODO: tune — is more depth better? (up to depth 16; needs at most 16 filters)
    representation_dense_layer_widths=[512, 64],
    dynamics_dense_layer_widths=[512, 64],
    actor_conv_layers=[],  # TODO: tune
    critic_conv_layers=[],  # TODO: tune
    reward_conv_layers=[],
    actor_dense_layer_widths=[512],  # TODO: tune
    critic_dense_layer_widths=[512],  # TODO: tune
    reward_dense_layer_widths=[],
    conv_layers=[],
    dense_layer_widths=[],
    noisy_sigma=0.0,
    value_loss_factor=1.0,
    # Search and exploration.
    root_dirichlet_alpha=0.25,  # TODO: tune
    root_exploration_fraction=0.25,
    num_simulations=50,  # TODO: tune — goal is to raise this and see whether learning speeds up
    temperatures=[1.0, 0.5, 0.25],
    temperature_updates=[30000, 60000],
    temperature_with_training_steps=True,
    clip_low_prob=0.0,
    pb_c_base=19652,
    pb_c_init=1.25,
    # Optimisation.
    optimizer=Adam,
    # TODO: tune learning_rate (no blow-ups, but not too small). MuZero pseudocode
    # decays 0.1 -> 0.01 (10% of the initial value) after 400k steps; the AlphaZero
    # paper used 0.2, decayed three times.
    learning_rate=0.005,
    momentum=0.9,
    adam_epsilon=1e-8,
    discount_factor=0.997,
    value_loss_function=CategoricalCrossentropyLoss(),
    reward_loss_function=CategoricalCrossentropyLoss(),
    policy_loss_function=CategoricalCrossentropyLoss(),
    action_function=action_as_onehot,
    training_steps=100000,
    # TODO: tune minibatch_size — it should be roughly 0.1 of the positions collected
    # (or of the replay buffer?). AlphaZero used 4096 and MuZero 2048, described as ~0.1.
    minibatch_size=128,
    # Replay buffer and value targets.
    min_replay_buffer_size=5000,  # TODO: tune
    replay_buffer_size=50000,  # TODO: tune — the paper used a buffer of 1M games
    unroll_steps=5,
    n_step=10,
    clipnorm=0.0,
    weight_decay=0.0001,
    kernel_initializer="he_normal",  # TODO: tune
    # Prioritised experience replay.
    per_alpha=0.0,
    per_beta=0.0,
    per_beta_final=0.0,
    per_use_batch_weights=True,
    per_initial_priority_max=True,
    per_epsilon=0.0001,
    # Self-play workers.
    multi_process=True,
    num_workers=6,  # TODO: tune
    lr_ratio=float("inf"),  # alternative tried: 0.1
    games_per_generation=8,  # TODO: tune — AlphaZero used ~64 games per generation
    reanalyze=True,  # TODO
    support_range=31,
)

# TODO: run an ablation on num_simulations — the original paper found that more
#       simulations give a better learning signal.
# TODO: double-check the MCTS code: is the value sign correct? Is it correct with
#       intermediate rewards? Is it correct for terminal states?
#
# Next steps:
# - run tic-tac-toe on fast settings for at least 200k steps; check it learns to play reasonably
# - only update the MCTS network every x steps
# - add a ratio of learning steps to self-play steps
# - run it without multiprocessing
# - increase num_simulations to 50 or 100 and see whether it learns faster


# Use a single CartPoleConfig instance both to build the env and to parameterise
# MuZeroConfig, so the two can never drift apart (previously two separate
# CartPoleConfig objects were constructed).
game_config = CartPoleConfig()
env = game_config.make_env()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_cartpole",
    device="cpu",
    test_agents=[],
)

In [None]:
# Checkpoint every 10 steps, evaluate every 250 steps over 50 trials, then train.
agent.checkpoint_interval, agent.test_interval, agent.test_trials = 10, 250, 50
agent.train()

In [None]:
import sys


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


# MuZero hyperparameters for the tic-tac-toe GitHub-comparison run.
config = {
    "known_bounds": [-1, 1],
    "residual_layers": [(16, 3, 1)] * 3,
    "conv_layers": [],
    # NOTE(review): the other configs in this notebook use "dense_layer_widths"
    # (also set below) rather than "dense_layers"; this key looks like a leftover —
    # confirm whether MuZeroConfig reads it before removing.
    "dense_layers": [],
    "actor_conv_layers": [(16, 1, 1)],
    "critic_conv_layers": [(16, 1, 1)],
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [8],
    "critic_dense_layer_widths": [8],
    "reward_dense_layer_widths": [],
    # "conv_layers" appears only once: a repeated key in a dict display silently
    # overrides the earlier value, which hides configuration mistakes.
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    # "games_per_generation": 64,  # also tried 32
    "value_loss_factor": 0.25,
    "root_dirichlet_alpha": 0.1,
    "root_exploration_fraction": 0.25,
    "num_simulations": 25,  # TODO: try larger
    "temperatures": [1.0, 0.1],
    "temperature_updates": [5],
    "temperature_with_training_steps": False,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": Adam,  # options: Adam, SGD
    "learning_rate": 0.003,
    "momentum": 0.0,
    "adam_epsilon": 1e-8,  # TODO: try lower
    "value_loss_function": CategoricalCrossentropyLoss(),  # alternative: MSELoss()
    "reward_loss_function": CategoricalCrossentropyLoss(),  # alternative: MSELoss()
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_plane,
    "training_steps": 100000,
    "minibatch_size": 64,
    "min_replay_buffer_size": 64,  # TODO: try larger
    "replay_buffer_size": 27000,  # TODO: try larger
    "unroll_steps": 20,
    "n_step": 20,
    "clipnorm": 0.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "he_normal",
    "per_alpha": 0.5,
    "per_beta": 1.0,
    "per_beta_final": 1.0,  # original value was 0.0
    "per_use_batch_weights": True,
    "per_initial_priority_max": False,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 1,
    "reanalyze": True,  # TODO
    "support_range": 10,  # alternative: None
}

# NOTE: reanalyze is not implemented here. In the online implementation this run is
# compared against, the reanalysed value is computed but never actually used, because
# the bootstrap index is always past the end of the game history.
#
# Remaining differences from that GitHub implementation (with both changes this
# setup should match it exactly):
# - TODO: add a maximum depth to the search tree.
# - TODO: make a game's priority the max of its position priorities.

env = tictactoe_v3.env(render_mode="rgb_array")

game_config = TicTacToeConfig(tictactoe_v3.env)
game_config.has_intermediate_rewards = True
config = MuZeroConfig(config, game_config)

test_opponents = [RandomAgent(), TicTacToeBestAgent()]
agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe-test-github",
    device="cpu",
    test_agents=test_opponents,
)

In [None]:
# Checkpoint every 50 steps, evaluate every 1000 steps over 1000 trials, then train.
agent.checkpoint_interval, agent.test_interval, agent.test_trials = 50, 1000, 1000
agent.train()

In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


# MuZero hyperparameters for tic-tac-toe with dense (MLP) networks.
# "TODO: tune" marks values that are still being explored.
config = dict(
    known_bounds=[-1, 1],
    # Network architecture.
    residual_layers=[],  # TODO: tune — is more depth better? (up to depth 16; needs at most 16 filters)
    representation_dense_layer_widths=[256, 64],
    dynamics_dense_layer_widths=[256, 64],
    actor_conv_layers=[],  # TODO: tune
    critic_conv_layers=[],  # TODO: tune
    reward_conv_layers=[],
    actor_dense_layer_widths=[256],  # TODO: tune
    critic_dense_layer_widths=[256],  # TODO: tune
    reward_dense_layer_widths=[],
    conv_layers=[],
    dense_layer_widths=[],
    noisy_sigma=0.0,
    value_loss_factor=1.0,
    # Search and exploration.
    root_dirichlet_alpha=0.25,  # TODO: tune
    root_exploration_fraction=0.25,
    num_simulations=25,  # TODO: tune — goal is to raise this and see whether learning speeds up
    temperatures=[1.0, 0.1],
    temperature_updates=[5],
    temperature_with_training_steps=False,
    clip_low_prob=0.0,
    pb_c_base=19652,
    pb_c_init=1.25,
    # Optimisation.
    optimizer=Adam,
    # TODO: tune learning_rate (no blow-ups, but not too small). MuZero pseudocode
    # decays 0.1 -> 0.01 (10% of the initial value) after 400k steps; the AlphaZero
    # paper used 0.2, decayed three times.
    learning_rate=0.002,
    momentum=0.9,
    adam_epsilon=1e-8,
    value_loss_function=MSELoss(),
    reward_loss_function=MSELoss(),
    policy_loss_function=CategoricalCrossentropyLoss(),
    action_function=action_as_onehot,
    training_steps=100000,
    # TODO: tune minibatch_size — it should be roughly 0.1 of the positions collected
    # (or of the replay buffer?). AlphaZero used 4096 and MuZero 2048, described as ~0.1.
    minibatch_size=128,
    # Replay buffer and value targets.
    min_replay_buffer_size=5000,  # TODO: tune
    replay_buffer_size=20000,  # TODO: tune — the paper used a buffer of 1M games
    unroll_steps=5,
    n_step=9,
    clipnorm=0.0,
    weight_decay=0.0001,
    kernel_initializer="he_normal",  # TODO: tune
    # Prioritised experience replay.
    per_alpha=0.0,
    per_beta=0.0,
    per_beta_final=0.0,
    per_use_batch_weights=True,
    per_initial_priority_max=True,
    per_epsilon=0.0001,
    # Self-play workers.
    multi_process=True,
    num_workers=6,  # TODO: tune
    lr_ratio=0.1,  # alternative tried: float("inf")
    games_per_generation=8,  # TODO: tune — AlphaZero used ~64 games per generation
    reanalyze=True,  # TODO
    support_range=None,
)

# TODO: run an ablation on num_simulations — the original paper found that more
#       simulations give a better learning signal.
# TODO: double-check the MCTS code: is the value sign correct? Is it correct with
#       intermediate rewards? Is it correct for terminal states?
#
# Next steps:
# - run tic-tac-toe on fast settings for at least 200k steps; check it learns to play reasonably
# - only update the MCTS network every x steps
# - add a ratio of learning steps to self-play steps
# - run it without multiprocessing
# - increase num_simulations to 50 or 100 and see whether it learns faster

# Wrappers are applied innermost-first: ActionMaskInInfoWrapper ->
# FrameStackWrapper (4 frames, channel-last) -> TwoPlayerPlayerPlaneWrapper ->
# ChannelLastToFirstWrapper.
env = ChannelLastToFirstWrapper(
    TwoPlayerPlayerPlaneWrapper(
        FrameStackWrapper(
            ActionMaskInInfoWrapper(tictactoe_v3.env(render_mode="rgb_array")),
            4,
            channel_first=False,
        ),
        channel_first=False,
    )
)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

test_opponents = [RandomAgent(), TicTacToeBestAgent()]
agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe-dense-5moves_1",
    device="cpu",
    test_agents=test_opponents,
)

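I tested the hyperparameters I found in the GitHub implementation, and they seemed to work well. However, I plan to use fewer workers. That implementation played about 480k episodes over 32k training steps. I reached about 250k episodes after only about 5k training steps. So I want to rebalance self-play against training.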

In [None]:
# Checkpoint every 100 steps, evaluate every 250 steps over 100 trials, then train.
agent.checkpoint_interval, agent.test_interval, agent.test_trials = 100, 250, 100
agent.train()

In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


# MuZero hyperparameters for the tic-tac-toe hyper-parameter search (conv networks, SGD).
# "TODO: tune" marks values that are still being explored.
config = dict(
    known_bounds=[-1, 1],
    # Network architecture.
    residual_layers=[(16, 3, 1)],
    representation_dense_layer_widths=[],
    dynamics_dense_layer_widths=[],
    actor_conv_layers=[(8, 1, 1)],  # TODO: tune
    critic_conv_layers=[(8, 1, 1)],  # TODO: tune
    reward_conv_layers=[],
    actor_dense_layer_widths=[],  # TODO: tune
    critic_dense_layer_widths=[],  # TODO: tune
    reward_dense_layer_widths=[],
    conv_layers=[],
    dense_layer_widths=[],
    noisy_sigma=0.0,
    value_loss_factor=1.0,
    # Search and exploration.
    root_dirichlet_alpha=1.8,  # TODO: tune
    root_exploration_fraction=0.25,
    num_simulations=25,  # TODO: tune — goal is to raise this and see whether learning speeds up
    temperatures=[1.0, 0.1],
    temperature_updates=[2],
    temperature_with_training_steps=False,
    clip_low_prob=0.0,
    pb_c_base=19652,
    pb_c_init=1.25,
    # Optimisation.
    optimizer=SGD,
    # TODO: tune learning_rate (no blow-ups, but not too small). MuZero pseudocode
    # decays 0.1 -> 0.01 (10% of the initial value) after 400k steps; the AlphaZero
    # paper used 0.2, decayed three times.
    learning_rate=0.1,
    momentum=0.9,
    adam_epsilon=1e-8,
    value_loss_function=MSELoss(),
    reward_loss_function=MSELoss(),
    policy_loss_function=CategoricalCrossentropyLoss(),
    action_function=action_as_plane,
    training_steps=4000,
    # TODO: tune minibatch_size — it should be roughly 0.1 of the positions collected
    # (or of the replay buffer?). AlphaZero used 4096 and MuZero 2048, described as ~0.1.
    minibatch_size=32,
    # Replay buffer and value targets.
    min_replay_buffer_size=1000,  # TODO: tune (also tried 9000)
    replay_buffer_size=10000,  # TODO: tune — the paper used a buffer of 1M games
    unroll_steps=5,
    n_step=9,
    clipnorm=1.0,
    weight_decay=0.0001,
    kernel_initializer="orthogonal",  # TODO: tune
    # Prioritised experience replay.
    per_alpha=0.0,
    per_beta=0.0,
    per_beta_final=0.0,
    per_use_batch_weights=False,
    per_initial_priority_max=False,
    per_epsilon=0.0001,
    # Self-play workers.
    multi_process=True,
    num_workers=2,  # TODO: tune
    lr_ratio=float("inf"),  # alternative tried: 0.1
    games_per_generation=8,  # TODO: tune — AlphaZero used ~64 games per generation
    reanalyze=True,  # TODO
    support_range=None,
)

# TODO: run an ablation on num_simulations — the original paper found that more
#       simulations give a better learning signal.
# TODO: double-check the MCTS code: is the value sign correct? Is it correct with
#       intermediate rewards? Is it correct for terminal states?
#
# Next steps:
# - run tic-tac-toe on fast settings for at least 200k steps; check it learns to play reasonably
# - only update the MCTS network every x steps
# - add a ratio of learning steps to self-play steps
# - add frame stacking
# - run it without multiprocessing
# - increase num_simulations to 50 or 100 and see whether it learns faster

# Wrappers are applied innermost-first: ActionMaskInInfoWrapper ->
# FrameStackWrapper (4 frames, channel-last) -> TwoPlayerPlayerPlaneWrapper ->
# ChannelLastToFirstWrapper.
env = ChannelLastToFirstWrapper(
    TwoPlayerPlayerPlaneWrapper(
        FrameStackWrapper(
            ActionMaskInInfoWrapper(tictactoe_v3.env(render_mode="rgb_array")),
            4,
            channel_first=False,
        ),
        channel_first=False,
    )
)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

test_opponents = [RandomAgent(), TicTacToeBestAgent()]
agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe_hyperopt",
    device="cpu",
    test_agents=test_opponents,
)

With two moves left, one losing and one drawing, the agent finds the drawing move. However, after tree search it evaluates the winning position as 0 instead of 1, even though the network predicts it as 1. This is shown in paper 2m-3.

In [None]:
# Checkpoint every 100 steps, evaluate every 250 steps over 100 trials, then train.
agent.checkpoint_interval, agent.test_interval, agent.test_trials = 100, 250, 100
agent.train()

In [None]:
import torch
from packages.utils.utils.utils import process_petting_zoo_obs

from pettingzoo.classic import tictactoe_v3


def play_game(player1, player2, max_steps=1000, verbose=True):
    """Play one tic-tac-toe game between two agents and return player_0's reward.

    ``player1`` acts when the current agent is at index 0 of ``env.agents`` and
    ``player2`` otherwise. Each player must provide ``predict(state, info, env=...)``
    and ``select_actions(prediction, info)``; the latter's result must support
    ``.item()`` (e.g. a tensor or NumPy scalar).

    Args:
        player1: Agent for the first seat (index 0 in ``env.agents``).
        player2: Agent for the second seat.
        max_steps: Safety cap on the number of moves played (default 1000).
        verbose: If True, print the final ``env.rewards`` dict.

    Returns:
        ``env.rewards["player_0"]`` after the last move.
    """
    env = tictactoe_v3.env(render_mode="rgb_array")
    players = (player1, player2)
    try:
        with torch.no_grad():  # no gradient computation while evaluating
            env.reset()
            state, reward, termination, truncation, info = env.last()
            done = termination or truncation
            current_player = env.agents.index(env.agent_selection)
            state, info = process_petting_zoo_obs(state, info, current_player)

            episode_length = 0
            while not done and episode_length < max_steps:
                episode_length += 1

                # Seat 0 -> player1, any other seat -> player2 (same as before).
                player = players[0] if current_player == 0 else players[1]
                prediction = player.predict(state, info, env=env)
                action = player.select_actions(prediction, info).item()

                env.step(action)
                state, reward, termination, truncation, info = env.last()
                current_player = env.agents.index(env.agent_selection)
                state, info = process_petting_zoo_obs(state, info, current_player)
                done = termination or truncation

            if verbose:
                print(env.rewards)
            return env.rewards["player_0"]
    finally:
        # Release the environment's resources (renderer etc.) even if a player raises.
        env.close()

In [None]:
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from elo.elo import StandingsTable


# 1v1 tournament between the trained MuZero `agent` and the rule-based expert.
# The table was previously named `random_vs_expert_table`, which was misleading:
# no random agent takes part. The first argument to play_1v1_tournament is
# presumably the number of games — confirm against StandingsTable.
muzero_vs_expert_table = StandingsTable([agent, TicTacToeBestAgent()], start_elo=1400)
muzero_vs_expert_table.play_1v1_tournament(1000, play_game)
print(muzero_vs_expert_table.bayes_elo())
print(muzero_vs_expert_table.get_win_table())

In [None]:
import numpy as np


class TicTacToeBestAgent:
    """Rule-based tic-tac-toe player: take a winning move, else block, else play randomly.

    NOTE: this cell redefines (and therefore shadows) the ``TicTacToeBestAgent``
    imported from ``agents.tictactoe_expert`` in earlier cells.

    Args:
        model_name: Identifier stored on the instance as ``model_name``.
        verbose: If True, print the reconstructed board on every move. Defaults to
            False so that long evaluation runs are not flooded with debug output.
    """

    # The 8 winning lines as (row, col) triples, in the order they are scanned:
    # row 0, col 0, row 1, col 1, row 2, col 2, main diagonal, anti-diagonal.
    _LINES = (
        ((0, 0), (0, 1), (0, 2)),  # row 0
        ((0, 0), (1, 0), (2, 0)),  # col 0
        ((1, 0), (1, 1), (1, 2)),  # row 1
        ((0, 1), (1, 1), (2, 1)),  # col 1
        ((2, 0), (2, 1), (2, 2)),  # row 2
        ((0, 2), (1, 2), (2, 2)),  # col 2
        ((0, 0), (1, 1), (2, 2)),  # main diagonal
        ((0, 2), (1, 1), (2, 0)),  # anti-diagonal
    )

    def __init__(self, model_name="tictactoe_expert", verbose=False):
        self.model_name = model_name
        self.verbose = verbose

    def predict(self, observation, info, env=None):
        """Pass-through: the "prediction" is simply the ``(observation, info)`` pair."""
        return observation, info

    def select_actions(self, prediction, info):
        """Choose a move for the player to act.

        Lines are scanned in ``_LINES`` order. The first line with two of the current
        player's marks and one empty square is played immediately (win). Otherwise,
        the last line found with two opponent marks and one empty square is played
        (block). Otherwise a random entry of ``info["legal_moves"]`` is returned.

        Args:
            prediction: ``(observation, info)`` as returned by :meth:`predict`.
                Assumes ``observation[0]`` / ``observation[1]`` are 3x3 planes of the
                current player's / opponent's marks — TODO confirm against the
                observation preprocessing.
            info: Dict with a ``"legal_moves"`` sequence of flat cell indices.

        Returns:
            Row-major flat cell index in ``[0, 9)`` as a NumPy integer.
        """
        observation = prediction[0]
        # +1 = current player's mark, -1 = opponent's mark, 0 = empty.
        board = observation[0] - observation[1]
        if self.verbose:
            print(board)

        # Fallback move. Drawn before the rule checks so the amount of RNG consumed
        # per move is the same whether or not a rule fires.
        action = np.random.choice(info["legal_moves"])

        for line in self._LINES:
            values = [board[row, col] for row, col in line]
            if 0 not in values:
                continue
            total = sum(values)
            empty_cell = line[values.index(0)]  # first empty square on the line
            if total == 2:
                # Two of ours + one empty: winning move, take it immediately.
                return np.ravel_multi_index(empty_cell, (3, 3))
            if total == -2:
                # Two of theirs + one empty: block, unless a later line is a win.
                action = np.ravel_multi_index(empty_cell, (3, 3))

        return action

In [1]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


# MuZero hyperparameters for the second tic-tac-toe hyper-parameter search (conv trunk, Adam).
# "TODO: tune" marks values that are still being explored.
config = dict(
    known_bounds=[-1, 1],
    # Network architecture.
    residual_layers=[(16, 3, 1)],
    representation_dense_layer_widths=[],
    dynamics_dense_layer_widths=[],
    actor_conv_layers=[],  # TODO: tune
    critic_conv_layers=[],  # TODO: tune
    reward_conv_layers=[],
    actor_dense_layer_widths=[],  # TODO: tune
    critic_dense_layer_widths=[],  # TODO: tune
    reward_dense_layer_widths=[],
    conv_layers=[],
    dense_layer_widths=[],
    noisy_sigma=0.0,
    value_loss_factor=1.0,
    # Search and exploration.
    root_dirichlet_alpha=2.0,  # TODO: tune
    root_exploration_fraction=0.25,
    num_simulations=25,  # TODO: tune — goal is to raise this and see whether learning speeds up
    temperatures=[1.0, 0.1],
    temperature_updates=[5],
    temperature_with_training_steps=False,
    clip_low_prob=0.0,
    pb_c_base=19652,
    pb_c_init=1.25,
    # Optimisation.
    optimizer=Adam,
    # TODO: tune learning_rate (no blow-ups, but not too small). MuZero pseudocode
    # decays 0.1 -> 0.01 (10% of the initial value) after 400k steps; the AlphaZero
    # paper used 0.2, decayed three times.
    learning_rate=0.001,
    momentum=0.0,
    adam_epsilon=1e-8,
    value_loss_function=MSELoss(),
    reward_loss_function=MSELoss(),
    policy_loss_function=CategoricalCrossentropyLoss(),
    action_function=action_as_plane,
    training_steps=33000,
    # TODO: tune minibatch_size — it should be roughly 0.1 of the positions collected
    # (or of the replay buffer?). AlphaZero used 4096 and MuZero 2048, described as ~0.1.
    minibatch_size=32,
    # Replay buffer and value targets.
    min_replay_buffer_size=1000,  # TODO: tune (also tried 9000)
    replay_buffer_size=40000,  # TODO: tune — the paper used a buffer of 1M games
    unroll_steps=5,
    n_step=9,
    clipnorm=0.0,
    weight_decay=0.0001,
    kernel_initializer="orthogonal",  # TODO: tune
    # Prioritised experience replay.
    per_alpha=0.0,
    per_beta=0.0,
    per_beta_final=0.0,
    per_use_batch_weights=False,
    per_initial_priority_max=False,
    per_epsilon=0.0001,
    # Self-play workers.
    multi_process=True,
    num_workers=2,  # TODO: tune
    lr_ratio=float("inf"),  # alternative tried: 0.1
    games_per_generation=8,  # TODO: tune — AlphaZero used ~64 games per generation
    reanalyze=True,  # TODO
    support_range=None,
)

# TODO: run an ablation on num_simulations — the original paper found that more
#       simulations give a better learning signal.
# TODO: double-check the MCTS code: is the value sign correct? Is it correct with
#       intermediate rewards? Is it correct for terminal states?
#
# Next steps:
# - run tic-tac-toe on fast settings for at least 200k steps; check it learns to play reasonably
# - only update the MCTS network every x steps
# - add a ratio of learning steps to self-play steps
# - add frame stacking
# - run it without multiprocessing
# - increase num_simulations to 50 or 100 and see whether it learns faster

# Wrappers are applied innermost-first: ActionMaskInInfoWrapper ->
# FrameStackWrapper (4 frames, channel-last) -> TwoPlayerPlayerPlaneWrapper ->
# ChannelLastToFirstWrapper.
env = ChannelLastToFirstWrapper(
    TwoPlayerPlayerPlaneWrapper(
        FrameStackWrapper(
            ActionMaskInInfoWrapper(tictactoe_v3.env(render_mode="rgb_array")),
            4,
            channel_first=False,
        ),
        channel_first=False,
    )
)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

test_opponents = [RandomAgent(), TicTacToeBestAgent()]
agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe_hyperopt-2",
    device="cpu",
    test_agents=test_opponents,
)

# Checkpoint every 100 steps, evaluate every 1000 steps over 500 trials, then train.
agent.checkpoint_interval, agent.test_interval, agent.test_trials = 100, 1000, 500
agent.train()

 80%|████████  | 402/500 [02:07<00:32,  3.01it/s]

Losses 0.9789805334761468 8.047132734420302 0.04376915918994026 9.069881439208984
Training Step: 19496


 81%|████████  | 404/500 [02:08<00:27,  3.50it/s]

Losses 1.382751777095169 7.666744708432816 0.034938833933484514 9.084436416625977
Training Step: 19497


 81%|████████  | 406/500 [02:08<00:22,  4.12it/s]

Losses 1.5402220708115308 8.675712044889224 0.023078601911809105 10.23901081085205
Training Step: 19498


 82%|████████▏ | 408/500 [02:09<00:21,  4.25it/s]

Losses 1.705008872465616 7.504435320966877 0.05348969953879745 9.262932777404785
Training Step: 19499


 82%|████████▏ | 409/500 [02:09<00:20,  4.35it/s]

Losses 1.7297336760951083 8.347355518300901 0.015066018676199555 10.092151641845703
Training Step: 19500
Saving Checkpoint
plotting score
plotting policy_loss
plotting value_loss
plotting reward_loss
plotting loss
plotting test_score
  subkey score
  subkey max_score
  subkey min_score
plotting test_score_vs_random
  subkey score
  subkey player_0_score
  subkey player_1_score
  subkey player_0_win%
  subkey player_1_win%
plotting test_score_vs_tictactoe_expert
  subkey score
  subkey player_0_score
  subkey player_1_score
  subkey player_0_win%
  subkey player_1_win%


 84%|████████▍ | 420/500 [02:12<00:22,  3.63it/s]

unroll step 0
predicted value tensor([0.2296], grad_fn=<UnbindBackward0>)
target value tensor(1.)
predicted reward tensor([0.], grad_fn=<UnbindBackward0>)
target reward tensor(0.)
predicted policy tensor([1.6884e-01, 1.4176e-01, 6.3912e-01, 5.2639e-03, 1.0492e-04, 4.9950e-03,
        6.5236e-03, 3.6114e-03, 2.9786e-02], grad_fn=<UnbindBackward0>)
target policy tensor([0.0800, 0.1600, 0.7200, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0400])
sample losses tensor([0.5934], grad_fn=<MulBackward0>) tensor(0.) tensor(0.9178, grad_fn=<NegBackward0>)
unroll step 1
predicted value tensor([-0.5842], grad_fn=<UnbindBackward0>)
target value tensor(-1.)
predicted reward tensor([0.0008], grad_fn=<UnbindBackward0>)
target reward tensor(0.)
predicted policy tensor([5.3861e-02, 2.1185e-02, 2.7938e-04, 6.4275e-03, 9.5829e-06, 1.6870e-02,
        1.0514e-01, 1.8217e-01, 6.1406e-01], grad_fn=<UnbindBackward0>)
target policy tensor([0.0000, 0.0800, 0.0000, 0.0000, 0.0000, 0.0000, 0.1200, 0.1600, 0.6400])
s

 84%|████████▍ | 421/500 [02:12<00:23,  3.43it/s]

Losses 1.428133685464445 8.509321720535809 0.01954444478828421 9.957001686096191
Training Step: 19501


 85%|████████▍ | 423/500 [02:13<00:20,  3.72it/s]

Losses 1.310933444682492 8.12679902146192 0.036966388467649436 9.474699020385742
Training Step: 19502


 85%|████████▌ | 425/500 [02:13<00:20,  3.73it/s]

Losses 1.760165622836331 7.817286049292306 0.01064504792875498 9.58809757232666
Training Step: 19503


 85%|████████▌ | 427/500 [02:14<00:20,  3.55it/s]

Losses 1.2802201423065283 7.887582034112711 0.01442713857574085 9.182226181030273
Training Step: 19504


 86%|████████▌ | 428/500 [02:14<00:21,  3.35it/s]

Losses 1.663477180571178 8.356820088163659 0.026165115016080875 10.046463012695312
Training Step: 19505


 86%|████████▌ | 430/500 [02:15<00:18,  3.83it/s]

Losses 1.34631562793771 7.3633041919601965 0.005520519280871841 8.715140342712402
Training Step: 19506


 86%|████████▋ | 432/500 [02:15<00:19,  3.48it/s]

Losses 1.8985915506513251 7.4575026708698715 0.034641120688465366 9.390734672546387
Training Step: 19507


 87%|████████▋ | 433/500 [02:16<00:23,  2.88it/s]

Losses 1.5090985537870543 8.970156170420523 0.06892091919571869 10.548171043395996
Training Step: 19508


 87%|████████▋ | 435/500 [02:16<00:19,  3.28it/s]

Losses 1.1485258037751418 7.786399624754267 0.006573351108879549 8.941497802734375
Training Step: 19509


 87%|████████▋ | 437/500 [02:17<00:16,  3.77it/s]

Losses 1.2056949288485157 7.984938715839235 0.05122756575469967 9.241859436035156
Training Step: 19510


 88%|████████▊ | 438/500 [02:17<00:17,  3.54it/s]

Losses 0.7435094367952644 8.018978505991981 0.007485334164910599 8.769974708557129
Training Step: 19511


 88%|████████▊ | 441/500 [02:18<00:14,  3.99it/s]

Losses 0.8993861633763114 7.596389985534188 0.00797521245677979 8.503751754760742
Training Step: 19512


 89%|████████▊ | 443/500 [02:18<00:13,  4.20it/s]

Losses 1.1284538021735107 7.533998463652097 0.013016028025860528 8.675469398498535
Training Step: 19513


 89%|████████▉ | 444/500 [02:19<00:13,  4.16it/s]

Losses 1.4010892530125432 6.857556072522129 0.0019284472928834395 8.260573387145996
Training Step: 19514


 89%|████████▉ | 446/500 [02:19<00:11,  4.51it/s]

Losses 1.5493259182755323 7.947292871576792 0.06534046078688957 9.561958312988281
Training Step: 19515


 90%|████████▉ | 448/500 [02:20<00:13,  3.98it/s]

Losses 2.0945289943220686 8.03136102372082 0.030503526638396186 10.156394004821777
Training Step: 19516


 90%|█████████ | 450/500 [02:20<00:10,  4.56it/s]

Losses 1.2180150658082902 7.846633883447794 0.025503144754411622 9.090147972106934
Training Step: 19517


 90%|█████████ | 452/500 [02:21<00:12,  3.98it/s]

Losses 1.9895737983611603 7.48112910753116 0.054936359336215906 9.52563762664795
Training Step: 19518


 91%|█████████ | 454/500 [02:21<00:11,  4.09it/s]

Losses 1.1202673303109854 7.664214143442223 0.009897077563127007 8.794381141662598
Training Step: 19519


 91%|█████████ | 456/500 [02:22<00:11,  3.99it/s]

Losses 1.9816745429520157 7.896648661240761 0.020692259632124677 9.899015426635742
Training Step: 19520


 91%|█████████▏| 457/500 [02:22<00:11,  3.76it/s]

Losses 1.811029830079626 7.409418554379954 0.016808992654667194 9.237258911132812
Training Step: 19521


 92%|█████████▏| 459/500 [02:22<00:10,  3.78it/s]

Losses 1.7051222954308844 8.038682541504386 0.062390846761004704 9.806194305419922
Training Step: 19522


 92%|█████████▏| 461/500 [02:23<00:10,  3.84it/s]

Losses 1.2127873419487294 8.187517616373952 0.0027069884676789446 9.403010368347168
Training Step: 19523


 92%|█████████▏| 462/500 [02:23<00:10,  3.75it/s]

Losses 1.7045460001894737 8.224431575086783 0.033382396207151555 9.962362289428711
Training Step: 19524


 93%|█████████▎| 464/500 [02:24<00:09,  3.66it/s]

Losses 1.004788455064606 8.020301043703512 0.01800532235126795 9.043097496032715
Training Step: 19525


 93%|█████████▎| 466/500 [02:24<00:09,  3.52it/s]

Losses 2.6536475722423276 7.3547227378658135 0.007509188979110432 10.015878677368164
Training Step: 19526


 93%|█████████▎| 467/500 [02:25<00:09,  3.40it/s]

Losses 1.7275052344957147 7.854047432549123 0.03360947224014765 9.615161895751953
Training Step: 19527


 94%|█████████▍| 470/500 [02:25<00:07,  4.02it/s]

Losses 2.3147399547048257 8.17988294755196 0.018464241146797233 10.51308822631836
Training Step: 19528


 94%|█████████▍| 471/500 [02:26<00:07,  3.83it/s]

Losses 0.8711543622763387 8.430681710633507 0.03367738865958081 9.335514068603516
Training Step: 19529


 94%|█████████▍| 472/500 [02:26<00:07,  3.85it/s]

Losses 1.6898273275639113 8.904097436869051 0.06030814002408036 10.654229164123535
Training Step: 19530


 95%|█████████▍| 474/500 [02:26<00:06,  4.06it/s]

Losses 1.268007149443136 8.665311318407475 0.009079448896493064 9.94239616394043
Training Step: 19531


 95%|█████████▌| 476/500 [02:27<00:06,  3.59it/s]

Losses 1.4478558649117033 8.671462070851703 0.018363487935148082 10.13768482208252
Training Step: 19532


 96%|█████████▌| 478/500 [02:28<00:06,  3.62it/s]

Losses 1.3909509338238353 7.960506363735476 0.05644484780921433 9.407902717590332
Training Step: 19533


 96%|█████████▌| 480/500 [02:28<00:05,  3.84it/s]

Losses 1.8304960899197456 8.880853796108568 0.06880919616357695 10.780160903930664
Training Step: 19534


 96%|█████████▌| 481/500 [02:28<00:05,  3.57it/s]

Losses 1.525561789276277 7.369323395221727 0.012033973655082353 8.906919479370117
Training Step: 19535


 97%|█████████▋| 483/500 [02:29<00:04,  3.61it/s]

Losses 1.3450949434594683 8.051947015803307 0.006566448964477445 9.403608322143555
Training Step: 19536


 97%|█████████▋| 485/500 [02:30<00:04,  3.50it/s]

Losses 1.4173133054303695 7.664755324309226 0.00358671246947687 9.085653305053711
Training Step: 19537


 97%|█████████▋| 486/500 [02:30<00:04,  2.96it/s]

Losses 1.453832037197374 8.183243490275345 0.032111970872141904 9.669191360473633
Training Step: 19538


 98%|█████████▊| 488/500 [02:31<00:03,  3.10it/s]

Losses 1.6959041533510404 7.863974294916261 0.05824682188366648 9.618128776550293
Training Step: 19539


 98%|█████████▊| 490/500 [02:31<00:02,  3.54it/s]

Losses 1.2044289101439856 7.885773060130305 0.007407651349758848 9.097612380981445
Training Step: 19540


 98%|█████████▊| 492/500 [02:32<00:02,  3.56it/s]

Losses 1.088523434264279 8.675466101696657 0.03222390624375193 9.796215057373047
Training Step: 19541


 99%|█████████▉| 494/500 [02:32<00:01,  3.63it/s]

Losses 1.342499735072115 7.836370058124885 0.014007017057358087 9.192876815795898
Training Step: 19542


 99%|█████████▉| 496/500 [02:33<00:00,  4.12it/s]

Losses 1.603873225494271 7.290694295981666 0.030021609552143502 8.924586296081543
Training Step: 19543


100%|█████████▉| 498/500 [02:33<00:00,  3.47it/s]

Losses 1.5613351436037064 7.80715021368087 0.027043398036646682 9.395528793334961
Training Step: 19544
Started recording episode 1499 to checkpoints/muzero_tictactoe_hyperopt-2/step_19000/videos/muzero_tictactoe_hyperopt-2/episode_001499.mp4


100%|█████████▉| 499/500 [02:34<00:00,  3.56it/s]

Stopped recording episode 1499. Recorded 6 frames.
average score: 0.156
Test score {'score': 0.156, 'max_score': 1, 'min_score': -1}
Losses 1.9430123336031742 8.122265110825538 0.02560769626367254 10.090887069702148
Training Step: 19545


100%|██████████| 500/500 [02:34<00:00,  3.24it/s]


Losses 1.6567725476548052 7.810082039679401 0.007249742786176849 9.474103927612305
Training Step: 19546
Losses 0.7640155943672085 8.037580729680485 0.01581317637899249 8.81740951538086
Training Step: 19547
Losses 1.2171300614391498 7.935810570095782 0.005974214149103152 9.158920288085938
Training Step: 19548
Losses 1.4163666904385082 8.222698607176426 0.042307989141465774 9.681368827819824
Training Step: 19549
Losses 1.5870062337871005 7.58452379135997 0.058396954217055264 9.229926109313965
Training Step: 19550
Losses 1.2253188036916114 7.540393577270152 0.006637683252957771 8.77235221862793
Training Step: 19551
Losses 0.9857457185960641 8.57183543722931 0.007664591939475329 9.565242767333984
Training Step: 19552
Losses 1.422819176508077 8.16500855977938 0.029776635807755025 9.617602348327637
Training Step: 19553
Losses 1.5804467356653298 8.519526532065356 0.06718832353329732 10.167158126831055
Training Step: 19554
Losses 1.7865356309619704 7.400916676611814 0.0033578358584652745 9.190

Process Process-1:
Process Process-2:
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.14/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/Cellar/python@3.10/3.10.14/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/muzero/muzero_agent_torch.py", line 219, in worker_fn
    score, num_steps = self.play_game(env=worker_env)
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/muzero/muzero_agent_torch.py", line 707, in play_game
    prediction = self.predict(
  File "/Users/jonathanlamontange-kratz/Documents/GitHub/rl-stuff/muzero/muzero_agent_torch.py", line 643, in predict
    value, visit_counts = self.monte_carlo_tree_search(
  File "/Users/jonathanlamontange-kratz/Documents/GitHub

KeyboardInterrupt: 

libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe


In [None]:
# Extend the training budget and resume training from the current state.
# NOTE: not idempotent — each re-run of this cell raises the target again.
EXTRA_TRAINING_STEPS = 100
agent.config.training_steps += EXTRA_TRAINING_STEPS
agent.train()

In [None]:
# Deeper MCTS search for inspection/evaluation (the training config at the
# top of the notebook uses 50 simulations).
EVAL_NUM_SIMULATIONS = 800
agent.config.num_simulations = EVAL_NUM_SIMULATIONS

In [9]:
import sys

sys.path.append("../")
from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    TwoPlayerPlayerPlaneWrapper,
    FrameStackWrapper,
)
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig
from agents.tictactoe_expert import TicTacToeBestAgent

# from agents.random import RandomAgent
from supersuit import frame_stack_v1, agent_indicator_v0


def inspect_position(agent, env, state, info):
    """Print the agent's MCTS and raw-network view of the current position.

    Prints the MCTS prediction, the initial-inference value and policy, and
    for every legal move one recurrent-inference step (value, reward, policy)
    starting from the initial inference's hidden state.

    Args:
        agent: MuZero agent providing ``predict``,
            ``predict_single_initial_inference`` and
            ``predict_single_recurrent_inference``.
        env: environment forwarded to ``agent.predict``.
        state: observation returned by ``env.last()``.
        info: info dict returned by ``env.last()``; must contain
            ``"legal_moves"``.

    Returns:
        The MCTS prediction returned by ``agent.predict``.
    """
    prediction = agent.predict(state, info, env)
    print("MCTS Prediction", prediction)
    initial_inference = agent.predict_single_initial_inference(state, info)
    print("Initial Value", initial_inference[0])
    print("Initial Policy", initial_inference[1])
    for move in info["legal_moves"]:
        # Element 2 of the initial inference is fed in as the hidden state.
        recurrent = agent.predict_single_recurrent_inference(
            initial_inference[2], move
        )
        print("Move", move)
        print("Reccurent Value", recurrent[2])
        print("Reccurent Reward", recurrent[0])
        print("Reccurent Policy", recurrent[3])
    return prediction


best_agent = TicTacToeBestAgent()

# Show how each wrapper in the chain transforms the observation space.
env = tictactoe_v3.env(render_mode="rgb_array")
print(env.observation_space("player_0"))
env = ActionMaskInInfoWrapper(env)
print(env.observation_space("player_0"))
env = FrameStackWrapper(env, 4, channel_first=False)
print(env.observation_space("player_0"))
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
print(env.observation_space("player_0"))
env = ChannelLastToFirstWrapper(env)
print(env.observation_space("player_0"))

# Fresh game with a fixed opening (actions 0, 6, 2), then let the agent
# choose the next moves while printing its view of each position.
env = TicTacToeConfig().make_env()
env.reset()
env.step(0)
env.step(6)
env.step(2)

state, reward, terminated, truncated, info = env.last()
prediction = inspect_position(agent, env, state, info)

action = agent.select_actions(prediction).item()

# Empirical distribution of select_actions over 100 draws from the same prediction.
selected_actions = {i: 0 for i in range(agent.num_actions)}
for _ in range(100):
    selected_actions[agent.select_actions(prediction).item()] += 1
print(selected_actions)
print("Action", action)
env.step(action)

state, reward, terminated, truncated, info = env.last()
prediction = inspect_position(agent, env, state, info)
action = agent.select_actions(prediction).item()
print("Action", action)
env.step(action)

state, reward, terminated, truncated, info = env.last()
prediction = inspect_position(agent, env, state, info)
action = agent.select_actions(prediction).item()
print("Action", action)

Dict('action_mask': Box(0, 1, (9,), int8), 'observation': Box(0, 1, (3, 3, 2), int8))
Box(0, 1, (3, 3, 2), int8)
Box(0, 1, (3, 3, 8), int8)
Box(0, 1, (3, 3, 9), int8)
Box(0, 1, (9, 3, 3), int8)
MCTS Prediction (tensor([0.0000, 0.1600, 0.8400, 0.0000, 0.0000, 0.0000]), tensor([0.0000, 0.0000, 0.0000, 0.1600, 0.8400, 0.0000, 0.0000, 0.0000, 0.0000]), [1, 3, 4, 5, 7, 8], 0.1852449732180685)
Initial Value tensor([-0.5853], grad_fn=<SelectBackward0>)
Initial Policy tensor([1.0743e-04, 7.6787e-02, 2.1165e-04, 4.6148e-01, 3.6692e-01, 6.7124e-02,
        7.2198e-04, 9.0125e-03, 1.7639e-02], grad_fn=<SelectBackward0>)
Move 1
Reccurent Value tensor([0.4830], grad_fn=<SelectBackward0>)
Reccurent Reward tensor([0.0614], grad_fn=<SelectBackward0>)
Reccurent Policy tensor([1.0179e-05, 1.3826e-04, 3.9588e-04, 1.3460e-01, 5.3352e-01, 2.8216e-01,
        1.0824e-04, 3.7477e-03, 4.5313e-02], grad_fn=<SelectBackward0>)
Move 3
Reccurent Value tensor([0.8411], grad_fn=<SelectBackward0>)
Reccurent Reward te

In [4]:
from copy import deepcopy
from math import log, sqrt, inf
import copy
import math
import numpy as np


class Node:
    """A node in the MuZero search tree (notebook debug copy).

    A node holds search statistics (visit count and summed backed-up value)
    and the prior probability its parent's policy assigned to it. Once
    expanded, it also holds the player to move, the model hidden state, the
    predicted reward for reaching it, and one child per candidate action.

    ``select_child`` and ``child_ucb_score`` print intermediate values for
    tracing; set ``Node.verbose = False`` to silence them.
    """

    # Class-level switch for the debug prints in select_child / child_ucb_score.
    verbose = True

    def __init__(self, prior_policy):
        """Create an unexpanded node.

        Args:
            prior_policy: prior probability of the action leading to this node.
        """
        self.visits = 0
        self.to_play = -1  # set by expand(); -1 until then
        self.prior_policy = prior_policy
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expand(self, legal_moves, to_play, policy, hidden_state, reward):
        """Create one child per legal move, with priors renormalized over them.

        Args:
            legal_moves: iterable of action indices to create children for.
            to_play: index of the player to move at this node.
            policy: indexable of per-action probabilities (torch tensor, numpy
                array or plain sequence); only ``legal_moves`` entries are used.
            hidden_state: model hidden state stored for recurrent inference.
            reward: predicted reward for reaching this node.
        """
        self.to_play = to_play
        self.reward = reward
        self.hidden_state = hidden_state
        legal_policy = {a: policy[a] for a in legal_moves}
        policy_sum = sum(legal_policy.values())

        for action, p in legal_policy.items():
            # float() (unlike .item()) accepts tensor, numpy and plain Python
            # scalars; the 1e-10 avoids division by zero if all priors are 0.
            self.children[action] = Node(float(p / (policy_sum + 1e-10)))

    def expanded(self):
        """Return True once expand() has created at least one child."""
        return len(self.children) > 0

    def value(self):
        """Mean backed-up value, or 0 for an unvisited node."""
        if self.visits == 0:
            return 0
        return self.value_sum / self.visits

    def add_noise(self, dirichlet_alpha, exploration_fraction):
        """Mix Dirichlet noise into the children's priors.

        Each child's prior becomes ``(1 - frac) * prior + frac * noise`` with
        ``frac = exploration_fraction``.
        """
        actions = list(self.children.keys())
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior_policy = (1 - frac) * self.children[
                a
            ].prior_policy + frac * n

    def select_child(self, min_max_stats, pb_c_base, pb_c_init, discount, num_players):
        """Return ``(action, child)`` for the child with the highest UCB score.

        Ties (within ``np.isclose`` tolerance) are broken uniformly at random.
        """
        child_ucbs = [
            self.child_ucb_score(
                child, min_max_stats, pb_c_base, pb_c_init, discount, num_players
            )
            for action, child in self.children.items()
        ]
        if self.verbose:
            print("Child UCBs", child_ucbs)
        action_index = np.random.choice(
            np.where(np.isclose(child_ucbs, max(child_ucbs)))[0]
        )
        action = list(self.children.keys())[action_index]
        return action, self.children[action]

    def child_ucb_score(
        self, child, min_max_stats, pb_c_base, pb_c_init, discount, num_players
    ):
        """pUCT score of ``child`` as seen from this (parent) node.

        prior_score = (log((N + pb_c_base + 1) / pb_c_base) + pb_c_init)
                      * sqrt(N) / (n_child + 1) * child.prior_policy
        where N is this node's visit count. value_score is
        ``min_max_stats.normalize(child.reward + discount * v)`` with
        ``v = child.value()`` for one player and ``-child.value()`` otherwise,
        or 0.0 for an unvisited child.

        Raises:
            AssertionError: if either score component is NaN.
        """
        pb_c = log((self.visits + pb_c_base + 1) / pb_c_base) + pb_c_init
        pb_c *= sqrt(self.visits) / (child.visits + 1)

        prior_score = pb_c * child.prior_policy
        if child.visits > 0:
            value_score = min_max_stats.normalize(
                child.reward
                + discount
                * (
                    child.value() if num_players == 1 else -child.value()
                )  # (or if on the same team)
            )
        else:
            value_score = 0.0

        # NaN is the only value that is not equal to itself.
        assert (
            value_score == value_score
        ), "value_score is nan, child value is {}, and reward is {},".format(
            child.value(),
            child.reward,
        )
        assert prior_score == prior_score, "prior_score is nan"
        if self.verbose:
            print("Prior Score", prior_score)
            print("Value Score", value_score)
        return prior_score + value_score

In [None]:
# Only a few simulations, so the printed step-by-step MCTS trace stays readable.
DEBUG_NUM_SIMULATIONS = 10
agent.config.num_simulations = DEBUG_NUM_SIMULATIONS

In [5]:
import collections
from typing import List, Optional

MAXIMUM_FLOAT_VALUE = float("inf")


class MinMaxStats(object):
    """Running min/max of values seen during search, used to rescale Q-values.

    With ``known_bounds = [low, high]`` the range starts at those bounds.
    Without bounds it starts empty (max = -inf, min = +inf), so the first
    ``update()`` sets both ends. Until a non-degenerate range has been seen,
    ``normalize()`` returns its input unchanged.

    ``normalize()`` prints its input and output for tracing; set
    ``MinMaxStats.verbose = False`` to silence it.
    """

    # Class-level switch for the debug prints in normalize().
    verbose = True

    def __init__(self, known_bounds: Optional[List[float]] = None):
        """
        Args:
            known_bounds: optional ``[min, max]`` pair; falsy means unbounded.
        """
        if known_bounds:
            self.max = known_bounds[1]
            self.min = known_bounds[0]
        else:
            # Start with an inverted (empty) range. Initialising max=+inf and
            # min=-inf instead could never be tightened by update(), and
            # normalize() would return inf / inf = nan.
            self.max = -MAXIMUM_FLOAT_VALUE
            self.min = MAXIMUM_FLOAT_VALUE

    def update(self, value: float):
        """Widen the tracked range to include ``value``."""
        self.max = max(self.max, value)
        self.min = min(self.min, value)

    def normalize(self, value: float) -> float:
        """Map ``value`` to ``(value - min) / (max - min)``.

        Returns ``value`` unchanged while ``max <= min`` (no usable range yet).
        """
        if self.verbose:
            print("Initial value", value)
        if self.max > self.min:
            normalized = (value - self.min) / (self.max - self.min)
            if self.verbose:
                print("normalized value", normalized)
            return normalized
        return value

    def __repr__(self):
        return f"min: {self.min}, max: {self.max}"

In [None]:
from utils.utils import get_legal_moves

# Step-by-step trace of one MCTS search from the current `state`, using the
# notebook-local Node / MinMaxStats definitions (packaged equivalents:
# muzero.muzero_mcts.Node and muzero.muzero_minmax_stats.MinMaxStats).
# Relies on `agent`, `env`, `state` and `info` from earlier cells.

# Config values that do not change during the search.
discount = agent.config.discount_factor
num_players = agent.config.game.num_players

root = Node(0.0)
_, policy, hidden_state = agent.predict_single_initial_inference(
    state,
    info,
)
print("root policy", policy)
legal_moves = get_legal_moves(info)[0]
root_to_play = env.agents.index(env.agent_selection)
to_play = root_to_play
root.expand(legal_moves, to_play, policy, hidden_state, 0.0)
print("expanded root")
min_max_stats = MinMaxStats(agent.config.known_bounds)

for _ in range(agent.config.num_simulations):
    print("at root")
    node = root
    search_path = [node]
    to_play = root_to_play

    # Descend until an unexpanded (leaf) node is reached.
    while node.expanded():
        print("selecting child")
        action, node = node.select_child(
            min_max_stats,
            agent.config.pb_c_base,
            agent.config.pb_c_init,
            discount,
            num_players,
        )
        print("Selected action", action)
        # NOTE: assumes a fixed turn rotation. Games where players can be
        # eliminated need agent_selector.next() on a cloned selector instead.
        to_play = (to_play + 1) % num_players
        search_path.append(node)

    # Expand the leaf with one recurrent-inference step from its parent.
    parent = search_path[-2]
    reward, hidden_state, value, policy = agent.predict_single_recurrent_inference(
        parent.hidden_state,
        action,
    )
    reward = reward.item()
    value = value.item()
    print("leaf value", value)
    print("leaf reward", reward)

    # Inside the tree every action is allowed. The predicted reward is stored
    # even for board games, which have no intermediate rewards.
    node.expand(
        list(range(agent.num_actions)),
        to_play,
        policy,
        hidden_state,
        reward,
    )

    # Backpropagate: `value` starts from the leaf player's (`to_play`) view and
    # is sign-flipped for nodes belonging to the other player.
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visits += 1
        min_max_stats.update(
            node.reward
            + discount * (node.value() if num_players == 1 else -node.value())
        )
        value = (
            -node.reward
            if node.to_play == to_play and num_players > 1
            else node.reward
        ) + discount * value

# Computed once after the search. Inside the loop it was recomputed every
# simulation, and left undefined or stale when num_simulations == 0.
visit_counts = [(child.visits, action) for action, child in root.children.items()]

print(visit_counts)

root policy tensor([3.8667e-05, 1.0536e-01, 8.9040e-05, 7.6352e-01, 2.9112e-05, 2.5028e-02,
        6.0007e-04, 1.5519e-03, 1.0378e-01], grad_fn=<SelectBackward0>)
expanded root
at root
selecting child
Prior Score 0.0
Value Score 0.0
Prior Score 0.0
Value Score 0.0
Prior Score 0.0
Value Score 0.0
Prior Score 0.0
Value Score 0.0
Child UCBs [0.0, 0.0, 0.0, 0.0]
Selected action 1
leaf value 0.15232999622821808
leaf reward 0.06857515871524811
at root
selecting child
Initial value -0.08375483751296997
normalized value 0.458122581243515
Prior Score 0.06600853030837159
Value Score 0.458122581243515
Prior Score 0.9566919723439027
Value Score 0.0
Prior Score 0.03136020727562555
Value Score 0.0
Prior Score 0.13003252074056998
Value Score 0.0
Child UCBs [0.5241311115518866, 0.9566919723439027, 0.03136020727562555, 0.13003252074056998]
Selected action 3
leaf value 0.5102172493934631
leaf reward 0.028090298175811768
at root
selecting child
Initial value -0.08375483751296997
normalized value 0.45812

: 