In [1]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


config = {
    "known_bounds": [0, 500],
    "residual_layers": [],  # ??? more depth better? up to depth 16, need at most 16 filters
    "representation_dense_layer_widths": [512, 64],
    "dynamics_dense_layer_widths": [512, 64],
    "actor_conv_layers": [],  # ???
    "critic_conv_layers": [],  # ???
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [512],  # ???
    "critic_dense_layer_widths": [512],  # ???
    "reward_dense_layer_widths": [],
    "conv_layers": [],
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    "value_loss_factor": 1.0,
    "root_dirichlet_alpha": 0.25,  # ???
    "root_exploration_fraction": 0.25,
    "num_simulations": 50,  # ??? goal is to increase this and see if it learns faster
    "temperatures": [1.0, 0.5, 0.25],
    "temperature_updates": [30000, 60000],
    "temperature_with_training_steps": True,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": Adam,
    "learning_rate": 0.005,  # ??? find a learning rate that works okay (no exploding, but not too small) # 0.1 to 0.01 decrease to 10% of the init value after 400k steps in pseudocode, but 0.2 in alphazero paper (and decreased 3 times)
    "momentum": 0.9,
    "adam_epsilon": 1e-8,
    "discount_factor": 0.997,
    "value_loss_function": CategoricalCrossentropyLoss(),
    "reward_loss_function": CategoricalCrossentropyLoss(),
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_onehot,
    "training_steps": 100000,
    "minibatch_size": 128,  # ??? this should be about 0.1 of the number of positions collected... or is it in the replay buffer? AlphaZero did a batch size of 4096 muzero 2048, and they said this was about 0.1.
    "min_replay_buffer_size": 5000,  # ???
    "replay_buffer_size": 50000,  # ??? paper used a buffer size of 1M games
    "unroll_steps": 5,
    "n_step": 10,
    "clipnorm": 0.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "he_normal",  # ???
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "per_use_batch_weights": True,
    "per_initial_priority_max": True,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 6,  # ???
    "lr_ratio": float("inf"),  # 0.1
    # "lr_ratio": 0.1,  # 0.1
    "games_per_generation": 8,  # ??? AlphaZero did ~64 games per generation
    "reanalyze": True,  # TODO
    "support_range": 31,
}

# DO AN ABALATION ON NUM SIMULATIONS, THE OG PAPER FOUND MORE SIMULATIONS MEANS BETTER LEARNING SIGNAL
# CHECK MY MCTS STUFF, IS THE SIGN CORRECT? IS IT CORRECT WITH REWARDS? IS IT CORRECT FOR TERMINAL STATES?

# steps, run tictactoe on fast settings for at least 200k steps, see if it learns to play okay
# add only updating mcts network every x steps
# add a ratio for learning steps to self play steps
# # run it without multiprocessing
# increase num simulations to 50 or 100 and see if it learns faster


env = CartPoleConfig().make_env()
game_config = CartPoleConfig()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_cartpole",
    device="cpu",
    test_agents=[],
)

ModuleNotFoundError: No module named 'wrappers'

In [None]:
agent.checkpoint_interval = 10
agent.test_interval = 250
agent.test_trials = 50
agent.train()

In [None]:
import sys


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


config = {
    "known_bounds": [-1, 1],
    "residual_layers": [(24, 3, 1)] * 1,  # increase num layers
    "conv_layers": [],
    "dense_layers": [],
    "actor_conv_layers": [(16, 1, 1)],
    "critic_conv_layers": [(16, 1, 1)],
    "reward_conv_layers": [(16, 1, 1)],
    "actor_dense_layer_widths": [],
    "critic_dense_layer_widths": [],
    "reward_dense_layer_widths": [],
    "conv_layers": [],
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    "value_loss_factor": 1.0,
    "root_dirichlet_alpha": 0.25,
    "root_exploration_fraction": 0.25,
    "num_simulations": 25,  # try larger
    "temperatures": [1.0, 0.1],
    "temperature_updates": [8],  # change this
    "temperature_with_training_steps": False,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": Adam,
    "learning_rate": 0.001,  # slightly increase this, maybe 0.002 or 0.003
    "momentum": 0.0,
    "adam_epsilon": 1e-8,  # try lower
    "value_loss_function": MSELoss(),  #  MSELoss(),
    "reward_loss_function": MSELoss(),  # MSELoss(),
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_plane,
    "training_steps": 100000,
    "minibatch_size": 8,
    "min_replay_buffer_size": 4000,  # try lower (or just different)
    "replay_buffer_size": 100000,  # try lower, 50k or 20k or 10k
    "unroll_steps": 5,
    "n_step": 9,
    "clipnorm": 0.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "glorot_normal",  # try different
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,  # 0.0 was original
    "per_use_batch_weights": False,
    "per_initial_priority_max": True,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 2,
    "reanalyze": True,  # TODO
    "support_range": None,  # None
}

# REANALYZE NOT IMPLEMENTED BUT NEVER USED IN THING I SAW ONLINE (like it is computed but never ends up used since bootstrap index always > than len of game history)

# add a max depth to the tree
# game priority is max of position priorities of the game
# with these two changes should be an exact match with online github implementation

env = tictactoe_v3.env(render_mode="rgb_array")
env = ActionMaskInInfoWrapper(env)
env = FrameStackWrapper(env, 4, channel_first=False)
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
env = ChannelLastToFirstWrapper(env)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe-hyperopt-best",
    device="cpu",
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)

In [None]:
agent.checkpoint_interval = 50
agent.test_interval = 1000
agent.test_trials = 300
agent.train()

In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


config = {
    "known_bounds": [-1, 1],
    "residual_layers": [],  # ??? more depth better? up to depth 16, need at most 16 filters
    "representation_dense_layer_widths": [256, 64],
    "dynamics_dense_layer_widths": [256, 64],
    "actor_conv_layers": [],  # ???
    "critic_conv_layers": [],  # ???
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [256],  # ???
    "critic_dense_layer_widths": [256],  # ???
    "reward_dense_layer_widths": [],
    "conv_layers": [],
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    "value_loss_factor": 1.0,
    "root_dirichlet_alpha": 0.25,  # ???
    "root_exploration_fraction": 0.25,
    "num_simulations": 25,  # ??? goal is to increase this and see if it learns faster
    "temperatures": [1.0, 0.1],
    "temperature_updates": [5],
    "temperature_with_training_steps": False,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": Adam,
    "learning_rate": 0.002,  # ??? find a learning rate that works okay (no exploding, but not too small) # 0.1 to 0.01 decrease to 10% of the init value after 400k steps in pseudocode, but 0.2 in alphazero paper (and decreased 3 times)
    "momentum": 0.9,
    "adam_epsilon": 1e-8,
    "value_loss_function": MSELoss(),
    "reward_loss_function": MSELoss(),
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_onehot,
    "training_steps": 100000,
    "minibatch_size": 128,  # ??? this should be about 0.1 of the number of positions collected... or is it in the replay buffer? AlphaZero did a batch size of 4096 muzero 2048, and they said this was about 0.1.
    "min_replay_buffer_size": 5000,  # ???
    "replay_buffer_size": 20000,  # ??? paper used a buffer size of 1M games
    "unroll_steps": 5,
    "n_step": 9,
    "clipnorm": 0.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "he_normal",  # ???
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "per_use_batch_weights": True,
    "per_initial_priority_max": True,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 6,  # ???
    # "lr_ratio": float("inf"),  # 0.1
    "lr_ratio": 0.1,  # 0.1
    "games_per_generation": 8,  # ??? AlphaZero did ~64 games per generation
    "reanalyze": True,  # TODO
    "support_range": None,
}

# DO AN ABALATION ON NUM SIMULATIONS, THE OG PAPER FOUND MORE SIMULATIONS MEANS BETTER LEARNING SIGNAL
# CHECK MY MCTS STUFF, IS THE SIGN CORRECT? IS IT CORRECT WITH REWARDS? IS IT CORRECT FOR TERMINAL STATES?

# steps, run tictactoe on fast settings for at least 200k steps, see if it learns to play okay
# add only updating mcts network every x steps
# add a ratio for learning steps to self play steps
# # run it without multiprocessing
# increase num simulations to 50 or 100 and see if it learns faster


env = tictactoe_v3.env(render_mode="rgb_array")
env = ActionMaskInInfoWrapper(env)
env = FrameStackWrapper(env, 4, channel_first=False)
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
env = ChannelLastToFirstWrapper(env)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe-dense-5moves_1",
    device="cpu",
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)

tested the ones i found on the github and they worked good? but i think i am gonna do less workers because they did 480k episodes after 32k training steps, i got 250k after like 5k steps, so i want to change that

In [None]:
agent.checkpoint_interval = 100
agent.test_interval = 250
agent.test_trials = 100
agent.train()

In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


config = {
    "known_bounds": [-1, 1],
    "residual_layers": [(16, 3, 1)],
    "representation_dense_layer_widths": [],
    "dynamics_dense_layer_widths": [],
    "actor_conv_layers": [(8, 1, 1)],  # ???
    "critic_conv_layers": [(8, 1, 1)],  # ???
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [],  # ???
    "critic_dense_layer_widths": [],  # ???
    "reward_dense_layer_widths": [],
    "conv_layers": [],
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    "value_loss_factor": 1.0,
    "root_dirichlet_alpha": 1.8,  # ???
    "root_exploration_fraction": 0.25,
    "num_simulations": 25,  # ??? goal is to increase this and see if it learns faster
    "temperatures": [1.0, 0.1],
    "temperature_updates": [2],
    "temperature_with_training_steps": False,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": SGD,
    "learning_rate": 0.1,  # ??? find a learning rate that works okay (no exploding, but not too small) # 0.1 to 0.01 decrease to 10% of the init value after 400k steps in pseudocode, but 0.2 in alphazero paper (and decreased 3 times)
    "momentum": 0.9,
    "adam_epsilon": 1e-8,
    "value_loss_function": MSELoss(),
    "reward_loss_function": MSELoss(),
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_plane,
    "training_steps": 4000,
    "minibatch_size": 32,  # ??? this should be about 0.1 of the number of positions collected... or is it in the replay buffer? AlphaZero did a batch size of 4096 muzero 2048, and they said this was about 0.1.
    "min_replay_buffer_size": 1000,  # 9000 # ???
    "replay_buffer_size": 10000,  # ??? paper used a buffer size of 1M games
    "unroll_steps": 5,
    "n_step": 9,
    "clipnorm": 1.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "orthogonal",  # ???
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "per_use_batch_weights": False,
    "per_initial_priority_max": False,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 2,  # ???
    "lr_ratio": float("inf"),
    # "lr_ratio": 0.1,
    "games_per_generation": 8,  # ??? AlphaZero did ~64 games per generation
    "reanalyze": True,  # TODO
    "support_range": None,
}

# DO AN ABALATION ON NUM SIMULATIONS, THE OG PAPER FOUND MORE SIMULATIONS MEANS BETTER LEARNING SIGNAL
# CHECK MY MCTS STUFF, IS THE SIGN CORRECT? IS IT CORRECT WITH REWARDS? IS IT CORRECT FOR TERMINAL STATES?

# steps, run tictactoe on fast settings for at least 200k steps, see if it learns to play okay
# add only updating mcts network every x steps
# add a ratio for learning steps to self play steps
# add frame stacking
# run it without multiprocessing
# increase num simulations to 50 or 100 and see if it learns faster


env = tictactoe_v3.env(render_mode="rgb_array")
env = ActionMaskInInfoWrapper(env)
env = FrameStackWrapper(env, 4, channel_first=False)
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
env = ChannelLastToFirstWrapper(env)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe_hyperopt",
    device="cpu",
    test_agents=[RandomAgent(), TicTacToeBestAgent()],
)

with 2 moves left one losing one drawing, it finds the drawing move, but evaluates the winning position after tree search as 0 instead of 1 (though predicts it as 1). shown in paper 2m-3

In [None]:
agent.checkpoint_interval = 100
agent.test_interval = 250
agent.test_trials = 100
agent.train()

In [None]:
import torch
from packages.utils.utils.utils import process_petting_zoo_obs

from pettingzoo.classic import tictactoe_v3


def play_game(player1, player2):

    env = tictactoe_v3.env(render_mode="rgb_array")
    with torch.no_grad():  # No gradient computation during testing
        # Reset environment
        env.reset()
        state, reward, termination, truncation, info = env.last()
        done = termination or truncation
        agent_id = env.agent_selection
        current_player = env.agents.index(agent_id)
        state, info = process_petting_zoo_obs(state, info, current_player)
        agent_names = env.agents.copy()

        episode_length = 0
        while not done and episode_length < 1000:  # Safety limit
            # Get current agent and player
            episode_length += 1

            # Get action from average strategy
            if current_player == 0:
                prediction = player1.predict(state, info, env=env)
                action = player1.select_actions(prediction, info).item()
            else:
                prediction = player2.predict(state, info, env=env)
                action = player2.select_actions(prediction, info).item()

            # Step environment
            env.step(action)
            state, reward, termination, truncation, info = env.last()
            agent_id = env.agent_selection
            current_player = env.agents.index(agent_id)
            state, info = process_petting_zoo_obs(state, info, current_player)
            done = termination or truncation
        print(env.rewards)
        return env.rewards["player_0"]

In [None]:
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from elo.elo import StandingsTable


random_vs_expert_table = StandingsTable([agent, TicTacToeBestAgent()], start_elo=1400)
random_vs_expert_table.play_1v1_tournament(1000, play_game)
print(random_vs_expert_table.bayes_elo())
print(random_vs_expert_table.get_win_table())

In [None]:
import numpy as np


class TicTacToeBestAgent:
    def __init__(self, model_name="tictactoe_expert"):
        self.model_name = model_name

    def predict(self, observation, info, env=None):
        return observation, info

    def select_actions(self, prediction, info):
        # Reconstruct board: +1 for current player, -1 for opponent, 0 otherwise
        board = prediction[0][0] - prediction[0][1]
        print(board)
        # Default: random legal move
        action = np.random.choice(info["legal_moves"])

        # Horizontal and vertical checks
        for i in range(3):
            # Row
            if np.sum(board[i, :]) == 2 and 0 in board[i, :]:
                ind = np.where(board[i, :] == 0)[0][0]
                return np.ravel_multi_index((i, ind), (3, 3))
            elif abs(np.sum(board[i, :])) == 2 and 0 in board[i, :]:
                ind = np.where(board[i, :] == 0)[0][0]
                action = np.ravel_multi_index((i, ind), (3, 3))

            # Column
            if np.sum(board[:, i]) == 2 and 0 in board[:, i]:
                ind = np.where(board[:, i] == 0)[0][0]
                return np.ravel_multi_index((ind, i), (3, 3))
            elif abs(np.sum(board[:, i])) == 2 and 0 in board[:, i]:
                ind = np.where(board[:, i] == 0)[0][0]
                action = np.ravel_multi_index((ind, i), (3, 3))

        # Diagonals
        diag = board.diagonal()
        if np.sum(diag) == 2 and 0 in diag:
            ind = np.where(diag == 0)[0][0]
            return np.ravel_multi_index((ind, ind), (3, 3))
        elif abs(np.sum(diag)) == 2 and 0 in diag:
            ind = np.where(diag == 0)[0][0]
            action = np.ravel_multi_index((ind, ind), (3, 3))

        anti_diag = np.fliplr(board).diagonal()
        if np.sum(anti_diag) == 2 and 0 in anti_diag:
            ind = np.where(anti_diag == 0)[0][0]
            return np.ravel_multi_index((ind, 2 - ind), (3, 3))
        elif abs(np.sum(anti_diag)) == 2 and 0 in anti_diag:
            ind = np.where(anti_diag == 0)[0][0]
            action = np.ravel_multi_index((ind, 2 - ind), (3, 3))

        return action

In [None]:
import sys

from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    FrameStackWrapper,
    TwoPlayerPlayerPlaneWrapper,
)


sys.path.append("../..")
from agent_configs import MuZeroConfig
import gymnasium as gym
from utils.utils import CategoricalCrossentropyLoss
from action_functions import action_as_plane, action_as_onehot
from muzero_agent_torch import MuZeroAgent
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig, CartPoleConfig
from utils import MSELoss
import torch
import os
from torch.optim import Adam, SGD
from agents.random import RandomAgent
from agents.tictactoe_expert import TicTacToeBestAgent
from supersuit import frame_stack_v1, agent_indicator_v0

# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["MKL_NUM_THREADS"] = "1"
# torch.set_num_threads(1)


config = {
    "known_bounds": [-1, 1],
    "residual_layers": [(16, 3, 1)],
    "representation_dense_layer_widths": [],
    "dynamics_dense_layer_widths": [],
    "actor_conv_layers": [],  # ???
    "critic_conv_layers": [],  # ???
    "reward_conv_layers": [],
    "actor_dense_layer_widths": [],  # ???
    "critic_dense_layer_widths": [],  # ???
    "reward_dense_layer_widths": [],
    "conv_layers": [],
    "dense_layer_widths": [],
    "noisy_sigma": 0.0,
    "value_loss_factor": 1.0,
    "root_dirichlet_alpha": 2.0,  # ???
    "root_exploration_fraction": 0.25,
    "num_simulations": 25,  # ??? goal is to increase this and see if it learns faster
    "temperatures": [1.0, 0.1],
    "temperature_updates": [5],
    "temperature_with_training_steps": False,
    "clip_low_prob": 0.0,
    "pb_c_base": 19652,
    "pb_c_init": 1.25,
    "optimizer": Adam,
    "learning_rate": 0.001,  # ??? find a learning rate that works okay (no exploding, but not too small) # 0.1 to 0.01 decrease to 10% of the init value after 400k steps in pseudocode, but 0.2 in alphazero paper (and decreased 3 times)
    "momentum": 0.0,
    "adam_epsilon": 1e-8,
    "value_loss_function": MSELoss(),
    "reward_loss_function": MSELoss(),
    "policy_loss_function": CategoricalCrossentropyLoss(),
    "action_function": action_as_plane,
    "training_steps": 33000,
    "minibatch_size": 32,  # ??? this should be about 0.1 of the number of positions collected... or is it in the replay buffer? AlphaZero did a batch size of 4096 muzero 2048, and they said this was about 0.1.
    "min_replay_buffer_size": 1000,  # 9000 # ???
    "replay_buffer_size": 40000,  # ??? paper used a buffer size of 1M games
    "unroll_steps": 5,
    "n_step": 9,
    "clipnorm": 0.0,
    "weight_decay": 0.0001,
    "kernel_initializer": "orthogonal",  # ???
    "per_alpha": 0.0,
    "per_beta": 0.0,
    "per_beta_final": 0.0,
    "per_use_batch_weights": False,
    "per_initial_priority_max": False,
    "per_epsilon": 0.0001,
    "multi_process": True,
    "num_workers": 2,  # ???
    "lr_ratio": float("inf"),
    # "lr_ratio": 0.1,
    "games_per_generation": 8,  # ??? AlphaZero did ~64 games per generation
    "reanalyze": True,  # TODO
    "support_range": None,
}

# DO AN ABALATION ON NUM SIMULATIONS, THE OG PAPER FOUND MORE SIMULATIONS MEANS BETTER LEARNING SIGNAL
# CHECK MY MCTS STUFF, IS THE SIGN CORRECT? IS IT CORRECT WITH REWARDS? IS IT CORRECT FOR TERMINAL STATES?

# steps, run tictactoe on fast settings for at least 200k steps, see if it learns to play okay
# add only updating mcts network every x steps
# add a ratio for learning steps to self play steps
# add frame stacking
# run it without multiprocessing
# increase num simulations to 50 or 100 and see if it learns faster


env = tictactoe_v3.env(render_mode="rgb_array")
env = ActionMaskInInfoWrapper(env)
env = FrameStackWrapper(env, 4, channel_first=False)
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
env = ChannelLastToFirstWrapper(env)

game_config = TicTacToeConfig()
config = MuZeroConfig(config, game_config)

agent = MuZeroAgent(
    env,
    config,
    name="muzero_tictactoe_hyperopt-2",
    device="cpu",
    test_agents=[RandomAgent(), TicTacToeBestAgent()],  # RandomAgent(),
)

agent.checkpoint_interval = 100
agent.test_interval = 1000
agent.test_trials = 500
agent.train()

In [None]:
agent.config.training_steps += 100
agent.train()

In [None]:
agent.config.num_simulations = 800

In [None]:
import sys

sys.path.append("../")
from wrappers import (
    ActionMaskInInfoWrapper,
    ChannelLastToFirstWrapper,
    TwoPlayerPlayerPlaneWrapper,
    FrameStackWrapper,
)
from pettingzoo.classic import tictactoe_v3
from game_configs import TicTacToeConfig

# from agents.tictactoe_expert import TicTacToeBestAgent

# from agents.random import RandomAgent
from supersuit import frame_stack_v1, agent_indicator_v0

best_agent = TicTacToeBestAgent()
env = tictactoe_v3.env(render_mode="rgb_array")
print(env.observation_space("player_0"))
env = ActionMaskInInfoWrapper(env)
print(env.observation_space("player_0"))
env = FrameStackWrapper(env, 4, channel_first=False)
print(env.observation_space("player_0"))
env = TwoPlayerPlayerPlaneWrapper(env, channel_first=False)
print(env.observation_space("player_0"))
env = ChannelLastToFirstWrapper(env)
print(env.observation_space("player_0"))

env = TicTacToeConfig().make_env()
env.reset()
env.step(0)
env.step(6)
env.step(2)

state, reward, terminated, truncated, info = env.last()
prediction = agent.predict(state, info, env)
print("MCTS Prediction", prediction)
initial_inference = agent.predict_single_initial_inference(state, info)
print("Initial Value", initial_inference[0])
print("Initial Policy", initial_inference[1])
for move in info["legal_moves"]:
    reccurent_inference = agent.predict_single_recurrent_inference(
        initial_inference[2], move
    )
    print("Move", move)
    print("Reccurent Value", reccurent_inference[2])
    print("Reccurent Reward", reccurent_inference[0])
    print("Reccurent Policy", reccurent_inference[3])


action = agent.select_actions(prediction).item()

selected_actions = {i: 0 for i in range(agent.num_actions)}
for i in range(100):
    selected_actions[agent.select_actions(prediction).item()] += 1
print(selected_actions)
print("Action", action)
env.step(action)
state, reward, terminated, truncated, info = env.last()
prediction = agent.predict(state, info, env)
print("MCTS Prediction", prediction)
initial_inference = agent.predict_single_initial_inference(state, info)
print("Initial Value", initial_inference[0])
print("Initial Policy", initial_inference[1])
for move in info["legal_moves"]:
    reccurent_inference = agent.predict_single_recurrent_inference(
        initial_inference[2], move
    )
    print("Move", move)
    print("Reccurent Value", reccurent_inference[2])
    print("Reccurent Reward", reccurent_inference[0])
    print("Reccurent Policy", reccurent_inference[3])


action = agent.select_actions(prediction).item()
print("Action", action)
env.step(action)
state, reward, terminated, truncated, info = env.last()
prediction = agent.predict(state, info, env)
print("MCTS Prediction", prediction)
initial_inference = agent.predict_single_initial_inference(state, info)
print("Initial Value", initial_inference[0])
print("Initial Policy", initial_inference[1])

for move in info["legal_moves"]:
    reccurent_inference = agent.predict_single_recurrent_inference(
        initial_inference[2], move
    )
    print("Move", move)
    print("Reccurent Value", reccurent_inference[2])
    print("Reccurent Reward", reccurent_inference[0])
    print("Reccurent Policy", reccurent_inference[3])

action = agent.select_actions(prediction).item()
print("Action", action)
# env.step(action)
# state, reward, terminated, truncated, info = env.last()
# prediction = agent.predict(state, info, env)
# print("MCTS Prediction", prediction)
# initial_inference = agent.predict_single_initial_inference(state, info)
# print("Initial Value", initial_inference[0])
# print("Initial Policy", initial_inference[1])
# action = agent.select_actions(prediction).item()
# print("Action", action)
# env.step(action)

In [None]:
from copy import deepcopy
from math import log, sqrt, inf
import copy
import math
import numpy as np


class Node:
    def __init__(self, prior_policy):
        self.visits = 0
        self.to_play = -1
        self.prior_policy = prior_policy
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expand(self, legal_moves, to_play, policy, hidden_state, reward):
        self.to_play = to_play
        self.reward = reward
        self.hidden_state = hidden_state
        # print(legal_moves)
        policy = {a: policy[a] for a in legal_moves}
        policy_sum = sum(policy.values())

        for action, p in policy.items():
            self.children[action] = Node((p / (policy_sum + 1e-10)).item())

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visits == 0:
            return 0
        return self.value_sum / self.visits

    def add_noise(self, dirichlet_alpha, exploration_fraction):
        actions = list(self.children.keys())
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior_policy = (1 - frac) * self.children[
                a
            ].prior_policy + frac * n

    def select_child(self, min_max_stats, pb_c_base, pb_c_init, discount, num_players):
        # Select the child with the highest UCB
        child_ucbs = [
            self.child_ucb_score(
                child, min_max_stats, pb_c_base, pb_c_init, discount, num_players
            )
            for action, child in self.children.items()
        ]
        print("Child UCBs", child_ucbs)
        action_index = np.random.choice(
            np.where(np.isclose(child_ucbs, max(child_ucbs)))[0]
        )
        action = list(self.children.keys())[action_index]
        return action, self.children[action]

    def child_ucb_score(
        self, child, min_max_stats, pb_c_base, pb_c_init, discount, num_players
    ):
        pb_c = log((self.visits + pb_c_base + 1) / pb_c_base) + pb_c_init
        pb_c *= sqrt(self.visits) / (child.visits + 1)

        prior_score = pb_c * child.prior_policy
        if child.visits > 0:
            value_score = min_max_stats.normalize(
                child.reward
                + discount
                * (
                    child.value() if num_players == 1 else -child.value()
                )  # (or if on the same team)
            )
        else:
            value_score = 0.0

        # check if value_score is nan
        assert (
            value_score == value_score
        ), "value_score is nan, child value is {}, and reward is {},".format(
            child.value(),
            child.reward,
        )
        assert prior_score == prior_score, "prior_score is nan"
        print("Prior Score", prior_score)
        print("Value Score", value_score)
        return prior_score + value_score
        # return value_score

In [None]:
agent.config.num_simulations = 10

In [None]:
import collections
from typing import Optional

from pyparsing import List

MAXIMUM_FLOAT_VALUE = float("inf")


class MinMaxStats(object):
    def __init__(
        self, known_bounds: Optional[List[float]]
    ):  # might need to say known_bounds=None
        self.max = known_bounds[1] if known_bounds else MAXIMUM_FLOAT_VALUE
        self.min = known_bounds[0] if known_bounds else -MAXIMUM_FLOAT_VALUE

    def update(self, value: float):
        self.max = max(self.max, value)
        self.min = min(self.min, value)

    def normalize(self, value: float) -> float:
        print("Initial value", value)
        if self.max > self.min:
            # We normalize only when we have at a max and min value
            print("normalized value", (value - self.min) / (self.max - self.min))
            return (value - self.min) / (self.max - self.min)
        return value

    def __repr__(self):
        return f"min: {self.min}, max: {self.max}"

In [None]:
from utils.utils import get_legal_moves

# from muzero.muzero_mcts import Node
# from muzero.muzero_minmax_stats import MinMaxStats


root = Node(0.0)
_, policy, hidden_state = agent.predict_single_initial_inference(
    state,
    info,
)
print("root policy", policy)
legal_moves = get_legal_moves(info)[0]
to_play = env.agents.index(env.agent_selection)
root.expand(legal_moves, to_play, policy, hidden_state, 0.0)
print("expanded root")
min_max_stats = MinMaxStats(agent.config.known_bounds)

for _ in range(agent.config.num_simulations):
    print("at root")
    node = root
    search_path = [node]
    to_play = env.agents.index(env.agent_selection)

    # GO UNTIL A LEAF NODE IS REACHED
    while node.expanded():
        print("selecting child")
        action, node = node.select_child(
            min_max_stats,
            agent.config.pb_c_base,
            agent.config.pb_c_init,
            agent.config.discount_factor,
            agent.config.game.num_players,
        )
        print("Selected action", action)
        # THIS NEEDS TO BE CHANGED FOR GAMES WHERE PLAYER COUNT DECREASES AS PLAYERS GET ELIMINATED, USE agent_selector.next() (clone of the current one)
        to_play = (to_play + 1) % agent.config.game.num_players
        search_path.append(node)
    parent = search_path[-2]
    reward, hidden_state, value, policy = agent.predict_single_recurrent_inference(
        parent.hidden_state,
        action,  # model=model
    )
    reward = reward.item()
    value = value.item()
    print("leaf value", value)
    print("leaf reward", reward)

    node.expand(
        list(range(agent.num_actions)),
        to_play,
        policy,
        hidden_state,
        (
            reward  # if self.config.game.has_intermediate_rewards else 0.0
        ),  # for board games and games with no intermediate rewards
    )

    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visits += 1
        min_max_stats.update(
            node.reward
            + agent.config.discount_factor
            * (node.value() if agent.config.game.num_players == 1 else -node.value())
        )
        value = (
            -node.reward
            if node.to_play == to_play and agent.config.game.num_players > 1
            else node.reward
        ) + agent.config.discount_factor * value

    visit_counts = [(child.visits, action) for action, child in root.children.items()]

print(visit_counts)

In [27]:
import sys

sys.path.append("../")
from muzero.muzero_mcts import Node
from muzero.muzero_minmax_stats import MinMaxStats
import torch

num_players = 2
min_max_stats = MinMaxStats([-1, 1])


def make_search_path():
    root = Node(0.0)
    policy = torch.tensor([0.0, 1.0])
    hidden_state = torch.tensor([1])
    legal_moves = [0, 1]
    root.expand(legal_moves, 0, policy, hidden_state, 0.0)

    search_path = [root]
    node = root.children[0]
    search_path.append(node)
    node.expand(legal_moves, 1, policy, hidden_state, 0.0)
    node = node.children[0]
    search_path.append(node)
    node.expand(legal_moves, 1, policy, hidden_state, 1.0)
    node = node.children[0]
    search_path.append(node)

    value = 0.0
    reward = 0.0
    to_play = 0
    node.expand(legal_moves, to_play, policy, hidden_state, reward)
    print("leaf value", value)
    print("leaf reward", reward)
    print("leaf to_play", to_play)

    return search_path, to_play, value


search_path_1, to_play_1, value_1 = make_search_path()
search_path_2, to_play_2, value_2 = make_search_path()

for _ in range(1):
    for node_1, node_2 in zip(reversed(search_path_1), reversed(search_path_2)):
        print(node_1)
        print("node 1 to_play", node_1.to_play)
        print("init value sum 1", node_1.value_sum)
        node_1.value_sum += value_1 if node_1.to_play == to_play_1 else -value_1
        print("new value sum 1", node_1.value_sum)
        node_1.visits += 1
        min_max_update_1 = node_1.reward + 1.0 * (
            node_1.value() if node_1.to_play == to_play_1 else -node_1.value()
        )
        print("min max update 1", min_max_update_1)
        min_max_stats.update(min_max_update_1)

        print(node_2)
        print("node to_play", node_2.to_play)
        print("init value sum", node_2.value_sum)
        node_2.value_sum += value_2 if node_2.to_play == to_play_2 else -value_2
        print("new value sum", node_2.value_sum)
        node_2.visits += 1
        if num_players == 1:
            min_max_update_2 = node_2.reward + 1.0 * node_2.value()
        elif node_2.to_play == to_play_2:
            min_max_update_2 = node_2.reward + 1.0 * node_2.value()
        else:
            min_max_update_2 = node_2.reward + 1.0 * (-node_2.value())
        print("min max update 2", min_max_update_2)
        min_max_stats.update(min_max_update_2)

        print(node_1.value_sum == node_2.value_sum)
        print(node_1.visits == node_2.visits)
        print(min_max_update_1 == min_max_update_2)

        # ✅ METHOD 1 (baseline)
        if node_1.to_play == to_play_1 and num_players > 1:
            value_1 = -node_1.reward + 1.0 * value_1
        else:
            value_1 = node_1.reward + 1.0 * value_1

        print("next value to be added", value_1)
        print("method 1 value", node_1.value())

        # ✅ METHOD 2 (generalized version)
        if num_players == 1:
            value_2 = node_2.reward + 1.0 * value_2
        elif node_2.to_play == to_play_2:
            value_2 = -node_2.reward + 1.0 * value_2
        else:
            value_2 = node_2.reward + 1.0 * value_2
        print("next value to be added", value_2)
        print("method 2 value", node_2.value())

        print(value_1 == value_2)

leaf value 1.0
leaf reward 0.0
leaf to_play 0
leaf value 1.0
leaf reward 0.0
leaf to_play 0
<muzero.muzero_mcts.Node object at 0x12ff21ae0>
node 1 to_play 0
init value sum 1 0
new value sum 1 1.0
min max update 1 1.0
<muzero.muzero_mcts.Node object at 0x12ff204f0>
node to_play 0
init value sum 0
new value sum 1.0
min max update 2 1.0
True
True
True
next value to be added 1.0
method 1 value 1.0
next value to be added 1.0
method 2 value 1.0
True
<muzero.muzero_mcts.Node object at 0x12ff23e20>
node 1 to_play 1
init value sum 1 0
new value sum 1 -1.0
min max update 1 2.0
<muzero.muzero_mcts.Node object at 0x12ff207f0>
node to_play 1
init value sum 0
new value sum -1.0
min max update 2 2.0
True
True
True
next value to be added 2.0
method 1 value -1.0
next value to be added 2.0
method 2 value -1.0
True
<muzero.muzero_mcts.Node object at 0x12ff21060>
node 1 to_play 1
init value sum 1 0
new value sum 1 -2.0
min max update 1 2.0
<muzero.muzero_mcts.Node object at 0x12ff23340>
node to_play 1
ini

In [7]:
import torch
import sys
import math

sys.path.append("../")
from muzero.muzero_mcts import Node
from muzero.muzero_minmax_stats import MinMaxStats


def make_search_path(path_config):
    """
    path_config is a list of tuples: (to_play, reward)
    Last element has the leaf value and to_play
    """
    root = Node(0.0)
    policy = torch.tensor([0.0, 1.0])
    hidden_state = torch.tensor([1])
    legal_moves = [0, 1]

    # Initialize root (player 0)
    root.expand(legal_moves, 0, policy, hidden_state, 0.0)

    search_path = [root]
    node = root.children[0]

    # Build path according to config
    for i, (to_play, reward) in enumerate(path_config[:-1]):
        search_path.append(node)
        node.expand(legal_moves, to_play, policy, hidden_state, reward)
        node = node.children[0]

    # Last node
    search_path.append(node)
    last_config = path_config[-1]
    leaf_to_play = last_config[0]
    leaf_reward = last_config[1]
    leaf_value = last_config[2] if len(last_config) > 2 else 0.0
    node.expand(legal_moves, leaf_to_play, policy, hidden_state, leaf_reward)

    return search_path, leaf_to_play, leaf_value


def backpropagate_method1(search_path, to_play, value, num_players):
    """Original Method 1"""
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visits += 1

        if node.to_play == to_play and num_players > 1:
            value = -node.reward + 1.0 * value
        else:
            value = node.reward + 1.0 * value

    return [n.value() for n in search_path]


def backpropagate_method2(search_path, to_play, value, num_players):
    """
    Robust method 2: compute, for each node in the search_path, the total return
    from that node according to that node's player's perspective, then add it
    to node.value_sum and increment visits.

    - `to_play` is the leaf player (owner of `value`).
    - `value` is the leaf value (scalar).
    - reward is stored on nodes so that search_path[j].reward is the reward
      for the parent at j-1 (i.e. it's the reward received when moving into j).
    """
    leaf_to_play = to_play
    leaf_value = value

    # For each node in the path compute the exact return from that node's perspective.
    # Complexity O(n^2) in path length — fine for typical MCTS path lengths.
    for i, node in enumerate(search_path):
        total = 0.0
        # sum future rewards (reward at search_path[j] belongs to acting_player = search_path[j-1].to_play)
        for j in range(i + 1, len(search_path)):
            acting_player = search_path[j - 1].to_play
            r = search_path[j].reward
            if acting_player == node.to_play:
                total += r
            else:
                total -= r

        # add the leaf value (belongs to leaf_to_play)
        if leaf_to_play == node.to_play:
            total += leaf_value
        else:
            total -= leaf_value

        # update node stats (value_sum / visits semantics preserved)
        node.value_sum += total
        node.visits += 1

    return [n.value() for n in search_path]


# Test cases: (path_config, expected_root_value, num_players, description)
test_cases = [
    (
        [(1, 0.0), (1, 1.0), (0, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 0",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 1",
    ),
    (
        [(1, 0.0), (1, 1.0), (0, 1.0, 0.0)],
        [-2.0, 2.0, 1.0, 0.0],
        2,
        "2-player: two player 1s both actions getting a reward (they should dont cancel), ending on a root for player 0",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 1.0, 0.0)],
        [-2.0, 2.0, 1.0, 0.0],
        2,
        "2-player: two player 1s both actions getting a reward (they should dont cancel), ending on a root for player 1",
    ),
    (
        [(1, 1.0), (1, 1.0), (0, 0.0, 0.0)],
        [0.0, 1.0, 0.0, 0.0],
        2,
        "2-player: Two player 1 turns (but player 0 got a reward), ending on player 0",
    ),
    (
        [(1, 1.0), (1, 1.0), (1, 0.0, 0.0)],
        [0.0, 1.0, 0.0, 0.0],
        2,
        "2-player: Two player 1 turns (but player 0 got a reward), ending on player 1",
    ),
    (
        [(1, 0.0), (0, 1.0), (1, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins on there first move",
    ),
    (
        [(1, 0.0), (0, 1.0), (1, 0.0), (0, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins on there first move",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 1.0, 0.0)],
        [1.0, -1.0, 1.0, 0.0],
        2,
        "2-player: alternating game, player 0 wins",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0), (0, 1.0, 0.0)],
        [-1.0, 1.0, -1.0, 1.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0, 1.0)],
        [-1.0, 1.0, -1.0, 1.0],
        2,
        "2-player: alternating game with a leaf value",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0), (0, 0.0, 1.0)],
        [1.0, -1.0, 1.0, -1.0, 1.0],
        2,
        "2-player: alternating game with a leaf value",
    ),
    (
        [(0, 1.0), (0, 1.0), (0, 1.0, 0.0)],
        [3.0, 2.0, 1.0, 0.0],
        2,
        "2-player: All player 0 turns",
    ),
    (
        [(0, 0.0), (0, 0.0), (0, 0.0, 4.0)],
        [4.0, 4.0, 4.0, 4.0],
        2,
        "2-player: All player 0 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (0, 0.0, 4.0)],
        [3.0, -3.0, -4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0, 4.0)],
        [-5.0, 5.0, 4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0), (0, 0.0, 4.0)],
        [3.0, -3.0, -4.0, -4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    # Single player test cases
    (
        [(0, 1.0), (0, 2.0), (0, 3.0, 0.0)],
        [6.0, 5.0, 3.0, 0.0],
        1,
        "1-player: All rewards sum up",
    ),
    (
        [(0, 1.0), (0, 0.0), (0, 0.0, 5.0)],
        [6.0, 5.0, 5.0, 5.0],
        1,
        "1-player: Rewards + leaf value",
    ),
]

print("Testing MuZero Value Backpropagation\n")
print("=" * 80)


def are_lists_roughly_equal(list1, list2):
    """
    Checks if two lists of floats are roughly equal in Python 2.

    Args:
        list1: The first list of floats.
        list2: The second list of floats.
        tolerance: The maximum allowed absolute difference between corresponding
                   elements for them to be considered roughly equal.

    Returns:
        True if the lists are roughly equal, False otherwise.
    """
    if len(list1) != len(list2):
        return False

    for i in range(len(list1)):
        # Compare elements using absolute tolerance
        if not math.isclose(list1[i], list2[i]):
            return False
    return True


all_passed = True
method1_correct = 0
method2_correct = 0
total_tests = 0
for i, (path_config, expected, num_players, description) in enumerate(test_cases, 1):
    print(f"\nTest Case {i}: {description}")
    print(f"Path config: {path_config}")
    print(f"Num players: {num_players}")
    print(f"Expected node values: {expected}")

    # Test Method 1
    search_path_1, to_play_1, value_1 = make_search_path(path_config)
    result_1 = backpropagate_method1(search_path_1, to_play_1, value_1, num_players)

    # Test Method 2
    search_path_2, to_play_2, value_2 = make_search_path(path_config)
    result_2 = backpropagate_method2(search_path_2, to_play_2, value_2, num_players)

    # Check results
    method1_pass = are_lists_roughly_equal(result_1, expected)
    method2_pass = are_lists_roughly_equal(result_2, expected)
    match = are_lists_roughly_equal(result_1, result_2)

    print(f"Method 1 result: {result_1} {'✓' if method1_pass else '✗'}")
    print(f"Method 2 result: {result_2} {'✓' if method2_pass else '✗'}")
    print(f"Methods match: {'✓' if match else '✗'}")
    if method1_pass:
        method1_correct += 1
    if method2_pass:
        method2_correct += 1
    if not method1_pass:
        print("❌ METHOD 1 FAILED")
    else:
        print("✅ METHOD 1 PASSED")
    if not method2_pass:
        print("❌ METHOD 2 FAILED")
    else:
        print("✅ METHOD 2 PASSED")

    # if not (method1_pass and method2_pass and match):
    #     all_passed = False
    #     print("❌ FAILED")
    # else:
    #     print("✅ PASSED")

print("\n" + "=" * 80)
if all_passed:
    print("✅ All tests passed!")
else:
    print("❌ Some tests failed")

print("Method 1 got", method1_correct)
print("Method 2 got", method2_correct)

Testing MuZero Value Backpropagation


Test Case 1: 2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 0
Path config: [(1, 0.0), (1, 1.0), (0, 0.0, 0.0)]
Num players: 2
Expected node values: [-1.0, 1.0, 0.0, 0.0]
Method 1 result: [1.0, -1.0, 0.0, 0.0] ✗
Method 2 result: [-1.0, 1.0, 0.0, 0.0] ✓
Methods match: ✗
❌ METHOD 1 FAILED
✅ METHOD 2 PASSED

Test Case 2: 2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 1
Path config: [(1, 0.0), (1, 1.0), (1, 0.0, 0.0)]
Num players: 2
Expected node values: [-1.0, 1.0, 0.0, 0.0]
Method 1 result: [1.0, -1.0, 0.0, 0.0] ✗
Method 2 result: [-1.0, 1.0, 0.0, 0.0] ✓
Methods match: ✗
❌ METHOD 1 FAILED
✅ METHOD 2 PASSED

Test Case 3: 2-player: two player 1s both actions getting a reward (they should dont cancel), ending on a root for player 0
Path config: [(1, 0.0), (1, 1.0), (0, 1.0, 0.0)]
Num players: 2
Expected node values: [-2.0, 2.0, 1.0, 0.0]
Method 1 result: 

In [7]:
import torch
import sys
import math

sys.path.append("../")
from muzero.muzero_mcts import Node
from muzero.muzero_minmax_stats import MinMaxStats


def make_search_path(path_config):
    """
    path_config is a list of tuples: (to_play, reward)
    Last element has the leaf value and to_play
    """
    root = Node(0.0)
    policy = torch.tensor([0.0, 1.0])
    hidden_state = torch.tensor([1])
    legal_moves = [0, 1]

    # Initialize root (player 0)
    root.expand(legal_moves, 0, policy, hidden_state, 0.0)

    search_path = [root]
    node = root.children[0]

    # Build path according to config
    for i, (to_play, reward) in enumerate(path_config[:-1]):
        search_path.append(node)
        node.expand(legal_moves, to_play, policy, hidden_state, reward)
        node = node.children[0]

    # Last node
    search_path.append(node)
    last_config = path_config[-1]
    leaf_to_play = last_config[0]
    leaf_reward = last_config[1]
    leaf_value = last_config[2] if len(last_config) > 2 else 0.0
    node.expand(legal_moves, leaf_to_play, policy, hidden_state, leaf_reward)

    return search_path, leaf_to_play, leaf_value


def backpropagate_method1(search_path, to_play, value, num_players, min_max_stats):
    """Original Method 1"""
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visits += 1
        min_max_stats.update(
            node.reward + 1.0 * (node.value() if num_players == 1 else -node.value())
        )

        if node.to_play == to_play and num_players > 1:
            value = -node.reward + 1.0 * value
        else:
            value = node.reward + 1.0 * value

    return [n.value() for n in search_path], min_max_stats


def backpropagate_method2(
    search_path, leaf_to_play, leaf_value, num_players, min_max_stats
):

    n = len(search_path)

    # 1) Compute exact total return for each node from that node's player perspective.
    #    totals[i] is the scalar to add to search_path[i].value_sum.
    totals = [0.0] * n
    for i, node in enumerate(search_path):
        total = 0.0
        # Sum future rewards: reward at search_path[j] belongs to acting_player = search_path[j-1].to_play
        for j in range(i + 1, n):
            acting_player = search_path[j - 1].to_play
            r = search_path[j].reward
            total += r if acting_player == node.to_play else -r

        # Add leaf value (owned by leaf_to_play)
        total += leaf_value if leaf_to_play == node.to_play else -leaf_value
        totals[i] = total

    # 2) Apply updates in reverse order and update MinMaxStats using parent-perspective value.
    #    For node at index i, its parent is search_path[i-1] (if i>0).
    for i in range(n - 1, -1, -1):
        node = search_path[i]

        # Update node stats (so node.value() reflects totals after update)
        node.value_sum += totals[i]
        node.visits += 1

        # Compute the scalar used to update MinMaxStats: it must be the value of this child
        # from its parent's perspective:
        # parent_value_contrib = child.reward + discount * (sign * child.value())
        # sign = +1 if child.to_play == parent.to_play (same player acts again), else -1.
        if i > 0:
            parent_to_play = search_path[i - 1].to_play
            # For single-player games, child.value() is always added (no sign flip).
            if num_players == 1:
                sign = 1.0
            else:
                sign = 1.0 if node.to_play == parent_to_play else -1.0
        else:
            # For root (no parent) we must still pass something to MinMaxStats.
            # Use sign = +1 (treat root as its own parent's perspective) for a consistent convention.
            # This is harmless because root's parent doesn't exist — MinMaxStats is just tracking
            # global min/max of these scalars for normalization.
            sign = 1.0 if num_players == 1 else 1.0

        parent_value_contrib = node.reward + 1.0 * (sign * node.value())
        min_max_stats.update(parent_value_contrib)

    return [n.value() for n in search_path], min_max_stats


def backpropagate_method3(
    search_path,  # list of nodes from root .. leaf
    leaf_to_play,  # player id that owns `leaf_value`
    leaf_value,  # scalar leaf value
    num_players,
    min_max_stats,
):
    """
    O(n) discounted backpropagation with correct sign handling for repeated same-player turns.
    - search_path: list of Nodes, index 0 is root, last is leaf.
      Each Node must expose: .reward (reward stored on the node),
      .to_play (player id who acts at that node),
      .value_sum, .visits, and .value() == value_sum/visits.
    - leaf_to_play, leaf_value: ownership and scalar of the leaf bootstrap.
    - num_players: number of players (1 for single-player).
    - discount: gamma
    - min_max_stats: MinMaxStats instance with .update(x) and .normalize(x)
    Returns: list of node.value() after updates.
    """

    n = len(search_path)
    if n == 0:
        return []

    # --- 1) Build per-player accumulator array acc[p] = Acc_p(i) for current i (starting from i = n-1) ---
    # Acc_p(i) definition: discounted return from node i for a node whose player is p:
    # Acc_p(i) = sum_{j=i+1..n-1} discount^{j-i-1} * sign(p, j) * reward_j
    #            + discount^{n-1-i} * sign(p, leaf) * leaf_value
    # Where sign(p, j) = +1 if acting_player_at_j (which is search_path[j-1].to_play) == p else -1.
    #
    # We compute Acc_p(n-1) = sign(p, leaf) * leaf_value as base, then iterate backward:
    # Acc_p(i-1) = s(p, i) * reward_i + discount * Acc_p(i)

    # Initialize acc for i = n-1 (base: discounted exponent 0 for leaf value)
    # acc is a Python list of floats length num_players
    acc = [0.0] * num_players
    for p in range(num_players):
        acc[p] = leaf_value if leaf_to_play == p else -leaf_value

    # totals[i] will hold Acc_{node_player}(i)
    totals = [0.0] * n

    # Iterate from i = n-1 down to 0
    for i in range(n - 1, -1, -1):
        node = search_path[i]
        node_player = node.to_play
        # totals for this node = acc[node_player] (current Acc_p(i))
        totals[i] = acc[node_player]

        # Prepare acc for i-1 (if any)
        if i > 0:
            # reward at index i belongs to acting_player = search_path[i-1].to_play
            r_i = search_path[i].reward
            acting_player = search_path[i - 1].to_play

            # Update per-player accumulators in O(num_players)
            # Acc_p(i-1) = sign(p, i) * r_i + discount * Acc_p(i)
            # sign(p, i) = +1 if acting_player == p else -1
            # We overwrite acc[p] in-place to be Acc_p(i-1)
            for p in range(num_players):
                sign = 1.0 if acting_player == p else -1.0
                acc[p] = sign * r_i + 1.0 * acc[p]

    # --- 2) Apply totals to nodes in reverse order and update MinMaxStats (parent-perspective scalar) ---
    # We must update nodes (value_sum, visits) from the leaf back to the root so that when
    # computing parent-perspective scalars we can use child.value() (which should reflect the
    # just-updated child totals).
    for i in range(n - 1, -1, -1):
        node = search_path[i]

        # apply computed discounted total for this node's player
        node.value_sum += totals[i]
        node.visits += 1

        # compute scalar that MinMaxStats expects for this child from its parent's perspective:
        # parent_value_contrib = child.reward + discount * (sign * child.value())
        # sign = +1 if single-player OR child.to_play == parent.to_play else -1
        if i > 0:
            parent = search_path[i - 1]
            if num_players == 1:
                sign = 1.0
            else:
                sign = 1.0 if node.to_play == parent.to_play else -1.0
        else:
            # root: choose sign = +1 convention (root has no parent)
            sign = 1.0

        parent_value_contrib = node.reward + 1.0 * (sign * node.value())
        min_max_stats.update(parent_value_contrib)

    # Return updated node values
    return [node.value() for node in search_path], min_max_stats


# Test cases: (path_config, expected_root_value, num_players, description)
test_cases = [
    (
        [(1, 0.0), (1, 1.0), (0, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 0",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 1",
    ),
    (
        [(1, 0.0), (1, 1.0), (0, 1.0, 0.0)],
        [-2.0, 2.0, 1.0, 0.0],
        2,
        "2-player: two player 1s both actions getting a reward (they should dont cancel), ending on a root for player 0",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 1.0, 0.0)],
        [-2.0, 2.0, 1.0, 0.0],
        2,
        "2-player: two player 1s both actions getting a reward (they should dont cancel), ending on a root for player 1",
    ),
    (
        [(1, 1.0), (1, 1.0), (0, 0.0, 0.0)],
        [0.0, 1.0, 0.0, 0.0],
        2,
        "2-player: Two player 1 turns (but player 0 got a reward), ending on player 0",
    ),
    (
        [(1, 1.0), (1, 1.0), (1, 0.0, 0.0)],
        [0.0, 1.0, 0.0, 0.0],
        2,
        "2-player: Two player 1 turns (but player 0 got a reward), ending on player 1",
    ),
    (
        [(1, 0.0), (0, 1.0), (1, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins on there first move",
    ),
    (
        [(1, 0.0), (0, 1.0), (1, 0.0), (0, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins on there first move",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 1.0, 0.0)],
        [1.0, -1.0, 1.0, 0.0],
        2,
        "2-player: alternating game, player 0 wins",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0), (0, 1.0, 0.0)],
        [-1.0, 1.0, -1.0, 1.0, 0.0],
        2,
        "2-player: alternating game, player 1 wins",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0, 1.0)],
        [-1.0, 1.0, -1.0, 1.0],
        2,
        "2-player: alternating game with a leaf value",
    ),
    (
        [(1, 0.0), (0, 0.0), (1, 0.0), (0, 0.0, 1.0)],
        [1.0, -1.0, 1.0, -1.0, 1.0],
        2,
        "2-player: alternating game with a leaf value",
    ),
    (
        [(0, 1.0), (0, 1.0), (0, 1.0, 0.0)],
        [3.0, 2.0, 1.0, 0.0],
        2,
        "2-player: All player 0 turns",
    ),
    (
        [(0, 0.0), (0, 0.0), (0, 0.0, 4.0)],
        [4.0, 4.0, 4.0, 4.0],
        2,
        "2-player: All player 0 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (0, 0.0, 4.0)],
        [3.0, -3.0, -4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0, 4.0)],
        [-5.0, 5.0, 4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    (
        [(1, 0.0), (1, 1.0), (1, 0.0), (0, 0.0, 4.0)],
        [3.0, -3.0, -4.0, -4.0, 4.0],
        2,
        "2-player: Two player 1 turns with leaf value",
    ),
    # Single player test cases
    (
        [(0, 1.0), (0, 2.0), (0, 3.0, 0.0)],
        [6.0, 5.0, 3.0, 0.0],
        1,
        "1-player: All rewards sum up",
    ),
    (
        [(0, 1.0), (0, 0.0), (0, 0.0, 5.0)],
        [6.0, 5.0, 5.0, 5.0],
        1,
        "1-player: Rewards + leaf value",
    ),
    (
        [(1, 0.0), (2, 0.0), (0, 1.0), (0, 0.0, 0.0)],
        [-1.0, -1.0, 1.0, 0.0, 0.0],
        3,
        "3-player: Player 2 wins",
    ),
    (
        [(1, 1.0), (2, 0.0), (0, 0.0), (0, 0.0, 0.0)],
        [1.0, 0.0, 0.0, 0.0, 0.0],
        3,
        "3-player: Player 0 wins",
    ),
    (
        [(1, 0.0), (2, 1.0), (0, 0.0), (0, 0.0, 0.0)],
        [-1.0, 1.0, 0.0, 0.0, 0.0],
        3,
        "3-player: Player 1 wins",
    ),
    (
        [(1, 0.0), (2, 0.0), (0, 0.0), (0, 1.0, 0.0)],
        [1.0, -1.0, -1.0, 1.0, 0.0],
        3,
        "3-player: Player 0 wins",
    ),
    (
        [(1, 0.0), (2, 0.0), (0, 0.0), (1, 0.0, 1.0)],
        [-1.0, 1.0, -1.0, -1.0, 1.0],
        3,
        "3-player: player 1 ends with a value prediction",
    ),
]

print("Testing MuZero Value Backpropagation\n")
print("=" * 80)


def are_lists_roughly_equal(list1, list2):
    """
    Checks if two lists of floats are roughly equal in Python 2.

    Args:
        list1: The first list of floats.
        list2: The second list of floats.
        tolerance: The maximum allowed absolute difference between corresponding
                   elements for them to be considered roughly equal.

    Returns:
        True if the lists are roughly equal, False otherwise.
    """
    if len(list1) != len(list2):
        return False

    for i in range(len(list1)):
        # Compare elements using absolute tolerance
        if not math.isclose(list1[i], list2[i]):
            return False
    return True


all_passed = True
method1_correct = 0
method2_correct = 0
method3_correct = 0
for i, (path_config, expected, num_players, description) in enumerate(test_cases, 1):
    print(f"\nTest Case {i}: {description}")
    print(f"Path config: {path_config}")
    print(f"Num players: {num_players}")
    print(f"Expected node values: {expected}")

    # Test Method 1
    search_path_1, to_play_1, value_1 = make_search_path(path_config)
    result_1, min_max_stats_1 = backpropagate_method1(
        search_path_1,
        to_play_1,
        value_1,
        num_players,
        min_max_stats=MinMaxStats(known_bounds=[-1, 1]),
    )

    # Test Method 2
    search_path_2, to_play_2, value_2 = make_search_path(path_config)
    result_2, min_max_stats_2 = backpropagate_method2(
        search_path_2,
        to_play_2,
        value_2,
        num_players,
        min_max_stats=MinMaxStats(known_bounds=[-1, 1]),
    )

    # Test Method 3
    search_path_3, to_play_3, value_3 = make_search_path(path_config)
    result_3, min_max_stats_3 = backpropagate_method3(
        search_path_3,
        to_play_3,
        value_3,
        num_players,
        min_max_stats=MinMaxStats(known_bounds=[-1, 1]),
    )

    # Check results
    method1_pass = are_lists_roughly_equal(result_1, expected)
    method2_pass = are_lists_roughly_equal(result_2, expected)
    method3_pass = are_lists_roughly_equal(result_3, expected)
    match1 = are_lists_roughly_equal(result_1, result_2)
    match2 = are_lists_roughly_equal(result_2, result_3)
    match3 = are_lists_roughly_equal(result_1, result_3)

    print(f"Method 1 result: {result_1} {'✓' if method1_pass else '✗'}")
    print(f"Method 2 result: {result_2} {'✓' if method2_pass else '✗'}")
    print(f"Method 3 result: {result_3} {'✓' if method3_pass else '✗'}")
    print(f"Methods 1 and 2 match: {'✓' if match1 else '✗'}")
    print(f"Methods 2 and 3 match: {'✓' if match2 else '✗'}")
    print(f"Methods 1 and 3 match: {'✓' if match3 else '✗'}")
    print(
        f"MinMaxStats maxes match:  {'✓' if min_max_stats_1.max == min_max_stats_2.max else '✗'} {min_max_stats_1.max} = {min_max_stats_2.max}"
    )
    print(
        f"MinMaxStats mins match:  {'✓' if min_max_stats_1.min == min_max_stats_2.min else '✗'} {min_max_stats_1.min} = {min_max_stats_2.min}"
    )
    print(
        f"MinMaxStats maxes match:  {'✓' if min_max_stats_2.max == min_max_stats_3.max else '✗'} {min_max_stats_2.max} = {min_max_stats_3.max}"
    )
    print(
        f"MinMaxStats mins match:  {'✓' if min_max_stats_2.min == min_max_stats_3.min else '✗'} {min_max_stats_2.min} = {min_max_stats_3.min}"
    )
    print(
        f"MinMaxStats maxes match:  {'✓' if min_max_stats_1.max == min_max_stats_3.max else '✗'} {min_max_stats_1.max} = {min_max_stats_3.max}"
    )
    print(
        f"MinMaxStats mins match:  {'✓' if min_max_stats_1.min == min_max_stats_3.min else '✗'} {min_max_stats_1.min} = {min_max_stats_3.min}"
    )
    if method1_pass:
        method1_correct += 1
    if method2_pass:
        method2_correct += 1
    if method3_pass:
        method3_correct += 1
    if not method1_pass:
        print("❌ METHOD 1 FAILED")
    else:
        print("✅ METHOD 1 PASSED")
    if not method2_pass:
        print("❌ METHOD 2 FAILED")
    else:
        print("✅ METHOD 2 PASSED")
    if not method3_pass:
        print("❌ METHOD 3 FAILED")
    else:
        print("✅ METHOD 3 PASSED")

    # if not (method1_pass and method2_pass and match):
    #     all_passed = False
    #     print("❌ FAILED")
    # else:
    #     print("✅ PASSED")

print("\n" + "=" * 80)
if all_passed:
    print("✅ All tests passed!")
else:
    print("❌ Some tests failed")

print(f"Method 1 got {method1_correct}/{len(test_cases)}")
print(f"Method 2 got {method2_correct}/{len(test_cases)}")
print(f"Method 3 got {method3_correct}/{len(test_cases)}")

Testing MuZero Value Backpropagation


Test Case 1: 2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 0
Path config: [(1, 0.0), (1, 1.0), (0, 0.0, 0.0)]
Num players: 2
Expected node values: [-1.0, 1.0, 0.0, 0.0]
Method 1 result: [1.0, -1.0, 0.0, 0.0] ✗
Method 2 result: [-1.0, 1.0, 0.0, 0.0] ✓
Method 3 result: [-1.0, 1.0, 0.0, 0.0] ✓
Methods 1 and 2 match: ✗
Methods 2 and 3 match: ✓
Methods 1 and 3 match: ✗
MinMaxStats maxes match:  ✓ 1 = 1
MinMaxStats mins match:  ✓ -1 = -1
MinMaxStats maxes match:  ✓ 1 = 1
MinMaxStats mins match:  ✓ -1 = -1
MinMaxStats maxes match:  ✓ 1 = 1
MinMaxStats mins match:  ✓ -1 = -1
❌ METHOD 1 FAILED
✅ METHOD 2 PASSED
✅ METHOD 3 PASSED

Test Case 2: 2-player: two player 1s with a reward for player 1 on a normally player 0 turn, ending on player 1
Path config: [(1, 0.0), (1, 1.0), (1, 0.0, 0.0)]
Num players: 2
Expected node values: [-1.0, 1.0, 0.0, 0.0]
Method 1 result: [1.0, -1.0, 0.0, 0.0] ✗
Method 2 result: [-1

In [13]:
def old_fn(
    index: int,
    values: list,
    policies: list,
    rewards: list,
    actions: list,
    infos: list,
    num_unroll_steps: int,
    n_step: int,
):
    n_step_values = torch.zeros(num_unroll_steps + 1, dtype=torch.float32)
    n_step_rewards = torch.zeros(num_unroll_steps + 1, dtype=torch.float32)
    n_step_policies = torch.zeros(
        (num_unroll_steps + 1, num_actions), dtype=torch.float32
    )
    n_step_actions = torch.zeros(num_unroll_steps, dtype=torch.int16)
    for current_index in range(index, index + num_unroll_steps + 1):
        unroll_step = current_index - index
        bootstrap_index = current_index + n_step
        # print("bootstrapping")
        # value of current position is the value at the position n_steps away + rewards to get to the n_step position
        if bootstrap_index < len(values):
            if (
                "player" not in infos[current_index]
                or infos[current_index]["player"] == infos[bootstrap_index]["player"]
            ):
                value = values[bootstrap_index] * gamma**n_step
            else:
                value = -values[bootstrap_index] * gamma**n_step
        else:
            value = 0

        # the rewards at this index to the bootstrap index should be added to the value
        for i, reward in enumerate(rewards[current_index:bootstrap_index]):
            # WHAT IS current_index + i + 1 when current index is the last frame?? IS THIS AN ERROR?
            if (
                "player" not in infos[current_index]
                or infos[current_index]["player"]
                == infos[current_index + i][
                    "player"
                ]  # + 1 if doing my og thing and i want to go back
            ):
                value += reward * gamma**i
            else:
                value -= reward * gamma**i

        # target reward is the reward before the ones added to the value
        if current_index > 0 and current_index <= len(rewards):
            last_reward = rewards[current_index - 1]
            # if self.has_intermediate_rewards:
            #     last_reward = rewards[current_index - 1]
            # else:
            #     value += (
            #         rewards[current_index - 1]
            #         if infos[current_index]["player"]
            #         == infos[current_index - 1]["player"]
            #         else -rewards[current_index - 1]
            #     )
            #     last_reward = rewards[current_index - 1]  # reward not used
        else:
            last_reward = 0  # self absorbing state 0 reward

        if current_index < len(values):
            n_step_values[unroll_step] = value
            n_step_rewards[unroll_step] = last_reward
            n_step_policies[unroll_step] = policies[current_index]
            if unroll_step < num_unroll_steps:
                # no action for last unroll step (since you dont act on that state)
                n_step_actions[unroll_step] = actions[current_index]
        else:
            n_step_values[unroll_step] = (
                value  # should be value or 0, maybe broken for single player
            )
            n_step_rewards[unroll_step] = last_reward
            n_step_policies[unroll_step] = (
                torch.ones(num_actions) / num_actions
            )  # self absorbing state
            if unroll_step < num_unroll_steps:
                # no action for last unroll step (since you dont act on that state)
                n_step_actions[unroll_step] = -1  # self absorbing state

    return (
        n_step_values,  # [initial value, recurrent values]
        n_step_policies,  # [initial policy, recurrent policies]
        n_step_rewards,  # [initial reward (0), recurrent rewards] initial reward is useless like the first last action, but we ignore it in the learn function
        n_step_actions,  # [recurrent actions, extra action]
    )  # remove the last actions, as there should be one less action than other stuff


def new_fn(
    index: int,
    values: list,
    policies: list,
    rewards: list,
    actions: list,
    infos: list,
    num_unroll_steps: int,
    n_step: int,
):
    """
    Returns:
        n_step_values: tensor shape (num_unroll_steps+1,)
        n_step_policies: tensor shape (num_unroll_steps+1, num_actions)
        n_step_rewards: tensor shape (num_unroll_steps+1,)
        n_step_actions: tensor shape (num_unroll_steps,)
    Conventions:
        - rewards[t] is the reward from taking action at state t (transition t → t+1)
        - infos[t]["player"] is the player who acted at state t
        - n_step_rewards[0] = 0 (no reward leading into root)
    """
    n_step_values = torch.zeros(num_unroll_steps + 1, dtype=torch.float32)
    n_step_rewards = torch.zeros(num_unroll_steps + 1, dtype=torch.float32)
    n_step_policies = torch.zeros(
        (num_unroll_steps + 1, num_actions), dtype=torch.float32
    )
    n_step_actions = torch.zeros(num_unroll_steps, dtype=torch.int16)

    max_index = len(values)

    for u in range(0, num_unroll_steps + 1):
        current_index = index + u

        # 1. discounted n-step value from current_index
        value = 0.0
        for k in range(n_step):
            r_idx = current_index + k
            if r_idx < len(rewards):
                r = rewards[r_idx]
                node_player = (
                    infos[current_index].get("player", None)
                    if current_index < len(infos)
                    else None
                )
                acting_player = (
                    infos[r_idx].get("player", None) if r_idx < len(infos) else None
                )
                sign = (
                    1.0
                    if (
                        node_player is None
                        or acting_player is None
                        or node_player == acting_player
                    )
                    else -1.0
                )
                value += (gamma**k) * (sign * r)
            else:
                break

        boot_idx = current_index + n_step
        if boot_idx < len(values):
            v_boot = values[boot_idx]
            node_player = (
                infos[current_index].get("player", None)
                if current_index < len(infos)
                else None
            )
            boot_player = (
                infos[boot_idx].get("player", None) if boot_idx < len(infos) else None
            )
            sign_leaf = (
                1.0
                if (
                    node_player is None
                    or boot_player is None
                    or node_player == boot_player
                )
                else -1.0
            )
            value += (gamma**n_step) * (sign_leaf * v_boot)

        n_step_values[u] = value

        # 2. reward target
        if u == 0:
            n_step_rewards[u] = 0.0  # root has no preceding reward
        else:
            reward_idx = current_index - 1
            n_step_rewards[u] = (
                rewards[reward_idx] if reward_idx < len(rewards) else 0.0
            )

        # 3. policy
        if current_index < len(policies):
            n_step_policies[u] = policies[current_index]
        else:
            n_step_policies[u] = torch.ones(num_actions) / num_actions

        # 4. action
        if u < num_unroll_steps:
            n_step_actions[u] = (
                actions[current_index] if current_index < len(actions) else -1
            )

    return n_step_values, n_step_policies, n_step_rewards, n_step_actions

In [9]:
import torch
import math


def compare_get_n_step_info(
    old_fn,
    new_fn,
    *,
    index,
    values,
    policies,
    rewards,
    actions,
    infos,
    num_unroll_steps,
    n_step,
    num_actions,
    gamma=0.997,
    verbose=True,
):
    """
    Compares old and new _get_n_step_info outputs and checks correctness.
    Arguments:
        old_fn: callable implementing the old logic
        new_fn: callable implementing the new logic
        (both should take arguments in same order as MuZeroReplayBuffer._get_n_step_info)
    """

    # Run both implementations
    old_vals, old_pols, old_rews, old_acts = old_fn(
        index, values, policies, rewards, actions, infos, num_unroll_steps, n_step
    )

    new_vals, new_pols, new_rews, new_acts = new_fn(
        index, values, policies, rewards, actions, infos, num_unroll_steps, n_step
    )

    # --- compute expected mathematically correct discounted values ---
    def expected_nstep_value(t):
        """Return correct discounted n-step bootstrap value from index t"""
        v = 0.0
        for k in range(n_step):
            idx = t + k
            if idx >= len(rewards):
                break
            # who acted for reward idx
            r_player = infos[idx]["player"]
            node_player = infos[t]["player"]
            sign = 1 if r_player == node_player else -1
            v += (gamma**k) * (sign * rewards[idx])
        boot_idx = t + n_step
        if boot_idx < len(values):
            node_player = infos[t]["player"]
            leaf_player = infos[boot_idx]["player"]
            sign_leaf = 1 if leaf_player == node_player else -1
            v += (gamma**n_step) * (sign_leaf * values[boot_idx])
        return v

    expected_vals = torch.tensor(
        [expected_nstep_value(index + u) for u in range(num_unroll_steps + 1)],
        dtype=torch.float32,
    )

    # --- compare ---
    def close(a, b, tol=1e-5):
        return torch.allclose(a, b, atol=tol, rtol=tol)

    ok_vals = close(new_vals, expected_vals)
    ok_rews = close(
        new_rews,
        torch.tensor(
            [
                rewards[index + u] if index + u < len(rewards) else 0
                for u in range(num_unroll_steps + 1)
            ]
        ),
    )
    ok_acts = close(
        new_acts,
        torch.tensor(
            [
                actions[index + u] if index + u < len(actions) else -1
                for u in range(num_unroll_steps)
            ],
            dtype=torch.int16,
        ),
    )

    # --- print diagnostics ---
    if verbose:
        print("====== N-STEP COMPARISON ======")
        for name, old, new in zip(
            ["values", "rewards", "actions"],
            [old_vals, old_rews, old_acts],
            [new_vals, new_rews, new_acts],
        ):
            print(f"\n{name.upper()}:")
            print(f" old: {old.tolist()}")
            print(f" new: {new.tolist()}")

        print("\nEXPECTED CORRECT VALUES:", expected_vals.tolist())
        print("\nChecks:")
        print(f"  New matches expected values: {ok_vals}")
        print(f"  Immediate rewards correct:   {ok_rews}")
        print(f"  Actions aligned:             {ok_acts}")

        # Highlight mismatches
        if not ok_vals:
            diffs = (new_vals - expected_vals).abs()
            print("  Value diffs:", diffs.tolist())
        if not close(new_vals, old_vals):
            print(
                "⚠️  Old and new disagree on n-step values (expected if old lacked proper discounting or sign logic)."
            )

    return {
        "ok_values": ok_vals,
        "ok_rewards": ok_rews,
        "ok_actions": ok_acts,
        "old_vals": old_vals,
        "new_vals": new_vals,
        "expected_vals": expected_vals,
    }

In [14]:
# Example episode
values = [0.2, -0.1, 0.4, 0.5]
rewards = [1.0, -2.0, 3.0]
actions = [0, 1, 2]
policies = [
    torch.tensor([0.7, 0.3]),
    torch.tensor([0.2, 0.8]),
    torch.tensor([0.5, 0.5]),
    torch.tensor([0.5, 0.5]),
]
infos = [
    {"player": 0},
    {"player": 1},
    {"player": 0},
    {"player": 1},
]
index = 0
num_unroll_steps = 2
n_step = 2
num_actions = 2
gamma = 0.99

results = compare_get_n_step_info(
    old_fn=old_fn,
    new_fn=new_fn,
    index=index,
    values=values,
    policies=policies,
    rewards=rewards,
    actions=actions,
    infos=infos,
    num_unroll_steps=num_unroll_steps,
    n_step=n_step,
    num_actions=num_actions,
    gamma=0.997,
)


VALUES:
 old: [3.372040033340454, -4.479949951171875, 3.0]
 new: [3.372040033340454, -4.479949951171875, 3.0]

REWARDS:
 old: [0.0, 1.0, -2.0]
 new: [0.0, 1.0, -2.0]

ACTIONS:
 old: [0, 1]
 new: [0, 1]

EXPECTED CORRECT VALUES: [3.391603708267212, -4.493995666503906, 3.0]

Checks:
  New matches expected values: False
  Immediate rewards correct:   False
  Actions aligned:             True
  Value diffs: [0.019563674926757812, 0.01404571533203125, 0.0]
