In [6]:
import os
import sys

# Put the parent of the current working directory at the front of the module
# search path (once), so the project modules imported below can be resolved.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(parent_dir) == 0:
    sys.path.insert(0, parent_dir)

import importlib
import ChessEngine
import Minimax.SmartMoveFinder as SmartMoveFinder
import Minimax.Evaluate as Evaluate
from Move import BestRLMove
import ChessEnv
import replaybuffer
import network
import copy 
import random
import torch
import torch.nn as nn
import tqdm
from tqdm import trange

# Re-execute the project modules so that edits made to their source files
# after the kernel first imported them take effect without a kernel restart.
# Only these four modules are reloaded; the other project imports above
# (SmartMoveFinder, Evaluate, Move) keep whatever version was first imported.
# The last call is a bare expression, so the notebook displays the reloaded
# `network` module object as this cell's output.
importlib.reload(ChessEngine)
importlib.reload(ChessEnv)
importlib.reload(replaybuffer)
importlib.reload(network)

<module 'network' from 'd:\\User\\ProjectGithub\\hiepnguyenn-99\\Chess\\RL\\network.py'>

In [7]:
# Environment and compute device (GPU when one is available, otherwise CPU).
env = ChessEnv.Env()

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)

# Online Q-network, sized by the environment's action space.
q_net = network.DQN(env.action_size).to(device)

# Resume from the saved weights when a checkpoint file is present.
model_path = 'DQN.pth'
if not os.path.exists(model_path):
    print("Không tìm thấy model")
else:
    checkpoint = torch.load(model_path, map_location=device)
    q_net.load_state_dict(checkpoint)
    q_net.train()
    print("Đã load model")

# Optimisation objects for the online network.
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-4)
criterion = nn.SmoothL1Loss()

# Target network: an independent copy of the online network, kept in eval mode.
target_net = copy.deepcopy(q_net).to(device)
target_net.eval()

cuda
Đã load model


In [8]:
# ---------------------------------------------------------------------------
# DQN training loop: the agent plays black against a minimax opponent (white).
#
# Each pass of the inner loop: white moves (if it is white's turn), the agent
# picks an action epsilon-greedily, the transition is stored in the replay
# buffer, and - once the buffer holds at least one batch - one optimisation
# step is performed on `q_net`.
# ---------------------------------------------------------------------------
env = ChessEnv.Env()
capacity = 32000            # replay buffer capacity passed to ReplayBuffer
rb = replaybuffer.ReplayBuffer(capacity)
batch_size = 64
epsilon = 1.0               # initial exploration probability
epsilon_final = 0.01        # floor for the exploration probability
epsilon_decay = 0.995       # multiplicative decay factor for epsilon
gamma = 0.99                # discount factor in the TD target
step = 0                    # number of optimiser updates performed so far
target_update_freq = 2000   # sync target_net and save a checkpoint every N updates
env.gs.whiteToMove = True  # white (the minimax opponent) moves first

for episode in trange(500000):
    done = False
    state = env.reset()
    while not done:
        # ---- White's turn: minimax opponent --------------------------------
        # Piece type captured by white this turn (None when nothing was
        # captured); forwarded to env.step below.
        white_capture = None
        if env.gs.whiteToMove:
            # Search depth 4 when Evaluate.check_mid_game is true, else depth 3.
            SmartMoveFinder.DEPTH = 4 if Evaluate.check_mid_game(env.gs) else 3
            AIMove = SmartMoveFinder.findBestMinimaxMove(env.gs, env.gs.getValidMoves())
            if AIMove is None:
                # Fall back to a random move when the search returns nothing.
                AIMove = SmartMoveFinder.findRandomMove(env.gs.getValidMoves())
            if AIMove.pieceCaptured != '--':
                white_capture = AIMove.pieceCaptured[1]
            env.gs.makeMove(AIMove)

        # ---- Agent's turn: epsilon-greedy action selection -----------------
        state_tensor = env.state_to_tensor()
        # Generate the move list once and reuse it for both the count and the
        # random pick (the position does not change in between).
        legal_moves = env.gs.getValidMoves()
        lenlegalmove = len(legal_moves)

        # Placeholder so `action` is always bound and never carries over from a
        # previous step when there is no legal move.
        # NOTE(review): assumes env.step does not apply the action when
        # lenlegalmove == 0 - confirm against ChessEnv.
        action = 0
        if lenlegalmove != 0:
            if random.random() < epsilon:
                # Explore: uniformly random legal move.
                move = random.choice(legal_moves)
                action = env.moveid_to_index[move.moveID]
            else:
                # Exploit: greedy action. No autograd graph is needed here, and
                # the index is converted to a plain int so the replay buffer
                # does not end up holding device tensors.
                # NOTE(review): the argmax is taken over all outputs, with no
                # legal-move mask - confirm that is intended.
                with torch.no_grad():
                    greedy_q = q_net(state_tensor.unsqueeze(0).to(device))
                action = int(greedy_q.argmax().item())

        # The agent has moved; env.step also receives white's capture (if any)
        # and the number of legal moves the agent had.
        next_state_tensor, reward, done, before_legal_mask, after_legal_mask = env.step(action, white_capture, lenlegalmove)
        rb.push(state_tensor, action, reward, next_state_tensor, done, before_legal_mask, after_legal_mask)

        # ---- Learning step -------------------------------------------------
        if len(rb) >= batch_size:
            batch = rb.sample(batch_size)
            # Unpack into batch-level names that do not shadow the per-step
            # `reward` / `after_legal_mask` above, then convert to tensors.
            (states, actions, rewards, next_states, dones,
             before_legal_masks, after_legal_masks) = zip(*batch)
            states = torch.stack([s.to(device) for s in states])
            actions = torch.tensor(actions, device=device, dtype=torch.int64).unsqueeze(1)  # (B, 1)
            rewards = torch.tensor(rewards, device=device, dtype=torch.float32).unsqueeze(1)
            next_states = torch.stack([ns.to(device) for ns in next_states])
            before_legal_masks = torch.stack([b.to(device) for b in before_legal_masks])
            after_legal_masks = torch.stack([a.to(device) for a in after_legal_masks])
            dones = torch.tensor(dones, device=device, dtype=torch.float32).unsqueeze(1)

            with torch.no_grad():
                next_q_values = target_net(next_states)  # (B, action_size)
                next_q_values[~after_legal_masks] = -torch.inf
                next_q_max = next_q_values.max(1)[0].unsqueeze(1)  # (B, 1)

                # Rows without any legal next move would yield -inf; use 0 there.
                legal_exists = after_legal_masks.any(dim=1, keepdim=True)
                next_q_max = torch.where(legal_exists, next_q_max, torch.zeros_like(next_q_max))

                q_target = rewards + gamma * (1 - dones) * next_q_max  # (B, 1)

            q_values = q_net(states)  # (B, action_size)
            q_values[~before_legal_masks] = -torch.inf
            # Q-value of the action that was actually taken, shape (B, 1).
            q_value = q_values.gather(dim=1, index=actions)

            if step % 200 == 0:
                # Log the first sample of the batch (not a maximum, despite the
                # label).
                # NOTE(review): the 1e-10 scaling and the special case for -10
                # are kept as found - confirm their purpose.
                with open('train_log.txt', 'a') as f:
                    if q_target[0][0] != -10:
                        print(f'q_target max {q_target[0][0] * 1e-10}', file=f)
                    else:
                        print(f'q_target max {q_target[0][0]}', file=f)
                    print(f'q_value max {q_value[0][0] * 1e-10}', file=f)

            loss = criterion(q_value, q_target)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(q_net.parameters(), max_norm=1.0)
            optimizer.step()
            step += 1

            if step % target_update_freq == 0:
                # Sync the target network with the online network and checkpoint.
                target_net.load_state_dict(q_net.state_dict())
                torch.save(target_net.state_dict(), 'DQN.pth')
                with open('train_log.txt', 'a') as f:
                    print(f'step {step}, model saved', file=f)

        # NOTE(review): this decay runs once per agent step (it is inside the
        # while loop), not once per episode - confirm that is intended.
        epsilon = max(epsilon_final, epsilon * epsilon_decay)


  0%|          | 1/500000 [00:29<4093:05:05, 29.47s/it]

total sum reward 0.0


  0%|          | 2/500000 [01:11<5087:36:15, 36.63s/it]

total sum reward 0.0


  0%|          | 33/500000 [02:00<179:56:17,  1.30s/it]

hết nước đi hợp lệ, self.gs.checkMate, -1


 11%|█▏        | 56955/500000 [3:16:16<25:26:48,  4.84it/s] 


KeyboardInterrupt: 