In [9]:
import chess
import torch

from chess_app.environment import ChessInterface
from chess_app.chess_exceptions import NotEnoughMovesExceptions
from re_algoritms.agents import RandomAgent, QAgent
from chess_app.dataset.loader import loaderChessPos, tanstsovVecLoader
import chess.svg
from tqdm import tqdm


In [10]:

def is_checkmate(board, state):
    """Tell whether the position described by ``state`` is a checkmate.

    The position is loaded into ``board`` via ``board.set_fen(state)``, so
    ``board`` is mutated and afterwards holds that position.

    Args:
        board: Object exposing ``set_fen`` and ``is_checkmate`` (the notebook
            passes a ``chess.Board``).
        state: Position string handed to ``board.set_fen``.

    Returns:
        Whatever ``board.is_checkmate()`` reports for the loaded position.
    """
    board.set_fen(state)
    mated = board.is_checkmate()
    return mated


In [11]:
from re_algoritms.dqn import GrandMasterNet
from re_algoritms.agents import DQNAgent

# --- Hyperparameters ---------------------------------------------------------
lr = 1e-2                # passed to DQNAgent as `lr`
gamma = 0.96             # passed to DQNAgent as `gamma`
n_epochs = 100           # number of passes over the sampled positions
stalemate_reward = 0.8   # magnitude only; the training loop applies it negated
frac = 0.4               # forwarded to loader.getAllData(frac=...)

# Scratch board that is_checkmate() loads positions into.
board = chess.Board()

# --- Model and agent ---------------------------------------------------------
policy_net = GrandMasterNet(board_vec_dim=68, moves_vec_dim=5)
player = DQNAgent(lr=lr, gamma=gamma, model=policy_net)

# --- Data --------------------------------------------------------------------
loader = tanstsovVecLoader(
    r'chess_app\dataset\result.json',
    r'chess_app\dataset\settings.json',
)
loader.createDf()

# Column 0 is kept separately as `fen_strings`; the remaining columns form
# the per-position feature frame `dataset`.
positions = loader.getAllData(frac=frac)
fen_strings = positions.iloc[:, 0]
dataset = positions.iloc[:, 1:]

In [12]:
# Show the feature frame (one row per sampled position); the notebook renders
# a truncated preview of its rows and columns.
dataset

Unnamed: 0,white_king_castle,white_queen_castle,black_king_castle,black_queen_castle,a8,b8,c8,d8,e8,f8,...,g2,h2,a1,b1,c1,d1,e1,f1,g1,h1
203,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,7,8,0,0,0
797,0,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2129,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8671,0,0,0,0,9,0,0,0,0,7,...,0,0,0,0,1,0,0,0,0,3
9336,0,0,0,0,0,0,0,0,0,9,...,5,5,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2910,0,0,0,0,9,0,0,8,0,0,...,0,5,3,0,4,10,0,3,1,0
3722,0,0,0,0,0,0,10,0,7,10,...,0,0,0,0,0,4,4,0,0,0
8806,0,0,0,0,9,0,0,0,0,0,...,5,5,0,0,0,0,3,0,1,0
4975,0,0,0,0,0,0,12,0,0,0,...,0,0,0,0,0,0,0,0,4,1


In [13]:
# Print only the first (feature row, FEN string) pair; nothing is printed when
# there are no rows.
first_pair = next(zip(dataset.to_numpy(), fen_strings), None)
if first_pair is not None:
    array, fen_string = first_pair
    print(array, fen_string)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 8 0 0 0] 8/7Q/8/8/8/1K6/8/3kq3 w - - 0 1


In [14]:
import wandb

# Hyperparameters stored alongside the run.
run_config = {
    'epochs': n_epochs,
    'lr': lr,
    'gamma': gamma,
    'frac': frac,
    'dataset_len': len(fen_strings),
}
wandb.init(project='chess-rl', config=run_config)

# Human-readable run title summarising the architecture and settings.
wandb.run.name = f'BoardNet: {3}, MovesNet: {2}, OutputNet: {2}, activation: leaky-relu, lr={lr}, gamma={gamma}, shuffle, stalemate_reward={stalemate_reward}, 1 - loss'

# (metric name, summary mode) in registration order. A mode of None means the
# metric is registered without a `summary` argument.
tracked_metrics = (
    ('loss', 'min'),
    ('loss_epoch', 'min'),
    ('move_num', None),
    ('win_rate', 'max'),
    ('win_rate_epoch', 'max'),
    ('lose_rate', 'min'),
    ('lose_rate_epoch', 'min'),
    ('stalemate_rate', 'min'),
    ('stalemate_rate_epoch', 'min'),
)
for metric_name, summary_mode in tracked_metrics:
    if summary_mode is None:
        wandb.define_metric(metric_name)
    else:
        wandb.define_metric(metric_name, summary=summary_mode)

# Kept as the final expression so the cell still displays its return value.
wandb.watch(player.policy_net, log='all', log_freq=10)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
lose_rate,▁▄█
loss,█▁▆
move_num,▁▁█
stalemate_rate,▁▅█
win_rate,▁▆█

0,1
move_num,1


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

[]

In [15]:
def check_gradients(model):
    """Print the mean and standard deviation of one gradient tensor.

    Only the first parameter (in ``model.parameters()`` order) whose ``grad``
    is not ``None`` is reported. Nothing is printed when no parameter has a
    gradient yet.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with a
            ``grad`` attribute (e.g. a ``torch.nn.Module``).
    """
    first_grad = next(
        (p.grad for p in model.parameters() if p.grad is not None),
        None,
    )
    if first_grad is None:
        return

    # Flatten once; both statistics are taken over every element of the gradient.
    flat = first_grad.data.view(-1).detach()
    print(f'Mean: {torch.mean(flat)}')
    print(f'STD: {torch.std(flat)}')

In [16]:
# Training loop: for every epoch, replay each sampled position as a short game
# (the agent moves first, the engine replies), turn the outcome into a reward,
# and update the policy network once per game.

# NOTE(review): bc_win / wh_win are initialised here but never read in this cell.
bc_win = 0
wh_win = 0
stockfish = ChessInterface(verbose=True, engine_path=r"./chess_app/src/stockfish_15_x64_avx2.exe")
metrics_l = []  # per-epoch lose rate
metrics_w = []  # per-epoch win rate
losses = []     # loss values sampled every 20th game


# Defined up front so the per-epoch log below has a value even if no game ran.
loss = None

player.reset_actions()
player.policy_net.train()
for epoch in range(n_epochs):
    print(f'\nEpoch num: {epoch}\n')
    # Per-epoch counters: [agent wins, agent losses, stalemates + move-limit games].
    count_w_l = [0, 0, 0]
    for num, (array, fen_string) in enumerate(zip(dataset.to_numpy(), fen_strings)):
        # print(f'Party №{num + 857}')
        # print(fen_string)
        # init necessary for game
        new_state = fen_string
        new_array = array

        move_num = None

        # Outcome flags for this game. `white_win` is set but never read: a win
        # is handled by the final `else` branch of the reward selection below.
        black_win = False
        white_win = False
        stalemate = False
        too_much_moves = False

        stockfish.set_board_fen(new_state)
        # One entry per agent move; the last entry is replaced by the terminal reward.
        all_rewards = list()

        # The agent gets at most 4 moves per game.
        for _ in range(4):
            # player make move
            stockfish.set_board_fen(new_state)
            try:
                player_move, move_num = player.return_move(
                    new_array,
                    stockfish.get_top_steps() # todo
                )

                stockfish.player_move(player_move)
                new_state = stockfish.get_board_fen()

                # Intermediate moves get a zero reward.
                all_rewards.append(0)
                # print(f'\tPlayer move: {player_move}, num of move: {move_num}')
                if is_checkmate(board, new_state):
                    white_win = True
                    break
            except NotEnoughMovesExceptions:
                # print('NotEnoughMovesExceptions')
                # Treated as a loss for the agent.
                # NOTE(review): whatever is written into all_rewards[-1] here is
                # overwritten after the loop with reward = -1 (black_win is True),
                # so the `= 1` assignment has no effect. If this happens on the
                # very first move, the single -1 entry exists only so that the
                # later all_rewards[-1] assignment has a slot to write to.
                if len(all_rewards) < 1:
                    all_rewards.append(-1)
                else:
                    all_rewards[-1] = 1
                black_win = True
                break

            # environment make move
            try:
                machine_move = stockfish.env_move()
                new_state = stockfish.get_board_fen()
                # print(f'\tMachine move: {machine_move}')
                # Re-encode the position after the engine's reply. The first three
                # values of the extracted dict are dropped; presumably they are
                # metadata rather than board features (TODO confirm against loader).
                new_array = loader.extractPos(new_state, -1, -1)
                new_array = list(new_array.values())[3:]

                if is_checkmate(board, new_state):
                    black_win = True
                    break

            except ValueError as err:
                # print('Stalemate occurred!')
                # NOTE(review): any ValueError raised in this try block is counted
                # as a stalemate; presumably env_move raises it when the engine
                # has no move to play. Confirm against ChessInterface.
                stalemate = True
                count_w_l[2] += 1
                break

        else:
            # Runs only when the loop finished all 4 rounds without a break.
            # print('> 4 moves')
            # NOTE(review): move-limit games share the stalemate counter.
            count_w_l[2] += 1
            too_much_moves = True


        # Terminal reward: loss -> -1, stalemate -> -stalemate_reward,
        # move limit reached -> 0.2, otherwise (agent delivered mate) -> 1.
        if black_win is True:
            reward = -1
            count_w_l[1] += 1
        elif stalemate is True:
            reward = - stalemate_reward
        elif too_much_moves is True:
            reward = 0.2
        else:
            reward = 1
            count_w_l[0] += 1

        # The terminal reward replaces the reward of the agent's last move.
        all_rewards[-1] = reward
        loss = player.update_policy(all_rewards)
        # Report progress every 20th game.
        if num % 20 == 0:
            print(f'Num of move: {move_num}')
            print(f'Loss value: {loss}, num of party: {num}')
            print(f'Win rate: {count_w_l[0]}')
            print(f'Lose rate: {count_w_l[1]}')
            print(f'Stalemate rate: {count_w_l[2]}')

            # NOTE(review): the running counts are divided by the full dataset
            # size, not by the number of games played so far, so these mid-epoch
            # rates grow towards their final value as the epoch progresses.
            wandb.log({
                'loss': loss,
                'move_num': move_num,
                'win_rate': count_w_l[0]/ len(fen_strings),
                'lose_rate': count_w_l[1]/ len(fen_strings),
                'stalemate_rate': count_w_l[2]/ len(fen_strings)
             })

            losses.append(loss)
    # End-of-epoch summary. `loss_epoch` is the loss of the last game of the
    # epoch, not an average over the epoch.
    wandb.log({
        'loss_epoch': loss,
        'win_rate_epoch': count_w_l[0]/ len(fen_strings),
        'lose_rate_epoch': count_w_l[1]/ len(fen_strings),
        'stalemate_rate_epoch': count_w_l[2]/ len(fen_strings)
    })

    metrics_l.append(count_w_l[1]/len(fen_strings))
    metrics_w.append(count_w_l[0]/len(fen_strings))
    print(f'Winrate {count_w_l[0]} \t Loserate {count_w_l[1]}')
    # Checkpoint after every epoch; the same file is overwritten each time.
    torch.save(player.policy_net.state_dict(), 'weight.pth')



Epoch num: 0

Num of move: tensor([3])
Loss value: 5.721584320068359, num of party: 0
Win rate: 0
Lose rate: 0
Stalemate rate: 1
Num of move: tensor([0])
Loss value: 4.528450965881348, num of party: 20
Win rate: 4
Lose rate: 2
Stalemate rate: 15
Num of move: tensor([3])
Loss value: 3.1421751976013184, num of party: 40
Win rate: 6
Lose rate: 9
Stalemate rate: 26
Num of move: tensor([2])
Loss value: -0.9144361019134521, num of party: 60
Win rate: 9
Lose rate: 17
Stalemate rate: 35
Num of move: tensor([4])
Loss value: 5.824033737182617, num of party: 80
Win rate: 11
Lose rate: 23
Stalemate rate: 47
Num of move: tensor([3])
Loss value: 1.0000042915344238, num of party: 100
Win rate: 14
Lose rate: 29
Stalemate rate: 58
Num of move: tensor([3])
Loss value: 1.0000003576278687, num of party: 120
Win rate: 17
Lose rate: 31
Stalemate rate: 73
Num of move: tensor([3])
Loss value: 1.0000001192092896, num of party: 140
Win rate: 20
Lose rate: 35
Stalemate rate: 86
Num of move: tensor([3])
Loss val

KeyboardInterrupt: 

In [None]:
# Print the raw gradient of every policy-network parameter, one per line
# (None for parameters that have no gradient).
param_grads = [param.grad for param in policy_net.parameters()]
for param_grad in param_grads:
    print(param_grad)