In [1]:
# Go Game Class  
class GoGame:
    """Go rules for a square board, exposing the game interface used by the training code.

    Board representation: a ``GRID_SIZE x GRID_SIZE`` integer numpy array in
    which ``1`` and ``-1`` are the two players' stones and ``0`` is an empty
    point. Actions ``0 .. action_size - 1`` place a stone at
    ``(action // GRID_SIZE, action % GRID_SIZE)``; action ``action_size``
    (``pass_action``, i.e. 81 on a 9x9 board) is the pass move.

    NOTE(review): ``get_next_state`` and ``check_end_game`` mutate per-game
    state stored on this object (``pass_count``, ``board_history``, timing).
    If the same instance is used for search simulations, those simulations
    update that state too -- confirm with callers.
    """

    # Orthogonal neighbour offsets, in the order up, down, left, right.
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, dimension):
        self.GRID_SIZE = dimension
        # Number of board points; the pass move uses the next index (pass_action).
        self.action_size = self.GRID_SIZE * self.GRID_SIZE
        # Positions after every accepted stone placement; consulted by is_ko.
        self.board_history = []

        self.last_time_switch = time.time()  # timestamp of the last update_time call
        self.time = 0.0  # accumulated elapsed seconds
        self.max_time = 10.0  # max time per move; not enforced (the time-up check is disabled)

        self.komi = 5.5  # compensation for the second player (-1); see calculate_points

        self.pass_count = 0  # number of consecutive passes
        self.is_terminal = False
        self.winner = None  # set by set_terminal_state: 1, -1, or 0 for a draw

    @property
    def pass_action(self):
        """Action index of the pass move (one past the last board point)."""
        return self.action_size

    def _neighbors(self, row, col):
        """Yield the on-board orthogonal neighbours of (row, col), up/down/left/right."""
        size = self.GRID_SIZE
        for dr, dc in self._DIRECTIONS:
            r, c = row + dr, col + dc
            if 0 <= r < size and 0 <= c < size:
                yield r, c

    def get_mask(self, state, player):
        """Return a 0/1 list of length ``action_size + 1`` marking *player*'s legal actions.

        Entry ``i`` is 1 when action ``i`` is in ``get_valid_moves``; the last
        entry is the pass move (always legal).
        """
        valid = set(self.get_valid_moves(state, player))  # set: O(1) membership per action
        return [1 if i in valid else 0 for i in range(self.action_size + 1)]

    def reset_time(self):
        """Reset the accumulated time to zero (``last_time_switch`` is left unchanged)."""
        self.time = 0.0

    def get_initial_state(self):
        """Return an empty ``GRID_SIZE x GRID_SIZE`` integer board."""
        return np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=int)

    def is_valid_move(self, action, state, player):
        """Return True if *player* may play *action* on *state*.

        The pass move is always legal. A stone placement is legal when the
        point is on the board and empty, and the placed stone's group either
        keeps at least one liberty or captures an adjacent opponent group
        (i.e. suicide is forbidden). Ko is NOT checked here; ``get_next_state``
        rejects repeated positions instead.

        *state* is modified temporarily and always restored before returning.
        """
        if action == self.pass_action:
            return True

        row, col = divmod(action, self.GRID_SIZE)
        if not (0 <= row < self.GRID_SIZE and 0 <= col < self.GRID_SIZE):
            return False  # outside the board
        if state[row][col] != 0:
            return False  # point already occupied

        # Place the stone temporarily to evaluate liberties and captures.
        state[row][col] = player
        try:
            _, liberties = self.check_liberties(row, col, state)
            opponent = -player
            captures = any(
                self.check_liberties(r, c, state)[1] == 0
                for r, c in self._neighbors(row, col)
                if state[r][c] == opponent
            )
        finally:
            state[row][col] = 0  # undo the temporary placement even on error

        return liberties > 0 or captures

    def get_valid_moves(self, state, player):
        """Return the ascending list of *player*'s legal actions, ending with the pass move.

        Action ``i`` maps to board point ``(i // GRID_SIZE, i % GRID_SIZE)``
        (on 9x9: 0 is [0][0], ..., 80 is [8][8], 81 is pass).
        """
        empty_points = np.flatnonzero(state.reshape(-1) == 0)  # occupied points can never be legal
        valid_moves = [int(i) for i in empty_points if self.is_valid_move(int(i), state, player)]
        valid_moves.append(self.pass_action)
        return valid_moves

    def update_time(self):
        """Add the wall-clock time since the last call to ``self.time``."""
        current_time = time.time()
        self.time += current_time - self.last_time_switch
        self.last_time_switch = current_time

    def check_end_game(self, state, action, player):
        """Update the clock and pass counter for *action*; report whether the game ended.

        Returns ``(winner, ended)``. ``winner`` is always ``None`` here because
        the time-up rule is disabled (no ``is_time_up`` method exists in this
        class). ``ended`` is True once two consecutive passes have been
        counted. *state* and *player* are currently unused.

        NOTE(review): ``get_next_state`` also increments ``pass_count`` on a
        pass. If a caller invokes both for the same move, one pass is counted
        twice and a single pass ends the game -- confirm the calling pattern.
        """
        self.update_time()

        if action == self.pass_action:
            self.pass_count += 1
        else:
            self.pass_count = 0  # any stone placement breaks a pass streak

        if self.pass_count >= 2:
            return None, True
        return None, False

    def check_liberties(self, row, col, state, checked=None):
        """Find the group containing (row, col) and count its liberties.

        Args:
            row, col: board coordinates of the starting stone.
            state: the board.
            checked: optional set of points already explored. It is updated in
                place with every stone of the group found here; stones already
                in it are not explored again.

        Returns:
            ``(group, liberties)``: the set of ``(row, col)`` points in the
            group and the number of DISTINCT empty points adjacent to it.
            ``(set(), 0)`` if the starting point is empty or already checked.
        """
        if checked is None:
            checked = set()
        if (row, col) in checked or state[row][col] == 0:
            return set(), 0

        color = state[row][col]
        group = set()
        liberty_points = set()  # a set, so a liberty shared by two stones counts once
        stack = [(row, col)]
        checked.add((row, col))

        # Iterative flood fill over same-coloured, orthogonally connected stones.
        while stack:
            r, c = stack.pop()
            group.add((r, c))
            for nr, nc in self._neighbors(r, c):
                cell = state[nr][nc]
                if cell == 0:
                    liberty_points.add((nr, nc))
                elif cell == color and (nr, nc) not in checked:
                    checked.add((nr, nc))
                    stack.append((nr, nc))

        return group, len(liberty_points)

    def calculate_territory(self, state, player):
        """Return the number of empty points that count as *player*'s territory.

        Empty points are grouped into orthogonally connected regions. A region
        is *player*'s territory when no stone bordering it belongs to another
        player. A region bordered only by the board edge (e.g. the whole empty
        board) therefore counts for every player.
        """
        size = self.GRID_SIZE
        visited = set()
        territory = 0

        for start_row in range(size):
            for start_col in range(size):
                if state[start_row][start_col] != 0 or (start_row, start_col) in visited:
                    continue

                # Explore the WHOLE region before deciding, so it is never
                # split into separately scored fragments.
                region_size = 0
                enclosed = True
                stack = [(start_row, start_col)]
                visited.add((start_row, start_col))
                while stack:
                    r, c = stack.pop()
                    region_size += 1
                    for nr, nc in self._neighbors(r, c):
                        cell = state[nr][nc]
                        if cell == 0:
                            if (nr, nc) not in visited:
                                visited.add((nr, nc))
                                stack.append((nr, nc))
                        elif cell != player:
                            enclosed = False  # touches a foreign stone: neutral region

                if enclosed:
                    territory += region_size

        return territory

    def calculate_points(self, state, player):
        """Return *player*'s score: stones on the board plus territory points.

        Player ``-1`` (the second player) additionally receives ``komi + 1``
        points, and 1 point is subtracted from every player's total. With the
        default komi the second player's net bonus is therefore 5.5.
        NOTE(review): the purpose of the extra +1 / -1 adjustments is unclear
        from this class; the -1 applies to both players, so it never changes
        the winner.
        """
        points = self.komi + 1 if player == -1 else 0
        points += int(np.count_nonzero(np.asarray(state) == player))  # stones on the board
        points += self.calculate_territory(state, player)
        points -= 1
        return points

    def someone_won(self, state):
        """Score the position for both players.

        Returns ``(1, True)`` if player 1 scores higher, ``(-1, True)`` if
        player -1 scores higher, and ``(None, False)`` on a tie.
        """
        player_score = self.calculate_points(state, 1)
        opponent_score = self.calculate_points(state, -1)

        if player_score > opponent_score:
            return 1, True
        if player_score < opponent_score:
            return -1, True
        return None, False

    def get_value_and_terminated(self, state, action, player):
        """Update end-of-game bookkeeping for *action* and report the outcome.

        Returns ``(value, winner, terminated)``:
        ``value`` is 1 if someone won and 0 otherwise; ``winner`` is 1 or -1
        (``None`` when nobody won); ``terminated`` is True when the game is over.
        The game ends after two consecutive passes; a tied score then returns
        ``(0, None, True)``.
        """
        winner_due_to_time, end_game = self.check_end_game(state, action, player)
        if winner_due_to_time is not None:
            return 1, winner_due_to_time, True
        if not end_game:
            return 0, None, False

        # Scoring is only needed once the game is over.
        winner, win = self.someone_won(state)
        if win:
            return 1, winner, True
        return 0, None, True  # two passes on a tied score: game over, draw

    def change_perspective(self, state, player):
        """Return the board multiplied by *player*, so *player*'s stones become 1."""
        return state * player

    def get_encoded_state(self, state):
        """Return one-hot float32 planes for (-1, empty, 1), shape ``(3, *state.shape)``."""
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        return encoded_state

    def get_next_state(self, state, action, player):
        """Apply *action* for *player* to *state* in place and return it.

        A pass increments ``pass_count``; the second consecutive pass marks the
        game terminal via ``set_terminal_state``. A stone placement removes
        adjacent opponent groups left without liberties. If the resulting
        position already appears in ``board_history`` (ko / positional
        superko), the whole move is undone: the stone is lifted and any
        captured stones are put back. Otherwise the position is recorded.

        Legality (occupied point, suicide) is not checked here; callers should
        choose *action* from ``get_valid_moves``.
        """
        if action == self.pass_action:
            self.pass_count += 1
            if self.pass_count >= 2:
                self.set_terminal_state(state)
            return state

        row, col = divmod(action, self.GRID_SIZE)
        state[row][col] = player
        self.update_time()

        captured_stones = self.capture_stones(row, col, state, -player)

        if self.is_ko(state):
            # Repeated position: revert the move, including its captures.
            state[row][col] = 0
            for r, c in captured_stones:
                state[r][c] = -player
        else:
            self.board_history.append(state.copy())

        # A stone placement (even a rejected one) breaks a pass streak.
        self.pass_count = 0

        return state

    def remove_captured_stones(self, state, group):
        """Clear every point of *group* on *state* in place."""
        for r, c in group:
            state[r][c] = 0

    def is_ko(self, state):
        """Return True if *state* equals any earlier position in ``board_history``.

        This compares against the whole history (positional superko), not only
        the previous position.
        """
        return any(np.array_equal(state, prev_state) for prev_state in self.board_history)

    def capture_stones(self, row, col, state, opponent):
        """Remove *opponent* groups next to (row, col) that have no liberties.

        *state* is modified in place. Returns the list of removed
        ``(row, col)`` points.
        """
        captured_stones = []
        for r, c in self._neighbors(row, col):
            if state[r][c] == opponent:
                group, liberties = self.check_liberties(r, c, state)
                if liberties == 0:
                    captured_stones.extend(group)
                    self.remove_captured_stones(state, group)
        return captured_stones

    def set_terminal_state(self, current_state):
        """Mark the game finished and store the winner in ``self.winner``.

        ``self.winner`` becomes 1 or -1 for the higher ``calculate_points``
        score, or 0 when the scores are equal.
        """
        self.is_terminal = True

        player_score = self.calculate_points(current_state, 1)
        opponent_score = self.calculate_points(current_state, -1)

        if player_score > opponent_score:
            self.winner = 1
        elif opponent_score > player_score:
            self.winner = -1
        else:
            self.winner = 0  # draw