In [30]:
import os, sys
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import importlib
import ChessEngine
import Minimax.SmartMoveFinder as SmartMoveFinder
import Minimax.Evaluate as Evaluate
import ChessEnv
import replaybuffer
import network
import copy 
import random
import torch
import torch.nn as nn
import tqdm
from tqdm import trange

importlib.reload(ChessEngine)
importlib.reload(ChessEnv)
importlib.reload(replaybuffer)
importlib.reload(network)

<module 'network' from 'd:\\User\\ProjectGithub\\hiepnguyenn-99\\Chess\\RL\\network.py'>

In [31]:
env = ChessEnv.Env()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

q_net = network.DQN(env.action_size).to(device)

# load mô hình lưu
model_path = 'DQN.pth'
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location=device)
    q_net.load_state_dict(checkpoint)
    q_net.train()
    print("Đã load model")
else:
    print("Không tìm thấy model")

target_net = copy.deepcopy(q_net).to(device)
target_net.eval()
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-4)
criterion  = nn.MSELoss(reduction='mean')

cuda
Đã load model


In [32]:
env = ChessEnv.Env()
capacity = 1
rb = replaybuffer.ReplayBuffer(capacity)
batch_size = 1
epsilon = 1.0
epsilon_final = 0.01
epsilon_decay = 0.995
gamma = 0.9
step = 0
target_update_freq = 1
done = False
env.gs.whiteToMove = True # train trắng đi trước (minimax)

for episode in trange(10000, desc="Episodes"):
    state = env.reset()
    while not done:
        # nước đi của minimax
        captured = None # kiểm tra minimax có ăn quân của rl không
        if env.gs.whiteToMove:
            if Evaluate.check_mid_game(env.gs):
                SmartMoveFinder.DEPTH = 4
            else:
                SmartMoveFinder.DEPTH = 3
            MinimaxMove = SmartMoveFinder.findBestMinimaxMove(env.gs, env.gs.getValidMoves())
            if MinimaxMove is None:
                MinimaxMove = SmartMoveFinder.findRandomMove(env.gs.getValidMoves())
            print(f'white {env.gs.whiteToMove} move {MinimaxMove.moveID}')
            if MinimaxMove.pieceCaptured != '--':
                captured = MinimaxMove.pieceCaptured[1]
                print(f'white captured {captured}')
            env.gs.makeMove(MinimaxMove)

        state_tensor = env.state_to_tensor()
        # khám phá
        if random.random() < epsilon:
            move = random.choice(env.gs.getValidMoves())
            action = env.moveid_to_index[move.moveID]
            print(f'white {env.gs.whiteToMove} move {move.moveID}')

        # khai thác
        else:
            q_value = q_net(state_tensor.unsqueeze(0).to(device)) 
            action = q_value.argmax()
        
        next_state_tensor, reward, done, legal_mask = env.step(action, captured) # đen đã đi, kiểm tra đen bị ăn quân ko
        rb.push(state_tensor, action, reward, next_state_tensor, done, legal_mask)

        if rb.__len__() >= batch_size:
            batch = rb.sample(batch_size)
            # đã chuyển về tensor
            states, actions, reward, next_states, dones,  legal_masks= zip(*batch)
            states = torch.stack([s.to(device) for s in states])
            actions = torch.tensor(actions, device=device, dtype=torch.int64).unsqueeze(0)
            reward = torch.tensor(reward, device=device, dtype=torch.float32).unsqueeze(0)
            next_states = torch.stack([ns.to(device) for ns in next_states])
            legal_masks = torch.stack([lm.to(device) for lm in legal_masks])
            dones = torch.tensor(dones, device=device, dtype=torch.float32).unsqueeze(0)

            with torch.no_grad():
                next_q_values = target_net(next_states) # (B, action_size)
                next_q_values[~legal_masks] = -torch.inf
                next_q_max = next_q_values.max(1)[0].unsqueeze(1) # (B, 1)
                q_target = reward + gamma * (1 - dones) * next_q_max # (B, 1)

            q_values = q_net(states.to(device)) # (B, action_size)
            q_values[~legal_masks] = -torch.inf
            q_value = q_values.gather(dim=1, index=actions) # tính q value với hành động đã chọn

            loss = criterion(q_value, q_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # step += 1

            if step % target_update_freq == 0:
                step = 0
                target_net.load_state_dict(q_net.state_dict())
                torch.save(target_net.state_dict(), 'DQN.pth')

        done = True
        epsilon = max(epsilon_final, epsilon*epsilon_decay)

Episodes:   0%|          | 1/10000 [00:00<46:46,  3.56it/s]

white True move 7152
white False move 1727


Episodes: 100%|██████████| 10000/10000 [00:01<00:00, 7256.59it/s]
