<a href="https://colab.research.google.com/github/averyrair/ChessBAKEN/blob/selector_testing/MoveSelector/MoveSelectorBAKEN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Install Dependencies

!pip install chess
#!pip install stockfish
#!apt install stockfish

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=000cec4511644f09ff2964e98a802e2da0c865509c72efd1f608a999f1642d58
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.2


In [2]:
#@title Import Directories

from google.colab import drive
# For this to work, you need to have the "Chess Bot BAKEN" project shared with
# the SAME email you linked this Colab to, presumably your GitHub email address.
drive.mount('/content/drive', force_remount=1)
chess_dir = '/content/drive/Shareddrives/Chess Bot BAKEN'

Mounted at /content/drive


In [3]:
#@title Import Libraries

import os
import chess
import chess.pgn
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [4]:
#@title Define Custom Board Representation

def customBoardRep(board):

  return board.fen().split(" ")[0]

In [5]:
#@title Import PGN and Export CSV

import chess.pgn
#if not os.path.exists(os.path.join(chess_dir, 'moveSelectorDataset.csv')):
if 0:
  pgn_dir = os.path.join(chess_dir, 'KingBase2019-pgn')
  pgnFiles = [os.path.join(pgn_dir, f) for f in os.listdir(pgn_dir) if f.endswith('.pgn')]

  # Read every PGN file in the directory.
  iterator = 0
  for pgnFile in pgnFiles:
    print(pgnFile)
    pgn = open(pgnFile)

    # Prepare CSV to export dataset.
    with open(os.path.join(chess_dir, 'moveSelectorDataset' + pgnFile[74:81] +'.csv'), 'w') as csvfile:
      csvwriter = csv.writer(csvfile)
      #csvwriter.writeheader(['Label', 'Input'])

      # Read every game in the PGN file.
      while True:
        game = chess.pgn.read_game(pgn)
        if game is None: break
        iterator += 1
        if iterator % 500 == 0:
          print(iterator)
          #break
        board = game.board()

        # New data point for every move in the game.
        gameMoves = list(game.mainline_moves())
        for i in range(len(gameMoves)-1):
          board.push(gameMoves[i])
          legalMoves = ""
          for m in list(board.legal_moves):
            legalMoves = legalMoves + str(m) + " "
          legalMoves = legalMoves.removesuffix(" ")
          #print(gameMoves[i+1], legalMoves, customBoardRep(board))
          csvwriter.writerow([gameMoves[i+1], customBoardRep(board), legalMoves])

In [5]:
#@title Define Fully-Connected Neural Network

class myFCN(nn.Module):
  def __init__(self, inSize, hiddenSizes, outSize):
    super().__init__()
    self.inSize = inSize
    self.hiddenSize = hiddenSizes
    self.outSize = outSize

    self.lin1 = nn.Linear(inSize, hiddenSizes[0])
    self.lin2 = nn.Linear(hiddenSizes[0], hiddenSizes[1])
    self.lin3 = nn.Linear(hiddenSizes[1], outSize)

    self.bn1 = nn.BatchNorm1d(hiddenSizes[0])
    self.bn2 = nn.BatchNorm1d(hiddenSizes[1])

    self.relu = nn.ReLU()
    self.softmax = nn.Softmax(dim=0)

  def forward(self, x):
    x = self.relu(self.bn1(self.lin1(x)))
    x = self.relu(self.bn2(self.lin2(x)))
    x = self.softmax(self.lin3(x))
    return x

In [15]:
#@title Initialize Fully-Connected Neural Network

currentModelFile = 'MoveSelectorV1.6.pt'

moveSelector = myFCN(inSize=2560, hiddenSizes=[2200, 2200], outSize=1792)
if (os.path.exists(os.path.join(chess_dir, currentModelFile))):
  moveSelector = torch.load(os.path.join(chess_dir, currentModelFile), weights_only=False)
  moveSelector.eval()
  print("Loaded Model: " + currentModelFile)
print(moveSelector)

Loaded Model: MoveSelectorV1.6.pt
myFCN(
  (lin1): Linear(in_features=2560, out_features=2200, bias=True)
  (lin2): Linear(in_features=2200, out_features=2200, bias=True)
  (lin3): Linear(in_features=2200, out_features=1792, bias=True)
  (bn1): BatchNorm1d(2200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(2200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (softmax): Softmax(dim=0)
)


In [16]:
#@title Define Dataset

class ChessDataset(Dataset):
    def __init__(self, chessgame_file, chess_dir, transform=None, target_transform=None):
        self.game_labels = pd.read_csv(chessgame_file)
        self.chess_dir = chess_dir  # This is not used. Everything is in the chessgame file.
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.game_labels)

    def __getitem__(self, idx):
        label = self.game_labels.iloc[idx, 0]
        gameState = self.game_labels.iloc[idx, 1]
        moves = self.game_labels.iloc[idx, 2]

        return gameState, moves, label

In [17]:
#@title Define Encode/Decode Functions

# The number of possible pieces and square states (8*8*6*2)
NUM_BOARD_NODES = 768
# The number of possible moves from any square to any other
NUM_MOVE_NODES = 1792


# Determines whether any piece could ever move from one square to another
# Note: pass in start and end coordinates as tuples containing two ints
def canMove(start, end):
    # check duplicate
    if start[0] == end[0] and start[1] == end[1]:
        return False

    # check row or col
    if start[0] == end[0] or start[1] == end[1]:
        return True

    # check diagonal
    if abs(start[0]-end[0]) == abs(start[1]-end[1]):
        return True

    # check knight
    knightCases = [(1, 2), (1, -2), (-1, 2), (-1, -2),
                   (2, 1), (2, -1), (-2, 1), (-2, -1)]
    for c in knightCases:
        if start[0] + c[0] == end[0] and start[1] + c[1] == end[1]:
            return True

    return False


# Generate the mapping matrix (only needs to be done once)
def getMappingMatrix():
    moveMap = [[[[-1 for l in range(8)] for k in range(
        8)] for j in range(8)] for i in range(8)]
    offset = 0
    for i in range(8):
        for j in range(8):
            for k in range(8):
                for l in range(8):
                    if canMove((i, j), (k, l)):
                        moveMap[i][j][k][l] = offset
                        offset += 1
    return moveMap

# Encode a board position and legal moves into a binary array
def encodeSelector(boardFEN, moves, moveMap):
    # Encode board position
    encodedBoard = [0 for _ in range(NUM_BOARD_NODES)]
    rowsFEN = boardFEN.split('/')
    for i in range(8):
        col = 0
        for j in range(len(rowsFEN[i])):
            if rowsFEN[i][j].isdigit():
                col += ord(rowsFEN[i][j])-48
            else:
                match rowsFEN[i][j]:
                    case 'P':
                        idx = (8*6*2*col) + (6*2*i) + (0*2) + 0
                        encodedBoard[idx] = 1
                    case 'N':
                        idx = (8*6*2*col) + (6*2*i) + (1*2) + 0
                        encodedBoard[idx] = 1
                    case 'B':
                        idx = (8*6*2*col) + (6*2*i) + (2*2) + 0
                        encodedBoard[idx] = 1
                    case 'R':
                        idx = (8*6*2*col) + (6*2*i) + (3*2) + 0
                        encodedBoard[idx] = 1
                    case 'Q':
                        idx = (8*6*2*col) + (6*2*i) + (4*2) + 0
                        encodedBoard[idx] = 1
                    case 'K':
                        idx = (8*6*2*col) + (6*2*i) + (5*2) + 0
                        encodedBoard[idx] = 1
                    case 'p':
                        idx = (8*6*2*col) + (6*2*i) + (0*2) + 1
                        encodedBoard[idx] = 1
                    case 'n':
                        idx = (8*6*2*col) + (6*2*i) + (1*2) + 1
                        encodedBoard[idx] = 1
                    case 'b':
                        idx = (8*6*2*col) + (6*2*i) + (2*2) + 1
                        encodedBoard[idx] = 1
                    case 'r':
                        idx = (8*6*2*col) + (6*2*i) + (3*2) + 1
                        encodedBoard[idx] = 1
                    case 'q':
                        idx = (8*6*2*col) + (6*2*i) + (4*2) + 1
                        encodedBoard[idx] = 1
                    case 'k':
                        idx = (8*6*2*col) + (6*2*i) + (5*2) + 1
                        encodedBoard[idx] = 1
                col += 1

    # Encode possible moves
    encodedMoves = [0 for _ in range(NUM_MOVE_NODES)]
    for m in moves.split(' '):
        startCol = ord(str(m)[0])-97  # e.g. c -> 2
        startRow = ord(str(m)[1])-49  # e.g. 8 -> 7
        endCol = ord(str(m)[2])-97
        endRow = ord(str(m)[3])-49
        encodedMoves[(moveMap[startCol][startRow][endCol][endRow])] = 1

    # Combine into single list
    return encodedBoard + encodedMoves


# Decode board position from a binary array of length 768
def decodePosition(input):
    # Create matrix of board
    boardMatrix = [['.' for _ in range(8)] for _ in range(8)]

    # Populate board
    for i in range(len(input)):
        if input[i] == 1:
            piece = i % 12
            # Piece: P = 0, R = 1, N = 2, B = 3, Q = 4, K = 5
            # Color: W = 0, B = 1
            pieceLabel = '.'
            match piece:
                case 0:
                    pieceLabel = 'P'
                case 1:
                    pieceLabel = 'p'
                case 2:
                    pieceLabel = 'N'
                case 3:
                    pieceLabel = 'n'
                case 4:
                    pieceLabel = 'B'
                case 5:
                    pieceLabel = 'b'
                case 6:
                    pieceLabel = 'R'
                case 7:
                    pieceLabel = 'r'
                case 8:
                    pieceLabel = 'Q'
                case 9:
                    pieceLabel = 'q'
                case 10:
                    pieceLabel = 'K'
                case 11:
                    pieceLabel = 'k'
            boardMatrix[int(i / (6*2)) % 8][int(i / (8*6*2))] = pieceLabel

    # Convert to FEN notation
    boardFEN = ''
    for i in range(len(boardMatrix)):
        empty = 0
        for j in range(len(boardMatrix[i])):
            if boardMatrix[i][j] == '.':
                empty += 1
            else:
                if empty > 0:
                    boardFEN += str(empty)
                    empty = 0
                boardFEN += boardMatrix[i][j]
        if empty > 0:
            boardFEN += str(empty)
        if i < len(boardMatrix) - 1:
            boardFEN += '/'

    return chess.Board(boardFEN)


# Decode moves from a binary array of length 1792
# Returns a list of moves with start and end squares
def decodeMoves(input, moveMap):
    moves = []
    for h in range(len(input)):
        if input[h] > 0: # NOTE: may need to change this condition based on selector NN output
            # Search moveMap to find the actual move squares
            found = False
            for i in range(8):
                for j in range(8):
                    for k in range(8):
                        for l in range(8):
                            if moveMap[i][j][k][l] == h:
                                moves.append(chr(i+97) + chr(j+49) + chr(k+97) + chr(l+49))
                                found = True
                                break
                        if found:
                            break
                    if found:
                        break
                if found:
                    break
    return moves

def decodeLegalMoves(original, input, moveMap, threshold = 0):
    indices = [i for i, x in enumerate(original) if x == 1]
    moves = []
    for h in indices:#range(len(input)):
        if input[h] > threshold: # NOTE: may need to change this condition based on selector NN output
            # Search moveMap to find the actual move squares
            found = False
            for i in range(8):
                for j in range(8):
                    for k in range(8):
                        for l in range(8):
                            if moveMap[i][j][k][l] == h:
                                moves.append((chr(i+97) + chr(j+49) + chr(k+97) + chr(l+49), input[h].item()))
                                found = True
                                break
                        if found:
                            break
                    if found:
                        break
                if found:
                    break
    return moves

def moveLabel(moveMade, moveMap):
    encodedMoves = [0 for _ in range(NUM_MOVE_NODES)]
    startCol = ord(str(moveMade)[0])-97  # e.g. c -> 2
    startRow = ord(str(moveMade)[1])-49  # e.g. 8 -> 7
    endCol = ord(str(moveMade)[2])-97
    endRow = ord(str(moveMade)[3])-49
    encodedMoves[(moveMap[startCol][startRow][endCol][endRow])] = 1
    return encodedMoves

def getMoveLoss(modelOutput, expectedOutput, weighting = NUM_MOVE_NODES):
    totalLoss = 0
    for i in range(NUM_MOVE_NODES):
        if expectedOutput[i] == 1:
            # model should output value close to 1
            totalLoss += abs(1 - modelOutput[i]) * weighting
            pass
        else:
            # model should output value close to 0
            totalLoss += abs(modelOutput[i])
            pass
    return totalLoss

def getMoveLossV1(modelOutput, moveMade, moveMap, exponent = 1):
    startCol = ord(str(moveMade)[0])-97  # e.g. c -> 2
    startRow = ord(str(moveMade)[1])-49  # e.g. 8 -> 7
    endCol = ord(str(moveMade)[2])-97
    endRow = ord(str(moveMade)[3])-49
    #totalLoss = torch.sum(modelOutput)
    totalLoss = abs(1 - modelOutput[(moveMap[startCol][startRow][endCol][endRow])]) ** exponent

    return totalLoss

def getMoveLossV2(modelOutput, moveMade, moveMap):
    startCol = ord(str(moveMade)[0])-97  # e.g. c -> 2
    startRow = ord(str(moveMade)[1])-49  # e.g. 8 -> 7
    endCol = ord(str(moveMade)[2])-97
    endRow = ord(str(moveMade)[3])-49

    moveChance = modelOutput[moveMap[startCol][startRow][endCol][endRow]]
    losses = [x for x in modelOutput if x >= moveChance]
    # This just doesn't really work at all.
    return torch.tensor(len(losses), dtype=torch.float, requires_grad=True)

### TESTING ###

# Test the encoder and decoder
def testEncodeDecode():
    # Encode board
    moveMap = getMappingMatrix()
    with open(os.path.join(chess_dir, 'Move Selector Dataset', 'moveSelectorDatasetD70-D99.csv'), 'r') as file:
        data = list(csv.reader(file))
    curr = data[0]
    encoded = encodeSelector(curr[1], curr[2], moveMap)

    # Decode board
    decodedBoard = decodePosition(encoded[0:NUM_BOARD_NODES])
    decodedMoves = decodeMoves(encoded[-NUM_MOVE_NODES:], moveMap)

    # Compare original and decoded position/moves
    print('Original position and legal moves: ')
    print(chess.Board(curr[1]))
    print(sorted(curr[2].split(' ')))
    print('---------------')
    print('Decoded position and legal moves: ')
    print(decodedBoard)
    print(decodedMoves)

#testEncodeDecode()

In [None]:
#@title Train Model

dataset = ChessDataset(chessgame_file=os.path.join(chess_dir, 'Move Selector Dataset/moveSelectorDatasetA00-A39.csv'), chess_dir=chess_dir)
moveMap = getMappingMatrix()

optim = torch.optim.SGD(moveSelector.parameters(), lr=0.0001, momentum=0.9)

epochs = 5
batchSize = 128
percentageOfDataset = 0.2

print("Number of Positions Training On: ", int(len(dataset)*percentageOfDataset))
print("Estimated Training Time: ", int(len(dataset)*percentageOfDataset*epochs/15000), " minutes")  #V1
#print("Estimated Training Time: ", int(len(dataset)*percentageOfDataset*epochs/1000), " minutes")  #V2

for e in range(epochs):
  runningLoss = 0

  # Number of batches
  for b in range(int(len(dataset)*percentageOfDataset/batchSize)):

    X = []
    labels = []

    # Make the batch
    for i in range(batchSize):
      offset = b * batchSize
      X.append(encodeSelector(dataset[i + offset][0], dataset[i + offset][1], moveMap))
      labels.append(dataset[i + offset][2])

    optim.zero_grad()

    X = torch.tensor(X).float()
    output = moveSelector(X)

    loss = 0
    for i in range(batchSize):
      #expectedOutput = moveLabel(labels[i], moveMap)
      #loss += getMoveLoss(output[i], expectedOutput, 3)
      loss += getMoveLossV1(output[i], labels[i], moveMap, exponent=8)

    loss.backward()
    runningLoss += loss.item()
    optim.step()

  else:
    print(e, runningLoss/len(dataset)*percentageOfDataset)


Number of Positions Training On:  100530
Estimated Training Time:  26  minutes
0 0.0001935178972451237
1 9.729932333386392e-05
2 6.255352399962323e-05
3 4.826813401632592e-05


In [None]:
#@title Save Model

newModelFile = 'MoveSelectorV1.6.pt'
torch.save(moveSelector, os.path.join(chess_dir, newModelFile))

In [62]:
#@title Evaluate Model
moveMap = getMappingMatrix()

# Comment this line out to test on same dataset as trained on.
dataset = ChessDataset(chessgame_file=os.path.join(chess_dir, 'Move Selector Dataset/moveSelectorDatasetD70-D99.csv'), chess_dir=chess_dir)
batchSize = 2

X = []
labels = []
for i in range(batchSize):
  X.append(encodeSelector(dataset[i][0], dataset[i][1], moveMap))
  labels.append(dataset[i][2])

X = torch.tensor(X).float()
# print(X)
testData = moveSelector(X)
# print(testData)

rankings = []
for i in range(batchSize):
  index = dict(decodeLegalMoves(X[i][-NUM_MOVE_NODES:], testData[i], moveMap))[labels[i]]
  rankings.append(sorted(testData[i].tolist(), reverse=True).index(index))
print(rankings)

dataPoint = 0
torch.set_printoptions(threshold=torch.inf)
#print(testData[dataPoint])
print(labels[dataPoint])
print(decodeLegalMoves(X[dataPoint][-NUM_MOVE_NODES:], testData[dataPoint], moveMap))
print(sorted(testData[dataPoint].tolist(), reverse=True))

[0, 6]
c2c4
[('a2a3', 0.9214321970939636), ('a2a4', 0.9685519337654114), ('b1a3', 0.5036048293113708), ('b1c3', 0.9945598244667053), ('b1d2', 0.9852638244628906), ('b2b3', 0.9973764419555664), ('b2b4', 0.7191117405891418), ('c1d2', 0.8265488743782043), ('c1e3', 0.6596152186393738), ('c1f4', 0.9612470865249634), ('c1g5', 0.8269076943397522), ('c1h6', 0.9473519325256348), ('c2c3', 0.997412383556366), ('c2c4', 0.9999884366989136), ('d1d2', 0.5791128873825073), ('d1d3', 0.9726913571357727), ('d4d5', 0.9395662546157837), ('e1d2', 0.5773073434829712), ('e2e3', 0.977073073387146), ('e2e4', 0.996885359287262), ('f2f3', 0.9989809393882751), ('f2f4', 0.9511043429374695), ('g1f3', 0.9952097535133362), ('g1h3', 0.9862961769104004), ('g2g3', 0.9965524673461914), ('g2g4', 0.9899294972419739), ('h2h3', 0.8641778230667114), ('h2h4', 0.843379020690918)]
[0.9999884366989136, 0.9989809393882751, 0.997412383556366, 0.9973764419555664, 0.996885359287262, 0.9965524673461914, 0.9952097535133362, 0.9945598244

In [91]:
#@title Play Against Selector (Testing)
import random

def sumOutput(input):
  sum = 0.0
  for i in input:
    sum += i[1]
  return sum

def getSelectorMove(board):
  boardFEN = board.fen().split(' ')[0]
  moves = ' '.join([move.uci() for move in board.legal_moves])

  encodedMoves = [[]]
  encodedMoves[0] = encodeSelector(boardFEN, moves, moveMap)
  testBatchSize = 127
  for i in range(testBatchSize):
    encodedMoves.append(encodeSelector(dataset[i][0], dataset[i][1], moveMap))

  X = torch.tensor(encodedMoves).float()
  torch.set_printoptions(threshold=torch.inf)
  moveSelector.eval()
  with torch.no_grad():
    modelMoves = moveSelector(X)
  decodedMoves = decodeLegalMoves(X[0][-NUM_MOVE_NODES:], modelMoves[0], moveMap)
  print(sumOutput(decodedMoves))
  decodedMoves = sorted(decodedMoves, key=lambda x: x[1], reverse=True)
  print(decodedMoves)
  # random.shuffle(decodedMoves)
  # print(decodedMoves)
  return decodedMoves[0][0]

def get_user_move(board):
    while True:
        try:
            move_str = input("Enter move in algebraic notation (e.g., e2e4): ")
            move = board.parse_uci(move_str)
            if move in board.legal_moves:
              return move
            else:
              print("Illegal move, try again.")
        except ValueError:
            print("Invalid move format. Please use algebraic notation (e.g., e2e4).")
        except chess.InvalidMoveError:
            print("Invalid move, try again.")
        except Exception as e:
            print(f"An unexpected error occurred: {e}. Please try again.")

playAsWhite = True
moveMap = getMappingMatrix()
board = chess.Board()
print('---------------')
print(board)
# encodeSelector(dataset[i][0], dataset[i][1], moveMap)
# print(board.fen())
if playAsWhite:
  userMove = get_user_move(board)
  board.push(userMove)
  print(str(userMove) + ' played')
  print('---------------')
  print(board)
while True:
  move = getSelectorMove(board)
  board.push_san(move)
  print(str(move) + ' played')
  print('---------------')
  print(board)
  if board.is_game_over():
    break
  userMove = get_user_move(board)
  board.push(userMove)
  print(str(userMove) + ' played')
  print('---------------')
  print(board)
  if board.is_game_over():
    break

---------------
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
Enter move in algebraic notation (e.g., e2e4): c2c4
c2c4 played
---------------
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . P . . . . .
. . . . . . . .
P P . P P P P P
R N B Q K B N R
0.9706582824358065
[('e7e5', 0.5040971040725708), ('g8f6', 0.09700669348239899), ('c7c5', 0.06723078340291977), ('g7g6', 0.061334580183029175), ('g8h6', 0.05843100696802139), ('e7e6', 0.0542818084359169), ('b8c6', 0.037450291216373444), ('f7f5', 0.028497418388724327), ('g7g5', 0.025492530316114426), ('c7c6', 0.01556718721985817), ('d7d6', 0.013024883344769478), ('d7d5', 0.007901963777840137), ('b7b6', 0.00011554310185601935), ('b7b5', 6.667120032943785e-05), ('h7h6', 5.530185444513336e-05), ('a7a6', 4.3269206798868254e-05), ('f7f6', 2.908816713897977e-05), ('h7h5', 1.5098282347025815e-05), ('a7a5', 1.3518989362637512e-05), ('b8a6', 3.540824

KeyboardInterrupt: Interrupted by user