# Project 2 - Development of an AlphaZero agent capable of playing games

## Introduction

The second project proposed for the Laboratórios de Inteligência Artificial e Ciência de Dados course unit consists of building an AlphaZero algorithm able to play 2 games, Attax and Go, tested in different scenarios: for Attax, 4x4, 6x6 and variable-size boards; for Go, 7x7 and 9x9 boards.

## Implemented games

#### Attax:
Attax is a two-player abstract strategy board game in which the goal is to own the majority of the pieces on the board when the game ends, converting as many of the opponent's pieces as possible.
Usually, the game starts with 2 pieces per player, each placed in a corner of the board.
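The rules above can be made concrete with a short sketch. It assumes the standard Ataxx move rules, which the text does not spell out: a piece may move to an adjacent square (the piece is duplicated) or jump to a square two steps away (the piece leaves its origin), and every opponent piece adjacent to the destination is converted. It also assumes each player starts on two opposite corners. The helpers `initial_board` and `apply_move` are hypothetical and are not the project's `Attaxx` class.

```python
# Hypothetical Ataxx-style helpers (not the project's Attaxx class).
# Board: list of rows; 1 = player 1, -1 = player 2, 0 = empty square.

def initial_board(n):
    """Empty n x n board with each player on two opposite corners (assumed layout)."""
    board = [[0] * n for _ in range(n)]
    board[0][0] = board[n - 1][n - 1] = 1
    board[0][n - 1] = board[n - 1][0] = -1
    return board


def apply_move(board, src, dst, player):
    """Return a new board after moving from src to dst, with conversions applied."""
    n = len(board)
    (r0, c0), (r1, c1) = src, dst
    dist = max(abs(r1 - r0), abs(c1 - c0))  # Chebyshev distance
    if board[r0][c0] != player or board[r1][c1] != 0 or dist not in (1, 2):
        raise ValueError("illegal move")
    new = [row[:] for row in board]
    new[r1][c1] = player
    if dist == 2:  # jump: the piece leaves its origin square
        new[r0][c0] = 0
    # convert every opponent piece in the 8 squares around the destination
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            r, c = r1 + dr, c1 + dc
            if 0 <= r < n and 0 <= c < n and new[r][c] == -player:
                new[r][c] = player
    return new
```

For example, on a 4x4 board, `apply_move(b, (0, 0), (0, 1), 1)` clones a piece next to its origin, and a reply by player -1 from `(0, 3)` to `(0, 2)` converts that clone.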


#### Go:
Go is also a board strategy game, in which the objective is to surround (and thus capture) more territory than the opponent; it is considered the oldest board game that has been played continuously to this day.
It is played with black and white stones placed on empty points of the board, and the player who moves second receives extra points (komi) to compensate for that disadvantage.
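To illustrate how the komi enters the final count, the sketch below uses Tromp-Taylor area scoring: stones on the board plus empty regions bordered by only one colour. This is a different rule from the influence heuristic used by `count_influenced_territory_enhanced` in our implementation; only the sign convention (black minus (white + komi)) and the komi of 5.5 are taken from the `Go` class, and `area_score` is a hypothetical helper.

```python
def area_score(board, komi=5.5):
    """Tromp-Taylor-style area score from black's point of view: black - (white + komi).

    board: list of rows with 1 = black, -1 = white, 0 = empty.
    """
    n = len(board)
    black = sum(row.count(1) for row in board)
    white = sum(row.count(-1) for row in board)
    seen = set()
    for i in range(n):
        for j in range(n):
            if board[i][j] != 0 or (i, j) in seen:
                continue
            # flood-fill this empty region and record which colours border it
            region, borders, stack = 0, set(), [(i, j)]
            seen.add((i, j))
            while stack:
                r, c = stack.pop()
                region += 1
                for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
                    if 0 <= nr < n and 0 <= nc < n:
                        v = board[nr][nc]
                        if v == 0:
                            if (nr, nc) not in seen:
                                seen.add((nr, nc))
                                stack.append((nr, nc))
                        else:
                            borders.add(v)
            # a region touching only one colour counts as that colour's territory
            if borders == {1}:
                black += region
            elif borders == {-1}:
                white += region
    return black - (white + komi)
```

On a 3x3 board whose middle column is black and right column is white, black gets 3 stones plus the 3 empty points on the left, so the score is 6 - (3 + 5.5) = -2.5: the komi alone is enough for white to win.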



## Libraries used
Several libraries were used to implement the games:


### Go libraries:
In the game implementation:
- `import numpy as np`

In the interface:
- `import pygame`
- `from pygame.locals import QUIT`
- `import sys`
- `import time`
- `import torch`
- `from torch.optim import Adam`
- `import numpy as np`
- `from go_pygame.go_1 import Go`
- `from alphazero import ResNet`
- `from alphazero import MCTS`
- `from button import Button`
- `import os`

### Attax libraries:
In the game implementation:
- `import numpy as np`

In the interface:
- `import pygame`
- `from pygame.locals import QUIT`
- `import sys`
- `import numpy as np`
- `from attaxx.attaxx import Attaxx`
- `from alphazero import MCTS`
- `import torch`
- `from alphazero import ResNet`
- `import time`
- `from args_manager import load_args_from_json`

## Game implementation

## Go
Our Go implementation is built around a `Go` class, which initializes the attributes needed to represent the game state and provides methods to tell legal from illegal moves, compute the score and detect a terminal state.

Below is a closer look at the role of each function that makes up our final game:

`get_initial_state(self)`: Returns the starting board, a matrix of zeros (an empty board), and resets `state_history`.

`count(self, x, y, state, player, liberties, block)`: A recursive function that not only collects the liberties of a group but also identifies the stones that belong to it (the block). It marks visited stones and liberties directly on the board, so `restore_board` must be called afterwards.

`clear_block(self, block, state)`: Removes captured stones from the board.

`get_current_state(self)`: Returns the most recent board stored in `state_history`.

`restore_board(self, state)`: Restores the board after stones and liberties have been counted, turning the temporary markers back into stones and empty points.

`print_board(self, state)`: Prints the current board as text, with row and column indices.

`captures(self, state, player, a, b)`: Checks the groups of `player` adjacent to (a, b) and removes those left without liberties, following the rules of the game. It returns a capture flag and the updated board.

`set_stone(self, a, b, state, player)`: Places a stone of `player` at position (a, b) of the board.

`get_next_state_mcts(self, state, action, player)`: Computes the next board state after an action, applying captures. Unlike `get_next_state`, it also records passes, so that two consecutive passes end the game.

`get_next_state(self, state, action, player)`: Computes the next board state after an action, applying captures but without tracking passes.

`is_valid_move(self, state, action, player)`: Checks whether a move is legal: inside the board, on an empty point and not suicide, plus a simple ko check against `state_history`.

`get_valid_moves(self, state, player)`: Returns a flat binary vector of the moves allowed for the given player, with an extra final entry for the pass move, which is always allowed.

`get_value_and_terminated(self, state, player)`: Returns the value of the state (1 if black is ahead after komi, -1 otherwise) and whether the game is over.

`scoring(self, state)`: Computes the score from the current board as black territory minus (white territory + komi).

`count_influenced_territory_enhanced(self, board)`: Estimates territory by influence: each empty point is assigned to black or white according to the sum of its orthogonal neighbours, and the result feeds the score.

`get_opponent(self, player)`: Returns the opponent of the current player.

`get_opponent_value(self, value)`: Negates a value, converting it to the opponent's point of view.

`get_encoded_state(self, state)`: Encodes the board as three binary planes (stones equal to -1, empty points, stones equal to +1) used as input to the neural network.

`change_perspective(self, state, player)`: Multiplies the board by `player`, so that the given player's stones appear as +1.
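The cooperation between `count`, `clear_block`, `captures` and the suicide check in `is_valid_move` can be summarised with a standalone sketch. Unlike `count`, which writes temporary markers onto the board and relies on `restore_board`, this version keeps visited stones in a set and works on a copy of the board. The function names are hypothetical, and ko is deliberately not handled.

```python
# Hypothetical helpers: 1 = black, -1 = white, 0 = empty; board is a list of rows.

def neighbours(n, y, x):
    """Yield the orthogonal neighbours of (y, x) that lie on an n x n board."""
    for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
        if 0 <= ny < n and 0 <= nx < n:
            yield ny, nx


def group_and_liberties(board, y, x):
    """Return the set of stones connected to (y, x) and the set of their liberties."""
    colour, n = board[y][x], len(board)
    group, liberties, stack = {(y, x)}, set(), [(y, x)]
    while stack:
        cy, cx = stack.pop()
        for ny, nx in neighbours(n, cy, cx):
            v = board[ny][nx]
            if v == 0:
                liberties.add((ny, nx))
            elif v == colour and (ny, nx) not in group:
                group.add((ny, nx))
                stack.append((ny, nx))
    return group, liberties


def play(board, y, x, player):
    """Place a stone, remove adjacent opponent groups without liberties, reject suicide."""
    if board[y][x] != 0:
        raise ValueError("occupied")
    new = [row[:] for row in board]
    new[y][x] = player
    for ny, nx in neighbours(len(new), y, x):
        if new[ny][nx] == -player:
            group, libs = group_and_liberties(new, ny, nx)
            if not libs:  # captured: clear the whole group
                for gy, gx in group:
                    new[gy][gx] = 0
    # captures are resolved first, so a capturing move is never treated as suicide
    if not group_and_liberties(new, y, x)[1]:
        raise ValueError("suicide")
    return new
```

Resolving captures before checking the new stone's own liberties mirrors the order used in `is_valid_move`, where `captures` is called before `count`.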

### Implementation code:

import numpy as np


class Go():
    EMPTY = 0
    BLACK = 1
    WHITE = -1
    BLACKMARKER = 4
    WHITEMARKER = 5
    LIBERTY = 8

    def __init__(self, small_board=True):
        self.row_count = 7 if small_board else 9
        self.column_count = 7 if small_board else 9
        self.board_size = 7 if small_board else 9
        self.komi = 5.5
        self.action_size = self.row_count * self.column_count + 1
        self.liberties = []
        self.name="Go"
        self.block = []
        self.seki_count = 0
        self.seki_liberties = []
        self.state_history = [self.get_initial_state()]
        self.currrent_player = self.BLACK
        self.passed_player_1 = False
        self.passed_player_2 = False


    def get_initial_state(self):
        board = np.zeros((self.row_count, self.column_count))
        self.state_history = [np.copy(board)]
        return board

    def count(self, x, y, state: list, player: int, liberties: list, block: list):


        # initialize piece
        piece = state[y][x]
        # if there's a stone at square of the given player
        if piece == player:
            # save stone coords
            block.append((y, x))
            # mark the stone
            if player == self.BLACK:
                state[y][x] = self.BLACKMARKER
            else:
                state[y][x] = self.WHITEMARKER

            # look for neighbours recursively
            if y - 1 >= 0:
                liberties, block = self.count(x, y - 1, state, player, liberties, block)  # walk north
            if x + 1 < self.column_count:
                liberties, block = self.count(x + 1, y, state, player, liberties, block)  # walk east
            if y + 1 < self.row_count:
                liberties, block = self.count(x, y + 1, state, player, liberties, block)  # walk south
            if x - 1 >= 0:
                liberties, block = self.count(x - 1, y, state, player, liberties, block)  # walk west

        # if square is empty
        elif piece == self.EMPTY:
            # mark liberty
            state[y][x] = self.LIBERTY
            # save liberties
            liberties.append((y, x))

        # print("Liberties: " + str(len(self.liberties)) + " in: " + str(x) + "," + str(y))
        # print("Block: " + str(len(self.block)) + " in: " + str(x) + "," + str(y))
        return liberties, block

    # remove captured stones
    def clear_block(self, block: list, state: list) -> list:


        # clears the elements in the block of elements which is captured
        for i in range(len(block)):
            y, x = block[i]
            state[y][x] = self.EMPTY

        return state

    def get_current_state(self):
        return self.state_history[-1]

    # restore board after counting stones and liberties
    def restore_board(self, state: list) -> list:

        for y in range(len(state)):
            for x in range(len(state)):
                # restore piece
                val = state[y][x]
                if val == self.BLACKMARKER:
                    state[y][x] = self.BLACK
                elif val == self.WHITEMARKER:
                    state[y][x] = self.WHITE
                elif val == self.LIBERTY:
                    state[y][x] = self.EMPTY

        # print("After Restore Board")
        # print(state)
        return state

    def print_board(self, state) -> None:

        print("   ", end="")
        for j in range(self.column_count):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(self.column_count):
            print("---", end="")
        print()

        for i in range(self.row_count):
            print(f"{i:2}|", end=" ")
            for j in range(self.column_count):
                print(f"{str(int(state[i][j])):2}", end=" ")
            print()

    def captures(self, state: list, player: int, a: int, b: int):

        check = False
        neighbours = []
        if (a > 0): neighbours.append((a - 1, b))
        if (a < self.column_count - 1): neighbours.append((a + 1, b))
        if (b > 0): neighbours.append((a, b - 1))
        if (b < self.row_count - 1): neighbours.append((a, b + 1))

        # loop over the board squares
        for pos in neighbours:
            # print(pos)
            x = pos[0]
            y = pos[1]
            # init piece
            piece = state[x][y]

            # if stone belongs to given colour
            if piece == player:
                # print("opponent piece")
                # count liberties
                liberties = []
                block = []
                liberties, block = self.count(y, x, state, player, liberties, block)
                # print("Liberties in count: " + str(len(liberties)))
                # if no liberties remove the stones
                if len(liberties) == 0:
                    # clear block

                    state = self.clear_block(block, state)

                    # flag the capture; seki_count is never reset, so only the first capture
                    # seen by this Go instance sets check = True (later ones are removed but not flagged)
                    if self.seki_count == 0:
                        check = True
                        self.seki_count = 1
                        continue
                # restore the board
                state = self.restore_board(state)
        # print("Seki Count: " + str(self.seki_count))
        return check, state




    def set_stone(self, a, b, state, player):
        state[a][b] = player
        return state

    def get_next_state_mcts(self, state, action, player):

        if action == self.row_count * self.column_count:
            if self.passed_player_1:
                self.passed_player_2 = True
            else:
                self.passed_player_1 = True
            return state # pass move

        a = action // self.row_count
        b = action % self.column_count

        state_copy = np.copy(state)
        state[a][b] = player
        state = self.captures(state, -player, a, b)[1]
        self.passed_player_1 = False
        self.passed_player_2 = False

        self.state_history.append(np.copy(state_copy))

        return state

    def get_next_state(self, state, action, player):

        if action == self.row_count * self.column_count:
            return state # pass move

        a = action // self.row_count
        b = action % self.column_count

        state_copy = np.copy(state)
        state[a][b] = player
        state = self.captures(state, -player, a, b)[1]

        self.state_history.append(np.copy(state_copy))

        return state

    # Inside the Go class

    def is_valid_move(self, state: list, action: tuple, player: int) -> bool:
        a, b = action[0], action[1]  # row and column of the candidate move
        state_copy = np.copy(state).astype(np.int8)

        if len(self.state_history) > 1:
            if np.array_equal(state, self.state_history[-2]):
                #print("Ko violation")
                return False

        # reject moves outside the board
        if a < 0 or a >= self.row_count or b < 0 or b >= self.column_count:
            #print("Invalid move: Out of bounds")
            return False

        if state[a][b] != self.EMPTY:
            #print("Space Occupied")
            return False

        state_copy = self.set_stone(a, b, state_copy, player)

        if self.captures(state_copy, -player, a, b)[0] == True:
            return True
        else:
            libs, block = self.count(b, a, state_copy, player, [], [])
            if len(libs) == 0:
                #print("Invalid move: Suicide")
                return False
            else:
                return True

    def get_valid_moves(self, state, player):

        newstate = np.zeros((self.row_count, self.column_count))
        for a in range(0, self.column_count):
            for b in range(0, self.row_count):
                if self.is_valid_move(state, (a, b), player):
                    newstate[a][b] = 1

        newstate = newstate.reshape(-1)
        newstate = np.concatenate([newstate, [1]])
        return (newstate).astype(np.uint8)

    def get_value_and_terminated(self, state, player):

        '''
        # Description:
        Returns the value of the state and if the game is over.
        '''

        scoring, endgame = self.scoring(state)

        if self.passed_player_1 and self.passed_player_2:
            endgame=True

        if endgame:
            if scoring > 0:
                return 1, True
            else:
                return -1, True
        else:
            if scoring > 0:
                return 1, False
            else:
                return -1, False


    def scoring(self, state):
        '''
        # Description:
        Checks the score of the game.
        '''
        black = 0
        white = 0
        empty = 0
        endgame = False
        # print("Scoring")
        for x in range(self.column_count):
            for y in range(self.row_count):
                if state[x][y] == self.EMPTY:
                    empty += 1
                    if empty >= self.column_count * self.row_count // 5: # at least 1/5 of the board empty: not the endgame (endgame is never set to True here)
                        endgame = False

        black, white = self.count_influenced_territory_enhanced(state)

        return black - (white + self.komi), endgame


    def count_influenced_territory_enhanced(self, board):
        black_territory = 0
        white_territory = 0
        visited = set()

        # Function to calculate influence score
        def influence_score(x, y):
            score = 0
            for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
                nx, ny = x + dx, y + dy
                if 0 <= nx < len(board) and 0 <= ny < len(board[0]):
                    score += board[nx][ny]
            return score

        # Function to explore territory
        def explore_territory(x, y):
            nonlocal black_territory, white_territory
            if (x, y) in visited or not (0 <= x < len(board) and 0 <= y < len(board[0])):
                return
            visited.add((x, y))

            if board[x][y] == 0:
                score = influence_score(x, y)
                if score > 0:
                    black_territory += 1
                elif score < 0:
                    white_territory += 1

        for i in range(len(board)):
            for j in range(len(board[0])):
                if board[i][j] == 0 and (i, j) not in visited:
                    explore_territory(i, j)

        return black_territory, white_territory


    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def get_encoded_state(self, state):
        layer_1 = np.where(np.array(state) == -1, 1, 0).astype(np.float32)
        layer_2 = np.where(np.array(state) == 0, 1, 0).astype(np.float32)
        layer_3 = np.where(np.array(state) == 1, 1, 0).astype(np.float32)

        result = np.stack([layer_1, layer_2, layer_3]).astype(np.float32)

        return result

    def change_perspective(self, state, player):
        return state * player
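A small check of the encoding and perspective conventions, written with plain lists instead of NumPy. `encode` reproduces the plane order of `get_encoded_state` (stones equal to -1, empty points, stones equal to +1), and `change_perspective` multiplies by the player, as in the class. Applying the perspective change before encoding, so that the side to move always appears as +1, is the usual AlphaZero pattern and is assumed here rather than taken from our MCTS code; both function names are hypothetical stand-ins.

```python
def encode(state):
    """Three binary planes in the order used by get_encoded_state:
    plane 0 = stones equal to -1, plane 1 = empty points, plane 2 = stones equal to +1."""
    return [[[1.0 if v == target else 0.0 for v in row] for row in state]
            for target in (-1, 0, 1)]


def change_perspective(state, player):
    """Multiply every point by player, as Go.change_perspective does."""
    return [[v * player for v in row] for row in state]


state = [[1, -1], [0, 0]]
planes = encode(state)                       # black (+1) stones land in plane 2
flipped = encode(change_perspective(state, -1))  # from white's view, white's stone moves to plane 2
```

After the flip, the white stone at `(0, 1)` appears in the last plane, so the network always finds the stones of the player it is evaluating in the same plane.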




### Interface implementation code

import pygame
from pygame.locals import QUIT
import sys
import time
import torch
from torch.optim import Adam
import numpy as np
from go_pygame.go_1 import Go
from alphazero import ResNet
from alphazero import MCTS
from button import Button
import os

pygame.init()  # the font module must be initialised before get_font() is used


# Size of each grid cell, in pixels (adjust as needed)
GRID_SIZE = 75


# Display mode: a 1280x720 window
SCREEN = pygame.display.set_mode((1280, 720))
pygame.display.set_caption("GO")
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (169, 169, 169)

STONE_RADIUS = GRID_SIZE // 2 - 5
BLACK_STONE_COLOR = (0, 0, 0)
WHITE_STONE_COLOR = (255, 255, 255)
BG = pygame.image.load("images/blue_background.jpg")



def get_font(size): # Returns Press-Start-2P in the desired size
  return pygame.font.Font("images/font.ttf", size)


def draw_board(board_state):
    board_size = len(board_state)

    offset_x = (1280 - board_size * GRID_SIZE) // 2
    offset_y = (720 - board_size * GRID_SIZE) // 2
    # draw the gray squares that form the board background

    for i in range(board_size-1):
        for j in range(board_size-1):
            x = j * GRID_SIZE + offset_x +50
            y = i * GRID_SIZE + offset_y +50
            pygame.draw.rect(SCREEN, GRAY, (x, y, GRID_SIZE, GRID_SIZE), border_radius=15)
    for i in range(board_size):
        for j in range(board_size):
            x = j * GRID_SIZE + offset_x +50
            y = i * GRID_SIZE + offset_y +50
            stone = board_state[i][j]
            if stone == Go.BLACK:
                pygame.draw.circle(SCREEN, BLACK_STONE_COLOR, (x, y), STONE_RADIUS)
            elif stone == Go.WHITE:
                pygame.draw.circle(SCREEN, WHITE_STONE_COLOR, (x, y), STONE_RADIUS)


def prepair_model(game):
    args = {
            'game': 'Go',
            'num_iterations': 10,             # number of highest level iterations
            'num_selfPlay_iterations': 10,   # number of self-play games to play within each iteration
            'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
            'num_epochs': 25,                  # number of epochs for training on self-play data for each iteration
            'batch_size': 8,                # batch size for training
            'temperature': 1.25,              # temperature for the softmax selection of moves
            'C': 2,                           # exploration constant of the MCTS selection formula
            'augment': False,                 # whether to augment the training data with flipped states
            'dirichlet_alpha': 0.3,           # concentration parameter of the Dirichlet noise added at the root
            'dirichlet_epsilon': 0.25,        # weight of the Dirichlet noise mixed into the root priors
            'alias': ('Goolaola')
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet(game, 9, 3, device)
    model.load_state_dict(torch.load(f'AlphaZero/Models/Goolaola/model_-1.pt', map_location=device))
    #optimizer.load_state_dict(torch.load(f'AlphaZero/Models/Attax_TestModel/optimizer_4.pt', map_location=device))
    mcts = MCTS(model, game, args)
    return mcts




def play_go(board_size):
    go_game = Go(small_board=(board_size == 7))  # Go() expects a boolean; Go(9) would still create a 7x7 board
    player = 1
    action=0
    b=0
    w=0
    state=go_game.get_initial_state()
    mcts=prepair_model(go_game)
    if board_size == 7: margin = 1
    else: margin = 2
    while True:
        SCREEN.blit(BG,(0,0))
        Ataxx_MENU_TEXT = get_font(50).render("GO", True, "#d7fcd4")
        Ataxx_MENU_RECT = Ataxx_MENU_TEXT.get_rect(center=(180,100))
        SCREEN.blit(Ataxx_MENU_TEXT, Ataxx_MENU_RECT)
        Go_Pontuacao_TEXT = get_font(30).render(f"HUMAN {b} - {w} ALPHAZERO", True, "#d7fcd4")
        Go_Pontuacao_Rect = Go_Pontuacao_TEXT.get_rect(center=(700,650))
        SCREEN.blit(Go_Pontuacao_TEXT,Go_Pontuacao_Rect)
        PASS = Button(image=None, pos=(1100,300), text_input="PASS", font=get_font(50), base_color="#b68f40", hovering_color="White")
        PASS_POS = pygame.mouse.get_pos()
        PASS.changeColor(PASS_POS)
        PASS.update(SCREEN)
        draw_board(state)
        if player==1:
            for event in pygame.event.get():
                if event.type == QUIT:
                    pygame.quit()
                    sys.exit()

                if event.type == pygame.MOUSEBUTTONDOWN:
                    if PASS.checkForInput(PASS_POS):
                        action=board_size*board_size
                        state=go_game.get_next_state_mcts(state,action, player)
                        player = -player
                    else:
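                        mouse_pos = pygame.mouse.get_pos()
                        # invert the geometry of draw_board: intersection (row, col) is drawn at
                        # (col * GRID_SIZE + offset_x + 50, row * GRID_SIZE + offset_y + 50)
                        offset_x = (1280 - board_size * GRID_SIZE) // 2 + 50
                        offset_y = (720 - board_size * GRID_SIZE) // 2 + 50
                        col = round((mouse_pos[0] - offset_x) / GRID_SIZE)
                        row = round((mouse_pos[1] - offset_y) / GRID_SIZE)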

                        print(col, row)
                        if  row >=0 and col >=0  and row < board_size and col < board_size:
                            action=col + row * board_size
                            if   go_game.is_valid_move(state, (row,col), player):
                                state=go_game.get_next_state_mcts(state,action, player)
                                player = -player  # Switch player after a move
                                draw_board(state)
                                print(state)
        else:
            time.sleep(1)
            neut = go_game.change_perspective(state, -player)
            action = mcts.search(neut, player)
            action = np.argmax(action)
            print(action)
            if action == board_size*board_size:
                state=go_game.get_next_state_mcts(state,action, player)
                print("passed")
                pass_text = get_font(20).render("The model passed", True, "#ffffff")
                pass_rect = pass_text.get_rect(center=(1105, 400))
                SCREEN.blit(pass_text, pass_rect)
                pygame.display.update()
                pygame.time.wait(2000)
            else:
                state = go_game.get_next_state_mcts(state, action, player)
            player = -player
        winner, win = go_game.get_value_and_terminated(state,player)
        b,w=go_game.count_influenced_territory_enhanced(state)
        if win:
            print(b,w,winner)
            return b,w,winner
        pygame.display.update()

if __name__ == "__main__":
    play_go(7)  # the board size can be changed as needed (7 or 9)
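The class and the interface share a flat action index: `action = row * board_size + col`, with the extra index `board_size * board_size` reserved for passing (the last entry of `get_valid_moves`, which is always 1). The hypothetical helpers below sketch that mapping; the class itself decodes with `action // self.row_count` and `action % self.column_count`, which is equivalent on square boards.

```python
def move_to_action(row, col, board_size):
    """Flatten (row, col) into the index used by the policy vector."""
    return row * board_size + col


def action_to_move(action, board_size):
    """Inverse mapping; the last index (board_size ** 2) means 'pass'."""
    if action == board_size * board_size:
        return None  # pass move
    return divmod(action, board_size)  # (row, col)


def valid_moves_vector(legal, board_size):
    """Binary vector of legal moves plus a final, always-allowed pass entry."""
    flat = [0] * (board_size * board_size + 1)
    for row, col in legal:
        flat[move_to_action(row, col, board_size)] = 1
    flat[-1] = 1
    return flat
```

With this layout, taking `np.argmax` over the MCTS probabilities, as the interface does, yields either a board point or the pass index.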



## Attax
The Attax structure is very similar to the one implemented