In [None]:
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.notebook import trange

from game import TicTacToe
from resNet import MuZeroResNet as MuZero
# from linearNet import MuZeroLinear as MuZero
from mctsParallel import MCTS
from replayBuffer import ReplayBuffer
from utils import KaggleAgent, evaluateKaggle

# Seed every RNG this notebook draws from (PyTorch, Python's `random`, NumPy)
# so that self-play and training are reproducible run to run.
SEED = 0
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# no masking for now

In [None]:
class Trainer:
    """MuZero training loop: batched self-play, K-step unrolled training, checkpointing, evaluation.

    Args:
        muZero: network exposing ``represent``, ``dynamics`` and ``predict``, the usual
            ``nn.Module`` methods (``train``/``eval``/``parameters``/``state_dict``) and a
            ``device`` attribute.
        optimizer: torch optimizer over ``muZero``'s parameters.
        game: game-rules object (e.g. ``TicTacToe``).
        args: configuration dict. Keys read here: ``num_iterations``, ``num_train_games``,
            ``num_parallel_games``, ``num_epochs``, ``batch_size``, ``temperature``, ``K``,
            ``value_loss_weight``, ``max_grad_norm``, ``evaluate`` and, when evaluating,
            ``eval_args``.
    """

    def __init__(self, muZero, optimizer, game, args):
        self.muZero = muZero
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(self.muZero, self.game, self.args)
        self.replayBuffer = ReplayBuffer(self.args, self.game)
        if self.args['evaluate']:
            self.evalPlayer = KaggleAgent(self.muZero, self.game, self.args['eval_args'])

    @staticmethod
    def _scale_gradient(tensor, scale):
        """Return ``tensor`` unchanged in the forward pass, with its backward gradient multiplied by ``scale``."""
        return tensor * scale + tensor.detach() * (1 - scale)

    def _visit_count_policy(self, root):
        """Root children's visit counts, normalised into a float32 policy over the full action space."""
        action_probs = np.zeros(self.game.action_size, dtype=np.float32)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        return action_probs / np.sum(action_probs)

    def self_play(self, game_idx_group, self_play_bar):
        """Play one group of ``num_parallel_games`` games in lock-step and store them in the replay buffer.

        Every unfinished game has the same player to move, so a single batched
        ``MCTS.search`` call serves all of them per ply.

        Args:
            game_idx_group: index of this group; game ``i`` of the group gets the global
                index ``game_idx_group * num_parallel_games + i``.
            self_play_bar: tqdm bar whose description is updated as games finish.
        """
        num_parallel = self.args['num_parallel_games']
        spGames = [
            SelfPlayGame(self.game, game_idx_group * num_parallel + i)
            for i in range(num_parallel)
        ]
        player = 1

        while len(spGames) > 0:
            observations = np.stack([g.observation for g in spGames])
            neutral_observations = self.game.change_perspective(observations, player)
            valid_moves = self.game.get_valid_moves(observations)
            encoded_observations = self.game.get_encoded_observation(neutral_observations)

            self.mcts.search(encoded_observations, valid_moves, spGames)

            # Walk backwards so deleting a finished game does not shift the indices
            # (into spGames and encoded_observations) that are still to be visited.
            for i in reversed(range(len(spGames))):
                g = spGames[i]

                action_probs = self._visit_count_policy(g.root)
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                g.memory.append((encoded_observations[i], action, action_probs, player))

                g.observation = self.game.get_next_state(g.observation, action, player)
                value, is_terminal = self.game.get_value_and_terminated(g.observation, action)

                if is_terminal:
                    self._store_finished_game(g, value, player)
                    del spGames[i]
                    self._update_progress(self_play_bar, game_idx_group, len(spGames))

            player = self.game.get_opponent(player)

    def _store_finished_game(self, g, value, player):
        """Append every position of a finished game, plus its terminal position, to the replay buffer.

        ``value`` is credited to positions where ``player`` (who made the final move) was
        to move; the opponent's positions receive ``get_opponent_value(value)``.
        """
        for hist_observation, hist_action, hist_action_probs, hist_player in g.memory:
            hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
            self.replayBuffer.memory.append((
                hist_observation,
                hist_action,
                hist_action_probs,
                hist_outcome,
                g.game_idx,
                False,  # is_terminal
            ))
        # The terminal position is encoded from the perspective of the player who would
        # move next (the opponent of `player`), so its outcome is the opponent's value.
        opponent = self.game.get_opponent(player)
        self.replayBuffer.memory.append((
            self.game.get_encoded_observation(self.game.change_perspective(g.observation, opponent)),
            None,
            np.zeros(self.game.action_size, dtype=np.float32),
            self.game.get_opponent_value(value),
            g.game_idx,
            True,  # is_terminal
        ))

    def _update_progress(self, self_play_bar, game_idx_group, num_unfinished):
        """Show finished-game count and average stored positions per finished game this iteration."""
        num_parallel = self.args['num_parallel_games']
        groups_per_iteration = self.args['num_train_games'] // num_parallel
        finished_in_group = num_parallel - num_unfinished
        finished_total = finished_in_group + num_parallel * game_idx_group
        # The replay buffer is emptied at the start of each iteration, so average only
        # over the games finished within the current iteration.
        finished_this_iteration = finished_in_group + num_parallel * (game_idx_group % groups_per_iteration)
        avg_steps = len(self.replayBuffer.memory) / finished_this_iteration
        self_play_bar.set_description(
            f"Games finished: {finished_total} | Avg. steps: {avg_steps:.2f}"
        )

    def train(self):
        """Run one epoch over the replay buffer's trajectories, unrolling the model ``K`` steps.

        Each trajectory is unpacked as ``(observation, policy_targets, action, value_targets)``.
        The indexing below needs ``K + 1`` targets and ``K`` actions per sample; presumably
        ``ReplayBuffer.build_trajectories`` produces them that way (not visible here).
        """
        trajectories = self.replayBuffer.trajectories
        random.shuffle(trajectories)
        batch_size = self.args['batch_size']
        K = self.args['K']
        device = self.muZero.device
        # Bound the loop by the list actually being sliced so a batch is never empty.
        for batchIdx in range(0, len(trajectories), batch_size):
            sample = trajectories[batchIdx:batchIdx + batch_size]
            observation, policy_targets, action, value_targets = zip(*sample)

            observation = torch.tensor(np.array(observation), dtype=torch.float32, device=device)
            action = np.array(action)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=device)
            value_targets = torch.tensor(np.array(value_targets), dtype=torch.float32, device=device).unsqueeze(-1)

            hidden_state = self.muZero.represent(observation)
            out_policy, out_value = self.muZero.predict(hidden_state)

            policy_loss = F.cross_entropy(out_policy, policy_targets[:, 0])
            value_loss = F.mse_loss(out_value, value_targets[:, 0])
            for k in range(1, K + 1):
                # The dynamics output must be assigned back; otherwise every step
                # predicts from the root latent state and dynamics gets no gradient.
                hidden_state = self.muZero.dynamics(hidden_state, action[:, k - 1])
                out_policy, out_value = self.muZero.predict(hidden_state)

                # Recurrent-step losses contribute 1/K of their gradient (MuZero pseudocode).
                policy_loss = policy_loss + self._scale_gradient(
                    F.cross_entropy(out_policy, policy_targets[:, k]), 1 / K)
                value_loss = value_loss + self._scale_gradient(
                    F.mse_loss(out_value, value_targets[:, k]), 1 / K)

                # Halve only the gradient flowing back through the next dynamics step;
                # this step's prediction gradient is left unscaled.
                hidden_state = self._scale_gradient(hidden_state, 0.5)

            loss = value_loss * self.args['value_loss_weight'] + policy_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.muZero.parameters(), self.args['max_grad_norm'])
            self.optimizer.step()

    def evaluate(self, num_iterations=20):
        """Play the evaluation agent against Kaggle's ``"random"`` agent from both seats."""
        # Pass the bound `run` method, as the standalone test cell does with KaggleAgent.
        agent = self.evalPlayer.run
        evaluateKaggle("tictactoe", ["random", agent], num_iterations=num_iterations)
        evaluateKaggle("tictactoe", [agent, "random"], num_iterations=num_iterations)

    def run(self):
        """Alternate self-play, training, checkpointing and (optionally) evaluation for ``num_iterations`` iterations."""
        os.makedirs("Models", exist_ok=True)  # torch.save does not create directories
        groups_per_iteration = self.args['num_train_games'] // self.args['num_parallel_games']
        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            self.replayBuffer.empty()

            self.muZero.eval()
            self_play_bar = trange(groups_per_iteration, desc="train_game")
            for train_game_idx in self_play_bar:
                self.self_play(train_game_idx + iteration * groups_per_iteration, self_play_bar)
            self.replayBuffer.build_trajectories()

            self.muZero.train()
            for _ in trange(self.args['num_epochs'], desc="epochs"):
                self.train()

            torch.save(self.muZero.state_dict(), f"Models/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"Models/optimizer_{iteration}.pt")

            if self.args['evaluate']:
                self.muZero.eval()
                self.evaluate()
        
class SelfPlayGame:
    """Mutable bookkeeping for one game inside a batch of parallel self-play games.

    Attributes:
        game: the shared game-rules object.
        game_idx: global index of this game, stored alongside its replay entries.
        observation: current board, initialised from ``game.get_initial_state()``.
        memory: list that ``Trainer.self_play`` fills with per-move
            ``(encoded_observation, action, action_probs, player)`` tuples.
        root, node: search-tree handles, initialised to ``None``; ``root`` is read after
            ``MCTS.search``, which presumably assigns it (not visible here).
    """

    def __init__(self, game, game_idx):
        self.game, self.game_idx = game, game_idx
        self.observation = game.get_initial_state()
        self.memory = []
        self.root = self.node = None

In [None]:
# Train
args = {
    'num_iterations': 20,
    'num_train_games': 500,
    # Games played in lock-step per self-play batch. Trainer reads this key, so it must
    # be present; it should divide num_train_games evenly, because Trainer plays
    # num_train_games // num_parallel_games batches per iteration.
    'num_parallel_games': 100,
    'num_mcts_searches': 25,
    'num_epochs': 4,
    'batch_size': 64,
    'temperature': 1,
    'K': 3,
    'C': 2,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'value_loss_weight': 0.25,
    'max_grad_norm': 5,
    'evaluate': True,
    'eval_args': {
        'search': True,
        'num_mcts_searches': 25,
        'temperature': 0.1,
        'C': 2,
        'dirichlet_alpha': 0.3,
        'dirichlet_epsilon': 0.25,
        'num_eval_games': 100,
    }
}

LOAD = False

game = TicTacToe()
muZero = MuZero(game, device)
optimizer = torch.optim.Adam(muZero.parameters(), lr=0.001, weight_decay=1e-4)

if LOAD:
    # NOTE: training writes Models/model_{iteration}.pt; point these paths at the
    # checkpoint you want (or copy it to model.pt / optimizer.pt) before loading.
    # map_location lets a GPU-saved checkpoint be restored on a CPU-only machine.
    muZero.load_state_dict(torch.load("Models/model.pt", map_location=device))
    optimizer.load_state_dict(torch.load("Models/optimizer.pt", map_location=device))

trainer = Trainer(muZero, optimizer, game, args)
trainer.run()

In [None]:
# Test
from utils import KaggleAgent, evaluateKaggle

# Games per seat order; the same count is used for both orders so the results are comparable.
NUM_TEST_GAMES = 100

args = {
    'num_mcts_searches': 25,
    'temperature': 1,
    'C': 1.25,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'search': True,
}

game = TicTacToe()
muZero = MuZero(game, device)

# Path is relative to the kernel's working directory (the training cell saves to Models/).
# map_location lets a GPU-saved checkpoint be restored on a CPU-only machine.
muZero.load_state_dict(torch.load("../../Environments/TicTacToe/Models/model_15.pt", map_location=device))
muZero.eval()

player = KaggleAgent(muZero, game, args)

evaluateKaggle("tictactoe", ["random", player.run], num_iterations=NUM_TEST_GAMES)
evaluateKaggle("tictactoe", [player.run, "random"], num_iterations=NUM_TEST_GAMES)


In [None]:
# Count Parameters
# Build a freshly constructed network (no checkpoint loaded) purely to report its size.
game = TicTacToe()
model = MuZero(game, device)

def get_n_params(model):
    """Return the total number of scalar values held by ``model.parameters()``.

    Each parameter contributes ``numel()`` elements, i.e. the product of its dimensions;
    an empty module yields 0.
    """
    return sum(param.numel() for param in model.parameters())

# Last expression of the cell, so Jupyter displays the parameter count as the output.
get_n_params(model)