In [1]:
import chess
import torch

from chess_app.environment import ChessInterface
from chess_app.chess_exceptions import NotEnoughMovesExceptions
from re_algoritms.agents import RandomAgent, QAgent
from chess_app.dataset.loader import loaderChessPos, tanstsovVecLoader
import chess.svg
from tqdm import tqdm


In [2]:
board = chess.Board()

def is_checkmate(board, state):
    board.set_fen(state)
    return board.is_checkmate()


# Hyperparameters

In [3]:
frac = 0.01

lr = 1e-2
gamma = 0.96
n_epochs = 100
max_moves = 4

stalemate_reward = 0.1
too_much_reward = -0.6

# DQN Agent and dataset

In [4]:
from re_algoritms.dqn import GrandMasterNet
from re_algoritms.agents import DQNAgent

policy_net = GrandMasterNet(
    board_vec_dim=68,
    moves_vec_dim=5
)
player = DQNAgent(
    lr=lr, gamma=gamma, model=policy_net
)

loader = tanstsovVecLoader(r'chess_app\dataset\result.json', r'chess_app\dataset\settings.json')
loader.createDf()

dataset = loader.getAllData(frac=frac)
fen_strings = dataset.iloc[:, 0]
dataset = dataset.iloc[:, 1:]

# wandb logging

In [5]:
import wandb

wandb.init(
    project='chess-rl',
    config={
        'epochs': n_epochs,
        'lr': lr,
        'max_moves': max_moves,
        'gamma': gamma,
        'frac': frac,
        'dataset_len': len(fen_strings),
        'loss_func': '1 - torch.mean(torch.mul(self.actions, rewards))'
    }
)

wandb.run.name = f'BoardNet: {3}, MovesNet: {2}, OutputNet: {2},' \
                 f'activation: leaky-relu,' \
                 f'lr={lr},' \
                 f'gamma={gamma},' \
                 f'shuffle,' \
                 f'stalemate_reward={stalemate_reward},' \
                 f'too_much_reward={too_much_reward}'

wandb.define_metric("epochs")

wandb.define_metric("loss", summary="min")
wandb.define_metric("loss_epoch", summary="min")

wandb.define_metric('move_num')

wandb.define_metric("win_rate", summary="max")
wandb.define_metric("win_rate_epoch", summary="max")

wandb.define_metric("lose_rate", summary="min")
wandb.define_metric("lose_rate_epoch", summary="min")

wandb.define_metric('stalemate_rate', summary="min")
wandb.define_metric('stalemate_rate_epoch', summary="min")

wandb.define_metric('too_much_rate', summary="min")
wandb.define_metric('too_much_rate_epoch', summary="min")

wandb.watch(player.policy_net, log='all', log_freq=10)

[34m[1mwandb[0m: Currently logged in as: [33mcrazy_historian[0m ([33mai_community[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

# Learning (or Training)

In [6]:
bc_win = 0
wh_win = 0
stockfish = ChessInterface(verbose=True, engine_path=r"./chess_app/src/stockfish_15_x64_avx2.exe")

loss = None

player.reset_actions()
player.policy_net.train()

for epoch in range(n_epochs):
    print(f'\nEpoch num: {epoch}\n')
    count_w_l = [0, 0, 0, 0]
    for num, (array, fen_string) in enumerate(zip(dataset.to_numpy(), fen_strings)):
        # print(fen_string)
        # init necessary for game
        new_state = fen_string
        new_array = array

        move_num = None

        black_win = False
        white_win = False
        stalemate = False
        too_much_moves = False

        stockfish.set_board_fen(new_state)
        all_rewards = list()

        for _ in range(max_moves):
            # player make move
            stockfish.set_board_fen(new_state)
            try:
                player_move, move_num = player.return_move(
                    new_array,
                    stockfish.get_top_steps() # todo
                )
                wandb.log({'move_num': move_num})
                stockfish.player_move(player_move)
                new_state = stockfish.get_board_fen()

                all_rewards.append(0)
                # print(f'\tPlayer move: {player_move}, num of move: {move_num}')
                if is_checkmate(board, new_state):
                    white_win = True
                    break
            except NotEnoughMovesExceptions:
                # print('NotEnoughMovesExceptions')
                if len(all_rewards) < 1:
                    all_rewards.append(-1)
                else:
                    all_rewards[-1] = 1
                black_win = True
                break

            # environment make move
            try:
                machine_move = stockfish.env_move()
                new_state = stockfish.get_board_fen()
                # print(f'\tMachine move: {machine_move}')
                new_array = loader.extractPos(new_state, -1, -1)
                new_array = list(new_array.values())[3:]

                if is_checkmate(board, new_state):
                    black_win = True
                    break

            except ValueError as err:
                # print('Произошел пат!!!!!!!!!!!!!')
                stalemate = True
                break

        else:
            # print(f'> {max_moves}  ходов')
            too_much_moves = True


        if black_win is True:
            reward = -1
            count_w_l[1] += 1
        elif stalemate is True:
            reward = stalemate_reward
            count_w_l[2] += 1
        elif too_much_moves is True:
            reward = too_much_reward
            count_w_l[3] += 1
        else:
            reward = 1
            count_w_l[0] += 1

        all_rewards[-1] = reward
        loss = player.update_policy(all_rewards)
        # if num % 20 == 0 and num != 0:
        #     print(f'Num of move: {move_num}')
        #     print(f'Loss value: {loss}, num of party: {num}')
        #     print(f'Win rate: {count_w_l[0]}')
        #     print(f'Lose rate: {count_w_l[1]}')
        #     print(f'Stalemate rate: {count_w_l[2]}')
        #     print(f'Too much moves rate: {count_w_l[3]}')
        if num != 0:
            wandb.log({
                'loss': loss,
                'win_rate': count_w_l[0]/ num,
                'lose_rate': count_w_l[1]/ num,
                'stalemate_rate': count_w_l[2]/ num,
                'too_much_rate': count_w_l[3]/ num,
             })
    wandb.log({
        'epochs': n_epochs,
        'loss_epoch': loss,
        'win_rate_epoch': count_w_l[0]/ len(fen_strings),
        'lose_rate_epoch': count_w_l[1]/ len(fen_strings),
        'stalemate_rate_epoch': count_w_l[2]/ len(fen_strings),
        'too_much_rate_epoch': count_w_l[3]/ len(fen_strings),
    })
    torch.save(player.policy_net.state_dict(), 'weight_2.pth')



Epoch num: 0


Epoch num: 1


Epoch num: 2


Epoch num: 3


Epoch num: 4


Epoch num: 5


Epoch num: 6


Epoch num: 7


Epoch num: 8


Epoch num: 9


Epoch num: 10


Epoch num: 11


Epoch num: 12


Epoch num: 13


Epoch num: 14


Epoch num: 15


Epoch num: 16


Epoch num: 17


Epoch num: 18


Epoch num: 19


Epoch num: 20


Epoch num: 21


Epoch num: 22


Epoch num: 23


Epoch num: 24


Epoch num: 25


Epoch num: 26


Epoch num: 27


Epoch num: 28


Epoch num: 29


Epoch num: 30


Epoch num: 31


Epoch num: 32


Epoch num: 33


Epoch num: 34


Epoch num: 35


Epoch num: 36


Epoch num: 37


Epoch num: 38


Epoch num: 39


Epoch num: 40


Epoch num: 41


Epoch num: 42


Epoch num: 43


Epoch num: 44


Epoch num: 45


Epoch num: 46


Epoch num: 47


Epoch num: 48


Epoch num: 49


Epoch num: 50


Epoch num: 51


Epoch num: 52


Epoch num: 53


Epoch num: 54


Epoch num: 55


Epoch num: 56


Epoch num: 57


Epoch num: 58


Epoch num: 59


Epoch num: 60


Epoch num: 61


Epoch num: 62


E

KeyboardInterrupt: 