In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import random
import math
from tqdm.notebook import trange
from kaggle_environments import make, evaluate

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# hidden_state_scale (paper), weights prio, parallel, (max grad norm), division of 1/k is only for unrolling steps, check reward loss
# include parallel valid moves to other scripts (might be slower though)

cuda


In [2]:
class TicTacToe:
    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count
        
    def __repr__(self):
        return "TicTacToe"
        
    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count))
    
    def get_next_state(self, state, action, player):
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state
    
    def get_valid_moves(self, state):
        if len(state.shape) == 3:
            return (state.reshape(-1, 9) == 0).astype(np.uint8)
        return (state.reshape(9) == 0).astype(np.uint8)
    
    def check_win(self, state, action):
        if action == None:
            return False
        
        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]
        
        return (
            np.sum(state[row, :]) == player * self.column_count
            or np.sum(state[:, column]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )
    
    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player
    
    def get_opponent_value(self, value):
        return -value
    
    def change_perspective(self, state, player):
        return state * player
    
    def get_encoded_observation(self, state):
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)
        
        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)
        
        return encoded_state

In [3]:
class ReplayBuffer:
    def __init__(self, args, game):
        self.memory = []
        self.trajectories = []
        self.args = args
        self.game = game

    def __len__(self):
        return len(self.trajectories)
    
    def empty(self):
        self.memory = []
        self.trajectories = []

    def build_trajectories(self):
        for i in range(len(self.memory)):
            observation, action, policy, value, game_idx, is_terminal = self.memory[i]
            if not is_terminal:
                policy_list, action_list, value_list = [policy], [action], [value]

                for k in range(1, self.args['K'] + 1):
                    if i + k < len(self.memory) and self.memory[i + k][4] == game_idx:
                        _, action, policy, value, _, is_terminal = self.memory[i + k]
                        if is_terminal:
                            action = np.random.choice(self.game.action_size)
                        policy_list.append(policy)
                        action_list.append(action)
                        value_list.append(value)

                    else:
                        policy_list.append(np.zeros(self.game.action_size, dtype=np.float32))
                        action_list.append(np.random.choice(self.game.action_size))
                        value_list.append(self.game.get_opponent_value(value_list[-1]))

                policy_list = np.stack(policy_list)
                self.trajectories.append((observation, policy_list, action_list, value_list))

In [None]:
class MuZero(nn.Module):
    def __init__(self, game, device):
        super().__init__()
        self.game = game
        self.device = device

        self.predictionFunction = PredictionFunction(game)
        self.dynamicsFunction = DynamicsFunction()
        self.representationFunction = RepresentationFunction()

        self.to(device)

    def predict(self, hidden_state):
        return self.predictionFunction(hidden_state)

    def represent(self, observation):
        return self.representationFunction(observation)

    def dynamics(self, hidden_state, actions):
        actionPlane = torch.zeros((hidden_state.shape[0], 1, self.game.row_count, self.game.column_count), device=self.device, dtype=torch.float32)
        for i, a in enumerate(actions):
            row = a // self.game.column_count
            col = a % self.game.column_count
            actionPlane[i, 0, row, col] = 1
        x = torch.cat((hidden_state, actionPlane), dim=1)
        return self.dynamicsFunction(x)

class DynamicsFunction(nn.Module):
    def __init__(self, num_resBlocks=2, num_hidden=16):
        super().__init__()
        
        self.startBlock = nn.Sequential(
            nn.Conv2d(4, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        self.backBone = nn.ModuleList([ResBlock(num_hidden) for _ in range(num_resBlocks)])
        self.endBlock = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.startBlock(x)
        for resblock in self.backBone:
            x = resblock(x)
        x = self.endBlock(x)
        return x

class PredictionFunction(nn.Module):
    def __init__(self, game, num_resBlocks=2, num_hidden=16):
        super().__init__()

        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden // 2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(num_hidden // 2 * game.row_count * game.column_count, game.action_size)
        )
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.row_count * game.column_count, 1),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value
 
class RepresentationFunction(nn.Module):
    def __init__(self, num_hidden=16):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Conv2d(3, num_hidden // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden // 2),
            nn.ReLU(),
            nn.Conv2d(num_hidden // 2, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
            # ResBlock(num_hidden),
            nn.Conv2d(num_hidden, num_hidden // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden // 2),
            nn.ReLU(),
            nn.Conv2d(num_hidden // 2, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.layers(x)
        return x

class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)
        return x


In [None]:
class Node:
    def __init__(self, muZero, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.muZero = muZero
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        self.children = []
        
        self.visit_count = visit_count
        self.value_sum = 0
        
    def is_fully_expanded(self):
        return len(self.children) > 0
    
    def select(self):
        best_child = None
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
                
        return best_child
    
    def get_ucb(self, child):
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    @torch.no_grad()
    def expand(self, policy):
        actions = [a for a in range(self.game.action_size) if policy[a] > 0]
        child_state = self.state.copy()
        child_state = np.expand_dims(child_state, axis=0).repeat(len(actions), axis=0)

        child_state = self.muZero.dynamics(
            torch.tensor(child_state, dtype=torch.float32, device=self.muZero.device), actions)
        child_state = child_state.cpu().numpy()
        
        for i, action in enumerate(actions):
            child = Node(
                self.muZero,
                self.game,
                self.args,
                state=child_state[i],
                parent=self,
                action_taken=action,
                prior=policy[action],
            )
            self.children.append(child)

    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            value = self.game.get_opponent_value(value)
            self.parent.backpropagate(value)

class MCTS:
    def __init__(self, muZero, game, args):
        self.muZero = muZero
        self.game = game
        self.args = args

    @torch.no_grad()
    def search(self, observations, valid_moves, spGames):
        hidden_states = self.muZero.represent(
            torch.tensor(observations, dtype=torch.float32, device=self.muZero.device)
        )
        policy, _ = self.muZero.predict(hidden_states)
        
        policy = torch.softmax(policy, dim=1).cpu().numpy()
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size, size=policy.shape[0])
        policy *= valid_moves
        policy /= np.sum(policy, axis=1, keepdims=True)

        hidden_states = hidden_states.cpu().numpy()

        for i, g in enumerate(spGames):
            g.root = Node(
                self.muZero, self.game, self.args, 
                hidden_states[i], visit_count=1)
            g.root.expand(policy[i])

        for search in range(self.args['num_mcts_searches']):
            for g in spGames:
                node = g.root

                while node.is_fully_expanded():
                    node = node.select()

                g.node = node

            hidden_states = np.stack([g.node.state for g in spGames])
            policy, value = self.muZero.predict(
                torch.tensor(hidden_states, dtype=torch.float32, device=self.muZero.device)
            )
            policy = torch.softmax(policy, dim=1).cpu().numpy()
            value = value.cpu().numpy().reshape(-1)

            for i, g in enumerate(spGames):
                g.node.expand(policy[i])
                g.node.backpropagate(value[i])

In [None]:
class Trainer:
    def __init__(self, muZero, optimizer, game, args):
        self.muZero = muZero
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(self.muZero, self.game, self.args)
        self.replayBuffer = ReplayBuffer(self.args, self.game)

    def self_play(self, game_idx_group, self_play_bar):
        spGames = [
            SelfPlayGame(
                self.game,
                game_idx_group * self.args['num_parallel_games'] + i
            ) for i in range(self.args['num_parallel_games'])
        ]
        player = 1

        while len(spGames) > 0:
            observations = np.stack([g.observation for g in spGames])
            valid_moves = self.game.get_valid_moves(observations)
            neutral_observations = self.game.change_perspective(observations, player)
            encoded_observations = self.game.get_encoded_observation(neutral_observations)
            
            self.mcts.search(encoded_observations, valid_moves, spGames)

            for i in range(len(spGames))[::-1]:
                g = spGames[i]

                action_probs = np.zeros(self.game.action_size, dtype=np.float32)
                for child in g.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                g.memory.append((encoded_observations[i], action, action_probs, player))

                g.observation = self.game.get_next_state(g.observation, action, player)
                value, is_terminal = self.game.get_value_and_terminated(g.observation, action)

                if is_terminal:
                    for hist_observation, hist_action, hist_action_probs, hist_player in g.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        self.replayBuffer.memory.append((
                            hist_observation,
                            hist_action, 
                            hist_action_probs,
                            hist_outcome,
                            g.game_idx,
                            False # is_terminal
                        ))
                    hist_outcome = value if self.game.get_opponent(player) == player else self.game.get_opponent_value(value)
                    self.replayBuffer.memory.append((
                        self.game.get_encoded_observation(self.game.change_perspective(g.observation, self.game.get_opponent(player))),
                        None,
                        np.zeros(self.game.action_size, dtype=np.float32),
                        hist_outcome,
                        g.game_idx,
                        True # is_terminal
                    ))
                    del spGames[i]
                    self_play_bar.set_description(
                        f"Games finished: {self.args['num_parallel_games'] - len(spGames) + self.args['num_parallel_games'] * game_idx_group} | Avg. steps: \
                        {len(self.replayBuffer.memory) / (self.args['num_parallel_games'] - len(spGames) + self.args['num_parallel_games'] * (game_idx_group % (self.args['num_train_games'] // self.args['num_parallel_games'])))}"
                    )
            
            player = self.game.get_opponent(player)

    def train(self):
        random.shuffle(self.replayBuffer.trajectories)
        for batchIdx in range(0, len(self.replayBuffer), self.args['batch_size']): 
            sample = self.replayBuffer.trajectories[batchIdx:batchIdx+self.args['batch_size']]
            observation, policy_targets, action, value_targets = list(zip(*sample))

            observation = torch.tensor(np.array(observation), dtype=torch.float32, device=self.muZero.device)
            action = np.array(action)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=self.muZero.device)
            value_targets = torch.tensor(np.array(value_targets), dtype=torch.float32, device=self.muZero.device).unsqueeze(-1)

            hidden_state = self.muZero.represent(observation)
            out_policy, out_value = self.muZero.predict(hidden_state)
            
            predictions = [(out_policy, out_value)]
            for k in range(1, self.args['K'] + 1):
                hidden_state, self.muZero.dynamics(hidden_state, action[:, k - 1])
                out_policy, out_value = self.muZero.predict(hidden_state)
                predictions.append((out_policy, out_value))

                hidden_state.register_hook(lambda grad: grad * 0.5)

            policy_loss = F.cross_entropy(predictions[0][0], policy_targets[:, 0])
            value_loss = F.mse_loss(predictions[0][1], value_targets[:, 0])
            for k in range(1, self.args['K'] + 1):
                current_policy_loss = F.cross_entropy(predictions[k][0], policy_targets[:, k], reduction='sum') \
                    / (policy_targets[:, k].sum(axis=1)!=0).sum()
                current_value_loss = F.mse_loss(predictions[k][1], value_targets[:, k])
                current_policy_loss.register_hook(lambda grad: grad / self.args['K'])
                current_value_loss.register_hook(lambda grad: grad / self.args['K'])

                policy_loss += current_policy_loss
                value_loss += current_value_loss

            loss = value_loss * self.args['value_loss_weight'] + policy_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.muZero.parameters(), self.args['max_grad_norm'])
            self.optimizer.step()

    def run(self):
        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            self.replayBuffer.empty()

            self.muZero.eval()
            for train_game_idx in (self_play_bar := trange(self.args['num_train_games'] // self.args['num_parallel_games'], desc="train_game")):
                self.self_play(train_game_idx + iteration * \
                    (self.args['num_train_games'] // self.args['num_parallel_games']), self_play_bar)
            self.replayBuffer.build_trajectories()

            self.muZero.train()
            for epoch in trange(self.args['num_epochs'], desc="epochs"):
                self.train()

            torch.save(self.muZero.state_dict(), f"../../Environments/{self.game}/Models/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"../../Environments/{self.game}/Models/optimizer_{iteration}.pt")
        
class SelfPlayGame:
    def __init__(self, game, game_idx):
        self.game = game
        self.game_idx = game_idx
        self.memory = []
        self.observation = self.game.get_initial_state()
        self.root = None
        self.node = None

In [None]:
args = {
    'num_iterations': 100,
    'num_train_games': 100,
    'num_parallel_games': 100,
    'num_mcts_searches': 60,
    'num_epochs': 3,
    'batch_size': 64,
    'temperature': 1,
    'K': 3,
    'C': 1.25,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'value_loss_weight': 0.5,
    'max_grad_norm': 5,
}

LOAD = False

game = TicTacToe()
muZero = MuZero(game, device)
optimizer = torch.optim.AdamW(muZero.parameters(), lr=0.001)

if LOAD:
    muZero.load_state_dict(torch.load(f"../../Environments/{game}/Models/model.pt"))
    optimizer.load_state_dict(torch.load(f"../../Environments/{game}/Models/optimizer.pt"))

trainer = Trainer(muZero, optimizer, game, args)
# trainer.run()

In [None]:
class KaggleAgent:
    def __init__(self, muZero, game, args):
        self.muZero = muZero
        self.game = game
        self.args = args
        if self.args['search']:
            self.mcts = MCTS(self.muZero, self.game, self.args)

    def run(self, obs, conf):
        player = obs['mark'] if obs['mark'] == 1 else -1
        observation = np.array(obs['board']).reshape(self.game.row_count, self.game.column_count)
        observation[observation==2] = -1
        valid_moves = self.game.get_valid_moves(observation)
        
        neutral_observation = self.game.change_perspective(observation, player).copy()
        encoded_observation = self.game.get_encoded_observation(neutral_observation)

        with torch.no_grad():
            if self.args['search']:
                policy = self.mcts.search(encoded_observation, valid_moves)

            else:
                hidden_state = torch.tensor(encoded_observation, dtype=torch.float32, device=self.muZero.device).unsqueeze(0)
                hidden_state = self.muZero.represent(hidden_state)
                print(hidden_state)

                policy, _ = self.muZero.predict(hidden_state)
                policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()

        policy *= valid_moves
        policy /= np.sum(policy)

        if self.args['temperature'] == 0:
            action = int(np.argmax(policy))
        elif self.args['temperature'] == float('inf'):
            action = np.random.choice([r for r in range(self.game.action_size) if policy[r] > 0])
        else:
            policy = policy ** (1 / self.args['temperature'])
            policy /= np.sum(policy)
            action = np.random.choice(self.game.action_size, p=policy)

        return action
    
def evaluateKaggle(gameName, players, num_iterations=1):
    if num_iterations == 1:
        env = make(gameName, debug=True)
        env.run(players)
        return env.render(mode="ipython")

    results = np.array(evaluate(gameName, players, num_episodes=num_iterations))[:, 0]
    print(f"""
Player 1 | Wins: {np.sum(results == 1)} | Draws: {np.sum(results == 0)} | Losses: {np.sum(results == -1)}
Player 2 | Wins: {np.sum(results == -1)} | Draws: {np.sum(results == 0)} | Losses: {np.sum(results == 1)}
    """)


In [None]:
args = {
    'num_mcts_searches': 60,
    'temperature': 1,
    'C': 1.25,
    'dirichlet_alpha': 0.1,
    'dirichlet_epsilon': 0.25,
    'search': True,
}

game = TicTacToe()
muZero = MuZero(game, device)

muZero.load_state_dict(torch.load("../../Environments/TicTacToe/Models/model_13.pt"))
muZero.eval()

player = KaggleAgent(muZero, game, args)

evaluateKaggle("tictactoe", ["random", player.run], num_iterations=1)
evaluateKaggle("tictactoe", [player.run, "random"], num_iterations=100)


tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]]], device='cuda:0')
Traceback (most recent call last):
  File "/home/robert/anaconda3/envs/myenv/lib/python3.8/site-packages/kaggle_environments/agent.py", line 159, in act
    action = self.agent(*args)
  File "/tmp/ipykernel_19649/2796621217.py", line 40, in run
    action = np.random.choice(self.game.action_size, p=policy)
  File "mtrand.pyx", line 954, in numpy.random.mtrand.RandomState.choice
ValueError: probabilities contain NaN
Error: ['Traceback (most recent call last):\n', '  File "/home/robert/anaconda3/envs/myenv/lib/python3.8/site-packages/kaggle_environments/agent.py", line 159, in act\n    action = self.agent(*args)\n', '  File "/tmp/ipykernel_19649/2796621217.py", line 40, in run\n    action = np.random.choice(self.game.action_s


Player 1 | Wins: 0 | Draws: 0 | Losses: 0
Player 2 | Wins: 0 | Draws: 0 | Losses: 0
    


In [None]:
game = TicTacToe()
model = MuZero(game, device)

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))