In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from tqdm.notebook import trange
import wandb

from game import TicTacToe
from resNet import MuZeroResNet as MuZero
# from linearNet import MuZeroLinear as MuZero
from mcts import MCTS
from replayBuffer import ReplayBuffer
from utils import KaggleAgent, evaluateKaggle

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

wandb.login()
wandb.init(project="TicTacToe-Desperate")

# no masking for now

In [None]:
class Trainer:
    def __init__(self, muZero, optimizer, game, args):
        self.muZero = muZero
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(self.muZero, self.game, self.args)
        self.replayBuffer = ReplayBuffer(self.args, self.game)
        if self.args['evaluate']:
            self.evalPlayer = KaggleAgent(self.muZero, self.game, self.args['eval_args'])

    def self_play(self, game_idx):
        memory = []
        player = 1
        observation = self.game.get_initial_state()

        while True:
            valid_moves = self.game.get_valid_moves(observation)
            neutral_observation = self.game.change_perspective(observation, player)
            encoded_observation = self.game.get_encoded_observation(neutral_observation)
            action_probs = self.mcts.search(encoded_observation, valid_moves)

            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            memory.append((encoded_observation, action, action_probs, player))

            observation = self.game.get_next_state(observation, action, player)
            value, is_terminal = self.game.get_value_and_terminated(observation, action)

            if is_terminal:
                return_memory = []
                for hist_observation, hist_action, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    return_memory.append((
                        hist_observation,
                        hist_action, 
                        hist_action_probs,
                        hist_outcome,
                        game_idx,
                        False # is_terminal
                    ))
                hist_outcome = value if self.game.get_opponent(player) == player else self.game.get_opponent_value(value)
                return_memory.append((
                    self.game.get_encoded_observation(self.game.change_perspective(observation, self.game.get_opponent(player))),
                    None,
                    np.ones(self.game.action_size) / self.game.action_size,
                    hist_outcome,
                    game_idx,
                    True # is_terminal
                ))
                return return_memory
            
            player = self.game.get_opponent(player)

    def train(self):
        random.shuffle(self.replayBuffer.trajectories)
        for batchIdx in range(0, len(self.replayBuffer), self.args['batch_size']): 
            sample = self.replayBuffer.trajectories[batchIdx:batchIdx+self.args['batch_size']]
            observation, policy_targets, action, value_targets = list(zip(*sample))

            observation = torch.tensor(np.array(observation), dtype=torch.float32, device=self.muZero.device)
            action = np.array(action)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=self.muZero.device)
            value_targets = torch.tensor(np.array(value_targets), dtype=torch.float32, device=self.muZero.device).unsqueeze(-1)

            hidden_state = self.muZero.represent(observation)
            out_policy, out_value = self.muZero.predict(hidden_state)
            
            policy_loss = F.cross_entropy(out_policy, policy_targets[:, 0])
            value_loss = F.mse_loss(out_value, value_targets[:, 0])
            for k in range(1, self.args['K'] + 1):
                hidden_state, self.muZero.dynamics(hidden_state, action[:, k - 1])
                out_policy, out_value = self.muZero.predict(hidden_state)

                current_policy_loss = F.cross_entropy(out_policy, policy_targets[:, k])
                current_value_loss = F.mse_loss(out_value, value_targets[:, k])

                current_policy_loss.register_hook(lambda grad: grad / self.args['K'])
                current_value_loss.register_hook(lambda grad: grad / self.args['K'])

                policy_loss += current_policy_loss
                value_loss += current_value_loss

                # hidden_state.register_hook(lambda grad: grad * 0.5)

            loss = value_loss * self.args['value_loss_weight'] + policy_loss

            wandb.log({
                "value_loss": value_loss.item(),
                "policy_loss": policy_loss.item(),
                "loss": loss.item()
            })

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.muZero.parameters(), self.args['max_grad_norm'])
            self.optimizer.step()

    def evaluate(self):
        results_1 = evaluateKaggle("tictactoe", ["random", self.evalPlayer], num_iterations=20)
        results_2 = evaluateKaggle("tictactoe", [self.evalPlayer, "random"], num_iterations=20)
        
        wandb.log({
            "win_rate (against random)": (np.sum(results_1==-1) + np.sum(results_2==1)) / 40,
            "tie_rate (against random)": (np.sum(results_1==0) + np.sum(results_2==0)) / 40,
        })

    def run(self):
        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            self.replayBuffer.empty()

            self.muZero.eval()
            for train_game_idx in (self_play_bar := trange(self.args['num_train_games'], desc="train_game")):
                self.replayBuffer.memory += self.self_play(train_game_idx + iteration * self.args['num_train_games'])
                self_play_bar.set_description(f"Avg. steps per Game: {len(self.replayBuffer.memory) / (train_game_idx + 1):.2f}")
                wandb.log({"steps_per_game": len(self.replayBuffer.memory) / (train_game_idx + 1)})
            self.replayBuffer.build_trajectories()

            self.muZero.train()
            for epoch in trange(self.args['num_epochs'], desc="epochs"):
                self.train()

            if self.args['evaluate']:
                self.muZero.eval()
                self.evaluate()

            torch.save(self.muZero.state_dict(), f"Models/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"Models/optimizer_{iteration}.pt")

In [None]:
# Train
args = {
    'num_iterations': 20,
    'num_train_games': 500,
    'num_mcts_searches': 25,
    'num_epochs': 4,
    'batch_size': 64,
    'temperature': 1,
    'K': 3,
    'C': 2,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'value_loss_weight': 0.25,
    'max_grad_norm': 5,
    'evaluate': True,
    'eval_args': {
        'search': True,
        'num_mcts_searches': 25,
        'temperature': 0.1,
        'C': 2,
        'dirichlet_alpha': 0.3,
        'dirichlet_epsilon': 0.25,
        'num_eval_games': 100,
    }
}

LOAD = False

game = TicTacToe()
muZero = MuZero(game, device)
optimizer = torch.optim.Adam(muZero.parameters(), lr=0.001, weight_decay=1e-4)

if LOAD:
    muZero.load_state_dict(torch.load(f"Models/model.pt"))
    optimizer.load_state_dict(torch.load(f"Models/optimizer.pt"))

trainer = Trainer(muZero, optimizer, game, args)
trainer.run()

In [None]:
# Test
from utils import KaggleAgent, evaluateKaggle

args = {
    'num_mcts_searches': 25,
    'temperature': 1,
    'C': 1.25,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'search': True,
}

game = TicTacToe()
muZero = MuZero(game, device)

# muZero.load_state_dict(torch.load("../../Environments/TicTacToe/Models/model_15.pt"))
muZero.eval()

player = KaggleAgent(muZero, game, args)

evaluateKaggle("tictactoe", ["random", player.run], num_iterations=1)
evaluateKaggle("tictactoe", [player.run, "random"], num_iterations=100)


In [None]:
# Count Parameters
game = TicTacToe()
model = MuZero(game, device)

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

get_n_params(model)

In [4]:
# Test MuZeroGeneral

from resNetGeneral import MuZeroResidualNetwork
from tictactoeGeneral import TicTacToe as TicTacToeGeneral



model = MuZeroResidualNetwork(
    observation_shape=(3, 3, 3),
    stacked_observations=0,
    action_space_size=list(range(9)),
    num_blocks=1,
    num_channels=16,
    reduced_channels_reward=16,
    reduced_channels_value=16,
    reduced_channels_policy=16,
    fc_reward_layers=[8],
    fc_value_layers=[8],
    fc_policy_layers=[8],
    support_size=10,
    downsample=False,
)

a = torch.load('/home/robert/Documents/GitHub/MuZero/MuZeroGeneral/2023-05-25--14-22-49/model.checkpoint')
a

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
