<a href="https://colab.research.google.com/github/averyrair/ChessBAKEN/blob/MoveSelect/MoveSelector/MoveSelectorBAKEN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
#@title Install Dependencies

!pip install chess
#!pip install stockfish
#!apt install stockfish

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147776 sha256=35db34c26ed5bfacd6be8427d086bd8e891a955ce7e1dba4c615c4c226de53b3
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.2


In [4]:
#@title Import Directories

from google.colab import drive
# For this to work, you need to have the "Chess Bot BAKEN" project shared with
# the SAME email you linked this Colab to, presumably your GitHub email address.
drive.mount('/content/drive', force_remount=1)
chess_dir = '/content/drive/Shareddrives/Chess Bot BAKEN'

Mounted at /content/drive


In [5]:
#@title Import Libraries

import os
import chess
import chess.pgn
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
#@title Define Custom Board Representation

def customBoardRep(board):
  # From the current board position, this method should output integer equivalent of ones and zeros of length 8 x 8 x 6 x 2.
  P = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.PAWN, chess.WHITE).tolist())).astype(int)))) + " "
  N = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.KNIGHT, chess.WHITE).tolist())).astype(int)))) + " "
  B = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.BISHOP, chess.WHITE).tolist())).astype(int)))) + " "
  R = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.ROOK, chess.WHITE).tolist())).astype(int)))) + " "
  Q = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.QUEEN, chess.WHITE).tolist())).astype(int)))) + " "
  K = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.KING, chess.WHITE).tolist())).astype(int)))) + " "
  p = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.PAWN, chess.BLACK).tolist())).astype(int)))) + " "
  n = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.KNIGHT, chess.BLACK).tolist())).astype(int)))) + " "
  b = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.BISHOP, chess.BLACK).tolist())).astype(int)))) + " "
  r = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.ROOK, chess.BLACK).tolist())).astype(int)))) + " "
  q = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.QUEEN, chess.BLACK).tolist())).astype(int)))) + " "
  k = str(sum(j<<i for i,j in enumerate((np.asarray(board.pieces(chess.KING, chess.BLACK).tolist())).astype(int)))) + " "

  return P + N + B + R + Q + K + p + n + b + r + q + k

In [None]:
#@title Import PGN and Export CSV

import chess.pgn
if not os.path.exists(os.path.join(chess_dir, 'moveSelectorDataset.csv')):
  pgn_dir = os.path.join(chess_dir, 'KingBase2019-pgn')
  pgnFiles = [os.path.join(pgn_dir, f) for f in os.listdir(pgn_dir) if f.endswith('.pgn')]

  # Prepare CSV to export dataset.
  with open(os.path.join(chess_dir, 'moveSelectorDataset.csv'), 'w') as csvfile:
    csvwriter = csv.writer(csvfile)
    #csvwriter.writeheader(['Label', 'Input'])

    # Read every PGN file in the directory.
    iterator = 0
    for pgnFile in pgnFiles:
      print(pgnFile)
      pgn = open(pgnFile)

      # Read every game in the PGN file.
      while True:
        game = chess.pgn.read_game(pgn)
        if game is None: break
        iterator += 1
        if iterator % 500 == 0: print(iterator)
        board = game.board()

        # New data point for every move in the game.
        gameMoves = list(game.mainline_moves())
        for i in range(len(gameMoves)-1):
          board.push(gameMoves[i])
          legalMoves = " "
          for m in list(board.legal_moves):
            legalMoves = legalMoves + str(m) + " "
          #print(gameMoves[i+1], legalMoves, customBoardRep(board))
          csvwriter.writerow([gameMoves[i+1], customBoardRep(board) + legalMoves])

In [None]:
#@title Define Fully-Connected Neural Network

class myFCN(nn.Module):
  def __init__(self, inSize, hiddenSizes, outSize):
    super().__init__()
    self.inSize = inSize
    self.hiddenSize = hiddenSizes
    self.outSize = outSize

    self.lin1 = nn.Linear(inSize, hiddenSizes[0])
    self.lin2 = nn.Linear(hiddenSizes[0], hiddenSizes[1])
    self.lin3 = nn.Linear(hiddenSizes[1], outSize)

    self.bn1 = nn.BatchNorm2d(hiddenSizes[0])
    self.bn2 = nn.BatchNorm2d(hiddenSizes[1])

    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.relu(self.bn1(self.lin1(x)))
    x = self.relu(self.bn2(self.lin2(x)))
    x = self.lin3(x)
    return x

In [None]:
#@title Initialize Fully-Connected Neural Network

moveSelector = myFCN(inSize=2560, hiddenSizes=[2200, 2200], outSize=1792)
moveSelector

myFCN(
  (lin1): Linear(in_features=2560, out_features=2200, bias=True)
  (lin2): Linear(in_features=2200, out_features=2200, bias=True)
  (lin3): Linear(in_features=2200, out_features=1792, bias=True)
  (bn1): BatchNorm2d(2200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(2200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
)

In [None]:
#@title Define Dataset

class ChessDataset(Dataset):
    def __init__(self, chessgame_file, chess_dir, transform=None, target_transform=None):
        self.game_labels = pd.read_csv(chessgame_file)
        self.chess_dir = chess_dir  # This is not used. Everything is in the chessgame file.
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.game_labels)

    def __getitem__(self, idx):
        label = self.game_labels.iloc[idx, 0]
        gameState = self.game_labels.iloc[idx, 1]
        if self.transform:
            gameState = self.transform(gameState)
        if self.target_transform:
            label = self.target_transform(label)
        return gameState, label

In [6]:
#@title Define Encode/Decode Functions

# The number of possible pieces and square states (8*8*6*2)
NUM_BOARD_NODES = 768
# The number of possible moves from any square to any other
NUM_MOVE_NODES = 1792


# Determines whether any piece could ever move from one square to another
# Note: pass in start and end coordinates as tuples containing two ints
def canMove(start, end):
    # check duplicate
    if start[0] == end[0] and start[1] == end[1]:
        return False

    # check row or col
    if start[0] == end[0] or start[1] == end[1]:
        return True

    # check diagonal
    if abs(start[0]-end[0]) == abs(start[1]-end[1]):
        return True

    # check knight
    knightCases = [(1, 2), (1, -2), (-1, 2), (-1, -2),
                   (2, 1), (2, -1), (-2, 1), (-2, -1)]
    for c in knightCases:
        if start[0] + c[0] == end[0] and start[1] + c[1] == end[1]:
            return True

    return False


# Generate the mapping matrix (only needs to be done once)
def getMappingMatrix():
    moveMap = [[[[0 for l in range(8)] for k in range(
        8)] for j in range(8)] for i in range(8)]
    offset = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                for l in range(8):
                    if canMove((i, j), (k, l)):
                        moveMap[i][j][k][l] = offset
                        offset += 1
    return moveMap

# Encode a board position and legal moves into a binary array
def encodeSelector(boardFEN, moves, moveMap):
    # Encode board position
    encodedBoard = [0 for _ in range(NUM_BOARD_NODES)]
    rowsFEN = boardFEN.split('/')
    for i in range(8):
        col = 0
        for j in range(len(rowsFEN[i])):
            if rowsFEN[i][j].isdigit():
                col += ord(rowsFEN[i][j])-48
            else:
                match rowsFEN[i][j]:
                    case 'P':
                        idx = (8*6*2*col) + (6*2*i) + (0*2) + 0
                        encodedBoard[idx] = 1
                    case 'N':
                        idx = (8*6*2*col) + (6*2*i) + (1*2) + 0
                        encodedBoard[idx] = 1
                    case 'B':
                        idx = (8*6*2*col) + (6*2*i) + (2*2) + 0
                        encodedBoard[idx] = 1
                    case 'R':
                        idx = (8*6*2*col) + (6*2*i) + (3*2) + 0
                        encodedBoard[idx] = 1
                    case 'Q':
                        idx = (8*6*2*col) + (6*2*i) + (4*2) + 0
                        encodedBoard[idx] = 1
                    case 'K':
                        idx = (8*6*2*col) + (6*2*i) + (5*2) + 0
                        encodedBoard[idx] = 1
                    case 'p':
                        idx = (8*6*2*col) + (6*2*i) + (0*2) + 1
                        encodedBoard[idx] = 1
                    case 'n':
                        idx = (8*6*2*col) + (6*2*i) + (1*2) + 1
                        encodedBoard[idx] = 1
                    case 'b':
                        idx = (8*6*2*col) + (6*2*i) + (2*2) + 1
                        encodedBoard[idx] = 1
                    case 'r':
                        idx = (8*6*2*col) + (6*2*i) + (3*2) + 1
                        encodedBoard[idx] = 1
                    case 'q':
                        idx = (8*6*2*col) + (6*2*i) + (4*2) + 1
                        encodedBoard[idx] = 1
                    case 'k':
                        idx = (8*6*2*col) + (6*2*i) + (5*2) + 1
                        encodedBoard[idx] = 1
                col += 1

    # Encode possible moves
    encodedMoves = [0 for _ in range(NUM_MOVE_NODES)]
    for m in moves.split(' '):
        startCol = ord(str(m)[0])-97  # e.g. c -> 2
        startRow = ord(str(m)[1])-49  # e.g. 8 -> 7
        endCol = ord(str(m)[2])-97
        endRow = ord(str(m)[3])-49
        encodedMoves[(moveMap[startCol][startRow][endCol][endRow])] = 1

    # Combine into single list
    return encodedBoard + encodedMoves


# Decode board position from a binary array of length 768
def decodePosition(input):
    # Create matrix of board
    boardMatrix = [['.' for _ in range(8)] for _ in range(8)]

    # Populate board
    for i in range(len(input)):
        if input[i] == 1:
            piece = i % 12
            # Piece: P = 0, R = 1, N = 2, B = 3, Q = 4, K = 5
            # Color: W = 0, B = 1
            pieceLabel = '.'
            match piece:
                case 0:
                    pieceLabel = 'P'
                case 1:
                    pieceLabel = 'p'
                case 2:
                    pieceLabel = 'N'
                case 3:
                    pieceLabel = 'n'
                case 4:
                    pieceLabel = 'B'
                case 5:
                    pieceLabel = 'b'
                case 6:
                    pieceLabel = 'R'
                case 7:
                    pieceLabel = 'r'
                case 8:
                    pieceLabel = 'Q'
                case 9:
                    pieceLabel = 'q'
                case 10:
                    pieceLabel = 'K'
                case 11:
                    pieceLabel = 'k'
            boardMatrix[int(i / (6*2)) % 8][int(i / (8*6*2))] = pieceLabel

    # Convert to FEN notation
    boardFEN = ''
    for i in range(len(boardMatrix)):
        empty = 0
        for j in range(len(boardMatrix[i])):
            if boardMatrix[i][j] == '.':
                empty += 1
            else:
                if empty > 0:
                    boardFEN += str(empty)
                    empty = 0
                boardFEN += boardMatrix[i][j]
        if empty > 0:
            boardFEN += str(empty)
        if i < len(boardMatrix) - 1:
            boardFEN += '/'

    return chess.Board(boardFEN)


# Decode moves from a binary array of length 1792
# Returns a list of moves with start and end squares
def decodeMoves(input, moveMap):
    moves = []
    for h in range(len(input)):
        if input[h] > 0: # NOTE: may need to change this condition based on selector NN output
            # Search moveMap to find the actual move squares
            found = False
            for i in range(8):
                for j in range(8):
                    for k in range(8):
                        for l in range(8):
                            if moveMap[i][j][k][l] == h:
                                moves.append(chr(i+97) + chr(j+49) + chr(k+97) + chr(l+49))
                                found = True
                                break
                        if found:
                            break
                    if found:
                        break
                if found:
                    break
    return moves

### TESTING ###

# Test the encoder and decoder
def testEncodeDecode():
    # Encode board
    moveMap = getMappingMatrix()
    with open(os.path.join(chess_dir, 'Move Selector Dataset', 'moveSelectorDatasetD70-D99.csv'), 'r') as file:
        data = list(csv.reader(file))
    curr = data[0]
    encoded = encodeSelector(curr[1], curr[2], moveMap)

    # Decode board
    decodedBoard = decodePosition(encoded[0:NUM_BOARD_NODES])
    decodedMoves = decodeMoves(encoded[-NUM_MOVE_NODES:], moveMap)

    # Compare original and decoded position/moves
    print('Original position and legal moves: ')
    print(chess.Board(curr[1]))
    print(sorted(curr[2].split(' ')))
    print('---------------')
    print('Decoded position and legal moves: ')
    print(decodedBoard)
    print(decodedMoves)

testEncodeDecode()

Original position and legal moves: 
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . P . . . .
. . . . . . . .
P P P . P P P P
R N B Q K B N R
['a7a5', 'a7a6', 'b7b5', 'b7b6', 'b8a6', 'b8c6', 'c7c5', 'c7c6', 'd7d5', 'd7d6', 'e7e5', 'e7e6', 'f7f5', 'f7f6', 'g7g5', 'g7g6', 'g8f6', 'g8h6', 'h7h5', 'h7h6']
---------------
Decoded position and legal moves: 
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . P . . . .
. . . . . . . .
P P P . P P P P
R N B Q K B N R
['a7a5', 'a7a6', 'b7b5', 'b7b6', 'b8a6', 'b8c6', 'c7c5', 'c7c6', 'd7d5', 'd7d6', 'e7e5', 'e7e6', 'f7f5', 'f7f6', 'g7g5', 'g7g6', 'g8f6', 'g8h6', 'h7h5', 'h7h6']
