In [1]:
import chess
import chess.svg

import torch
import json
import chess
import numpy as np

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import torch
import torch.nn as nn

class ChessMovePredictor(nn.Module):
    def __init__(self, num_moves):
        super().__init__()
        self.conv1 = nn.Conv2d(13, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*4*4, 256)
        self.fc_move = nn.Linear(256, num_moves)
        self.fc_eval = nn.Linear(256, 1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc_move(x), self.fc_eval(x)

In [119]:
with open('move_idx.json', 'r') as f:
    move2idx = json.load(f)

idx2move = {i:m for m,i in move2idx.items()}

model = ChessMovePredictor(num_moves=len(move2idx))
model.load_state_dict(torch.load("chess-project/chess_model.pth", map_location="cpu"))
model.eval()


ChessMovePredictor(
  (conv1): Conv2d(13, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=256, bias=True)
  (fc_move): Linear(in_features=256, out_features=1791, bias=True)
  (fc_eval): Linear(in_features=256, out_features=1, bias=True)
)

In [122]:
piece_map = {'P':0,'N':1,'B':2,'R':3,'Q':4,'K':5,
             'p':6,'n':7,'b':8,'r':9,'q':10,'k':11}

def fen_to_tensor(fen):
    board = chess.Board(fen)
    mat = np.zeros((8,8,12), dtype=np.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            row = 7 - chess.square_rank(square)
            col = chess.square_file(square)
            mat[row, col, piece_map[piece.symbol()]] = 1.0
    turn_channel = np.full((8,8,1), int(board.turn), dtype=np.float32)
    mat = np.concatenate([mat, turn_channel], axis=-1)
    return mat

def choose_best_move(board):
    model.eval()
    with torch.no_grad():
        temp_board = board.copy()

        board_tensor = torch.tensor(fen_to_tensor(temp_board.fen()), dtype=torch.float32)
        board_tensor = board_tensor.permute(2, 0, 1).unsqueeze(0).to(device)

        logits_move, pred_eval = model(board_tensor)
        probs = torch.softmax(logits_move, dim=1).cpu().numpy().flatten()

        legal_moves = list(board.legal_moves)
        best_move = legal_moves[0]
        best_score = -float('inf')

        total_moves = list(move2idx.keys())
        flag = False
        for move in legal_moves:
            move_str = str(move)

            if move_str in move2idx:
                move_score = probs[move2idx[move_str]]
                if move_score > best_score:
                    # print("best_score:", best_score, "move_score:", move_score)
                    best_score = move_score
                    best_move = move
        return best_move


In [123]:
def help_func(play):
    if play == "h":
        print("*chess game explained*")
    if play == "m":
        print("|".join([move.uci() for move in board.legal_moves]))

In [124]:
board = chess.Board()
count = 0
while not board.is_game_over():
    print(board)
    move = choose_best_move(board)
    # print("choosen move:", move)
    board.push(move)
    print("_____________")
    count+=1
    # if count>3:
    #     break
print("game over:", board.result())


r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
_____________
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . N . .
P P P P P P P P
R N B Q K B . R
_____________
r n b q k b . r
p p p p p p p p
. . . . . n . .
. . . . . . . .
. . . . . . . .
. . . . . N . .
P P P P P P P P
R N B Q K B . R
_____________
r n b q k b . r
p p p p p p p p
. . . . . n . .
. . . . N . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B . R
_____________
r n b q k b . r
p p p p p p . p
. . . . . n p .
. . . . N . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B . R
_____________
r n b q k b . r
p p p p p p . p
. . . . . n p .
. . . . . . . .
. . N . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B . R
_____________
r n b q k b . r
p p p p . p . p
. . . . p n p .
. . . . . . . .
. . N . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B . R
_____________
r n b 

In [125]:
import pygame

# Basic Pygame setup (simplified)
pygame.init()
screen_width = 800
screen_height = 800
screen = pygame.display.set_mode((screen_width, screen_height))
pygame.display.set_caption("Python Chess")

# Game loop (simplified)
running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
    
    # Drawing board and pieces, handling moves, etc.
    # ...

    pygame.display.flip()

pygame.quit()

ModuleNotFoundError: No module named 'pygame'

In [90]:
board = chess.Board()
count = 0
while not board.is_game_over():
    print(board)
    if board.turn == chess.WHITE:
        move = choose_move(board)
        board.push(move)
    else:
        try:
            play = input("play legal move\n------- help (h) | legal moves (m) -------")
            while play == "h" or play == "m":
                help_func(play)
                play = input("play legal move\n------- help (h) | legal moves (m) -------")
            move = chess.Move.from_uci(play)
            if move in board.legal_moves:
                board.push(move)
            else:
                print("illegal move")
        except:
            print("unexpected input")
    
    count+=1
    if count>3:
        break
print("game over:", board.result())


r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
0.16361868381500244


AssertionError: push() expects move to be pseudo-legal, but got g1f3 in rnbqkbnr/pppppppp/8/8/8/7N/PPPPPPPP/RNBQKB1R

In [12]:
board = chess.Board()
[move for move in board.legal_moves]

[Move.from_uci('g1h3'),
 Move.from_uci('g1f3'),
 Move.from_uci('b1c3'),
 Move.from_uci('b1a3'),
 Move.from_uci('h2h3'),
 Move.from_uci('g2g3'),
 Move.from_uci('f2f3'),
 Move.from_uci('e2e3'),
 Move.from_uci('d2d3'),
 Move.from_uci('c2c3'),
 Move.from_uci('b2b3'),
 Move.from_uci('a2a3'),
 Move.from_uci('h2h4'),
 Move.from_uci('g2g4'),
 Move.from_uci('f2f4'),
 Move.from_uci('e2e4'),
 Move.from_uci('d2d4'),
 Move.from_uci('c2c4'),
 Move.from_uci('b2b4'),
 Move.from_uci('a2a4')]

In [97]:
board = chess.Board()


In [101]:
a = []
a.append(board.parse_san("d4").uci())

In [102]:
a

['d2d4']