#GLOBAL FUNCTIONS AND IMPORTS

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn import preprocessing

from tqdm.notebook import trange

##ALPHA-ZERO

##ALPHA-ZERO-PARALLEL

In [None]:
class MCTSParallel:
    """Batched tree search that advances the trees of several self-play games together.

    The network is evaluated once per search step on all leaves that still
    need expansion, instead of once per game.
    """

    def __init__(self, game, args, model):
        self.game, self.args, self.model = game, args, model

    @torch.no_grad()
    def search(self, states, spGames):
        """Grow a search tree for every entry of ``spGames``.

        Args:
            states: batch of root states, one per game, in the same order as
                ``spGames`` (passed to ``game.get_encoded_state`` as a batch).
            spGames: per-game records; their ``root`` and ``node`` attributes
                are overwritten by this method.

        Reads ``dirichlet_epsilon``, ``dirichlet_alpha`` and ``num_searches``
        from ``self.args``. Returns nothing; results live in ``spg.root``.
        """
        # Root priors: network policy mixed with Dirichlet noise.
        root_logits, _ = self.model(
            torch.tensor(self.game.get_encoded_state(states), device=self.model.device)
        )
        root_priors = torch.softmax(root_logits, axis=1).cpu().numpy()
        epsilon = self.args['dirichlet_epsilon']
        noise = np.random.dirichlet(
            [self.args['dirichlet_alpha']] * self.game.action_size,
            size=root_priors.shape[0],
        )
        root_priors = (1 - epsilon) * root_priors + epsilon * noise

        for idx, spg in enumerate(spGames):
            prior = root_priors[idx]
            # Zero out illegal moves, then renormalise (in place on the row view).
            prior *= self.game.get_valid_moves(states[idx])
            prior /= np.sum(prior)

            spg.root = Node(self.game, self.args, states[idx], visit_count=1)
            spg.root.expand(prior)

        for _ in range(self.args['num_searches']):
            # Selection: descend each tree to a leaf.
            for spg in spGames:
                spg.node = None
                leaf = spg.root
                while leaf.is_fully_expanded():
                    leaf = leaf.select()

                value, is_terminal = self.game.get_value_and_terminated(leaf.state, leaf.action_taken)
                value = self.game.get_opponent_value(value)

                if is_terminal:
                    # Terminal leaves need no network call.
                    leaf.backpropagate(value)
                    continue
                spg.node = leaf

            pending = [idx for idx, spg in enumerate(spGames) if spg.node is not None]
            if not pending:
                continue

            # Expansion: one batched forward pass for all non-terminal leaves.
            leaf_states = np.stack([spGames[idx].node.state for idx in pending])
            leaf_logits, leaf_values = self.model(
                torch.tensor(self.game.get_encoded_state(leaf_states), device=self.model.device)
            )
            leaf_priors = torch.softmax(leaf_logits, axis=1).cpu().numpy()
            leaf_values = leaf_values.cpu().numpy()

            for row, idx in enumerate(pending):
                leaf = spGames[idx].node
                prior, leaf_value = leaf_priors[row], leaf_values[row]

                prior *= self.game.get_valid_moves(leaf.state)
                prior /= np.sum(prior)

                leaf.expand(prior)
                leaf.backpropagate(leaf_value)

In [None]:
class AlphaZeroParallel:
    """AlphaZero-style trainer that plays several self-play games in parallel.

    Args:
        model: network called as ``model(states)`` returning ``(policy, value)``;
            must expose a ``device`` attribute.
        optimizer: optimizer already bound to ``model``'s parameters.
        game: game object providing the state/transition helpers used below.
        args: dict of hyper-parameters (``num_parallel_games``, ``temperature``,
            ``batch_size``, ``num_iterations``, ``num_selfPlay_iterations``,
            ``num_epochs``, plus whatever ``MCTSParallel`` reads).
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(game, args, model)

    def selfPlay(self):
        """Play ``num_parallel_games`` games to completion.

        Returns:
            list of ``(encoded_state, action_probs, outcome)`` tuples, one per
            recorded position, with ``outcome`` taken from the point of view of
            the player who was to move in that position.
        """
        return_memory = []
        player = 1
        spGames = [SPG(self.game) for _ in range(self.args['num_parallel_games'])]

        while spGames:
            states = np.stack([spg.state for spg in spGames])
            neutral_states = self.game.change_perspective(states, player)

            self.mcts.search(neutral_states, spGames)

            # Walk backwards so finished games can be deleted by index safely.
            for i in reversed(range(len(spGames))):
                spg = spGames[i]

                # Visit counts of the root's children form the policy target.
                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.state, action_probs, player))

                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                spg.state = self.game.get_next_state(spg.state, action, player)

                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    for hist_neutral_state, hist_action_probs, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ))
                    del spGames[i]

            player = self.game.get_opponent(player)

        return return_memory

    def train(self, memory):
        """Run one epoch of mini-batch updates over ``memory``.

        ``memory`` is shuffled in place. Every sample is used exactly once;
        the last batch may be smaller than ``batch_size``.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing already clamps at the end of the list. The former
            # ``min(len(memory) - 1, ...)`` bound dropped the last sample and
            # could yield an empty batch, which made ``zip(*sample)`` fail.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training, saving a checkpoint per iteration."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            # File names embed str(self.game).
            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

In [None]:
class SPG:
    """Mutable record for one self-play game.

    Attributes are plain fields: ``state`` (current board, starting from
    ``game.get_initial_state()``), ``memory`` (positions recorded so far), and
    ``root`` / ``node`` (search bookkeeping, initially ``None``).
    """

    def __init__(self, game):
        # Search bookkeeping starts empty; the search code fills it in later.
        self.root = self.node = None
        self.memory = []
        self.state = game.get_initial_state()
        

##EVALUATE-AI

In [None]:
class BigBrain(nn.Module):
    """Value-only residual CNN: 3 input planes -> a single tanh-squashed value.

    Args:
        game: object exposing ``row_count`` and ``column_count`` (used to size
            the final linear layer).
        num_resBlocks: number of ``ResBlock`` modules in the backbone.
        num_hidden: channel width of the stem and backbone.
        device: device the module is moved to; also stored as ``self.device``.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()
        self.device = device

        # Stem: lift the 3 input planes to ``num_hidden`` channels.
        stem_layers = [
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        ]
        self.startBlock = nn.Sequential(*stem_layers)

        # Backbone: a stack of identical residual blocks.
        self.backBone = nn.ModuleList()
        for _ in range(num_resBlocks):
            self.backBone.append(ResBlock(num_hidden))

        # Value head: squeeze to 3 channels, flatten, map to one scalar.
        flat_features = 3 * game.row_count * game.column_count
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 1),
            nn.Tanh(),
        )

        self.to(device)

    def forward(self, x):
        """Return the value estimate for a batch of encoded states."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        return self.valueHead(features)
        
class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection."""

    def __init__(self, num_hidden):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # In-place add on the fresh bn2 output; the caller's tensor is untouched.
        out += x
        return F.relu(out)

In [None]:
class evaluateAI:
    """Trainer for a value-only network that learns from greedy-ish self-play.

    Moves are sampled with weights derived from the network's value of the
    position each move leads to; finished games become ``(state, outcome)``
    training samples.

    Args:
        model: network called as ``model(states)`` returning one value per
            state; must expose a ``device`` attribute.
        optimizer: optimizer already bound to ``model``'s parameters.
        game: game object providing the state/transition helpers used below.
        args: dict with ``batch_size``, ``num_iterations``,
            ``num_selfPlay_iterations`` and ``num_epochs``.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args

    @staticmethod
    def _sample_action(action_weights):
        """Draw one key of ``action_weights`` proportionally to its weight."""
        actions = list(action_weights)
        weights = list(action_weights.values())
        if sum(weights) <= 0:
            # random.choices rejects an all-zero total; fall back to uniform.
            weights = None
        return random.choices(actions, weights=weights, k=1)[0]

    def selfPlay(self):
        """Play one game and return its positions labelled with the outcome.

        Returns:
            list of ``(encoded_state, outcome)`` tuples, ``outcome`` being the
            final value from the point of view of the player to move.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)
            opponent = self.game.get_opponent(player)

            action_weights = {}
            # get_valid_moves returns a 0/1 mask (it is multiplied into policies
            # and indexed with np.where elsewhere in this file), so the legal
            # actions are the indices of its non-zero entries.
            legal_actions = np.flatnonzero(self.game.get_valid_moves(state))
            with torch.no_grad():  # inference only; keep plain floats as weights
                for action in legal_actions:
                    action = int(action)
                    state_after_action = self.game.get_next_state(np.copy(state), action, player)
                    new_neutral_state = self.game.change_perspective(state_after_action, opponent)

                    predicted = self.model(
                        torch.tensor(self.game.get_encoded_state(new_neutral_state), device=self.model.device).unsqueeze(0)
                    ).item()
                    # The prediction is from the opponent's side; flip it and
                    # shift by one to get a non-negative weight.
                    action_weights[action] = self.game.get_opponent_value(predicted) + 1

            memory.append((neutral_state, player))
            action = self._sample_action(action_weights)
            state = self.game.get_next_state(state, action, player)
            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_outcome
                    ))
                return returnMemory
            player = opponent

    def train(self, memory):
        """Run one epoch of mini-batch MSE updates over ``memory``.

        ``memory`` is shuffled in place. Every sample is used exactly once;
        the last batch may be smaller than ``batch_size``.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing clamps at the end of the list; the former
            # ``min(len(memory) - 1, ...)`` bound dropped the last sample and
            # could produce an empty batch that broke ``zip(*sample)``.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, value_targets = zip(*sample)

            state, value_targets = np.array(state), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_value = self.model(state)
            loss = F.mse_loss(out_value, value_targets)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training, saving a checkpoint per iteration."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            # File names embed str(self.game).
            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

##EVALUATE-AI-PARALLEL

In [None]:
class evaluateAIParallel:
    """Value-only self-play trainer that advances several games per round.

    Args:
        model: network called as ``model(states)`` returning one value per
            state; must expose a ``device`` attribute.
        optimizer: optimizer already bound to ``model``'s parameters.
        game: game object providing the state/transition helpers used below.
        args: dict with ``num_parallel_games``, ``temperature``, ``batch_size``,
            ``num_iterations``, ``num_selfPlay_iterations`` and ``num_epochs``.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args

    @staticmethod
    def _sample_action(action_weights):
        """Draw one key of ``action_weights`` proportionally to its weight."""
        actions = list(action_weights)
        weights = list(action_weights.values())
        if sum(weights) <= 0:
            # random.choices rejects an all-zero total; fall back to uniform.
            weights = None
        return random.choices(actions, weights=weights, k=1)[0]

    def selfPlay(self):
        """Play ``num_parallel_games`` games to completion.

        Returns:
            list of ``(encoded_state, outcome)`` tuples, ``outcome`` being the
            final value from the point of view of the player to move.
        """
        return_memory = []
        player = 1
        spGames = [SPG(self.game) for _ in range(self.args['num_parallel_games'])]

        while spGames:
            states = np.stack([spg.state for spg in spGames])
            neutral_states = self.game.change_perspective(states, player)
            opponent = self.game.get_opponent(player)

            # Walk backwards so finished games can be deleted by index safely.
            for i in reversed(range(len(spGames))):
                spg = spGames[i]

                action_weights = {}
                # get_valid_moves returns a 0/1 mask (it is multiplied into
                # policies and indexed with np.where elsewhere in this file),
                # so legal actions are the indices of its non-zero entries.
                legal_actions = np.flatnonzero(self.game.get_valid_moves(states[i]))
                with torch.no_grad():  # inference only
                    for action in legal_actions:
                        action = int(action)
                        state_after_action = self.game.get_next_state(np.copy(states[i]), action, player)
                        new_neutral_state = self.game.change_perspective(state_after_action, opponent)

                        predicted = self.model(
                            torch.tensor(self.game.get_encoded_state(new_neutral_state), device=self.model.device).unsqueeze(0)
                        ).item()
                        # Flip to the mover's side, map [-1, 1] -> [0, 1], then
                        # sharpen/flatten with the temperature exponent.
                        weight = self.game.get_opponent_value(predicted) / 2 + 0.5
                        action_weights[action] = weight ** self.args['temperature']

                spg.memory.append((neutral_states[i], player))
                action = self._sample_action(action_weights)
                spg.state = self.game.get_next_state(spg.state, action, player)
                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)
                if is_terminal:
                    for hist_neutral_state, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_outcome
                        ))
                    del spGames[i]

            player = opponent

        return return_memory

    def train(self, memory):
        """Run one epoch of mini-batch MSE updates over ``memory``.

        ``memory`` is shuffled in place. Every sample is used exactly once;
        the last batch may be smaller than ``batch_size``.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing clamps at the end of the list; the former
            # ``min(len(memory) - 1, ...)`` bound dropped the last sample and
            # could produce an empty batch that broke ``zip(*sample)``.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, value_targets = zip(*sample)

            state, value_targets = np.array(state), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_value = self.model(state)
            loss = F.mse_loss(out_value, value_targets)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training, saving a checkpoint per iteration."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            # File names embed str(self.game).
            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")


##DUMMY

In [None]:
class dummy():
    """Baseline agent that plays a uniformly random action."""

    def choose_action(self, game, game_state):
        """Return a random element of ``game.actions(game_state)``."""
        legal_actions = game.actions(game_state)
        return random.choice(legal_actions)

##ALPHA-BETA

In [None]:
class alphabeta():
    """Exhaustive minimax agent with alpha-beta pruning.

    The ``game`` argument must provide ``is_terminal``, ``utility``,
    ``actions`` and ``perform_action``; ``game_state`` must expose ``player``.
    """

    def choose_action(self, game, game_state):
        """Return the minimax-optimal action for ``game_state.player``.

        Returns ``None`` when the state is terminal or offers no actions.
        """
        player = game_state.player
        # np.inf is the canonical spelling; the np.Inf alias was removed in NumPy 2.0.
        _, move = self.max_value(game, game_state, player, -np.inf, np.inf)
        return move

    def max_value(self, game, game_state, player, alfa, beta):
        """Return ``(value, best_action)`` for the maximising side."""
        if game.is_terminal(game_state):
            return game.utility(player, game_state), None
        # ``move`` starts as None so a non-terminal state without actions
        # returns cleanly instead of raising UnboundLocalError.
        value, move = -np.inf, None

        for action in game.actions(game_state):
            value2, _ = self.min_value(game, game.perform_action(action, game_state), player, alfa, beta)
            if value2 > value:
                value, move = value2, action
                alfa = max(alfa, value)
            if value >= beta:
                # The minimiser already has a better option elsewhere: prune.
                return value, move
        return value, move

    def min_value(self, game, game_state, player, alfa, beta):
        """Return ``(value, best_action)`` for the minimising side."""
        if game.is_terminal(game_state):
            return game.utility(player, game_state), None
        value, move = np.inf, None

        for action in game.actions(game_state):
            value2, _ = self.max_value(game, game.perform_action(action, game_state), player, alfa, beta)
            if value2 < value:
                value, move = value2, action
                beta = min(beta, value)
            if value <= alfa:
                # The maximiser already has a better option elsewhere: prune.
                return value, move
        return value, move

##MCTS

In [None]:
class MCTSNode:
    """One node of a plain (rollout-based) Monte Carlo search tree.

    Each node stores its state, the action that led to it, running visit and
    value statistics, and a mask of moves that have not been expanded yet.
    """

    def __init__(self, game, args, state, parent = None, action_taken = None):
        self.game, self.args = game, args
        self.state = state
        self.parent, self.action_taken = parent, action_taken

        self.children = []
        # Mask of moves still to try; entries are cleared as children are added.
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """True when no untried move is left and at least one child exists."""
        has_children = len(self.children) > 0
        return np.sum(self.expandable_moves) == 0 and has_children

    def get_ucb(self, child):
        """Return the UCB score of ``child`` as seen from this node."""
        mean_value = child.value_sum / child.visit_count
        # The child's statistics are from the other side, hence the ``1 - ...``.
        q_value = 1 - (mean_value + 1) / 2
        exploration = np.sqrt(np.log(self.visit_count) / child.visit_count)
        return q_value + self.args['c'] * exploration

    def select(self):
        """Return the child with the highest UCB score (first one on ties)."""
        best_child, best_ucb = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > best_ucb:
                best_child, best_ucb = candidate, score
        return best_child

    def expand(self):
        """Create, attach and return a child for one random untried move."""
        untried = np.where(self.expandable_moves == 1)[0]
        action = np.random.choice(untried)
        self.expandable_moves[action] = 0

        # Play the move as player 1, then flip the board for the child.
        child_state = self.game.change_perspective(
            self.game.get_next_state(self.state.copy(), action, 1),
            player = -1,
        )

        child = MCTSNode(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child

    def simulate(self):
        """Return the result of a uniformly random playout from this node."""
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
        value = self.game.get_opponent_value(value)
        if is_terminal:
            return value

        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            legal = np.where(self.game.get_valid_moves(rollout_state) == 1)[0]
            action = np.random.choice(legal)
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)

            if not is_terminal:
                rollout_player = self.game.get_opponent(rollout_player)
                continue

            # Flip the result when the last move was made by player -1.
            if rollout_player == -1:
                return self.game.get_opponent_value(value)
            return value

    def backpropagate(self, value):
        """Add ``value`` to this node and, sign-flipped each level, to its ancestors."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = node.game.get_opponent_value(value)
            node = node.parent

class MCTS:
    """Plain Monte Carlo tree search driven by random rollouts.

    ``model`` is accepted by the constructor but not stored or used.
    """

    def __init__(self, game, args, model):
        self.game, self.args = game, args

    def search(self, state):
        """Run ``args['num_searches']`` simulations from ``state``.

        Returns:
            array of length ``game.action_size`` holding the root children's
            visit counts normalised to sum to one.
        """
        root = MCTSNode(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            # Selection: follow best-UCB children down to a leaf.
            node = root
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            # Expansion + rollout, skipped for finished games.
            if not is_terminal:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        visit_counts = np.zeros(self.game.action_size)
        for child in root.children:
            visit_counts[child.action_taken] = child.visit_count
        return visit_counts / np.sum(visit_counts)

#TRAINING

In [None]:
class ELO_Tournament_Trainer:
    """Train a pool of value-network players and rank them in a round robin.

    Args:
        game: game object providing the state/transition helpers used below.
        device: torch device (or device string) for this trainer; when ``None``
            CUDA is used if available, otherwise the CPU.
        args: hyper-parameter dict handed to every player; ``temperature`` is
            also read by ``choose_action`` (the string ``'inf'`` means greedy).
        model: template network; each new player trains its own deep copy.
        optimizer: template optimizer bound to ``model``.
        number_of_players: size of the pool to train.
        players: optional list of already trained players to start from.
            The list is used (and appended to) as-is.
    """

    def __init__(self, game, device, args, model, optimizer, number_of_players = 4, players = None):
        self.game = game
        self.number_of_players = number_of_players
        # A ``[]`` default would be shared by every trainer created without an
        # explicit list, so a second trainer would inherit the first one's pool.
        self.players = [] if players is None else players
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.args = args
        self.model = model
        self.optimizer = optimizer

    def train_models(self, no_improvement = 10):
        """Fill the pool up to ``number_of_players``, run a tournament, return the pool.

        ``no_improvement`` is currently unused; it is kept for interface
        compatibility with a replace-the-worst-player loop that is disabled.
        """
        while len(self.players) < self.number_of_players:
            # Copy model and optimizer in ONE deepcopy call so the copied
            # optimizer references the copied model's parameters. Copying them
            # separately leaves the optimizer updating detached tensors.
            model, optimizer = copy.deepcopy((self.model, self.optimizer))
            player = evaluateAIParallel(model, optimizer, self.game, self.args)
            player.learn()
            self.players.append(player)

        results = self.tournament()

        print(f'best_performence: {max(results.values())}')
        print(f'worst_performence: {min(results.values())}')
        # TODO: disabled evolution step - remove the worst performer, retrain a
        # deep copy of the best performer and append it to the pool.
        return self.players

    def tournament(self):
        """Play 10 games per ordered pair of distinct players.

        Returns:
            dict mapping each player to minus the number of games it lost.
        """
        players_performances = {player: 0 for player in self.players}
        for player1 in self.players:
            for player2 in self.players:
                if player1 is player2:
                    # A player facing itself would only be penalised for its own wins.
                    continue
                for _ in range(10):
                    result = self.play_game(player1, player2)
                    if result == 1:
                        players_performances[player2] -= 1
                    if result == -1:
                        players_performances[player1] -= 1
        return players_performances

    def play_game(self, player1, player2, debug = False):
        """Play one game; ``player1`` moves as 1, ``player2`` as -1.

        Returns:
            the id (1 or -1) of the player who made the winning move, or 0
            when the game ended with any other value.
        """
        state = self.game.get_initial_state()
        player = 1
        while True:
            if debug:
                print(state)
            mover = player1 if player == 1 else player2
            action = self.choose_action(state, player, mover.model)
            state = self.game.get_next_state(state, action, player)
            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                return player if value == 1 else 0
            player = self.game.get_opponent(player)

    def choose_action(self, state, player, model):
        """Pick a legal action for ``player`` using ``model``'s value estimates.

        With ``args['temperature'] == 'inf'`` the best-valued action is
        returned; otherwise an action is sampled with weights
        ``((1 - v) / 2) ** temperature`` where ``v`` is the model's value of
        the resulting position from the opponent's perspective.
        """
        opponent = self.game.get_opponent(player)
        action_weights = {}
        # get_valid_moves returns a 0/1 mask (it is multiplied into policies
        # and indexed with np.where elsewhere in this file), so legal actions
        # are the indices of its non-zero entries.
        legal_actions = np.flatnonzero(self.game.get_valid_moves(state))
        with torch.no_grad():  # inference only
            for action in legal_actions:
                action = int(action)
                state_after_action = self.game.get_next_state(np.copy(state), action, player)
                new_neutral_state = self.game.change_perspective(state_after_action, opponent)

                value = model(
                    torch.tensor(self.game.get_encoded_state(new_neutral_state), device=model.device).unsqueeze(0)
                ).item()
                action_weights[action] = self.game.get_opponent_value(value) / 2 + 0.5

        if self.args['temperature'] == 'inf':
            return max(action_weights, key = action_weights.get)

        actions = list(action_weights)
        weights = [weight ** self.args['temperature'] for weight in action_weights.values()]
        if sum(weights) <= 0:
            # random.choices rejects an all-zero total; fall back to uniform.
            weights = None
        return random.choices(actions, weights = weights, k = 1)[0]

##TIC-TAC-TOE-TRAINING

In [None]:
# Train a pool of value-network players on TicTacToe.
game = TicTacToe()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Positional arguments are (game, num_resBlocks, num_hidden, device).
model = BigBrain(game, 8, 8, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
args = {
    'num_iterations': 1,  # self-play/train cycles per learn() call
    'num_selfPlay_iterations': 512,  # divided by num_parallel_games -> 4 self-play rounds
    'num_parallel_games': 128,  # games advanced together in one selfPlay() call
    'num_epochs': 4,  # passes over the collected memory per iteration
    'batch_size': 128,  # mini-batch size used by train()
    'temperature': 1,  # exponent applied to move weights; 'inf' selects greedy play
    'dirichlet_epsilon': 0.25,  # in this file only read by MCTSParallel.search
    'dirichlet_alpha': 0.3  # in this file only read by MCTSParallel.search
}
trainer = ELO_Tournament_Trainer(game, device, args, model, optimizer)
# train_models() returns the trainer's list of trained players.
players = trainer.train_models()

  0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
# Notebook echo: display the list of players produced by the training cell.
players

##CONNECT-FOUR-TRAINING

In [None]:
# Train a pool of value-network players on ConnectFour.
game = ConnectFour()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Positional arguments are (game, num_resBlocks, num_hidden, device).
model = BigBrain(game, 8, 8, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
args = {
    'num_iterations': 1,  # self-play/train cycles per learn() call
    'num_selfPlay_iterations': 512,  # divided by num_parallel_games -> 4 self-play rounds
    'num_parallel_games': 128,  # games advanced together in one selfPlay() call
    'num_epochs': 4,  # passes over the collected memory per iteration
    'batch_size': 128,  # mini-batch size used by train()
    'temperature': 1,  # exponent applied to move weights; 'inf' selects greedy play
    'dirichlet_epsilon': 0.25,  # in this file only read by MCTSParallel.search
    'dirichlet_alpha': 0.3  # in this file only read by MCTSParallel.search
}
trainer = ELO_Tournament_Trainer(game, device, args, model, optimizer)
# train_models() returns the trainer's list of trained players.
players = trainer.train_models()

#TESTS

##TIC-TAC-TOE-TESTS

In [None]:
game = TicTacToe()

# Hand-built TicTacToe boards, grouped by the result the value network is
# expected to predict for them.
wins = [np.array(board) for board in (
    [[0, 0, 1], [-1, 1, 0], [1, -1, 0]],
    [[1, 0, -1], [1, -1, 0], [1, 0, 0]],
    [[0, 0, 0], [0, -1, -1], [1, 1, 1]],
    [[1, 0, -1], [1, -1, 0], [0, 0, 0]],
    [[0, 0, 1], [-1, 0, 0], [1, -1, 0]],
)]
win1, win2, win3, win4, win5 = wins

draws = [np.array(board) for board in (
    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, -1, 0], [0, 0, 0]],
    [[-1, 1, -1], [-1, 1, -1], [1, -1, 1]],
)]
draw1, draw2, draw3 = draws

loses = [np.array(board) for board in (
    [[0, 0, -1], [1, -1, 0], [-1, 1, 0]],
    [[-1, 0, -1], [1, 0, 0], [-1, 1, 0]],
    [[-1, 0, 1], [-1, 0, 0], [-1, 1, 0]],
)]
lose1, lose2, lose3 = loses

# Heading printed before each group, paired with the boards of that group.
position_groups = (
    ('winning_possitions:', wins),
    ('\ndrawing_possitions:', draws),
    ('\nlosing_possitions:', loses),
)

# Print every player's value estimate for every board, group by group.
for i, player in enumerate(players):
    print(f'model{i}')
    model = player.model
    model.eval()
    for heading, positions in position_groups:
        print(heading)
        for position in positions:
            encoded_state = game.get_encoded_state(position)
            tensor_state = torch.tensor(encoded_state, device=device).unsqueeze(0)

            value = model(tensor_state).item()
            print(value)
    print()

#TOURNAMENTS

##TIC-TAC-TOE

#ADDITIONAL RESOURCES

In [None]:
class CRAZY_FAST_LEARNING:
    """Train a pool of value-network players and evolve it via tournaments.

    Args:
        game: game object providing the state/transition helpers used below.
        device: torch device (or device string) new models are created on;
            when ``None`` CUDA is used if available, otherwise the CPU.
        args: hyper-parameter dict handed to every player; ``temperature`` is
            also read by ``choose_action`` (the string ``'inf'`` means greedy).
        number_of_players: size of the pool to train.
        players: optional list of already trained players to start from.
            The list is used (and modified) as-is.
    """

    def __init__(self, game, device, args, number_of_players = 4, players = None):
        self.game = game
        self.number_of_players = number_of_players
        # A ``[]`` default would be shared by every trainer created without an
        # explicit list, so a second trainer would inherit the first one's pool.
        self.players = [] if players is None else players
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.args = args

    def train_models(self, no_improvement = 10):
        """Fill the pool, then run evolution rounds while ``no_improvement > 9``.

        Each round plays a tournament, drops the worst performer and adds a
        retrained deep copy of the best one. With the default of 10 exactly
        one round is played. Returns the list of players.
        """
        while len(self.players) < self.number_of_players:
            # Use this trainer's own game/device instead of notebook globals.
            model = BigBrain(self.game, 8, 8, self.device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
            player = evaluateAIParallel(model, optimizer, self.game, self.args)
            player.learn()
            self.players.append(player)

        while no_improvement > 9:
            results = self.tournament()

            best_performer = max(results, key=results.get)
            worst_performer = min(results, key=results.get)
            print(f'best_performence: {max(results.values())}')
            print(f'worst_performence: {min(results.values())}')
            self.players.remove(worst_performer)
            # Deep-copying the whole player keeps its optimizer bound to the
            # copied model's parameters.
            best_performer = copy.deepcopy(best_performer)
            best_performer.learn()
            self.players.append(best_performer)
            no_improvement -= 1
        return self.players

    def tournament(self):
        """Play 10 games per ordered pair of distinct players.

        Returns:
            dict mapping each player to minus the number of games it lost.
        """
        players_performances = {player: 0 for player in self.players}
        for player1 in self.players:
            for player2 in self.players:
                if player1 is player2:
                    # A player facing itself would only be penalised for its own wins.
                    continue
                for _ in range(10):
                    result = self.play_game(player1, player2)
                    if result == 1:
                        players_performances[player2] -= 1
                    if result == -1:
                        players_performances[player1] -= 1
        return players_performances

    def play_game(self, player1, player2):
        """Play one game; ``player1`` moves as 1, ``player2`` as -1.

        Returns:
            the id (1 or -1) of the player who made the winning move, or 0
            when the game ended with any other value.
        """
        state = self.game.get_initial_state()
        player = 1
        while True:
            mover = player1 if player == 1 else player2
            action = self.choose_action(state, player, mover.model)
            state = self.game.get_next_state(state, action, player)
            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                return player if value == 1 else 0
            player = self.game.get_opponent(player)

    def choose_action(self, state, player, model):
        """Pick a legal action for ``player`` using ``model``'s value estimates.

        With ``args['temperature'] == 'inf'`` the best-valued action is
        returned; otherwise an action is sampled with weights
        ``((1 - v) / 2) ** temperature`` where ``v`` is the model's value of
        the resulting position from the opponent's perspective.
        """
        opponent = self.game.get_opponent(player)
        action_weights = {}
        # get_valid_moves returns a 0/1 mask (it is multiplied into policies
        # and indexed with np.where elsewhere in this file), so legal actions
        # are the indices of its non-zero entries.
        legal_actions = np.flatnonzero(self.game.get_valid_moves(state))
        with torch.no_grad():  # inference only
            for action in legal_actions:
                action = int(action)
                state_after_action = self.game.get_next_state(np.copy(state), action, player)
                new_neutral_state = self.game.change_perspective(state_after_action, opponent)

                value = model(
                    torch.tensor(self.game.get_encoded_state(new_neutral_state), device=model.device).unsqueeze(0)
                ).item()
                action_weights[action] = self.game.get_opponent_value(value) / 2 + 0.5

        if self.args['temperature'] == 'inf':
            return max(action_weights, key = action_weights.get)

        actions = list(action_weights)
        weights = [weight ** self.args['temperature'] for weight in action_weights.values()]
        if sum(weights) <= 0:
            # random.choices rejects an all-zero total; fall back to uniform.
            weights = None
        return random.choices(actions, weights = weights, k = 1)[0]

In [None]:
# Train and evolve a pool of TicTacToe players with CRAZY_FAST_LEARNING.
game = TicTacToe()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = {
    'num_iterations': 1,  # self-play/train cycles per learn() call
    'num_selfPlay_iterations': 1024,  # divided by num_parallel_games -> 8 self-play rounds
    'num_parallel_games': 128,  # games advanced together in one selfPlay() call
    'num_epochs': 4,  # passes over the collected memory per iteration
    'batch_size': 128,  # mini-batch size used by train()
    'temperature': 1,  # exponent applied to move weights; 'inf' selects greedy play
    'dirichlet_epsilon': 0.25,  # in this file only read by MCTSParallel.search
    'dirichlet_alpha': 0.3  # in this file only read by MCTSParallel.search
}
trainer = CRAZY_FAST_LEARNING(game, device, args)
# train_models() returns the trainer's list of players.
players = trainer.train_models()

In [None]:
# Notebook echo: display the list of players produced by the training cell.
players

In [None]:
game = TicTacToe()

# Hand-built TicTacToe boards, grouped by the result the value network is
# expected to predict for them.
win1 = np.array([[0, 0, 1], [-1, 1, 0], [1, -1, 0]])
win2 = np.array([[1, 0, -1], [1, -1, 0], [1, 0, 0]])
win3 = np.array([[0, 0, 0], [0, -1, -1], [1, 1, 1]])
win4 = np.array([[1, 0, -1], [1, -1, 0], [0, 0, 0]])
win5 = np.array([[0, 0, 1], [-1, 0, 0], [1, -1, 0]])
wins = [win1, win2, win3, win4, win5]

draw1 = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
draw2 = np.array([[0, 0, 0], [0, -1, 0], [0, 0, 0]])
draw3 = np.array([[-1, 1, -1], [-1, 1, -1], [1, -1, 1]])
draws = [draw1, draw2, draw3]

lose1 = np.array([[0, 0, -1], [1, -1, 0], [-1, 1, 0]])
lose2 = np.array([[-1, 0, -1], [1, 0, 0], [-1, 1, 0]])
lose3 = np.array([[-1, 0, 1], [-1, 0, 0], [-1, 1, 0]])
loses = [lose1, lose2, lose3]

# Heading printed before each group, paired with the boards of that group.
position_groups = (
    ('winning_possitions:', wins),
    ('\ndrawing_possitions:', draws),
    ('\nlosing_possitions:', loses),
)

# Print every player's value estimate for every board, group by group.
for i, player in enumerate(players):
    print(f'model{i}')
    # Evaluate the player being reported; the former ``players[1].model``
    # printed player 1's values under every ``model{i}`` heading.
    model = player.model
    model.eval()
    with torch.no_grad():  # inference only
        for heading, positions in position_groups:
            print(heading)
            for position in positions:
                encoded_state = game.get_encoded_state(position)
                tensor_state = torch.tensor(encoded_state, device=device).unsqueeze(0)

                value = model(tensor_state).item()
                print(value)