In [1]:
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import trange

# Seed the Python, NumPy and PyTorch global RNGs used throughout the notebook.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [2]:
class TicTacToe:
    """Rules for 3x3 tic-tac-toe on a NumPy board.

    Board cells hold 0 (empty), 1 or -1 (the two players). An action is the
    flat cell index ``row * column_count + column`` in ``range(action_size)``.
    """

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def __repr__(self):
        return "TicTacToe"

    def get_initial_state(self):
        """Return an empty (all-zero) float board of shape (row_count, column_count)."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Write ``player`` into the cell addressed by ``action``.

        Note: ``state`` is modified in place and also returned. The cell is not
        checked for emptiness.
        """
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask of length ``action_size``: 1 where the cell is empty."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the owner of the cell at ``action`` has a completed line.

        Checks the row and column through that cell, plus both full diagonals.
        ``action=None`` (no move played yet) and empty cells are never a win.
        """
        if action is None:
            return False

        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]
        if player == 0:
            # An empty cell has no owner; without this guard a row such as
            # [1, -1, 0] would "sum to 0 * 3" and be reported as a win.
            return False

        return (
            np.sum(state[row, :]) == player * self.column_count
            or np.sum(state[:, column]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` after ``action`` was played.

        ``(1, True)`` if that move won (value from the mover's perspective),
        ``(0, True)`` if the board is full, otherwise ``(0, False)``.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        """Multiply the board by ``player`` so that ``player``'s marks become +1."""
        return state * player

In [3]:
class ReplayBuffer:
    """Holds self-play positions and turns them into K-step unroll targets.

    ``memory`` entries are ``(observation, action, policy, value, game_idx, is_terminal)``.
    ``build_trajectories`` appends ``(observation, stacked_policies, actions, values)``
    to ``trajectories``, where each of the three target sequences has ``K + 1`` items.
    """

    def __init__(self, args, game):
        self.memory = []
        self.trajectories = []
        self.args = args
        self.game = game

    def __len__(self):
        return len(self.memory)

    def empty(self):
        """Discard all stored positions and trajectories."""
        self.memory = []
        self.trajectories = []

    def _random_action(self):
        # Filler action for steps with no real move (terminal entry or past the game's end).
        return np.random.choice(self.game.action_size)

    def build_trajectories(self):
        """Append one unroll target per memory entry to ``self.trajectories``."""
        unroll_steps = self.args['K']
        total = len(self.memory)

        for start, entry in enumerate(self.memory):
            observation, first_action, first_policy, first_value, game_idx, first_terminal = entry
            if first_terminal:
                first_action = self._random_action()

            policies = [first_policy]
            actions = [first_action]
            values = [first_value]

            for offset in range(1, unroll_steps + 1):
                pos = start + offset
                if not (pos < total and self.memory[pos][4] == game_idx):
                    # Past the end of this game: repeat the last policy, random action, value 0.
                    # NOTE(review): the MuZero pseudocode uses an empty (masked) policy target
                    # here; repeating the last policy is an approximation — confirm.
                    policies.append(policies[-1])
                    actions.append(self._random_action())
                    values.append(0)
                    continue

                _, step_action, step_policy, step_value, _, step_terminal = self.memory[pos]
                actions.append(self._random_action() if step_terminal else step_action)
                policies.append(step_policy)
                values.append(step_value)

            self.trajectories.append((observation, np.stack(policies), actions, values))

In [4]:
class Node:
    """A node of the MuZero search tree.

    ``state`` is a latent (hidden) state array produced by the network, not a
    real board. ``prior`` is the probability the parent's policy gave to
    ``action_taken``.
    """

    def __init__(self, muZero, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        # Tree wiring and model handles.
        self.muZero = muZero
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        self.children = []

        # Search statistics, updated by backpropagate().
        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        # expand() creates every child at once, so "has any child" means "fully expanded".
        return bool(self.children)

    def select(self):
        """Return the child with the highest UCB score (earliest child wins ties)."""
        chosen, chosen_score = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > chosen_score:
                chosen, chosen_score = candidate, score
        return chosen

    def get_ucb(self, child):
        """PUCT score of ``child`` as seen from this node."""
        if child.visit_count != 0:
            mean_value = child.value_sum / child.visit_count
            # Rescale the child's mean value (assumed in [-1, 1]) to [0, 1] and flip it,
            # because backpropagate() stores the child's value from the opponent's side.
            exploit = 1 - (mean_value + 1) / 2
        else:
            exploit = 0
        explore = self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
        return exploit + explore

    @torch.no_grad()
    def expand(self, policy):
        """Create one child per entry of ``policy`` via a single batched dynamics call."""
        n_actions = len(policy)
        batched = np.repeat(self.state.copy()[np.newaxis], n_actions, axis=0)

        next_states = self.muZero.dynamics(
            torch.tensor(batched, dtype=torch.float32, device=self.muZero.device),
            list(range(n_actions)),
        )
        # Children are evaluated from the opponent's point of view.
        next_states = self.game.change_perspective(
            next_states.cpu().numpy(), player=self.game.get_opponent(1)
        )

        self.children.extend(
            Node(self.muZero, self.game, self.args, next_states[a], self, a, p)
            for a, p in enumerate(policy)
        )

    def backpropagate(self, value):
        """Add ``value`` here, then walk to the root negating it at every level."""
        node = self
        while True:
            node.value_sum += value
            node.visit_count += 1
            if node.parent is None:
                break
            value = node.game.get_opponent_value(value)
            node = node.parent

class MCTS:
    """Monte-Carlo tree search over MuZero's learned latent model.

    The real game rules are consulted only at the root, to restrict the search
    to legal moves. Everything below the root is simulated with
    ``muZero.dynamics`` (MuZero does no masking inside the tree).
    """

    def __init__(self, muZero, game, args):
        self.muZero = muZero
        self.game = game
        self.args = args

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_mcts_searches']`` simulations from the board ``state``.

        Args:
            state: board observation from the perspective of the player to move.

        Returns:
            np.ndarray of length ``game.action_size`` with the root visit-count
            distribution. Illegal moves always get probability 0.
        """
        hidden_state = self.muZero.represent(
            torch.tensor(state, dtype=torch.float32, device=self.muZero.device).unsqueeze(0)
        )
        root_logits, _ = self.muZero.predict(hidden_state)
        hidden_state = hidden_state.cpu().numpy().squeeze(0)

        root = Node(self.muZero, self.game, self.args, hidden_state, visit_count=1)

        root_policy = torch.softmax(root_logits, dim=1).squeeze(0).cpu().numpy()
        root_policy = (1 - self.args['dirichlet_epsilon']) * root_policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)

        # The environment is known at the root, so mask illegal moves there. Without this,
        # self-play could sample an occupied cell, and get_next_state would silently
        # overwrite the opponent's mark.
        valid_moves = self.game.get_valid_moves(state)
        root_policy *= valid_moves
        root_policy /= np.sum(root_policy)

        root.expand(root_policy)
        # Drop illegal children so select() can never descend into them.
        root.children = [child for child in root.children if valid_moves[child.action_taken]]

        for _ in range(self.args['num_mcts_searches']):
            node = root

            while node.is_fully_expanded():
                node = node.select()

            leaf_logits, leaf_value = self.muZero.predict(
                torch.tensor(node.state, dtype=torch.float32, device=self.muZero.device).unsqueeze(0)
            )
            leaf_policy = torch.softmax(leaf_logits, dim=1).squeeze(0).cpu().numpy()

            node.expand(leaf_policy)
            node.backpropagate(leaf_value.item())

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        total_visits = np.sum(action_probs)
        if total_visits == 0:
            # No simulation ran (num_mcts_searches == 0): fall back to the masked prior.
            return root_policy
        return action_probs / total_visits

In [5]:
class MuZero(nn.Module):
    """Bundles MuZero's learned functions: representation, dynamics and prediction."""

    def __init__(self, game, device):
        super().__init__()
        self.game = game
        self.device = device

        # Construction order is kept stable: it fixes parameter init under a seed
        # and the state_dict layout of saved checkpoints.
        self.predictionFunction = PredictionFunction(game)
        self.dynamicsFunction = DynamicsFunction(game)
        self.representationFunction = RepresentationFunction(game)

        self.to(device)

    def predict(self, hidden_state):
        """Return ``(policy_logits, value)`` for a batch of hidden states."""
        return self.predictionFunction(hidden_state)

    def represent(self, observation):
        """Encode a batch of observations into initial hidden states."""
        return self.representationFunction(observation)

    def dynamics(self, hidden_state, actions):
        """Advance each hidden state by one action.

        Args:
            hidden_state: tensor with one row per batch element; presumably
                (batch, hidden_size) on ``self.device`` — TODO confirm.
            actions: list/array of integer actions, exactly one per row of
                ``hidden_state``.

        Returns:
            Output of the dynamics network applied to ``[hidden_state, one_hot(action)]``.
        """
        # One vectorised op instead of one device write per row.
        action_index = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        actionPlane = F.one_hot(action_index, num_classes=self.game.action_size).to(torch.float32)
        x = torch.cat((hidden_state, actionPlane), dim=1)
        return self.dynamicsFunction(x)

class DynamicsFunction(nn.Module):
    """Dynamics network: maps a concatenated ``[hidden_state, action_plane]`` row to the next hidden state.

    Input width is ``32 + game.action_size``. Output is 32 features squashed by tanh.
    """

    def __init__(self, game):
        super().__init__()
        in_features = 32 + game.action_size
        stack = [
            nn.Linear(in_features, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.Tanh(),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
    
class PredictionFunction(nn.Module):
    """Prediction network: hidden state -> (policy logits, value in [-1, 1]).

    A shared 32-wide trunk feeds a policy head (``game.action_size`` logits)
    and a tanh-bounded scalar value head.
    """

    def __init__(self, game):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
        )
        self.policy_head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, game.action_size))
        self.value_head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1), nn.Tanh())

    def forward(self, x):
        features = self.layers(x)
        return self.policy_head(features), self.value_head(features)

class RepresentationFunction(nn.Module):
    """Representation network: builds the initial hidden state from an observation.

    The input is flattened from dim 1 onward, so each batch row must hold
    ``row_count * column_count`` values. Today that is a single board; the
    original note suggests a stack of several observations may be used later.
    Output is 32 tanh-bounded features.
    """

    def __init__(self, game):
        super().__init__()
        board_cells = game.row_count * game.column_count
        self.layers = nn.Sequential(
            nn.Linear(board_cells, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.Tanh(),
        )

    def forward(self, x):
        return self.layers(torch.flatten(x, start_dim=1))

In [6]:
class Trainer:
    """Self-play data generation plus the MuZero training loop.

    ``args`` keys read here: ``K`` (unroll steps), ``temperature``, ``batch_size``,
    ``value_loss_weight``, ``max_grad_norm``, ``num_iterations``, ``num_train_games``
    and ``num_epochs``. MCTS and ReplayBuffer read their own keys from the same dict.
    """

    def __init__(self, muZero, optimizer, game, args):
        self.muZero = muZero
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(self.muZero, self.game, self.args)
        self.replayBuffer = ReplayBuffer(self.args, self.game)

    def self_play(self, game_idx):
        """Play one game against itself with MCTS and return its memory entries.

        Each entry is ``(neutral_observation, action, mcts_policy, outcome, game_idx, is_terminal)``.
        ``outcome`` is the final result from the perspective of the player who moved.
        When ``K > 0``, one extra terminal entry is appended (action ``None``, uniform
        policy, value 0) so unroll targets can step past the last move.
        """
        memory = []
        player = 1
        observation = self.game.get_initial_state()

        while True:
            neutral_observation = self.game.change_perspective(observation, player)
            action_probs = self.mcts.search(neutral_observation)

            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            memory.append((neutral_observation, action, action_probs, player))

            observation = self.game.get_next_state(observation, action, player)

            value, is_terminal = self.game.get_value_and_terminated(observation, action)

            if is_terminal:
                episode = []
                for hist_neutral_observation, hist_action, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    episode.append((
                        hist_neutral_observation,
                        hist_action,
                        hist_action_probs,
                        hist_outcome,
                        game_idx,
                        False  # is_terminal
                    ))
                if self.args['K'] > 0:
                    # NOTE(review): terminal board from the next player's perspective with
                    # value 0, i.e. an absorbing state — confirm against the paper.
                    episode.append((
                        self.game.change_perspective(observation, self.game.get_opponent(player)),
                        None,
                        np.array([1 / self.game.action_size] * self.game.action_size),
                        0,
                        game_idx,
                        True  # is_terminal
                    ))
                return episode

            player = self.game.get_opponent(player)

    @staticmethod
    def _scale_gradient(tensor, scale):
        """Identity in the forward pass; multiplies the backward gradient by ``scale``."""
        return tensor * scale + tensor.detach() * (1 - scale)

    def train(self):
        """Run one epoch of minibatch updates over the shuffled trajectories."""
        trajectories = self.replayBuffer.trajectories
        random.shuffle(trajectories)
        K = self.args['K']
        batch_size = self.args['batch_size']
        device = self.muZero.device

        for batch_start in range(0, len(trajectories), batch_size):
            # Slicing already clamps at the end. The former `min(len - 1, ...)` bound dropped
            # the last trajectory and produced an empty, crashing batch when
            # len % batch_size == 1.
            sample = trajectories[batch_start:batch_start + batch_size]
            observations, policy_targets, actions, value_targets = zip(*sample)

            observations = torch.tensor(np.array(observations), dtype=torch.float32, device=device)
            actions = np.array(actions)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=device)
            value_targets = torch.tensor(np.array(value_targets), dtype=torch.float32, device=device).unsqueeze(-1)

            hidden_state = self.muZero.represent(observations)
            out_policy, out_value = self.muZero.predict(hidden_state)

            policy_loss = F.cross_entropy(out_policy, policy_targets[:, 0])
            value_loss = F.mse_loss(out_value, value_targets[:, 0])

            for k in range(1, K + 1):
                # MuZero halves the gradient entering the dynamics function.
                hidden_state = self._scale_gradient(hidden_state, 0.5)
                # Advance the latent state. The old code used `hidden_state, dynamics(...)`
                # (a discarded tuple), so the unroll never moved past the root.
                hidden_state = self.muZero.dynamics(hidden_state, actions[:, k - 1])
                hidden_state = self.game.change_perspective(hidden_state, -1)

                out_policy, out_value = self.muZero.predict(hidden_state)

                policy_loss = policy_loss + F.cross_entropy(out_policy, policy_targets[:, k])
                value_loss = value_loss + F.mse_loss(out_value, value_targets[:, k])

            loss = value_loss * self.args['value_loss_weight'] + policy_loss
            if K > 0:
                # Same gradient as the former 1/K hook on the total loss; skipped for
                # K == 0, where that hook divided by zero.
                loss = loss / K

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.muZero.parameters(), self.args['max_grad_norm'])
            self.optimizer.step()

    def run(self):
        """Alternate self-play and training for ``num_iterations`` iterations, checkpointing each."""
        model_dir = f"../../Environments/{self.game}/Models"
        # torch.save does not create parent directories; fail fast before any self-play.
        os.makedirs(model_dir, exist_ok=True)

        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            self.replayBuffer.empty()

            self.muZero.eval()
            for train_game_idx in (self_play_bar := trange(self.args['num_train_games'], desc="train_game")):
                self.replayBuffer.memory += self.self_play(train_game_idx + iteration * self.args['num_train_games'])
                self_play_bar.set_description(f"Avg. steps per Game: {len(self.replayBuffer) / (train_game_idx + 1):.2f}")
            self.replayBuffer.build_trajectories()

            self.muZero.train()
            for _ in trange(self.args['num_epochs'], desc="epochs"):
                self.train()

            torch.save(self.muZero.state_dict(), f"{model_dir}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"{model_dir}/optimizer_{iteration}.pt")

In [7]:
args = dict(
    num_iterations=20,        # self-play + training rounds (one checkpoint each)
    num_train_games=100,      # self-play games per iteration
    num_mcts_searches=50,     # MCTS simulations per move
    num_epochs=4,             # passes over the trajectories per iteration
    batch_size=64,
    temperature=1,            # exponent 1/temperature applied to visit counts
    K=3,                      # unroll steps through the dynamics function
    C=2,                      # exploration constant in the UCB score
    dirichlet_alpha=0.3,      # root Dirichlet noise concentration
    dirichlet_epsilon=0.25,   # weight of that noise in the root prior
    value_loss_weight=0.25,
    max_grad_norm=5,
)

# NOTE(review): LOAD is not read anywhere in this notebook.
LOAD = False

game = TicTacToe()
muZero = MuZero(game, device)
optimizer = torch.optim.AdamW(muZero.parameters(), lr=0.001)

trainer = Trainer(muZero, optimizer, game, args)
trainer.run()

iteration: 0


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 1


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 2


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 3


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 4


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 5


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 6


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 7


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 8


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 9


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 10


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 11


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 12


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 13


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 14


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 15


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 16


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 17


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 18


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]

iteration: 19


train_game:   0%|          | 0/100 [00:00<?, ?it/s]

epochs:   0%|          | 0/4 [00:00<?, ?it/s]