In [1]:
import chess
import torch

from chess_app.environment import ChessInterface
from chess_app.chess_exceptions import NotEnoughMovesExceptions
from re_algoritms.agents import RandomAgent, QAgent
from chess_app.dataset.loader import loaderChessPos, tanstsovVecLoader
import chess.svg
from tqdm import tqdm


In [2]:

def is_checkmate(board, state):
    board.set_fen(state)
    return board.is_checkmate()


In [3]:
from re_algoritms.dqn import GrandMasterNet
from re_algoritms.agents import DQNAgent

board = chess.Board()

lr = 1e-2
gamma = 0.7
n_epochs = 100

policy_net = GrandMasterNet(
    board_vec_dim=68,
    moves_vec_dim=5
)
player = DQNAgent(
    lr=lr, gamma=gamma, model=policy_net
)
loader = tanstsovVecLoader(r'chess_app\dataset\result.json', r'chess_app\dataset\settings.json')
loader.createDf()
frac = 1.0
dataset = loader.getAllData(frac=frac)
fen_strings = dataset.iloc[:, 0]
dataset = dataset.iloc[:, 1:]

In [4]:
dataset

Unnamed: 0,white_king_castle,white_queen_castle,black_king_castle,black_queen_castle,a8,b8,c8,d8,e8,f8,...,g2,h2,a1,b1,c1,d1,e1,f1,g1,h1
4441,0,0,1,1,9,0,10,8,7,10,...,2,0,0,0,0,0,1,0,0,0
7022,0,0,0,0,0,0,8,0,10,0,...,5,5,0,0,0,0,0,0,1,0
479,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4504,0,0,0,0,3,0,0,0,0,2,...,0,0,0,1,0,0,0,0,0,0
2456,0,0,0,0,0,10,0,4,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9228,0,0,0,0,9,0,10,0,0,9,...,5,5,0,0,0,0,0,3,1,0
1411,1,1,0,0,9,0,10,8,0,9,...,5,0,3,0,4,0,1,0,0,3
5544,0,0,0,0,0,0,0,9,0,0,...,5,5,0,0,0,0,3,0,1,0
7396,0,0,0,0,0,0,0,0,0,4,...,0,5,0,0,0,0,0,0,0,7


In [5]:
for array, fen_string in zip(dataset.to_numpy(), fen_strings):
    # print(_)
    print(array, fen_string)
    break

[ 0  0  1  1  9  0 10  8  7 10 12  9 11 11 12 11 11 11 11  0  0  0  0 11
  0  0  0 11  0  0  0  0  0  0  0  4  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  1  0  0  0] r1bqkbnr/ppnpppp1/3p3p/7B/8/8/6Q1/4K3 w kq - 0 1


In [6]:
# import wandb
#
# wandb.init(
#     project='chess-rl',
#     config={
#         'epochs': n_epochs,
#         'lr': lr,
#         'gamma': gamma,
#         'frac': frac
#     }
# )
#
# # wandb.run.name = f'BoardNet: {3}, MovesNet: {2}, OutputNet: {2}, activation: leaky-relu, lr={lr}, gamma={gamma}'
#
# wandb.define_metric("loss", summary="min")
# wandb.define_metric("win_rate", summary="max")
# wandb.define_metric("lose_rate", summary="min")
# wandb.define_metric('stalemate_rate', summary="min")
# wandb.watch(player.policy_net, log='all', log_freq=10)

In [7]:
def check_gradients(model):
    for params in model.parameters():
        if params.grad is not None:
            print(f'Mean: {torch.mean(params.grad.data.view(-1).detach())}')
            print(f'STD: { torch.std(params.grad.data.view(-1).detach())}')
            break

In [8]:
bc_win = 0
wh_win = 0
stockfish = ChessInterface(verbose=True, engine_path=r"./chess_app/src/stockfish_15_x64_avx2.exe")
metrics_l = []
metrics_w = []
losses = []

epoch_slice = 10

loss = None

player.reset_actions()
player.policy_net.train()
for epoch in range(n_epochs):
    print(f'\nEpoch num: {epoch}\n')
    count_w_l = [0, 0, 0]
    for num, (array, fen_string) in enumerate(zip(dataset.to_numpy()[:epoch_slice], fen_strings[:epoch_slice])):
        # print(f'Party №{num + 857}')
        # print(fen_string)
        # init necessary for game
        new_state = fen_string
        new_array = array

        black_win = False
        white_win = False
        stalemate = False
        too_much_moves = False

        stockfish.set_board_fen(new_state)
        all_rewards = list()

        for _ in range(4):
            # player make move
            stockfish.set_board_fen(new_state)
            try:
                player_move, move_num = player.return_move(
                    new_array,
                    stockfish.get_top_steps() # todo
                )

                stockfish.player_move(player_move)
                new_state = stockfish.get_board_fen()

                all_rewards.append(0)
                # print(f'\tPlayer move: {player_move}, num of move: {move_num}')
                if is_checkmate(board, new_state):
                    white_win = True
                    break
            except NotEnoughMovesExceptions:
                # print('NotEnoughMovesExceptions')
                if len(all_rewards) < 1:
                    all_rewards.append(-1)
                else:
                    all_rewards[-1] = 1
                black_win = True
                break

            # environment make move
            try:
                machine_move = stockfish.env_move()
                new_state = stockfish.get_board_fen()
                # print(f'\tMachine move: {machine_move}')
                new_array = loader.extractPos(new_state, -1, -1)
                new_array = list(new_array.values())[3:]

                if is_checkmate(board, new_state):
                    black_win = True
                    break

            except ValueError as err:
                # print('Произошел пат!!!!!!!!!!!!!')
                stalemate = True
                count_w_l[2] += 1
                break

        else:
            # print('> 4 ходов')
            count_w_l[2] += 1
            too_much_moves = True


        if black_win is True:
            reward = -1
            count_w_l[1] += 1
        elif stalemate is True:
            reward = - 0.1
        elif too_much_moves is True:
            reward = 0.2
        else:
            reward = 1
            count_w_l[0] += 1

        all_rewards[-1] = reward
        loss = player.update_policy(all_rewards)
        if num % 2 == 0:
            # print(f'Loss value: {loss}, num of party: {num}')
            # print(f'Win rate: {count_w_l[0]}')
            # print(f'Lose rate: {count_w_l[1]}')
            # print(f'Stalemate rate: {count_w_l[2]}')
            losses.append(loss)
            # check_gradients(player.policy_net)
    # wandb.log({
    #         'loss': loss,
    #         'win_rate': count_w_l[0]/epoch_slice,
    #         'lose_rate': count_w_l[1]/epoch_slice,
    #         'stalemate_rate': count_w_l[2]/epoch_slice
    #     })
    print(f'Loss value: {loss}, num of party: {num}')
    print(f'Win rate: {count_w_l[0]}')
    print(f'Lose rate: {count_w_l[1]}')
    print(f'Stalemate rate: {count_w_l[2]}')

    metrics_l.append(count_w_l[1]/len(fen_strings))
    metrics_w.append(count_w_l[0]/len(fen_strings))
    print(f'Winrate {count_w_l[0]} \t Loserate {count_w_l[1]}')
    torch.save(player.policy_net.state_dict(), 'weight.pth')



Epoch num: 0



  for num, (array, fen_string) in enumerate(zip(dataset.to_numpy()[:epoch_slice], fen_strings[:epoch_slice])):


Loss value: 2.0867767333984375, num of party: 9
Win rate: 2
Lose rate: 2
Stalemate rate: 6
Winrate 2 	 Loserate 2

Epoch num: 1

Loss value: 0.3438456356525421, num of party: 9
Win rate: 2
Lose rate: 3
Stalemate rate: 5
Winrate 2 	 Loserate 3

Epoch num: 2

Loss value: 0.00017084150749724358, num of party: 9
Win rate: 10
Lose rate: 0
Stalemate rate: 0
Winrate 10 	 Loserate 0

Epoch num: 3

Loss value: 5.960466182841628e-07, num of party: 9
Win rate: 10
Lose rate: 0
Stalemate rate: 0
Winrate 10 	 Loserate 0

Epoch num: 4

Loss value: 1.1920930376163597e-07, num of party: 9
Win rate: 10
Lose rate: 0
Stalemate rate: 0
Winrate 10 	 Loserate 0

Epoch num: 5

Loss value: 1.1920930376163597e-07, num of party: 9
Win rate: 10
Lose rate: 0
Stalemate rate: 0
Winrate 10 	 Loserate 0

Epoch num: 6

Loss value: 1.1920930376163597e-07, num of party: 9
Win rate: 10
Lose rate: 0
Stalemate rate: 0
Winrate 10 	 Loserate 0

Epoch num: 7

Loss value: 1.1920930376163597e-07, num of party: 9
Win rate: 10
Los

KeyboardInterrupt: 

In [None]:
for param in policy_net.parameters():
    print(param.grad)