# Titulo

texto de introducao

### Go

texto introducao do Go

### Implementação do Go Graficamente

In [1]:
import pygame
import numpy as np
import itertools
import sys
import networkx as nx
import collections
from pygame import gfxdraw

# Define board size
print("Select board size (7 or 9): ")
size = int(input())

# Game constants
BOARD_BROWN = (141, 104, 75)  # Change color as desired
BOARD_WIDTH = 800  # New size of the board
BOARD_BORDER = 75
STONE_RADIUS = int(abs(BOARD_WIDTH / size * 20 * 0.02))  # Adjust stone size to grid size
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
TURN_POS = (BOARD_BORDER, 20)
SCORE_POS = (BOARD_BORDER, BOARD_WIDTH - BOARD_BORDER + 30)
DOT_RADIUS = 2


def make_grid(size):
    # Return list of (start_point, end_point pairs) defining gridlines
    start_points, end_points = [], []

    # Vertical start points (constant y)
    xs = np.linspace(BOARD_BORDER, BOARD_WIDTH - BOARD_BORDER, size)
    ys = np.full((size), BOARD_BORDER)
    start_points += list(zip(xs, ys))

    # Horizontal start points (constant x)
    xs = np.full((size), BOARD_BORDER)
    ys = np.linspace(BOARD_BORDER, BOARD_WIDTH - BOARD_BORDER, size)
    start_points += list(zip(xs, ys))

    # Vertical end points (constant y)
    xs = np.linspace(BOARD_BORDER, BOARD_WIDTH - BOARD_BORDER, size)
    ys = np.full((size), BOARD_WIDTH - BOARD_BORDER)
    end_points += list(zip(xs, ys))

    # Horizontal end points (constant x)
    xs = np.full((size), BOARD_WIDTH - BOARD_BORDER)
    ys = np.linspace(BOARD_BORDER, BOARD_WIDTH - BOARD_BORDER, size)
    end_points += list(zip(xs, ys))

    return start_points, end_points


class HumanPlayer:
    def make_move(self, board):
        pass


class AI:
    def __init__(self, color):
        self.color = color

    def make_move(self, board):
        # Implement AI move generation logic here
        # Construct alphaZero on the side and import here

        # random move generator (valid moves only)
        valid_moves = [(col, row) for col in range(board.shape[0]) for row in range(board.shape[1]) if board[col, row] == 0]
        if not valid_moves:
            return None  # No valid moves available

        return valid_moves[np.random.choice(len(valid_moves))]


def xy_to_colrow(x, y, size):
    inc = (BOARD_WIDTH - 2 * BOARD_BORDER) / (size - 1)
    x_dist = x - BOARD_BORDER
    y_dist = y - BOARD_BORDER
    col = int(round(x_dist / inc))
    row = int(round(y_dist / inc))
    return col, row


def colrow_to_xy(col, row, size):
    inc = (BOARD_WIDTH - 2 * BOARD_BORDER) / (size - 1)
    x = int(BOARD_BORDER + col * inc)
    y = int(BOARD_BORDER + row * inc)
    return x, y


def has_no_liberties(board, group):
    for x, y in group:
        if x > 0 and board[x - 1, y] == 0:
            return False
        if y > 0 and board[x, y - 1] == 0:
            return False
        if x < board.shape[0] - 1 and board[x + 1, y] == 0:
            return False
        if y < board.shape[0] - 1 and board[x, y + 1] == 0:
            return False
    return True


def get_stone_groups(board, color):
    size = board.shape[0]
    color_code = 1 if color == "black" else 2
    xs, ys = np.where(board == color_code)
    graph = nx.grid_graph(dim=[size, size])
    stones = set(zip(xs, ys))
    all_spaces = set(itertools.product(range(size), range(size)))
    stones_to_remove = all_spaces - stones
    graph.remove_nodes_from(stones_to_remove)
    return nx.connected_components(graph)


def is_valid_move(col, row, board):
    if col < 0 or col >= board.shape[0]:
        return False
    if row < 0 or row >= board.shape[0]:
        return False
    return board[col, row] == 0


class Game:
    def __init__(self, size, player1, player2):
        self.board = np.zeros((size, size))
        self.size = size
        self.black_turn = True
        self.prisoners = collections.defaultdict(int)
        self.start_points, self.end_points = make_grid(self.size)
        self.player1 = player1
        self.player2 = player2
        self.player1_passed = False
        self.player2_passed = False
        self.current_player = self.player1 if self.black_turn else self.player2

    def init_pygame(self):
        pygame.init()
        screen = pygame.display.set_mode((BOARD_WIDTH, BOARD_WIDTH))
        self.screen = screen
        self.ZOINK = pygame.mixer.Sound("wav/zoink.wav")
        self.CLICK = pygame.mixer.Sound("wav/click.wav")
        self.font = pygame.font.SysFont("arial", 30)

    def clear_screen(self):
        # Fill board and add gridlines
        self.screen.fill(BOARD_BROWN)
        for start_point, end_point in zip(self.start_points, self.end_points):
            pygame.draw.line(self.screen, BLACK, start_point, end_point)

        # Add guide dots
        guide_dots = [3, self.size // 2, self.size - 4]
        for col, row in itertools.product(guide_dots, guide_dots):
            x, y = colrow_to_xy(col, row, self.size)
            gfxdraw.aacircle(self.screen, x, y, DOT_RADIUS, BLACK)
            gfxdraw.filled_circle(self.screen, x, y, DOT_RADIUS, BLACK)

        pygame.display.flip()

    def check_territory(self, col, row, color):
        visited = set()

        def dfs(x, y):
            if x < 0 or x >= self.size or y < 0 or y >= self.size:
                return False
            if self.board[x, y] == color or (x, y) in visited:
                return True
            visited.add((x, y))
            return all(dfs(nx, ny) for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)))

        return dfs(col, row)

    def calculate_score(self):
        """Calculate the score of the game."""
        territory_black, territory_white, stones_black, stones_white = 0, 0, 0, 0

        for col in range(self.size):
            for row in range(self.size):
                color = self.board[col, row]
                if color == 1:
                    stones_black += 1
                elif color == 2:
                    stones_white += 1
                elif color == 0:
                    neighbors = [(col + i, row + j) for i, j in [(-1, 0), (1, 0), (0, -1), (0, 1)] if
                                0 <= col + i < self.size and 0 <= row + j < self.size]
                    if all(0 <= n_col < self.size and 0 <= n_row < self.size and self.board[n_col, n_row] != color
                        for n_col, n_row in neighbors):
                        # Empty intersection surrounded by opponent's stones
                        territory_black += 1
                    elif all(0 <= n_col < self.size and 0 <= n_row < self.size and self.board[n_col, n_row] != color
                            for n_col, n_row in neighbors):
                        # Empty intersection surrounded by opponent's stones
                        territory_white += 1

        return territory_black, territory_white, stones_black, stones_white

    def draw(self):
        # Draw stones - filled circle and antialiased ring
        self.clear_screen()
        for col, row in zip(*np.where(self.board == 1)):
            x, y = colrow_to_xy(col, row, self.size)
            gfxdraw.aacircle(self.screen, x, y, STONE_RADIUS, BLACK)
            gfxdraw.filled_circle(self.screen, x, y, STONE_RADIUS, BLACK)
        for col, row in zip(*np.where(self.board == 2)):
            x, y = colrow_to_xy(col, row, self.size)
            gfxdraw.aacircle(self.screen, x, y, STONE_RADIUS, WHITE)
            gfxdraw.filled_circle(self.screen, x, y, STONE_RADIUS, WHITE)

        # Text for score and turn info
        territory_black, territory_white, stones_black, stones_white = self.calculate_score()
        score_msg = (
            f"Black - Territory: {territory_black}, Stones: {stones_black} | White - Territory: {territory_white}, Stones: {stones_white}"
        )
        txt = self.font.render(score_msg, True, BLACK)
        self.screen.blit(txt, SCORE_POS)
        turn_msg = (
            f"{'Black' if self.black_turn else 'White'} to move. "
            + "Click to place a stone, press P to pass."
        )
        txt = self.font.render(turn_msg, True, BLACK)
        self.screen.blit(txt, TURN_POS)

        pygame.display.flip()

    def update(self):
        events = pygame.event.get()
        for event in events:
            if event.type == pygame.QUIT:
                self.print_final_scores()
                pygame.quit()
                sys.exit()

        if not self.current_player_is_human():
            self.make_ai_move()
    def make_ai_move(self):
        ai_col, ai_row = self.current_player.make_move(self.board)
        while not is_valid_move(ai_col, ai_row, self.board):
            ai_col, ai_row = self.current_player.make_move(self.board)
        self.board[ai_col, ai_row] = 2 if self.current_player == self.player2 else 1
        self.black_turn = not self.black_turn
        self.current_player = self.player1 if self.black_turn else self.player2  # Update current player
        self.draw()

        if self.check_end_game():
            self.print_final_scores()
            pygame.quit()
            sys.exit()
    def handle_click(self):
        # get board position
        x, y = pygame.mouse.get_pos()
        col, row = xy_to_colrow(x, y, self.size)
        if not is_valid_move(col, row, self.board):
            self.ZOINK.play()
            return

        # update board array
        self.board[col, row] = 1 if self.current_player == self.player1 else 2

        # get stone groups for black and white
        self_color = "black" if self.black_turn else "white"
        other_color = "white" if self.black_turn else "black"

        # handle captures
        prisoners_captured = 0
        for group in list(get_stone_groups(self.board, other_color)):
            if has_no_liberties(self.board, group):
                prisoners_captured += len(group)
                for i, j in group:
                    self.board[i, j] = 0

        # update prisoners count
        if self_color == "black":
            self.prisoners['white'] += prisoners_captured
        else:
            self.prisoners['black'] += prisoners_captured

        # change turns and draw screen
        self.CLICK.play()
        self.black_turn = not self.black_turn
        self.current_player = self.player1 if self.black_turn else self.player2  # Update current player
        self.draw()

        # If it's the AI's turn, let AI make a move
        if not self.current_player_is_human():
            ai_col, ai_row = self.current_player.make_move(self.board)
            while not is_valid_move(ai_col, ai_row, self.board):
                ai_col, ai_row = self.current_player.make_move(self.board)
            self.board[ai_col, ai_row] = 2 if self.current_player == self.player2 else 1
            self.black_turn = not self.black_turn
            self.current_player = self.player1 if self.black_turn else self.player2  # Update current player
            self.draw()

    def pass_move(self):
        if self.black_turn:
            self.player1_passed = True
        else:
            self.player2_passed = True

        self.black_turn = not self.black_turn
        self.draw()

        if self.check_end_game():
            self.print_final_scores()
            pygame.quit()
            sys.exit()

    def check_end_game(self):
        if self.passed_twice() or not any(
            is_valid_move(col, row, self.board) for col in range(self.size) for row in range(self.size)
        ):
            return True
        return False

    def passed_twice(self):
        return self.player1_passed and self.player2_passed

    def print_final_scores(self):
        territory_black, territory_white, stones_black, stones_white = self.calculate_score()
        print("Game Over!")
        print(f"Black - Territory: {territory_black}, Stones: {stones_black}; Total Score: {territory_black+stones_black}")
        print(f"White - Territory: {territory_white}, Stones: {stones_white}; Total Score: {territory_white+stones_white}")

        if territory_black + stones_black > territory_white + stones_white:
            print("Black wins!")
        elif territory_black + stones_black < territory_white + stones_white:
            print("White wins!")
        else:
            print("It's a tie!")

    def update(self):
        events = pygame.event.get()
        for event in events:
            if event.type == pygame.MOUSEBUTTONUP:
                self.handle_click()
            if event.type == pygame.QUIT:
                self.print_final_scores()
                pygame.quit()
                sys.exit()
            if event.type == pygame.KEYUP:
                if event.key == pygame.K_p:
                    self.pass_move()

    def current_player_is_human(self):
        return self.current_player is None or isinstance(self.current_player, HumanPlayer)

if __name__ == "__main__":
    print("Select Player 1 (H for human, A for AI): ")
    player1_type = input().upper()
    player1 = HumanPlayer() if player1_type == "H" else AI(color="black")

    print("Select Player 2 (H for human, A for AI): ")
    player2_type = input().upper()
    player2 = HumanPlayer() if player2_type == "H" else AI(color="white")

    g = Game(size, player1, player2)

    g.init_pygame()
    g.clear_screen()
    g.draw()

    while True:
        g.update()

        if not g.current_player_is_human():
            g.make_ai_move()

pygame 2.3.0 (SDL 2.24.2, Python 3.9.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Select board size (7 or 9): 
7
Select Player 1 (H for human, A for AI): 
1
Select Player 2 (H for human, A for AI): 
1


FileNotFoundError: No file 'wav/zoink.wav' found in working directory 'C:\Users\beatr\Desktop\UNI\3ºANO\Lab_IA_CD\Projeto2 - Attax&Go\Trabalho'.

### Classe Go() - Lógica

In [None]:
#Logica para usar no AlphaZero

class Go():

    EMPTY = 0
    BLACK = 1
    WHITE = -1
    BLACKMARKER = 4
    WHITEMARKER = 5
    LIBERTY = 8

    def __init__(self, size, komi):
        self.row_count = size
        self.column_count = size
        self.komi = 5.5
        self.action_size = self.row_count * self.column_count + 1
        self.liberties = []
        self.block = []
        self.seki_liberties = []
        
        def get_initial_state(self):
        '''
        # Description:
        Returns a board of the argument size filled of zeros.

        # Retuns:
        Empty board full of zeros
        '''
        board = np.zeros((self.row_count, self.column_count))
        return board
    
    
    def count(self, x, y, state: list, player:int , liberties: list, block: list) -> tuple[list, list]:
        '''
        # Description:
        Counts the number of liberties of a stone and the number of stones in a block.
        Follows a recursive approach to count the liberties of a stone and the number of stones in a block.

        # Returns:
        A tuple containing the number of liberties and the number of stones in a block.
        '''
        
        #initialize piece
        piece = state[y][x]
        #if there's a stone at square of the given player
        if piece == player:
            #save stone coords
            block.append((y,x))
            #mark the stone
            if player == self.BLACK:
                state[y][x] = self.BLACKMARKER
            else:
                state[y][x] = self.WHITEMARKER
            
            #look for neighbours recursively
            if y-1 >= 0:
                liberties, block = self.count(x,y-1,state,player,liberties, block) #walk north
            if x+1 < self.column_count:
                liberties, block = self.count(x+1,y,state,player,liberties, block) #walk east
            if y+1 < self.row_count:
                liberties, block = self.count(x,y+1,state,player,liberties, block) #walk south
            if x-1 >= 0:
                liberties, block = self.count(x-1,y,state,player,liberties, block) #walk west
                #if square is empty
        elif piece == self.EMPTY:
            #mark liberty
            state[y][x] = self.LIBERTY
            #save liberties
            liberties.append((y,x))

        return liberties, block
    
    #remove captured stones
    def clear_block(self, block: list, state: list) -> list:
        '''
        # Description:
        Clears the block of stones captured by the opponent on the board.

        # Returns:
        The board with the captured stones removed.
        '''

        #clears the elements in the block of elements which is captured
        for i in range(len(block)): 
            y, x = block[i]
            state[y][x] = self.EMPTY
        
        return state
    
     def print_board(self, state: list) -> None:
            '''
            # Description:
            Draws the board in the console.

            # Returns:
            None
            '''

        # Print column coordinates
            print("   ", end="")
            for j in range(len(state[0])):
                print(f"{j:2}", end=" ")
            print("\n  +", end="")
            for _ in range(len(state[0])):
                print("---", end="")
            print()

            # Print rows with row coordinates
            for i in range(len(state)):
                print(f"{i:2}|", end=" ")
                for j in range(len(state[0])):
                    print(f"{str(int(state[i][j])):2}", end=" ")
                print()
                
    def captures(self, state: list,player: int, a:int, b:int) -> tuple[bool, list]:
        '''
        # Description:
        Checks if a move causes a capture of stones of the player passed as an argument.
        If a move causes a capture, the stones are removed from the board.

        # Returns:
        A tuple containing a boolean indicating if a capture has been made and the board with the captured stones removed.
        '''
        check = False
        neighbours = []
        if(a > 0): neighbours.append((a-1, b))
        if(a < self.column_count - 1): neighbours.append((a+1, b))
        if(b > 0): neighbours.append((a, b - 1))
        if(b < self.row_count - 1): neighbours.append((a, b+1))
        #loop over the board squares
        for pos in neighbours:
            # print(pos)
            x = pos[0]
            y = pos[1]    
            # init piece
            piece = state[x][y]

                #if stone belongs to given colour
            if piece == player:
                # print("opponent piece")
                # count liberties
                liberties = []
                block = []
                liberties, block = self.count(y, x, state, player, liberties, block)
                # print("Liberties in count: " + str(len(liberties)))
                # if no liberties remove the stones
                if len(liberties) == 0: 
                    #clear block
                    state = self.clear_block(block, state)
                    check = True

                #restore the board
                state = self.restore_board(state)

        #print("Captures: " + str(check))
        return check, state
    
    def set_stone(self, a, b, state, player):
        '''
        # Description:
        Places the piece on the board. THIS DOES NOT account for the rules of the game, use get_next_state().

        # Retuns:
        Board with the piece placed.
        '''
        state[a][b] = player
        return state
    
    def get_next_state(self, state, action, player):
        '''
        # Description
        Plays the move, verifies and undergoes captures and saves the state to the history.
        
        # Returns:
        New state with everything updated.
        '''
        if action == self.row_count * self.column_count:
            return state # pass move

        a = action // self.row_count
        b = action % self.column_count

        # checking if the move is part of is the secondary move to a ko fight
        state = self.set_stone(a, b, state, player)
        # print(state)
        state = self.captures(state, -player, a, b)[1]
        return state
    
    def is_valid_move(self, state: list, action: tuple, player: int) -> bool:
        '''
        # Description:
        Checks if a move is valid.
        If a move repeats a previous state or commits suicide (gets captured without capturing back), it is not valid.
        
        A print will follow explaining the invalid move in case it exists.

        # Returns:
        A boolean confirming the validity of the move.
        '''

        a = action[0]
        b = action[1]

        #print(f"{a} , {b}")

        statecopy = np.copy(state).astype(np.int8)

        if state[a][b] != self.EMPTY:
            # print("Space Occupied")
            return False
        
        statecopy = self.set_stone(a,b,statecopy,player)

        if self.captures(statecopy, -player, a, b)[0] == True:
            return True
        else:
            #print("no captures")
            libs, block = self.count(b,a,statecopy,player,[],[])
            #print(libs)
            if len(libs) == 0:
                #print("Invalid, Suicide")
                return False
            else:
                return True
            
    def get_valid_moves(self, state, player):
        '''
        # Description:
        Returns a matrix with the valid moves for the current player.
        '''
        newstate = np.zeros((self.row_count, self.column_count))
        for a in range(0, self.column_count):
            for b in range(0, self.row_count):
                if self.is_valid_move(state, (a,b), player):
                    newstate[a][b] = 1
        
        newstate = newstate.reshape(-1)

        empty = 0
        endgame = True
        for x in range(self.column_count):
            for y in range(self.row_count):
                if state[x][y] == self.EMPTY:
                    empty += 1
                    if empty >= self.column_count * self.row_count // 3: # if 2/3ds are already filled, skipping becomes available
                        endgame = False
                        break
        if endgame:
            newstate = np.concatenate([newstate, [1]])
        else:
            newstate = np.concatenate([newstate, [0]])
        return (newstate).astype(np.int8)
    
    def get_value_and_terminated(self, state, action, player):
        '''
        # Description:
        Returns the value of the state and if the game is over.
        '''

        scoring, endgame = self.scoring(state)

        if endgame:
            if player == self.BLACK:
                if scoring > 0:
                    return 1, True
                else:
                    return -1, True
            else:
                if scoring < 0:
                    return 1, True
                else:
                    return -1, True
        else:
            if player == self.BLACK:
                if scoring > 0:
                    return 1, False
                else:
                    return -1, False
            else:
                if scoring < 0:
                    return 1, False
                else:
                    return -1, False
                
                
    def scoring(self, state: list) -> int:
        '''
        # Description:
        Checks the score of the game. Score is calculated using:

        black - (white + komi)

        # Returns:
        Integer with score.
        '''
        black = 0
        white = 0
        empty = 0
        endgame = True

        for x in range(self.column_count):
            for y in range(self.row_count):
                if state[x][y] == self.EMPTY:
                    empty += 1
                    if empty >= self.column_count * self.row_count // 4:
                        endgame = False
                        break

        black, white = self.count_influenced_territory_enhanced(state)
        black_eyes, black_strong_groups = self.count_eyes_and_strong_groups(state, self.BLACK)
        white_eyes, white_strong_groups = self.count_eyes_and_strong_groups(state, self.WHITE)
        
        black += black_eyes + black_strong_groups
        white += white_eyes + white_strong_groups
        
        return black - (white + self.komi), endgame
    
    
    def count_influenced_territory_enhanced(self, board: list) -> tuple[int, int]:
        '''
        # Description 
        Calculates the territory influenced by black and white players on the Go board.

        This function iterates through the board, analyzing each empty point to determine 
        if it's influenced by the surrounding black or white stones. The influence is calculated
        based on the adjacent stones, with positive scores indicating black influence and negative
        scores indicating white influence.

        # Returns:
        Tuple (black_territory, white_territory)
        '''
        black_territory = 0
        white_territory = 0
        visited = set()

        # Function to calculate influence score
        def influence_score(x, y):
            score = 0
            for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
                nx, ny = x + dx, y + dy
                if 0 <= nx < len(board) and 0 <= ny < len(board[0]):
                    score += board[nx][ny]
            return score
        
        # Function to explore territory
        def explore_territory(x, y):
            nonlocal black_territory, white_territory
            if (x, y) in visited or not (0 <= x < len(board) and 0 <= y < len(board[0])):
                return
            visited.add((x, y))

            if board[x][y] == 0:
                score = influence_score(x, y)
                if score > 0:
                    black_territory += 1
                elif score < 0:
                    white_territory += 1

        for i in range(len(board)):
            for j in range(len(board[0])):
                if board[i][j] == 0 and (i, j) not in visited:
                    explore_territory(i, j)

        return black_territory, white_territory
    
    
    def is_eye(self, board, x, y, player):

        # An eye is an empty point with all adjacent points of the player's color
        # and at least one diagonal point of the player's color.
        
        if board[x][y] != self.EMPTY:
            return False
        
        for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            nx, ny = x + dx, y + dy
            if not (0 <= nx < len(board) and 0 <= ny < len(board[0])):
                continue
            if board[nx][ny] != player:
                return False
            
        true_eye = False
        count = 0
        for dx, dy in [(1, 1), (1, -1), (-1, 1), (-1, -1)]:
            nx, ny = x + dx, y + dy

            if 0 <= nx < len(board) and 0 <= ny < len(board[0]) and board[nx][ny] == player:
                count += 1
                if count >= 3:
                    true_eye = True


        return true_eye
    
    
    def count_eyes_and_strong_groups(self, board, player):
        eyes = 0
        strong_groups = 0
        visited = set()

        def dfs(x, y):
            if (x, y) in visited or board[x][y] != player:
                return 0

            visited.add((x, y))
            liberties = 0
            for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
                nx, ny = x + dx, y + dy
                if not (0 <= nx < len(board) and 0 <= ny < len(board[0])):
                    continue
                if board[nx][ny] == self.EMPTY:
                    liberties += 1
                elif board[nx][ny] == player:
                    liberties += dfs(nx, ny)

            return liberties
        
        for x in range(len(board)):
            for y in range(len(board[0])):
                if board[x][y] == player and (x, y) not in visited:
                    liberties = dfs(x, y)
                    if liberties >= 2:  # Arbitrary threshold for a strong group
                        strong_groups += 1
                if board[x][y] != player and (x, y) not in visited and self.is_eye(board, x, y, player):
                    eyes += 1

        return eyes, strong_groups
    
    
    
    def get_opponent(self, player):
        '''
        # Description:
        Changes Opponent
        '''
        return -player
    
    def get_opponent_value(self, value):
        '''
        # Description
        Returns the negative value of the value
        '''
        return -value
    
    
    def get_encoded_state(self, state):
        '''
        # Description: 
        Converts the current state of the Go board into a 3-layer encoded format suitable for neural network input.
        Each layer in the encoded format represents the presence of a specific type of stone or an empty space on the board:
        - Layer 1 encodes the positions of white stones (represented by -1 in the input state) as 1s, and all other positions as 0s.
        - Layer 2 encodes the positions of empty spaces (represented by 0 in the input state) as 1s, and all other positions as 0s.
        - Layer 3 encodes the positions of black stones (represented by 1 in the input state) as 1s, and all other positions as 0s.
        This encoding helps in clearly distinguishing between different elements on the board for machine learning applications.

        # Returns: 
        A NumPy array of shape (3, height, width) containing the 3-layer encoded representation of the board state. Each layer is a 2D array where the board's height and width correspond to the dimensions of the original state.
        '''
        layer_1 = np.where(np.array(state) == -1, 1, 0).astype(np.float32)
        layer_2 = np.where(np.array(state) == 0, 1, 0).astype(np.float32)
        layer_3 = np.where(np.array(state) == 1, 1, 0).astype(np.float32)

        result = np.stack([layer_1, layer_2, layer_3]).astype(np.float32)

        return result
    
    
     def change_perspective(self, state, player):
        '''
        # Description: 
        Adjusts the perspective of the Go board state based on the current player.

        # Returns: 
        A two-dimensional array representing the Go board state adjusted for the current player's perspective.
        '''
        return state * player

        
        
        

### Attaxx

introducao attaxx

### Implementação do Attax

In [None]:

from tkinter import *
import numpy as np
import copy
import random as r



b_w="BLUE WINS!!"
r_w="RED WINS!!"

print("Escolha número de linhas/colunas:")
NB = int(input())  # Board number of rows/columns
size_of_board = 600
size_of_square = size_of_board/NB
symbol_size = (size_of_square*0.75-10)/2
symbol_thickness = 20
blue_color = '#496BAB'
red_color = '#F33E30'

possible_moves_global=[]
position_global=[]
bool=False
origin_pos=[]
moves_blue_global=[]
moves_red_global=[]
blue_pieces=[]
red_pieces=[]
board2=[]

class ataxx():
    def __init__(self):
        self.window = Tk()
        self.window.title('Ataxx')
        self.canvas = Canvas(self.window, width=size_of_board, height=size_of_board, background="white")
        self.canvas.pack()
        self.window.bind('<Button-1>', self.click)
        self.board = np.zeros(shape=(NB, NB))
        self.board[0][0]=2
        self.board[0][NB-1]=1
        self.board[NB-1][NB-1]=1
        self.board[NB-1][0]=2
        self.player_blue_turn = True
        self.game_ended = False
        self.mode1=0
        self.mode2=0
        self.init_draw_board()


    def mainloop(self):
        self.window.mainloop()
        if self.mode1==1 and self.mode2==1:
            self.ai_vs_ai()
            
    #----------------DESENHO DO TABULEIRO---------------------------------------------------------------------------------------------------------

    def init_draw_board(self):
        self.canvas.delete("all")
        for i in range(NB-1):
            self.canvas.create_line((i+1)*size_of_square, 0, (i+1)*size_of_square, size_of_board)
        for i in range(NB-1):
            self.canvas.create_line(0,(i+1)*size_of_square, size_of_board, (i+1)*size_of_square)
        self.canvas.create_oval(size_of_square/2 - symbol_size, size_of_square/2 - symbol_size,
                                size_of_square/2 + symbol_size, size_of_square/2 + symbol_size,
                                width=symbol_thickness, outline=red_color,
                                fill=red_color)
        self.canvas.create_oval(size_of_board - size_of_square/2 - symbol_size,size_of_board - size_of_square/2 - symbol_size,
                                size_of_board - size_of_square/2 + symbol_size, size_of_board - size_of_square/2 + symbol_size,
                                width=symbol_thickness, outline=blue_color,
                                fill=blue_color)
        self.canvas.create_oval(size_of_square/2 - symbol_size,size_of_board - size_of_square/2 - symbol_size,
                                size_of_square/2 + symbol_size, size_of_board - size_of_square/2 + symbol_size,
                                width=symbol_thickness, outline=blue_color,
                                fill=blue_color)
        self.canvas.create_oval(size_of_board - size_of_square/2 - symbol_size, size_of_square/2- symbol_size,
                                size_of_board - size_of_square/2 + symbol_size, size_of_square/2 + symbol_size,
                                width=symbol_thickness, outline=red_color,
                                fill=red_color)


    def update_board(self, x, y, origin):
        for i in range(max(0, x-1), min(NB, x+2)):
            for j in range(max(0, y-1), min(NB, y+2)):
                if not self.is_square_clear([i,j]):
                    if self.player_blue_turn:
                        self.draw_blue([i,j])
                    else:
                        self.draw_red([i,j])
                    self.board[i][j]=self.board[x][y]
        if x-origin[0]== 2 or y-origin[1]== 2 or x-origin[0]== -2 or y-origin[1]== -2:
            self.board[origin[0]][origin[1]]=0
            pos=self.convert_logical_to_grid_position(origin)
            self.draw_whitespace(pos)
        self.score()
        self.all_moves()
        self.player_blue_turn = not self.player_blue_turn

    def update_board2(self, board, x, y, origin):
        for i in range(max(0, x-1), min(NB, x+2)):
            for j in range(max(0, y-1), min(NB, y+2)):
                if not board[pos[0]][pos[1]] == 0:
                    board[i][j]=board[x][y]
        if x-origin[0]== 2 or y-origin[1]== 2 or x-origin[0]== -2 or y-origin[1]== -2:
            board[origin[0]][origin[1]]=0

    def all_moves(self):
        global moves_blue_global
        global moves_red_global
        global blue_pieces
        global red_pieces
        for i in range(NB):
            for j in range(NB):
                if self.board[i][j]==1:
                    moves_blue_global.append(self.possible_moves([i,j]))
                    blue_pieces.append([i,j])
                elif self.board[i][j]==2:
                    moves_red_global.append(self.possible_moves([i,j]))
                    red_pieces.append([i,j])
        if len(moves_blue_global)==0:
            self.no_moves(1)
        elif len(moves_red_global)==0:
            self.no_moves(2)
        moves_blue_global=[]
        moves_red_global=[]

    def no_moves(self, player):
        print("3")
        if player==1:
            for i in range(NB):
                for j in range(NB):
                    if self.board[i][j]==0:
                        self.board[i][j]=2
                        self.draw_red([i,j])
        elif player==2:
            for i in range(NB):
                for j in range(NB):
                    if self.board[i][j]==0:
                        self.board[i][j]=1
                        self.draw_blue([i,j])
        self.score()




    def execute_move(self, move, origin, player):

        
        self.board[move[0]][move[1]] = player
        self.update_board(move[0], move[1], origin)

    def is_square_clear(self, pos):
        if not np.array_equal(pos, []):
            return self.board[pos[0]][pos[1]] == 0

    def valid_move(self, logical_pos):
        return self.is_square_clear(logical_pos)

    def possible_moves(self, move):

    #dado a peça selecionada, devolve uma lista
    #com todos os movimentos possiveis da mesma 

        possible_moves=[]
        for i in range(max(0,move[0]-2), min(NB, move[0]+3)):
            for j in range(max(0,move[1]-2), min(NB, move[1]+3)):
                
                if self.is_square_clear([i,j]):
                    possible_moves.append([i,j])
        
        return possible_moves

    def score(self):
        cont_blue=0
        cont_red=0
        cheio=True
        for i in range(NB):
            for j in range(NB):
                if self.board[i][j]==1:
                    cont_blue+=1
                elif self.board[i][j]==2:
                    cont_red+=1
                if self.board[i][j]==0:
                    cheio=False
        print("Blue score= ",  cont_blue)
        print("Red score= ",  cont_red)
        print("------------------------")
        self.window.title("Ataxx - Red : %d vs %d : Blue" % (cont_red, cont_blue))
        if cont_blue==0:
            self.game_is_over(cont_red, cont_blue)
        elif cont_red==0:
            self.game_is_over(cont_red, cont_blue)
        elif cheio:
            self.game_is_over(cont_red, cont_blue)




    def game_is_over(self, red, blue):
        print("Blue score= ",  blue)
        print("Red score= ",  red)
        print("\n")
        if blue>red:
            print(b_w)
            a=1
        else:
            print(r_w)
            a=2
        print("\n")
        self.clear_possible_moves()
        self.window.destroy()


#----------------------TRANSFORMAR EM MATRIZ PARA APLICAR REGRAS---------------

    def convert_logical_to_grid_position(self, logical_pos):
        logical_pos = np.array(logical_pos, dtype=int)
        return np.array((size_of_square)*logical_pos + size_of_square/2)

    def convert_grid_to_logical_position(self, grid_pos):
        grid_pos = np.array(grid_pos)
        return np.array(grid_pos//size_of_square, dtype=int)

#-----------------------DESENHAR PECAS----------------------------------------
    def draw_whitespace(self, grid_pos):

        

        self.canvas.create_rectangle(grid_pos[0] - symbol_size, grid_pos[1] - symbol_size,
                            grid_pos[0] + symbol_size, grid_pos[1] + symbol_size,
                            width=symbol_thickness, outline="white",
                            fill="white")


    def draw_blue(self, logical_pos):
        logical_pos = np.array(logical_pos)
        grid_pos = self.convert_logical_to_grid_position(logical_pos)
        self.canvas.create_oval(grid_pos[0] - symbol_size, grid_pos[1] - symbol_size,
                            grid_pos[0] + symbol_size, grid_pos[1] + symbol_size,
                            width=symbol_thickness, outline=blue_color,
                            fill=blue_color)

    def draw_red(self, logical_pos):
        logical_pos = np.array(logical_pos)
        grid_pos = self.convert_logical_to_grid_position(logical_pos)
        self.canvas.create_oval(grid_pos[0] - symbol_size, grid_pos[1] - symbol_size,
                            grid_pos[0] + symbol_size, grid_pos[1] + symbol_size,
                            width=symbol_thickness, outline=red_color,
                            fill=red_color)

    def draw_possible_moves(self, possible_moves):

        # desenha no tabuleiro as jogadas possiveis da bola selecionada

        moves=[0]*len(possible_moves)
        for i in range(len(possible_moves)):
            moves[i]=self.convert_logical_to_grid_position(possible_moves[i])
            self.canvas.create_oval(moves[i][0]-symbol_size, moves[i][1] - symbol_size,
                                    moves[i][0]+symbol_size, moves[i][1]+ symbol_size,
                                    width=symbol_thickness, outline="gray", fill="gray", tags="possible")


    def clear_possible_moves(self):

        

        self.canvas.delete("possible")
        
#----------------------------------- Verificaçao movimentos e jogadas ------------------------------------

    def total_moves(board, player,ROWS):
        moves = []
        moves_aval = []
        for peca in totalpecas(board, ROWS, player):
            moves_aval = plays_eval(peca,ROWS,board)
            for move in moves_aval:
                temp_board = deepcopy(board)
                temp_peca = (peca[0],peca[1])
                new_board = simula_move(temp_peca, move, temp_board, player, ROWS)
                moves.append(new_board)
        return moves

#----------------------- MOUSE -----------------------------------------------------------

    def click(self, event):        
        if self.game_ended: return
        grid_pos = [event.x, event.y]
        logical_pos = self.convert_grid_to_logical_position(grid_pos)
        global origin_pos
        global possible_moves_global
        origin_pos = logical_pos
        
        if self.board[logical_pos[0]][logical_pos[1]] == 1 and self.player_blue_turn:
            possible_moves_global = self.possible_moves(logical_pos)
            
            if not np.array_equal(possible_moves_global, []):
                self.window.bind("<Button-1>", self.second_click)
            self.draw_possible_moves(possible_moves_global)
            
        elif self.board[logical_pos[0]][logical_pos[1]] == 2 and not self.player_blue_turn:
            possible_moves_global = self.possible_moves(logical_pos)
            
            if not np.array_equal(possible_moves_global, []):
                self.window.bind("<Button-1>", self.second_click)
            self.draw_possible_moves(possible_moves_global)


    def second_click(self, event):
        global bool
        grid_pos = [event.x, event.y]
        logical_pos = self.convert_grid_to_logical_position(grid_pos)
        global possible_moves_global
        possible_moves_global=np.array(possible_moves_global, dtype=int)
        
        for element in possible_moves_global:
            
            if np.array_equal(logical_pos, element):
                global position_global
                position_global = logical_pos
                bool=True
        bool = True
        self.click2()
        possible_moves_global=[]
        position_global=[]


    def click2(self):
        global bool
        
        if self.player_blue_turn:
            player=1
        else:
            player=2
            
        if self.valid_move(position_global):
            
            if self.second_click_pressed(bool):
                if self.player_blue_turn and self.board[origin_pos[0]][origin_pos[1]] == 1:
                    self.draw_blue(position_global)
                    self.execute_move(position_global, origin_pos, player)
                    
                elif not self.player_blue_turn and self.board[origin_pos[0]][origin_pos[1]] == 2:
                    self.draw_red(position_global)
                    self.execute_move(position_global, origin_pos, player)
        self.clear_possible_moves()
        self.window.bind("<Button-1>", self.click)
        

    def second_click_pressed(self, bool):
        if bool:
            return True
        return False



def PvsP():
    game = ataxx()
    game.mainloop()
   

PvsP()

### MCTS

explicação

### Implementação

In [None]:
class Node:
    '''
    # Alpha Zero Node
    ## Description:
        A node for the AlphaZero MCTS. It contains the state, the action taken to get to the state, the prior probability of the action, the visit count, the value sum, and the children of the node.
    ## Metohds:
        - `is_expanded()`: Returns whether the node has been expanded.
        - `select()`: Selects the best child node based on the UCB.
        - `get_ucb()`: Returns the UCB of a child node.
        - `expand()`: Expands the node by adding children.
        - `backpropagate()`: Backpropagates the value of the node to the parent node.
        '''
    def __init__(self, game, args, state, player, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.player = player
        self.prior = prior
        self.children = []
        
        self.visit_count = visit_count
        self.value_sum = 0
        
    def is_expanded(self):
        '''
        # is_expanded
        ## Description:
            Returns whether the node has been expanded.
        ## Returns:
            - `bool`: Whether the node has been expanded.'''
        return len(self.children) > 0
    
    def select(self):
        '''
        # Description: 
        Selects the best child node from the current node's children in a Monte Carlo Tree Search using the Upper Confidence Bound (UCB) algorithm. 

        # Returns: 
        The best child node, chosen based on the highest UCB value or randomly if there's a tie.
        '''
        best_child = []
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = [child]
                best_ucb = ucb
            elif ucb == best_ucb:
                best_child.append(child)
                
        return best_child[0] if len(best_child) == 1 else random.choice(best_child)
    
    def get_ucb(self, child):
        '''
        # Description: 
        Calculates the Upper Confidence Bound (UCB) value for a given child node in a Monte Carlo Tree Search.

        # Returns: 
        The calculated UCB value for the given child node.
        '''
        if child.visit_count == 0:
            q_value = child.prior * self.args['C'] * (math.sqrt(self.visit_count)) / (child.visit_count + 1)
        else:
            q_value = -(child.value_sum / child.visit_count) + child.prior * self.args['C'] * (math.sqrt(self.visit_count)) / (child.visit_count + 1)
        return q_value
    
    def serialize(self):
        # Serialize only essential data
        node_data = {
            'game': self.game,
            'args': self.args,
            'parent': self.parent,
            'state': self.state,
            'action_taken': self.action_taken,
            'player': self.player,
            'prior': self.prior,
            'visit_count': self.visit_count,
            'value_sum': self.value_sum,
            'children': [child for child in self.children]  # Assuming each child has a unique ID
        }
        return json.dumps(node_data)
    
    def deserialize(node_json):
        # Convert JSON back into a Node object
        node_data = json.loads(node_json)
        node = Node(  # assuming constructor can handle this data
            game = node_data['game'],
            args = node_data['args'],
            parent = node_data['parent'],
            player = node_data['player'],
            state=node_data['state'],
            action_taken=node_data['action_taken'],
            prior=node_data['prior'],
            visit_count=node_data['visit_count'],
        )
        node.value_sum = node_data['value_sum']

        for child in node_data['children']:
            child.parent = node
            node.children.append(child)

        # You'll need to handle children reconstruction separately
        return node
    def expand(self, policy):
        '''
        # Description: 
        Expands the current node by adding new child nodes based on the given policy probabilities. For each possible action, it calculates the next state, adjusts the perspective for the opponent, and creates a new child node if the probability for that action is greater than zero.

        # Returns: 
        None
        '''
        for action, prob in enumerate(policy):
            if prob > 0:
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)
                child = Node(self.game, self.args, child_state, self.game.get_opponent(self.player), self, action, prob)
                self.children.append(child)
            
    def backpropagate(self, value):
        '''
        # Description: 
        Performs the backpropagation step in Monte Carlo Tree Search. It updates the current node's value sum and visit count based on the received value.

        # Returns: 
        None
        '''
        self.value_sum += value
        self.visit_count += 1
        
        if self.parent is not None:
            value = self.game.get_opponent_value(value)
            self.parent.backpropagate(value)
            
            
class MCTS:
    def __init__(self, model, game, args):
        self.model = model
        self.game = game
        self.args = args
        self.tree_dict = {}
        
    @torch.no_grad()
    def search(self, states, player):
        """
        # Description:
        Performs Monte Carlo Tree Search (MCTS) in batch to find the best action.

        # Returns:
        An array of arrays of action probabilities for each possible action.
        """
        roots = []
        for state in states:
            root = Node(self.game, self.args, state, player, visit_count=1)

            searches = self.args['num_mcts_searches']

            # if str(state)+str(player) not in self.tree_dict.keys():
            #     self.tree_dict.update({str(state)+str(player): root.serialize()})

            # else: # the state is already in the dictionary
            #    root = Node.deserialize(self.tree_dict.get(str(state)+str(player)))
            #    searches = (searches // 4)

            roots.append(root)
            
            policy, _ = self.model(
                torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
            )
            policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
            policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
                * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
                
            valid_moves = self.game.get_valid_moves(state, player)

            if self.args["game"] == "Attaxx":
                if np.sum(valid_moves) == 0:
                    valid_moves[-1] = 1
                else:
                    valid_moves[-1] = 0
                    policy *= valid_moves
            policy /= np.sum(policy)
            root.expand(policy)
                
            for search in range(searches):
                node = root
                while node.is_expanded():
                    node = node.select()
                    
                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken, node.player)
                value = self.game.get_opponent_value(value)
                    
                if node.parent is not None:
                    if node.action_taken == self.game.action_size - 1 and node.parent.action_taken == self.game.action_size - 1 and self.args['game'] == 'Go':
                        is_terminal = True # if the action is pass when the previous action was also pass, end the game

                if not is_terminal:
                    policy, value = self.model(torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0))
                    policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
                    valid_moves = self.game.get_valid_moves(node.state, player)

                    if self.args["game"] == "Attaxx":
                        if np.sum(valid_moves) == 0:
                            valid_moves[-1] = 1
                        else:
                            valid_moves[-1] = 0

                    policy *= valid_moves
                    policy /= np.sum(policy)
                        
                    value = value.item()
                    node.expand(policy)

                node.backpropagate(value)
                
        action_prob_list = []

        for root in roots:
            action_probs = np.zeros(self.game.action_size)
            for child in root.children:
                action_probs[child.action_taken] = child.visit_count
            action_probs /= np.sum(action_probs)
            action_prob_list.append(action_probs)

        return action_prob_list

### Rede Residual

Introdução de uma Rede Residual

In [None]:
class ResNet(nn.Module):
    '''
    # ResNet
    ## Description:
        A ResNet model for AlphaZero.
        The model takes in a state and outputs a policy and value.
         - The policy is a probability distribution over all possible actions.
         - The value is a number between -1 and 1, where -1 means the current player loses and 1 means the current player wins following a tanh activation.
        '''
    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()
        self.device = device

        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding="same"),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * game.row_count * game.column_count, game.action_size)
        )
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding="same"),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.row_count * game.column_count, 1),
            nn.Tanh()
        )
        
        self.to(device)
        
        
        
    def forward(self, x):
        '''
        # Description:
        The forward pass of the model. This overrides the forward method of nn.Module so that it can be called directly on the model.

        # Returns:
        - `policy`: The policy output of the model.
        - `value`: The value output of the model.
        '''
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value
    
class ResBlock(nn.Module):
    '''
    # Description:
    A residual block for the ResNet model.
    '''
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding="same")
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm2d(num_hidden)
        
    def forward(self, x):
        """
        # Description:
        Forward pass through the residual block.

        # Returns:
        Output tensor after passing through the block.
        """
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)
        return x

### AlphaZero

Introdução

In [None]:
class AlphaZero:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(model, game, args)

    def augment_state(self, state, probs):

        augmented_states = []

        skip_prob = probs[-1]
        action_probs_matrix = np.array(probs[:-1]).reshape(self.game.column_count, self.game.row_count)
        augmented_action_probs = []

        def augment_and_append(transformed_state, transformed_probs_matrix):

            # Append state
            augmented_states.append(transformed_state)

            # Flatten probs matrix, append the last value, and then append to augmented_action_probs
            augmented_action_probs.append(list(transformed_probs_matrix.flatten()) + [skip_prob])

        # Original state and probs
        augment_and_append(state, action_probs_matrix)
        # Rotate 90 degrees clockwise
        augment_and_append(np.rot90(state, k=1), np.rot90(action_probs_matrix, k=1))

        # Rotate 180 degrees clockwise
        augment_and_append(np.rot90(state, k=2), np.rot90(action_probs_matrix, k=2))

        # Rotate 270 degrees clockwise
        augment_and_append(np.rot90(state, k=3), np.rot90(action_probs_matrix, k=3))

        # Flip horizontally
        augment_and_append(np.fliplr(state), np.fliplr(action_probs_matrix))

        # Flip vertically
        augment_and_append(np.flipud(state), np.flipud(action_probs_matrix))

        # Rotate 90 degrees clockwise and flip horizontally
        augment_and_append(np.rot90(np.fliplr(state), k=1), np.rot90(np.fliplr(action_probs_matrix), k=1))

        # Rotate 90 degrees clockwise and flip vertically
        augment_and_append(np.rot90(np.flipud(state), k=1), np.rot90(np.flipud(action_probs_matrix), k=1))

        return augmented_states, augmented_action_probs
    
    def selfPlay(self):
        player = 1

        memory = []
        states = []

        for _ in range(0, self.args['parallel_games']):
            state = self.game.get_initial_state()
            states.append(state)
            memory.append([])

        iter = 0
        prev_skip = False
        temperature = self.args['temperature']
        debugging = False

        returnData = []
        while True:
            if self.args["game"] == "Attaxx" and debugging:
                print("\nSEARCHING...")

            neutral_states_list = []

            for state in states:
                neutral_states_list.append(self.game.change_perspective(state, player))

            action_probs_list = self.mcts.search(states, player)

            for i, (neutral_state, action_probs) in enumerate(zip(neutral_states_list, action_probs_list)):
                memory[i].append((neutral_state, action_probs, player))

            for idx, (state, action_probs) in enumerate(zip(states, action_probs_list)):
                temperature_action_probs = action_probs ** (1 / temperature)
                temperature_action_probs /= np.sum(temperature_action_probs)
                
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                state = self.game.get_next_state(state, action, player)

                if self.args["game"] == "Attaxx" and debugging:
                    print(f"Player: {player} with move {self.game.int_to_move(action)}\nBoard:")
                    self.game.print_board(state)    

                value, is_terminal = self.game.get_value_and_terminated(state, action, player)
                    

                if action == self.game.action_size - 1 and self.args['game'] == 'Go':
                    if prev_skip:
                        is_terminal = True
                    else:
                        prev_skip = True
                else:
                    prev_skip = False

                if is_terminal or iter >= self.args['max_moves']:
                    returnMemory = []
                    if self.args["game"] == "Attaxx" and debugging:
                        print("GAME OVER\n\n")
                    for hist_neutral_state, hist_action_probs, hist_player in memory[idx]:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)

                        if self.args['augment']:
                            augmented_states, augmented_action_probs = self.augment_state(hist_neutral_state, hist_action_probs)

                            for augmented_state, augmented_probs in zip(augmented_states, augmented_action_probs):
                                returnMemory.append((self.game.get_encoded_state(augmented_state), augmented_probs, hist_outcome))
                        else:
                            returnMemory.append((self.game.get_encoded_state(hist_neutral_state), hist_action_probs, hist_outcome))

                        returnData = returnData + returnMemory
                        
                        del memory[idx]
                    del states[idx]

                if len(memory) <= 0:
                    return returnData

            player = self.game.get_opponent(player)

            if temperature >= 0.1:
                temperature = temperature * self.args['cooling_constant']
            else:
                temperature = 0.1

            iter += 1
            
            
    def train(self, memory):
        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:batchIdx+self.args['batch_size']]
            state, policy_targets, value_targets = zip(*sample)
            
            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)
            
            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)
            
            out_policy, out_value = self.model(state)
            
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
    def learn(self, memory = None, LAST_ITERATION=0):
        primary_memory = []

        if memory != None:
            primary_memory = memory

        for iteration in range(LAST_ITERATION+1, self.args['num_iterations']):
            print(f"Iteration {iteration + 1}")

            secondary_memory = []

            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                states = self.selfPlay()
                secondary_memory += states

            training_memory = []
            if self.args['experience_replay']:
                sample_size = int(len(primary_memory) * 0.3)

                training_memory += random.sample(primary_memory, min(sample_size, len(primary_memory)))
                training_memory += secondary_memory
                
                primary_memory += secondary_memory
            else:
                training_memory += secondary_memory

            print(f"Memory size: {len(training_memory)}")

            self.model.train()

            for epoch in trange(self.args['num_epochs']):
                self.train(training_memory)

            print("\n")
            
            torch.save(self.model.state_dict(), f"DevelopmentModels/{self.args['alias']}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"DevelopmentModels/{self.args['alias']}/optimizer_{iteration}.pt")
            with open(f'DevelopmentModels/{self.args["alias"]}/memory_{iteration}.pkl', 'wb') as f:
                pickle.dump(primary_memory, f)
            print("Data Saved!")


### Treino

In [None]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


SAVE_NAME = None

if __name__ == '__main__':

    # Go / Attaxx
    GAME = "Attaxx"

    # Board size (7/9 for Go, 4/5/6 for Attaxx)
    SIZE = 6
    # True to load previous model
    # False to start from scratch
    LOAD = True
    LAST_ITERATION = 1

    # Save Name
    SAVE_NAME = "4x4Parallel_5"

    # False for training
    # True for playing
    TEST = True

    # False if locally 
    # True if playing in the server
    ONLINE = False

    # Train from scratch
    if not LOAD and not TEST:
        LAST_ITERATION=-1
        
        
    if GAME == 'Go':
        if SIZE == 7:
            args = {
                'game': 'Go',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 15,    # number of self-play games to play within each iteration
                'num_mcts_searches': 200,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 20,                 # number of epochs for training on self-play data for each iteration
                'batch_size': 16,                 # batch size for training
                'temperature': 3,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.90,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 2,                           # the value of the constant policy
                'experience_replay': True,        # recycle a certain % of old random selfplay data in the current training iteration
                'augment': False,                 # whether to augment the training data with flipped and rotated states
                'parallel_games': 10,            # number of games run in parallel
                'dirichlet_alpha': 0.03,          # the value of the dirichlet noise (alpha)
                'dirichlet_epsilon': 0.25,        # the value of the dirichlet noise (epsilon)
                'alias': ('Go' + SAVE_NAME)
            }

            game = Go(size = SIZE, komi = 5.5)
            model = ResNet(game, 10, 10, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
            
        elif SIZE == 9:
            args = {
                'game': 'Go',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 20,    # number of self-play games to play within each iteration
                'num_mcts_searches': 200,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 60,                 # number of epochs for training on self-play data for each iteration
                'batch_size': 32,                 # batch size for training
                'temperature': 3,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 2,                           # the value of the constant policy
                'experience_replay': True,        # recycle a certain % of old random selfplay data in the current training iteration
                'augment': False,                 # whether to augment the training data with flipped and rotated states
                'parallel_games': 5,            # number of games run in parallel
                'dirichlet_alpha': 0.032,          # the value of the dirichlet noise (alpha)
                'dirichlet_epsilon': 0.25,        # the value of the dirichlet noise (epsilon)
                'alias': ('Go' + SAVE_NAME)
            }

            game = Go(size = SIZE, komi = 5.5)
            model = ResNet(game, 15, 15, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
            
    elif GAME == 'Attaxx':
        game_size = [SIZE,SIZE]
        if SIZE == 4:
            args = {
                'game': 'Attaxx',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 10,                # number of epochs for training on self-play data for each iteration
                'batch_size': 16,                # batch size for training
                'temperature': 2,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.9,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 4,                           # the value of the constant policy
                'dirichlet_alpha': 0.3,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.2,       # the 001value of the dirichlet noise
                'parallel_games': 10,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 4, 8, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
            
        elif SIZE == 5:
            args = {
                'game': 'Attaxx',
                'num_iterations': 10000,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 10,                # number of epochs for training on self-play data for each iteration
                'batch_size': 64,                # batch size for training
                'temperature': 1,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 4,                           # the value of the constant policy
                'dirichlet_alpha': 0.3,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.2,       # the value of the dirichlet noise
                'parallel_games': 15,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 8, 16, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
            
            
            
        elif SIZE == 6:
            args = {
                'game': 'Attaxx',
                'num_iterations': 10000,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 150,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 20,                # number of epochs for training on self-play data for each iteration
                'batch_size': 64,                # batch size for training
                'temperature': 1,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 4,                           # the value of the constant policy
                'dirichlet_alpha': 0.3,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.2,       # the value of the dirichlet noise
                'parallel_games': 20,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 12, 32, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    else:
        print("Game Unavailable")
        
        
    if LOAD:
        model.load_state_dict(torch.load(f'DevelopmentModels/{GAME+SAVE_NAME}/model_{LAST_ITERATION}.pt', map_location=device))
        optimizer.load_state_dict(torch.load(f'DevelopmentModels/{GAME+SAVE_NAME}/optimizer_{LAST_ITERATION}.pt', map_location=device))
    
        #with open(f'DevelopmentModels/{GAME+SAVE_NAME}/memory_{LAST_ITERATION}.pkl', 'rb') as f:
         #   memory = pickle.load(f)
    else:
        memory = None

    if not TEST:
        os.makedirs(f'DevelopmentModels/{GAME+SAVE_NAME}', exist_ok=True)
        alphaZero = AlphaZero(model, optimizer, game, args)
        alphaZero.learn(memory, LAST_ITERATION)
        
        
    elif not ONLINE:

        if not LOAD:
            print("No model to test")
            exit()

        if GAME == 'Go':
            PLAYER1 = "user"
            PLAYER2 = "AI"
            game = Go(SIZE, 5.5)

            model.load_state_dict(torch.load(f'DevelopmentModels/{GAME+SAVE_NAME}/model_{LAST_ITERATION}.pt'), map_location = device)
            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            game.print_board(state)

            player = 1
            prev_skip = False
            while True:
                if player == 1:
                    if PLAYER1 == 'user':
                        a, b = tuple(int(x.strip()) for x in input("\nInput your move: ").split(' '))
                        print("\n")
                        action = a * SIZE + b
                        state = game.get_next_state(state, action, player)
                    else:
                        tmp_state = game.change_perspective(state, -1)
                        action = mcts.search([tmp_state], -player)                    
                        action = np.argmax(action[0])

                        print(f"\nAlphaZero Action: {action // game.row_count} {action % game.column_count}\n")
                        state = game.get_next_state(state, action, player)
                else:
                    if PLAYER2 == 'user':
                        a, b = tuple(int(x.strip()) for x in input("\nInput your move: ").split(' '))
                        print("\n")
                        action = a * SIZE + b
                        state = game.get_next_state(state, action, player)
                        
                    else:
                        action = mcts.search([tmp_state], -player)                    
                        action = np.argmax(action[0])

                        print(f"\nAlphaZero Action: {action // game.row_count} {action % game.column_count}\n")
                        state = game.get_next_state(state, action, player)

                winner, win = game.get_value_and_terminated(state, action, player)
                
                if action == game.action_size:
                    if prev_skip:
                        win = True
                    else:
                        prev_skip = True
                else:
                    prev_skip = False

                if win:
                    game.print_board(state)
                    print(f"player {winner} wins")
                    exit()

                player = - player
                game.print_board(state)
                
        elif GAME == 'Attaxx':
            PLAYER1 = "AI"
            PLAYER2 = "AI"
            game = Attaxx(game_size)

            model.load_state_dict(torch.load(f'DevelopmentModels/{GAME+SAVE_NAME}/model_{LAST_ITERATION}.pt', map_location = device))
            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            game.print_board(state)

            player = 1
            
            
        while True:
                if player == 1:
                    if PLAYER1 == 'user':
                        move = tuple(int(x.strip()) for x in input("\nInput your move: ").split(' '))
                        print("\n")
                        action = game.move_to_int(move)
                        state = game.get_next_state(state, action, player)
                    else:
                        tmp_state = game.change_perspective(state, -1)
                        action = mcts.search([tmp_state], -player)
                        action = np.argmax(action)
                        print(f"\nAlphaZero Action: {game.int_to_move(action)}\n")
                        state = game.get_next_state(state, action, player)
                else:
                    if PLAYER2 == 'user':
                        move = tuple(int(x.strip()) for x in input("\nInput your move: ").split(' '))
                        print("\n")
                        action = game.move_to_int(move)
                        state = game.get_next_state(state, action, player)
                    else:
                        action = mcts.search([state], player)
                        action = np.argmax(action)
                        print(f"\nAlphaZero Action: {game.int_to_move(action)}\n")
                        state = game.get_next_state(state, action, player)

                winner, win = game.get_value_and_terminated(state, action, player)
                if win:
                    game.print_board(state)
                    print(f"player {winner} wins")
                    exit()
                    break

                player = -player
                game.print_board(state)