In [229]:
# Imports
from time import perf_counter as ttt
import numpy as np
import pickle
from numba import njit, jit
from Chess_Pieces import Pawn

from IPython.display import clear_output

In [230]:
def add(coor1, coor2) -> list:
    return [coor1[0]+coor2[0], coor1[1]+coor2[1]]

def sign(x:int) -> int:
    return 1 if x < 0 else 1

In [231]:
legal_moves = {
    (0, 0): [],
    (0, 1): [(2, 0), (2, 2)],
    (0, 2): [],
    (0, 3): [],
    (0, 4): [],
    (0, 5): [],
    (0, 6): [(2, 5), (2, 7)],
    (0, 7): [],
    (1, 0): [(2, 0), (3, 0)],
    (1, 1): [(2, 1), (3, 1)],
    (1, 2): [(2, 2), (3, 2)],
    (1, 3): [(2, 3), (3, 3)],
    (1, 4): [(2, 4), (3, 4)],
    (1, 5): [(2, 5), (3, 5)],
    (1, 6): [(2, 6), (3, 6)],
    (1, 7): [(2, 7), (3, 7)],
    (6, 0): [(4, 0, (5, 0))],
    (6, 1): [(4, 1, (5, 1))],
    (6, 2): [(4, 2, (5, 2))],
    (6, 3): [(4, 3, (5, 3))],
    (6, 4): [(4, 4, (5, 4))],
    (6, 5): [(4, 5, (5, 5))],
    (6, 6): [(4, 6, (5, 6))],
    (6, 7): [(4, 7, (5, 7))],
    (7, 0): [],
    (7, 1): [(5, 0), (5, 2)],
    (7, 2): [],
    (7, 3): [],
    (7, 4): [],
    (7, 5): [],
    (7, 6): [(5, 5), (5, 7)],
    (7, 7): [],
}

In [232]:
class Piece:

    def __init__(self, board, coor, piece_index, castle=None):
            # castle format [bool, bool] -- queen, king
        self.board_obj = board
        self.coor = coor
        self.p_index = piece_index
        self.side = np.sign(piece_index)
        if castle != None:
            self.castle = castle
        self.gen_move_range()
        
    def gen_move_range(self):
        with open('move_range.pkl', 'rb') as f:
            move_range = pickle.load(f)
        self.move_range = move_range[self.coor][self.p_index]

In [233]:
class Pawn1(Piece):
    # move_range: [default_move, first_jump, left_capture, right_capture]
    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side)

        # selecting best function from different versions
        self.gen_legal = self.gen_prelegal3

    def gen_legal(self):
        pass


    # does not work. do not run
    @jit
    def gen_prelegal1(self):
        move_range = np.array(self.move_range)
        mask = np.full(4, False, dtype=np.bool_)
        if self.board_obj.get_grid_side(move_range[0]) == None:
            mask[0] = True
        if len(move_range) == 4 and self.board_obj.get_grid_side(move_range[1]) == None:
            mask[1] = True
        if self.board_obj.get_grid_side(move_range[-2]) == -self.side:
            mask[2] = True
        if self.board_obj.get_grid_side(move_range[-1]) == -self.side:
            mask[3] = True

        return move_range[mask]

    # slower than gen_prelegal3
    def gen_prelegal2(self):
        move_range = np.array(self.move_range)
        n_moves = len(move_range)
        mask = np.full(n_moves, False, dtype=np.bool_)
        if self.board_obj.get_grid_side(move_range[0]) == None:
            mask[0] = True
        if n_moves == 4 and self.board_obj.get_grid_side(move_range[1]) == None:
            mask[1] = True
        if self.board_obj.get_grid_side(move_range[-2]) == -self.side:
            mask[-2] = True
        if self.board_obj.get_grid_side(move_range[-1]) == -self.side:
            mask[-1] = True

        return move_range[mask]
        
    def gen_prelegal3(self):
            # Average time per run: 0.11856174400076269ms
        dest = []
        # default move
        if self.board_obj.get_grid_side(self.move_range[0]) == None:
            dest.append(self.move_range[0])
        # first jump
        if len(self.move_range) == 4 and self.board_obj.get_grid_side(self.move_range[1]) == None:
            dest.append(self.move_range[1])
        # left capture
        if self.board_obj.get_grid_side(self.move_range[-2]) == -self.side:
            dest.append(self.move_range[-2])
        # right capture
        if self.board_obj.get_grid_side(self.move_range[-1]) == -self.side:
            dest.append(self.move_range[-1])
        return dest

    def gen_prelegal4(self):
            # Average time per run: 0.3493830720009282ms
        n_moves = len(self.move_range)
        mask = [False]*n_moves
        # default move
        if self.board_obj.get_grid_side(self.move_range[0]) == None:
            mask[0] = True
        # first jump
        if len(self.move_range) == 4 and self.board_obj.get_grid_side(self.move_range[1]) == None:
            mask[1] = True
        # left capture
        if self.board_obj.get_grid_side(self.move_range[-2]) == -self.side:
            mask[-2] = True
        # right capture
        if self.board_obj.get_grid_side(self.move_range[-1]) == -self.side:
            mask[-1] = True
        return np.array(self.move_range)[mask]


In [234]:
class King(Piece):
    # move_range: clockwise starting topleft
    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side*6)

    def gen_prelegal1(self):
        destinations = []
        for dest in self.move_range:
            occupant = self.board_obj.get_grid_side(dest)
            if occupant == 0 or occupant == -self.side:
                destinations.append(dest)
        return destinations

In [235]:
class Knight1(Piece):
    # move_range: clockwise starting topleft(-2, -1)
    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side*2)

    def gen_prelegal1(self):
        destinations = []
        for dest in self.move_range:
            occupant = self.board_obj.get_grid_side(dest)
            if occupant == 0 or occupant == -self.side:
                destinations.append(dest)
        return destinations

In [236]:
class SeqPiece(Piece):
    # pieces that move sequentially: bishop/rook/queen
    # move_range: clockwise starting topleft
    def prelegal_direc(self, direction):
        destinations = []
        for dest in self.move_range[direction]:
            occupant = self.board_obj.get_grid_side(dest)
            if occupant == 0:
                destinations.append(dest)
            elif occupant == -self.side:
                destinations.append(dest)
                break
            else:
                break
        return destinations

    def gen_prelegal(self):
        destinations = []
        for direction in self.direcs:
            destinations += self.prelegal_direc(direction)
        return destinations

In [237]:
class Bishop(SeqPiece):
    direcs = [(-1, -1), (-1, 1), (1, 1), (1, -1)]

    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side*3)

    # gen_legal() is inherited from SeqPiece

In [238]:
class Rook(SeqPiece):
    direcs = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side*3)
    
    # gen_legal() is inherited from SeqPiece

In [239]:
class Queen(SeqPiece):
    direcs = [(-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1)]

    def __init__(self, board, coor:tuple, side:int):
        super().__init__(board, coor, side*3)
    
    # gen_legal() is inherited from SeqPiece

In [240]:
class EmptyCell:
    pass

In [241]:
class Board:
    legal_moves = {}  # should not be class attr, should be instance attr
    all_legal_dest = []

    def __init__(self, board_arr):
        self.board_arr = board_arr
        self.gen_pieces()
    
    def gen_pieces(self):
        self.grid = {}
        self.king_coors = {}
        
        for row in range(8):
            for col in range(8):
                coor = (row, col)
                piece = self.board_arr[coor]
                # print(coor)
                # print(piece)
                if abs(piece) == 1:
                    self.grid[coor] = Pawn(self, coor, piece)
                elif abs(piece) == 2:
                    self.grid[coor] = Knight(self, coor, int(piece/2))
                elif abs(piece) == 3:
                    self.grid[coor] = Bishop(self, coor, int(piece/3))
                elif abs(piece) == 4:
                    self.grid[coor] = Rook(self, coor, int(piece/4))
                elif abs(piece) == 5:
                    self.grid[coor] = Queen(self, coor, int(piece/5))
                elif abs(piece) == 6:
                    self.grid[coor] = King(self, coor, int(piece/6))
                    self.king_coors[piece] = coor
                else:
                    self.grid[coor] = EmptyCell()

    def get_grid_side(self, coor):
        try:
            return self.grid[coor].side
        except:
            pass

    def get_grid_obj(self, coor):
        try:
            return self.grid[coor]
        except:
            pass

    def gen_all_legal(self):
        pass

    def is_check(self, side):
        # get ally king coor
        king_coor = self.king_coors[side*6]
        # check if all opponent prelegal_moves includes king position

In [242]:
class Tree:
    pass

In [243]:
# pawn
# king
# knight
# bishop/rook/queen

# dict of all the coor:pieces with this coor in their move range

# class of empty cell with reference to pieces that can move to that cell?

# --> is_check function
# --> is_under_attack takes in prelegal move - yes. use list of prelegal moves of other pieces. 

board needs to be a dataclass: only holds information
cells also need to hold information

each board will have 64 cells. every cell must have info:  what legal-invaders can enter this cell (for check)
cell must be one out of 7 types of cell: [empty, pawn, knight, bishop, rook, queen, king]
pieces will have additional info: move_range, prelegal_dest, legal_dest
king will have additional info: castle, check?

when a move is made, 2 cells are directly involved: start_cell and end_cell. 
  - start cell becomes empty
  - the prelegal-invaders and legal-invaders of start and end cell should be updated.
  - if a piece is a seq prelegal-invader of start/end cell, its moves should be updated
  - if a pawn_move range includes a start/end cell, its moves should be updated



-- if king is in prelegal_dest of opponent piece, the king is in check
-- add cell info: affected cells ->
    cells which are blocked by this cell - allies (for seq move update)
    cells which have this cell in their prelegal_dest

steps:
1. prelegal move generation:
- generate all prelegal moves for all pieces --> collision, capture and offboard
- update enemy king check status if prelegal move for a piece includes the enemy king

2. legal move pruning
- perform pseudomove 
- check if the move reveals an attack that places the ally king in check
- reverse pseudomove

3. occupancy of cell
- return the side of the piece in cell
- indicate if the cell is empty
- indicate if the cell is offboard

4. is_check
- calculates if the king is placed in check for a given position

# Testing

In [244]:
# variables
ini_board = np.array([
    [-4,-2,-3,-5,-6,-3,-2,-4],
    [-1,-1,-1,-1,-1,-1,-1,-1],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 1, 1, 1, 1, 1, 1, 1, 1],
    [ 4, 2, 3, 5, 6, 3, 2, 4], 
])

empty_board = np.array([
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0],
    [ 0, 0, 0, 0, 0, 0, 0, 0]
])

In [245]:
# Speed Test
pawn_coors = [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7)]
test_board = Board(ini_board)
n_iter = 100000

start = ttt()
for _ in range(n_iter):
    for pawn_coor in pawn_coors:
        test_board.grid[pawn_coor].gen_prelegal2()
end = ttt()

print(f'gen_prelegal2() for Pawn took {end-start} seconds to run {n_iter*len(pawn_coors)} iterations\nAverage time per run: {(end-start)/n_iter*len(pawn_coors)*1000}ms')

start = ttt()
for _ in range(10000):
    for pawn_coor in pawn_coors:
        test_board.grid[pawn_coor].gen_prelegal3()
end = ttt()

print(f'gen_prelegal3() for Pawn took {end-start} seconds to run {n_iter*len(pawn_coors)} iterations\nAverage time per run: {(end-start)/n_iter*len(pawn_coors)*1000}ms')

start = ttt()
for _ in range(10000):
    for pawn_coor in pawn_coors:
        test_board.grid[pawn_coor].gen_prelegal4()
end = ttt()

print(f'gen_prelegal4() for Pawn took {end-start} seconds to run {n_iter*len(pawn_coors)} iterations\nAverage time per run: {(end-start)/n_iter*len(pawn_coors)*1000}ms')

gen_prelegal2() for Pawn took 25.32106719999865 seconds to run 1600000 iterations
Average time per run: 4.051370751999784ms
gen_prelegal3() for Pawn took 0.7410109000047669 seconds to run 1600000 iterations
Average time per run: 0.11856174400076269ms
gen_prelegal4() for Pawn took 2.1836442000058014 seconds to run 1600000 iterations
Average time per run: 0.3493830720009282ms
