In [268]:
import os, sys
# Put the directory one level above the current working directory at the front
# of the import search path, unless it is already listed there.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(parent_dir) == 0:
    sys.path.insert(0, parent_dir)

import importlib
import ChessEngine
import Minimax.SmartMoveFinder as SmartMoveFinder
import Minimax.Evaluate as Evaluate
import ChessEnv
import replaybuffer
import network
import copy 
import random
import torch
import torch.nn as nn
import tqdm
from tqdm import trange

# Reload the project modules in the same order as before: ChessEngine, ChessEnv,
# replaybuffer, then network.
for module_to_reload in (ChessEngine, ChessEnv, replaybuffer):
    importlib.reload(module_to_reload)
# Left as a bare final expression so the cell still displays the reloaded module.
importlib.reload(network)

<module 'network' from 'd:\\User\\ProjectGithub\\hiepnguyenn-99\\Chess\\RL\\network.py'>

In [269]:
# Environment instance; its action_size is handed to the DQN constructor below.
env = ChessEnv.Env()

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

q_net = network.DQN(env.action_size).to(device)

# Restore saved weights when the checkpoint file exists in the working directory.
model_path = 'DQN.pth'
if not os.path.exists(model_path):
    print("Không tìm thấy model")
else:
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    q_net.load_state_dict(checkpoint)
    q_net.train()
    print("Đã load model")

# The target network starts as a deep copy of the online network and is kept in eval mode.
target_net = copy.deepcopy(q_net).to(device)
target_net.eval()

criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-4)

cuda
Đã load model


In [270]:
# DQN training loop: white is played by the minimax engine, black by the
# epsilon-greedy Q-network. After every black move the transition is pushed to
# the replay buffer and, once the buffer holds at least `batch_size` items, one
# gradient step is taken on a sampled batch.
#
# NOTE(review): 1 episode, a buffer of 1, a batch of 1 and a single step per
# episode look like a smoke-test configuration — raise them for real training.
env = ChessEnv.Env()
capacity = 1
rb = replaybuffer.ReplayBuffer(capacity)
batch_size = 1
num_episodes = 1
max_steps_per_episode = 1  # None = play until env.step reports done
epsilon = 1.0
epsilon_final = 0.01
epsilon_decay = 0.995
gamma = 0.9
step = 0  # optimisation steps since the last target-network sync
target_update_freq = 1
env.gs.whiteToMove = True  # white (the minimax opponent) moves first

for episode in trange(num_episodes, desc="Episodes"):
    done = False
    state = env.reset()
    steps_in_episode = 0
    while not done:
        # --- White: minimax opponent -------------------------------------
        white_capture = None  # piece type captured by white on this ply, if any
        if env.gs.whiteToMove:
            SmartMoveFinder.DEPTH = 4 if Evaluate.check_mid_game(env.gs) else 3
            white_moves = env.gs.getValidMoves()
            MinimaxMove = SmartMoveFinder.findBestMinimaxMove(env.gs, white_moves)
            if MinimaxMove is None:
                MinimaxMove = SmartMoveFinder.findRandomMove(white_moves)
            print(f'white {env.gs.whiteToMove} move {MinimaxMove.moveID}')
            if MinimaxMove.pieceCaptured != '--':
                white_capture = MinimaxMove.pieceCaptured[1]
                print(f'white captured {white_capture}')
            env.gs.makeMove(MinimaxMove)

        # --- Black: epsilon-greedy agent ---------------------------------
        state_tensor = env.state_to_tensor()
        legal_moves = env.gs.getValidMoves()  # generated once and reused below
        lenlegalmove = len(legal_moves)
        print(f'len move black {lenlegalmove}')

        if lenlegalmove == 0:
            # No action can be chosen here; the old code fell through with
            # `action` unbound (first step) or stale from the previous step.
            # NOTE(review): env.step takes lenlegalmove, so it may be meant to
            # produce a terminal reward for this case — confirm how that
            # transition should be generated and stored before relying on this.
            break

        if random.random() < epsilon:
            # Explore: uniformly random legal move.
            move = random.choice(legal_moves)
            action = env.moveid_to_index[move.moveID]
        else:
            # Exploit: best Q-value among legal moves only. The training step
            # below masks illegal actions to -inf, so an illegal greedy pick
            # would give an infinite loss.
            legal_indices = [env.moveid_to_index[m.moveID] for m in legal_moves]
            with torch.no_grad():
                q_row = q_net(state_tensor.unsqueeze(0).to(device)).squeeze(0)
            best_pos = q_row[torch.tensor(legal_indices, device=device)].argmax().item()
            action = legal_indices[best_pos]  # plain int, same type as the explore branch

        # Black has moved; white_capture lets the env account for a black piece lost this ply.
        next_state_tensor, reward, done, before_legal_mask, after_legal_mask = env.step(action, white_capture, lenlegalmove)
        rb.push(state_tensor, action, reward, next_state_tensor, done, before_legal_mask, after_legal_mask)

        if len(rb) >= batch_size:
            batch = rb.sample(batch_size)
            # Batch variables use plural names so they do not overwrite the
            # per-step `reward` / `after_legal_mask` above.
            states, actions, rewards, next_states, dones, before_legal_masks, after_legal_masks = zip(*batch)
            states = torch.stack([s.to(device) for s in states])
            actions = torch.tensor(actions, device=device, dtype=torch.int64).unsqueeze(1)  # (B, 1)
            rewards = torch.tensor(rewards, device=device, dtype=torch.float32).unsqueeze(1)
            next_states = torch.stack([ns.to(device) for ns in next_states])
            before_legal_masks = torch.stack([b.to(device) for b in before_legal_masks])
            after_legal_masks = torch.stack([a.to(device) for a in after_legal_masks])
            dones = torch.tensor(dones, device=device, dtype=torch.float32).unsqueeze(1)

            with torch.no_grad():
                next_q_values = target_net(next_states).masked_fill(~after_legal_masks, -torch.inf)  # (B, action_size)
                next_q_max = next_q_values.max(1)[0].unsqueeze(1)  # (B, 1)

                # Rows with no legal move in the next state bootstrap from 0 instead of -inf.
                legal_exists = after_legal_masks.any(dim=1, keepdim=True)
                next_q_max = torch.where(legal_exists, next_q_max, torch.zeros_like(next_q_max))

                # Use each sampled transition's own done flag, not the done flag
                # of the env step that just happened.
                q_target = rewards + gamma * (1.0 - dones) * next_q_max  # (B, 1)

            q_values = q_net(states).masked_fill(~before_legal_masks, -torch.inf)  # (B, action_size)
            q_value = q_values.gather(dim=1, index=actions)  # Q-value of the action actually taken

            print (f'type states {type(states)}, shape {states.shape}')
            print (f'type actions {type(actions)}, shape {actions.shape}')
            print (f'actions {actions}')
            print (f'type reward {type(rewards)}, shape {rewards.shape}')
            print (f'reward {rewards}')
            print (f'type next_states  {type(next_states)}, shape {next_states.shape}')
            print (f'type before_legal_masks {type(before_legal_masks)}, shape {before_legal_masks.shape}')
            print (f'before_legal_masks {before_legal_masks}')
            print (f'type next_legal_masks {type(after_legal_masks)}, shape {after_legal_masks.shape}')
            print (f'next_legal_masks {after_legal_masks}')
            print (f'type dones {type(dones)}, shape {dones.shape}')
            print (f'dones {dones}')
            print (f'type q_target {type(q_target)}, shape {q_target.shape}')
            print (f'q_target {q_target}')
            print (f'type q_values {type(q_values)}, shape {q_values.shape}')
            print (f'q_values {q_values[0, 75:85]}')
            print (f'type q_value {type(q_value)}, shape {q_value.shape}')
            print (f'q_values {q_value}')

            loss = criterion(q_value, q_target)
            print(f'loss {loss.item()}')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Count optimisation steps so target_update_freq is honoured.
            step += 1
            if step % target_update_freq == 0:
                step = 0
                target_net.load_state_dict(q_net.state_dict())
                # torch.save(target_net.state_dict(), 'DQN.pth')

        epsilon = max(epsilon_final, epsilon * epsilon_decay)

        # Explicit per-episode step cap (replaces the unconditional `done = True`).
        steps_in_episode += 1
        if max_steps_per_episode is not None and steps_in_episode >= max_steps_per_episode:
            done = True

Episodes:   0%|          | 0/1 [00:00<?, ?it/s]

white True move 7152
len move black 25
Các chỉ số có giá trị True trong before_legal_mask: [0, 1, 2, 254, 255, 256, 257, 446, 447, 519, 527, 583, 591, 647, 655, 711, 719, 775, 783, 839, 847, 903, 911, 967, 975]
Index: 0 -> MoveID: 1
Index: 1 -> MoveID: 2
Index: 2 -> MoveID: 3
Index: 254 -> MoveID: 402
Index: 255 -> MoveID: 403
Index: 256 -> MoveID: 405
Index: 257 -> MoveID: 406
Index: 446 -> MoveID: 705
Index: 447 -> MoveID: 706
Index: 519 -> MoveID: 1020
Index: 527 -> MoveID: 1030
Index: 583 -> MoveID: 1121
Index: 591 -> MoveID: 1131
Index: 647 -> MoveID: 1222
Index: 655 -> MoveID: 1232
Index: 711 -> MoveID: 1323
Index: 719 -> MoveID: 1333
Index: 775 -> MoveID: 1424
Index: 783 -> MoveID: 1434
Index: 839 -> MoveID: 1525
Index: 847 -> MoveID: 1535
Index: 903 -> MoveID: 1626
Index: 911 -> MoveID: 1636
Index: 967 -> MoveID: 1727
Index: 975 -> MoveID: 1737


Episodes: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]

type states <class 'torch.Tensor'>, shape torch.Size([1, 13, 8, 8])
type actions <class 'torch.Tensor'>, shape torch.Size([1, 1])
actions tensor([[447]], device='cuda:0')
type reward <class 'torch.Tensor'>, shape torch.Size([1, 1])
reward tensor([[0.]], device='cuda:0')
type next_states  <class 'torch.Tensor'>, shape torch.Size([1, 13, 8, 8])
type before_legal_masks <class 'torch.Tensor'>, shape torch.Size([1, 4032])
before_legal_masks tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')
type next_legal_masks <class 'torch.Tensor'>, shape torch.Size([1, 4032])
next_legal_masks tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')
type dones <class 'torch.Tensor'>, shape torch.Size([1, 1])
dones tensor([[0.]], device='cuda:0')
type q_target <class 'torch.Tensor'>, shape torch.Size([1, 1])
q_target tensor([[0.9598]], device='cuda:0')
type q_values <class 'torch.Tensor'>, shape torch.Size([1, 4032])
q_values tensor([-inf, -inf, -inf, -inf, -inf


