# Alpha Go Mini

In beating 18-time Go World Champion Lee Sedol, Alpha Go became the first computer player to attain the highest possible certification of 9 dan. More than that, it inspired a new wave of possibilities for the application of Machine Learning/Artificial Intelligence models, shattering the glass ceiling that was previously thought to limit a computer's creativity. Here I implement a basic 9x9 version of Alpha Go's model using a similar combination of two neural networks (policy and value), a residual network model structure, and supervised training on labeled 9x9 Go games.

Author: Jason Katz (June 2025 - August 2025)

### Step 1: Go Engine

Playing Go requires an engine for representing the game's state and performing actions like making moves or calculating score. Below is an implementation of several critical methods such as "move", "_calculate_liberties" and "get_legal_actions".

In [None]:
from collections import deque
import copy
import numpy as np

class Board:
    """A minimal Go engine: stone placement, captures, simple ko and scoring.

    Stones live in a ``rows x cols`` NumPy array where ``1`` is black, ``-1``
    is white and ``0`` is empty.  ``player`` is the colour to move and
    ``turn`` is the number stamped onto the next stone placed; the stamp of
    every stone on the board is kept in ``_turn_board``.
    """

    # Offsets of the four orthogonally adjacent points.
    _OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def __init__(self, rows, cols):
        self._last_black_move = None  # (row, col), "pass" or None
        self._last_white_move = None  # (row, col), "pass" or None
        self._blacks_prisoners = 0    # white stones captured by black
        self._whites_prisoners = 0    # black stones captured by white
        self._rows = rows
        self._cols = cols
        self.player = 1
        self._ko = None               # point forbidden by the ko rule, if any
        self._ko_player = None        # colour that made the ko capture
        self._board = np.zeros((rows, cols))
        self._turn_board = np.zeros((rows, cols))
        self.turn = 1

    def get_board(self):
        """Return the live stone array (not a copy)."""
        return self._board

    def get_rows(self):
        """Return the number of rows."""
        return self._rows

    def get_cols(self):
        """Return the number of columns."""
        return self._cols

    def percent_filled(self):
        """Return the fraction of board points that hold a stone (0.0 - 1.0)."""
        return np.count_nonzero(self._board) / (self._rows * self._cols)

    def _coord_outside_board(self, row, col):
        """Return True when (row, col) is not a point on the board."""
        return row < 0 or row >= self._rows or col < 0 or col >= self._cols

    def _adjacent(self, row, col):
        """Return the on-board points orthogonally adjacent to (row, col)."""
        return [
            (row + dr, col + dc)
            for dr, dc in self._OFFSETS
            if not self._coord_outside_board(row + dr, col + dc)
        ]

    def _get_connected(self, row, col, board=None):
        """Return the set of points connected to (row, col) holding the same value.

        Works for stones (a chain) as well as for empty points (a region).

        Raises:
            ValueError: if (row, col) is outside the board.
        """
        board = board if board is not None else self._board
        if self._coord_outside_board(row, col):
            raise ValueError(
                f"({row}, {col}) is outside the bounds of the board; Cannot calculate liberties"
            )

        colour = board[row][col]
        chain = {(row, col)}
        frontier = deque([(row, col)])
        while frontier:
            r, c = frontier.popleft()
            for point in self._adjacent(r, c):
                if point not in chain and board[point[0]][point[1]] == colour:
                    chain.add(point)
                    frontier.append(point)
        return chain

    def _calculate_liberties(self, chain, board=None):
        """Return the number of distinct empty points adjacent to ``chain``."""
        board = board if board is not None else self._board
        liberties = {
            point
            for (row, col) in chain
            for point in self._adjacent(row, col)
            if board[point[0]][point[1]] == 0
        }
        return len(liberties)

    def _evaluate_placement(self, row, col):
        """Simulate the current player placing a stone on the empty point (row, col).

        Returns:
            ``(legal, ko_point)``.  ``legal`` is True when the new stone's
            chain keeps a liberty or the stone captures an adjacent opposing
            chain.  ``ko_point`` is the captured point when a lone stone with
            no liberties of its own captures a lone stone (the only shape that
            can start a ko); otherwise None.
        """
        scratch = self._board.copy()
        scratch[row][col] = self.player
        own_chain = self._get_connected(row, col, board=scratch)
        if self._calculate_liberties(own_chain, board=scratch) != 0:
            return True, None

        for (r, c) in self._adjacent(row, col):
            if scratch[r][c] != -self.player:
                continue
            opp_chain = self._get_connected(r, c, board=scratch)
            if self._calculate_liberties(opp_chain, board=scratch) == 0:
                # A multi-stone capturing chain can be recaptured without
                # repeating the position, so it never creates a ko.
                is_ko_shape = len(opp_chain) == 1 and len(own_chain) == 1
                return True, ((r, c) if is_ko_shape else None)
        return False, None

    def get_legal_actions(self):  # max 82 moves
        """Return ``["pass"]`` followed by every legal (row, col) for ``player``."""
        legal_actions = ["pass"]
        for row in range(self._rows):
            for col in range(self._cols):
                if self._board[row][col] != 0:
                    continue
                if (row, col) == self._ko and self._ko_player == -self.player:
                    continue
                legal, _ = self._evaluate_placement(row, col)
                if legal:
                    legal_actions.append((row, col))
        return legal_actions

    def is_game_over(self, board=None):
        """Return True after two consecutive-colour passes or when no point is empty."""
        board = board if board is not None else self._board
        if self._last_black_move == "pass" and self._last_white_move == "pass":
            return True
        return not bool((np.asarray(board) == 0).any())

    def _calculate_territories(self, area_scoring, board=None):
        """Count each colour's territory (plus stones when ``area_scoring``).

        An empty region counts for a colour only when every stone bordering
        it has that colour.

        Note: the concept of stones being "dead" is ignored in this implementation.

        Returns:
            ``{"Black": int, "White": int}``.
        """
        board = board if board is not None else self._board
        black_territories, white_territories = 0, 0
        visited = set()
        for row in range(self._rows):
            for col in range(self._cols):
                if (row, col) in visited:
                    continue
                if board[row][col] == 0:
                    region = self._get_connected(row, col, board)
                    visited.update(region)
                    borders = {
                        board[nr][nc]
                        for (r, c) in region
                        for (nr, nc) in self._adjacent(r, c)
                        if board[nr][nc] != 0
                    }
                    if len(borders) == 1:
                        if next(iter(borders)) == 1:
                            black_territories += len(region)
                        else:
                            white_territories += len(region)
                elif area_scoring:
                    group = self._get_connected(row, col, board)
                    visited.update(group)
                    if board[row][col] == 1:
                        black_territories += len(group)
                    else:
                        white_territories += len(group)
        return {"Black": black_territories, "White": white_territories}

    def _get_newest_turn(self, chain, board=None, turn_board=None):
        """Return the latest turn stamp among ``chain``'s stones (at least 1)."""
        turn_board = turn_board if turn_board is not None else self._turn_board
        newest = 1
        for (r, c) in chain:
            # Read the caller-supplied turn board, not always self._turn_board.
            if turn_board[r][c] > newest:
                newest = turn_board[r][c]
        return newest

    def _capture_prisoners(self, board=None, turn_board=None):
        """Remove every chain that should die and report the captures.

        A chain without liberties is removed unless it touches an opposing
        chain that also has no liberties; in that stand-off the chain holding
        the most recently placed stone survives and the older one is removed.

        Returns:
            ``{"Black": n, "White": m}`` - stones captured *by* each colour.
        """
        whites_prisoners, blacks_prisoners = 0, 0
        board = board if board is not None else self._board
        turn_board = turn_board if turn_board is not None else self._turn_board
        visited, to_remove = set(), set()

        for row in range(self._rows):
            for col in range(self._cols):
                if (row, col) in visited or board[row][col] == 0:
                    continue
                candidate_group = self._get_connected(row, col, board)
                visited.update(candidate_group)
                if self._calculate_liberties(candidate_group, board) != 0:
                    continue

                group_timestamp = self._get_newest_turn(candidate_group, board, turn_board)
                this_player = board[row][col]
                seen_neighbors = set()
                # Tracked per group: a shared "anything removed yet?" test
                # would leave a second dead group standing on the board.
                resolved = False
                remove_this_group = False

                for (r, c) in candidate_group:
                    for (nr, nc) in self._adjacent(r, c):
                        if (nr, nc) in seen_neighbors or board[nr][nc] != -this_player:
                            continue
                        opponent_group = self._get_connected(nr, nc, board)
                        seen_neighbors.update(opponent_group)
                        visited.update(opponent_group)
                        if self._calculate_liberties(opponent_group, board) != 0:
                            continue

                        # Both chains lack liberties: exactly one of them goes.
                        resolved = True
                        neighbor_timestamp = self._get_newest_turn(opponent_group, board, turn_board)
                        if neighbor_timestamp < group_timestamp:
                            to_remove.update(opponent_group)
                        else:
                            to_remove.update(candidate_group)
                            remove_this_group = True
                            break
                    if remove_this_group:
                        break

                if not resolved:
                    to_remove.update(candidate_group)

        for (r, c) in to_remove:
            if board[r][c] == 1:
                whites_prisoners += 1
            else:
                blacks_prisoners += 1
            board[r][c] = 0
            turn_board[r][c] = 0

        return {"Black": blacks_prisoners, "White": whites_prisoners}

    def pass_turn(self):
        """Record a pass for the current player and hand the move over."""
        if self.player == 1:
            self._last_black_move = "pass"
        else:
            self._last_white_move = "pass"
        # A ko restriction only lasts for the reply to the capture; a pass
        # is that reply, so the restriction ends here.
        self._ko = None
        self._ko_player = None
        self.player *= -1

    def move(self, row, col):  # no simulations with move
        """Place a stone for the current player and resolve captures and ko.

        Raises:
            ValueError: if the point is off the board, occupied, forbidden by
                ko, or the stone would be captured immediately (suicide).
        """
        if self._coord_outside_board(row, col):
            raise ValueError("Invalid move: coordinate outside board")
        if self._board[row][col] != 0:
            raise ValueError("Invalid move: coordinate not empty")
        if (row, col) == self._ko and self._ko_player == -self.player:
            raise ValueError("Invalid move: Ko")

        legal, ko_candidate = self._evaluate_placement(row, col)
        if not legal:
            raise ValueError("Invalid move: immediate capture")

        self._board[row][col] = self.player
        self._turn_board[row][col] = self.turn
        self.turn += 1

        if self.player == 1:
            self._last_black_move = (row, col)
        else:
            self._last_white_move = (row, col)

        new_prisoners = self._capture_prisoners()
        blacks_prisoners, whites_prisoners = new_prisoners["Black"], new_prisoners["White"]
        self._blacks_prisoners += blacks_prisoners
        self._whites_prisoners += whites_prisoners

        # Only store a ko when a real point was identified; otherwise a
        # (None, None) placeholder would leak into the board state.
        if ko_candidate is not None and blacks_prisoners + whites_prisoners == 1:
            self._ko = ko_candidate
            self._ko_player = self.player
        else:
            self._ko = None
            self._ko_player = None
        self.player *= -1

### Step 2: Board to Tensor

With the Go engine complete, we now need a way to convert our representation of the board's state to a tensor—a nested array that is *mostly* one-hot encoded. I've opted to use 10 channels, each of which has a uniform 9x9 shape. These channels are as follows:

**Channel 0: Current Player's Stones** A 1 in (row, col) means a stone exists in that location, a 0 indicates otherwise.

**Channel 1: Opponent Player's Stones** A 1 in (row, col) means that an opposing stone exists in that location, a 0 indicates otherwise.

**Channel 2: Empty Positions** A 1 in (row, col) means that the space is empty, a 0 means it is filled.

**Channel 3: Ko Position** A 1 in (row, col) means that is the active Ko position. There can be at most 1 Ko position.

**Channel 4: Ko Player** The whole 9x9 plane is filled with a single value: 1 if black imposed the active Ko (white cannot move there), -1 if white imposed it, and 0 when there is no active Ko.

**Channel 5: Current Player Turn** All 1's means it is black's turn; All 0's mean it's white's.

**Channel 6: Last Black Move** A 1 in (row, col) means this was black's last move. Black can have at most 1 last move.

**Channel 7: Last White Move** A 1 in (row, col) means this was white's last move. White can have at most 1 last move.

**Channel 8: Turn Board** This channel is *not* one-hot encoded. The number in (row, col) indicates the turn on which that stone was placed.

**Channel 9: Turn Number** (0,0) is the only relevant board location. This channel is also not one-hot encoded, with the number at (0,0) representing the current turn number. 

In [None]:
def board_to_tensor(board_instance):
    """Encode a board's state as a ``(10, rows, cols)`` float32 array.

    Channels:
        0  current player's stones        5  all 1 when black is to move
        1  opponent's stones              6  black's last move (one-hot)
        2  empty points                   7  white's last move (one-hot)
        3  ko point (one-hot)             8  turn on which each stone was placed
        4  ko player: plane of 1 (black)  9  current turn number at [0][0]
           or -1 (white); 0 if no ko

    A last move of ``"pass"``/``None`` and a missing or incomplete ko point
    leave their channels at zero.

    Args:
        board_instance: object exposing ``_board``, ``_ko``, ``_ko_player``,
            ``player``, ``_last_black_move``, ``_last_white_move``,
            ``_turn_board`` and ``turn``.
    """

    def as_point(value):
        # Only a complete (row, col) pair is a point: "pass" would fail to
        # unpack, and None inside a tuple would be read by NumPy as newaxis
        # and flood the whole plane.
        if isinstance(value, tuple) and len(value) == 2 and None not in value:
            return value
        return None

    board = np.asarray(board_instance._board)
    player = board_instance.player  # 1 or -1
    rows, cols = board.shape
    tensor = np.zeros((10, rows, cols), dtype=np.float32)

    # Channels 0-2: stone ownership relative to the player to move.
    tensor[0] = board == player
    tensor[1] = board == -player
    tensor[2] = board == 0

    # Channels 3-4: ko point and the colour that imposed it.
    ko = as_point(board_instance._ko)
    if ko is not None:
        tensor[3][ko[0]][ko[1]] = 1
        tensor[4][:, :] = 1 if board_instance._ko_player == 1 else -1

    # Channel 5: side to move (1 for black, 0 for white).
    if player == 1:
        tensor[5][:, :] = 1

    # Channels 6-7: most recent move of each colour.
    last_black_move = as_point(board_instance._last_black_move)
    if last_black_move is not None:
        tensor[6][last_black_move[0]][last_black_move[1]] = 1
    last_white_move = as_point(board_instance._last_white_move)
    if last_white_move is not None:
        tensor[7][last_white_move[0]][last_white_move[1]] = 1

    # Channel 8: when each stone was placed.
    tensor[8] = board_instance._turn_board

    # Channel 9: current turn number.
    tensor[9][0][0] = board_instance.turn

    return tensor

### ResNet Architecture

Below is a simple implementation of a Residual Network (ResNet) architecture with a shared convolutional trunk and two heads: policy and value. The policy head combines PyTorch's convolutional and linear layers with ReLU activations to output 81 logits (one per board point) scoring how likely each move is to be played. The value head takes a similar approach and outputs a scalar in [-1, 1] estimating the expected game result from the current board state.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleResidualBlock(nn.Module):
    """Two 3x3 same-padding convolutions with a skip connection.

    Input and output both have ``channels`` feature maps and identical
    spatial size, so blocks can be stacked freely.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        """Return ``relu(relu(conv2(relu(conv1(x)))) + x)``."""
        branch = F.relu(self.conv1(x))
        branch = F.relu(self.conv2(branch))
        # Skip connection: add the untouched input back before the last ReLU.
        return F.relu(branch + x)

In [None]:
class ResNet(nn.Module):
    """Shared convolutional trunk with a policy head and a value head.

    Expects input of shape ``(batch, 10, 9, 9)``.  ``forward`` returns
    ``(policy_logits, value)`` where ``policy_logits`` has shape
    ``(batch, 81)`` (one unnormalised score per board point) and ``value``
    has shape ``(batch, 1)`` with entries in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Conv2d(in_channels=10, out_channels=32, kernel_size=3, padding=1)
        self.res_block1 = SimpleResidualBlock(32)
        self.res_block2 = SimpleResidualBlock(32)

        self.policy_conv = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=1)
        self.policy_fc1 = nn.Linear(in_features=(2 * 9 * 9), out_features=128)
        self.policy_fc2 = nn.Linear(in_features=128, out_features=(9 * 9 * 1))

        self.value_conv = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1)
        self.value_fc1 = nn.Linear(in_features=(1 * 9 * 9), out_features=64)
        self.value_fc2 = nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        # Shared trunk.
        x = F.relu(self.initial_conv(x))
        x = self.res_block1(x)
        x = self.res_block2(x)

        # Policy head: (batch, 2, 9, 9) -> (batch, 162) -> (batch, 81).
        # flatten() also works when the conv output is not contiguous.
        policy = F.relu(self.policy_conv(x)).flatten(start_dim=1)
        policy = F.relu(self.policy_fc1(policy))
        policy_logits = self.policy_fc2(policy)

        # Value head: (batch, 1, 9, 9) -> (batch, 81) -> (batch, 1).
        value = F.relu(self.value_conv(x)).flatten(start_dim=1)
        value = F.relu(self.value_fc1(value))
        value = torch.tanh(self.value_fc2(value))  # squashes into [-1, 1] (W/L)

        return policy_logits, value

### Training Data

I've downloaded many games of Go in SGF format from https://github.com/gto76/online-go-games/blob/master/games.zip.005. To train our model, we first need to parse these SGF files into an accepted form of training data (tensors). Part of the data cleaning process included filtering out games that were not played on a 9x9 board as well as those where either player's rank was weaker than 9k. The parsing and chunking of the resulting dataset were performed on my local CPU.

In [None]:
import os
import torch
import numpy as np
from sgfmill import sgf


# Parse a single SGF file and generate training samples
def extract_data_from_sgf_custom(file_path):
    """Replay one SGF game and collect (position, move played) training pairs.

    Games are skipped when the file cannot be parsed, the board is not 9x9,
    a rank property is missing, or either rank string is longer than two
    characters (e.g. "10k" and weaker).

    Returns:
        ``(samples, result)``.  ``samples`` is a list of
        ``(board_tensor, move_index)`` with ``move_index = row * 9 + col``
        and the tensor taken *before* the move is played.  ``result`` is 1
        for a black win, -1 for a white win and 0 when no winner is recorded.
        Skipped games return ``([], None)``.
    """
    try:
        with open(file_path, "rb") as f:
            game = sgf.Sgf_game.from_bytes(f.read())
    except Exception as e:
        print(f"Skipping file {file_path}: {e}")
        return [], None

    # Skip non-9x9 boards
    if game.get_size() != 9:
        return [], None

    try:
        root = game.get_root()
        black_rank = root.get('BR')
        white_rank = root.get('WR')
    except Exception as e:
        print(f"Skipping file {file_path}: {e}")
        return [], None
    if len(black_rank) > 2 or len(white_rank) > 2:
        return [], None

    board = Board(9, 9)
    samples = []

    for node in game.get_main_sequence():
        color, coords = node.get_move()
        if coords is None:
            continue  # root/setup nodes and pass moves carry no board point

        row, col = coords
        # Passes are skipped above, so the engine's own alternation can drift
        # from the record; the colour stored in the SGF is authoritative.
        board.player = 1 if color == 'b' else -1

        board_tensor = board_to_tensor(board)
        try:
            board.move(row, col)
        except ValueError as e:
            # The engine rejected the move, so every later position would be
            # replayed on a board that no longer matches the game.
            print(f"Stopping {file_path} at move ({row}, {col}): {e}")
            break
        samples.append((board_tensor, row * 9 + col))

    winner = game.get_winner()
    result = 1 if winner == 'b' else -1 if winner == 'w' else 0
    return samples, result


# Process all SGF files in a directory
def process_all_sgfs_custom(sgf_dir, output_dir=None, chunk_size=100, max_files=200000):
    """Convert every ``.sgf`` file in ``sgf_dir`` into chunked ``.pt`` datasets.

    Samples are buffered and written as ``dataset_chunk_<i>.pt`` each time
    ``chunk_size`` SGF files have been visited; any remainder is written at
    the end.  Each chunk is a dict with ``'X'`` (stacked board tensors),
    ``'y_policy'`` (move indices) and ``'y_value'`` (float32 game results).

    Args:
        sgf_dir: directory scanned (non-recursively) for ``.sgf`` files.
        output_dir: where chunks are written; ``None`` keeps the original
            behaviour of writing to the current working directory.
        chunk_size: number of SGF files visited between chunk writes.
        max_files: processing stops once more than this many SGF files have
            been visited.

    Returns:
        ``(None, None, None)`` - data is persisted in chunks rather than
        returned in memory.
    """
    buffer_X, buffer_policy, buffer_value = [], [], []
    counter = 0
    chunk_index = 0

    def write_chunk(index):
        name = f"dataset_chunk_{index}.pt"
        path = name if output_dir is None else os.path.join(output_dir, name)
        torch.save({
            'X': torch.stack(buffer_X),
            'y_policy': torch.tensor(buffer_policy),
            'y_value': torch.tensor(buffer_value, dtype=torch.float32),
        }, path)

    # Sorted so that chunk contents are reproducible between runs.
    for filename in sorted(os.listdir(sgf_dir)):
        if not filename.endswith(".sgf"):
            continue

        game_data, result = extract_data_from_sgf_custom(os.path.join(sgf_dir, filename))

        if game_data and result is not None:
            for board_tensor, move_idx in game_data:
                buffer_X.append(torch.tensor(board_tensor))
                buffer_policy.append(move_idx)
                buffer_value.append(result)

        counter += 1

        # Save chunk if ready
        if counter % chunk_size == 0 and buffer_X:
            print(f"Saving chunk {chunk_index} with {len(buffer_X)} samples")
            write_chunk(chunk_index)
            chunk_index += 1
            buffer_X, buffer_policy, buffer_value = [], [], []

        if counter > max_files:
            break

    # Final save if anything is left over
    if buffer_X:
        print(f"Saving final chunk {chunk_index} with {len(buffer_X)} samples")
        write_chunk(chunk_index)

    return None, None, None  # if saving in chunks, we don't return a full in-memory dataset

# Save to .pt
def save_dataset(path, X, y_policy, y_value):
    """Write the three dataset components to ``path`` as a single dict.

    The saved object has the keys ``"X"``, ``"y_policy"`` and ``"y_value"``.
    """
    torch.save({"X": X, "y_policy": y_policy, "y_value": y_value}, path)

sgf_dir = "[FILL ME IN]"
# sgf_dir = "/Users/jasonkatz/Desktop/pro-go-games"
output_path = "[FILL ME IN]/pro_go_dataset.pt"
# output_path = "/Users/jasonkatz/Desktop/test-games/go_dataset_custom.pt"

X, y_policy, y_value = process_all_sgfs_custom(sgf_dir)
# process_all_sgfs_custom writes dataset_chunk_*.pt files and returns
# (None, None, None); only write a combined file if it ever returns data.
if X is not None:
    save_dataset(output_path, X, y_policy, y_value)
    print(f"Saved {len(X)} training samples to {output_path}")
else:
    print("Training samples were saved in chunks (dataset_chunk_*.pt); no combined dataset written")


### Training the Model

To train the model, we load each chunked dataset, instantiate our model, and iteratively update a policy loss and value loss after each pass through the set of labeled data. The resulting weights are ultimately saved to a ".pt" file for future use.

In [None]:
import torch
import os
from torch.utils.data import Dataset

class GoChunkedDataset(Dataset):
    """Map-style dataset over every ``.pt`` chunk file in a folder.

    Each chunk is a dict with ``'X'``, ``'y_policy'`` and ``'y_value'``
    entries of equal length.  Chunks are read once at construction to learn
    their sizes; afterwards the most recently used chunks are kept in memory
    so that consecutive lookups do not reload a file from disk per sample.

    Args:
        chunk_folder: directory containing the chunk files.
        max_cached_chunks: how many loaded chunks to keep in memory (at
            least one is always kept).
    """

    def __init__(self, chunk_folder, max_cached_chunks=4):
        # Sorted so that sample indices are stable between runs.
        self.chunk_paths = sorted(
            os.path.join(chunk_folder, f)
            for f in os.listdir(chunk_folder)
            if f.endswith(".pt")
        )

        self.samples = []  # (file_idx, local_idx) mapping
        self.chunk_metadata = []  # (file_idx, number of samples in that file)
        self._max_cached_chunks = max(1, max_cached_chunks)
        self._chunk_cache = {}  # file_idx -> loaded chunk, oldest entry first

        # Load file sizes
        for i, path in enumerate(self.chunk_paths):
            size = len(torch.load(path)['X'])
            self.chunk_metadata.append((i, size))
            self.samples.extend((i, j) for j in range(size))

    def _load_chunk(self, file_idx):
        """Return chunk ``file_idx``, reading it from disk only on a cache miss."""
        cache = self._chunk_cache
        if file_idx in cache:
            data = cache.pop(file_idx)  # re-inserted below as most recent
        else:
            data = torch.load(self.chunk_paths[file_idx])
            if len(cache) >= self._max_cached_chunks:
                # dicts preserve insertion order, so the first key is the
                # least recently used chunk.
                del cache[next(iter(cache))]
        cache[file_idx] = data
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_idx, local_idx = self.samples[idx]
        data = self._load_chunk(file_idx)
        return (
            data['X'][local_idx],          # [10, 9, 9]
            data['y_policy'][local_idx],   # int (0–80)
            data['y_value'][local_idx]     # float (-1 to 1)
        )


In [None]:
from torch.utils.data import DataLoader

# Index every chunk file in the given folder (replace the placeholder with
# the folder holding the dataset_chunk_*.pt files), then serve the samples
# in shuffled mini-batches of 64.
dataset = GoChunkedDataset("[FILL ME IN]")
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


In [None]:
import torch.nn as nn

# Cross-entropy scores the policy logits against the move actually played;
# mean squared error scores the value output against the game result.
policy_loss_fn = nn.CrossEntropyLoss()
value_loss_fn = nn.MSELoss()
model = ResNet()
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()  # in case the model was left in eval mode by an earlier cell
for epoch in range(40):
    total_loss = 0  # sum of the per-batch losses over this epoch
    for board_tensor, move_idx, game_result in dataloader:
        board_tensor = board_tensor.to(device)
        move_idx = move_idx.to(device)
        game_result = game_result.to(device)

        policy_logits, value_output = model(board_tensor)

        loss_policy = policy_loss_fn(policy_logits, move_idx)
        # squeeze(-1) maps (batch, 1) -> (batch,).  A bare squeeze() would
        # also drop the batch dimension when the final batch holds a single
        # sample, leaving a shape that no longer matches game_result.
        loss_value = value_loss_fn(value_output.squeeze(-1), game_result)

        # Both heads are trained jointly with equal weight.
        loss = loss_policy + loss_value
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")


In [None]:
# Persist only the learned parameters (the state dict), not the whole model object.
torch.save(model.state_dict(), "resnet_go_pretrained.pt")

In [None]:
import torch

# Load a previously saved .pt file; replace the placeholder with a real path.
# NOTE(review): presumably one of the dataset chunks, for inspection - confirm.
data = torch.load("[FILL ME IN]")

In [None]:
import torch

# Load the dataset chunk to evaluate on
data = torch.load("dataset_chunk_0.pt")

X = data['X']
y_policy = data['y_policy']
y_value = data['y_value']

model = ResNet()
model.load_state_dict(torch.load("resnet_go_pretrained.pt", map_location=torch.device("cpu")))
model.eval()

import torch.nn.functional as F

correct_policy = 0
total = len(X)
mse_loss = 0
eval_batch_size = 256  # evaluate in batches instead of one forward pass per sample

with torch.no_grad():
    for start in range(0, total, eval_batch_size):
        stop = start + eval_batch_size
        policy_logits, value_output = model(X[start:stop])  # [B, 81], [B, 1]

        # POLICY ACCURACY: the highest-scoring move must match the move played
        predicted_moves = policy_logits.argmax(dim=1)
        correct_policy += (predicted_moves == y_policy[start:stop]).sum().item()

        # VALUE MSE: accumulate squared errors, averaged after the loop
        value_errors = value_output.squeeze(-1) - y_value[start:stop]
        mse_loss += (value_errors ** 2).sum().item()

# Final metrics (guarded so that an empty chunk does not divide by zero)
accuracy = correct_policy / total if total else 0.0
mse = mse_loss / total if total else 0.0

print("Evaluation Complete")
print(f"Policy Accuracy: {accuracy * 100:.2f}%")
print(f"Value Head MSE:  {mse:.4f}")