In [3]:
import random
import torch
import torch.nn as nn #  PyTorch의 Neural Network 모듈을 포함하는 모듈
import torch.optim as optim #  PyTorch에서 제공하는 최적화(optimizer) 알고리즘을 포함하는 모듈
import chess
import math

class Node:
    """One position in the Monte Carlo search tree.

    Attributes:
        board: Board state this node represents. It must provide
            ``legal_moves``, ``is_game_over()`` and (for ``best_child``)
            ``peek()``.
        parent: The node this one was expanded from; ``None`` for the root.
        children: Child nodes expanded so far.
        visit: Number of times the node has been backed up through.
        wins: Accumulated value backed up through the node.
        untried_moves: Legal moves that have not been expanded yet; empty
            when the game is already over at this node.
        game_over: Whether ``board.is_game_over()`` was true at construction.
        action: UCI string (e.g. ``"e2e4"``) of the move that led here from
            the parent; ``None`` for the root.
    """

    def __init__(self, board, parent=None, action=None):
        self.board = board
        self.parent = parent
        self.children = []
        self.visit = 0
        self.wins = 0
        self.game_over = board.is_game_over()
        # A finished game is never expanded, even if the board still lists
        # legal moves (e.g. a draw that is not a stalemate).
        self.untried_moves = [] if self.game_over else list(board.legal_moves)
        self.action = action

    def uct_value(self):
        """Return the UCT score of this node.

        Unvisited nodes score ``inf`` so they are explored first. A node
        without a (visited) parent has no exploration term, so only its mean
        value is returned instead of failing on ``self.parent.visit``.
        """
        if self.visit == 0:
            return float('inf')
        mean_value = self.wins / self.visit
        if self.parent is None or self.parent.visit < 1:
            # Root node, or a parent that was never backed up through:
            # log(parent.visit) is undefined, so skip the exploration bonus.
            return mean_value
        exploration = math.sqrt(2 * math.log(self.parent.visit) / self.visit)
        return mean_value + exploration

    def best_child(self, policy_probabilities):
        """Pick the child maximising ``uct_value() + policy probability``.

        Args:
            policy_probabilities: One probability per entry of
                ``list(self.board.legal_moves)``, in the same order.

        Returns:
            The best-scoring child, or ``None`` when there are no eligible
            children or when ``policy_probabilities`` does not have one entry
            per legal move (an error message is printed in that case).
        """
        legal_moves = list(self.board.legal_moves)
        if len(policy_probabilities) != len(legal_moves):
            print("Error: Policy probabilities length does not match number of legal moves.")
            return None

        # One pass to index the moves instead of list.index() per child.
        position_of = {}
        for position, legal_move in enumerate(legal_moves):
            position_of.setdefault(legal_move, position)

        selected_child = None
        max_value = float('-inf')
        for child in self.children:
            # child.board.peek() is the last move played on the child's board,
            # i.e. the move that leads from this node to the child.
            position = position_of.get(child.board.peek())
            if position is None:
                continue
            # Plain sum of UCT and prior; the weighting is a tunable choice.
            value = child.uct_value() + policy_probabilities[position]
            if value > max_value:
                max_value = value
                selected_child = child
        return selected_child


# 정책 네트워크 ( Policy Network )
class PolicyNetwork(nn.Module):
    """Convolutional policy head: 12 piece planes in, 4672 move scores out.

    The forward pass returns raw, unnormalised scores (logits); no softmax is
    applied here, so callers must normalise them themselves.

    NOTE(review): the traceback recorded with this notebook shows that
    ``policy_net_0328.pth`` also contains ``bn1``-``bn4``, ``conv3``,
    ``conv4`` and ``fc3`` and has an ``fc1`` weight of shape [512, 16384].
    That checkpoint was therefore saved from a different, deeper
    architecture and cannot be loaded into this class as written. Confirm
    the training-time definition before changing the layers below.
    """

    def __init__(self):
        super().__init__()
        # Two 3x3 convolutions keep the 8x8 spatial size (padding=1).
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Flattened 64 x 8 x 8 feature map -> hidden layer -> one score per move slot.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 4672)

    def forward(self, x):
        """Return a (batch, 4672) tensor of move logits for a (batch, 12, 8, 8) input."""
        features = self.conv1(x).relu()
        features = self.conv2(features).relu()
        flattened = features.view(features.size(0), -1)
        hidden = self.fc1(flattened).relu()
        return self.fc2(hidden)

# 가치 네트워크 (Value Network)
class ValueNetwork(nn.Module):
    """Convolutional value head: 12 piece planes in, one score in (0, 1) out."""

    def __init__(self):
        super().__init__()
        # Two 3x3 convolutions keep the 8x8 spatial size (padding=1).
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Flattened 64 x 8 x 8 feature map -> hidden layer -> single scalar.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid-squashed values for a (batch, 12, 8, 8) input."""
        features = self.conv1(x).relu()
        features = self.conv2(features).relu()
        flattened = features.view(features.size(0), -1)
        hidden = self.fc1(flattened).relu()
        # The sigmoid keeps the estimate strictly between 0 and 1.
        return self.fc2(hidden).sigmoid()

def move_to_index(move_uci):
    """Map a UCI move string to a stable slot in the 4672-wide policy output.

    The previous implementation used the built-in ``hash()``. String hashes
    are salted per interpreter process, so the same move landed on a
    different slot in every run, and unrelated moves could collide. This
    version is deterministic and collision-free for well-formed moves:

    * ``0 .. 4095``: ``from_square * 64 + to_square`` for ordinary moves and
      queen promotions (squares numbered a1=0, b1=1, ..., h8=63).
    * ``4096 .. 4239``: under-promotions, keyed by piece (n, b, r), promotion
      rank (8th or 1st), origin file and file shift (-1, 0, +1).
    * ``4240 .. 4671``: CRC32-based fallback for anything else (for example
      the null move ``"0000"``), so the function still accepts any input
      without overlapping the slots of real moves.

    NOTE(review): a network trained with the old hash-based mapping has to be
    retrained (or re-indexed) to match this layout.

    Args:
        move_uci: Move in UCI notation, e.g. ``"e2e4"`` or ``"e7e8q"``.

    Returns:
        An int in ``range(4672)``.
    """
    import zlib  # local import: only needed for the fallback path

    files = "abcdefgh"
    ranks = "12345678"
    underpromotions = "nbr"
    policy_size = 4672  # width of the policy network's output layer
    plain_slots = 64 * 64
    structured_slots = plain_slots + len(underpromotions) * 2 * 8 * 3  # 4240

    well_formed = (
        isinstance(move_uci, str)
        and len(move_uci) in (4, 5)
        and move_uci[0] in files
        and move_uci[1] in ranks
        and move_uci[2] in files
        and move_uci[3] in ranks
    )
    if well_formed:
        from_file = files.index(move_uci[0])
        from_rank = ranks.index(move_uci[1])
        to_file = files.index(move_uci[2])
        to_rank = ranks.index(move_uci[3])
        promotion = move_uci[4:]

        if promotion in ("", "q"):
            # A queen promotion shares the slot of the plain from/to move:
            # the two can never both be legal in the same position.
            return (from_rank * 8 + from_file) * 64 + (to_rank * 8 + to_file)

        file_shift = to_file - from_file
        if promotion in underpromotions and abs(file_shift) <= 1 and to_rank in (0, 7):
            piece = underpromotions.index(promotion)
            side = 0 if to_rank == 7 else 1  # promoting on the 8th vs. the 1st rank
            slot = ((piece * 2 + side) * 8 + from_file) * 3 + (file_shift + 1)
            return plain_slots + slot

    checksum = zlib.crc32(str(move_uci).encode("utf-8"))
    return structured_slots + checksum % (policy_size - structured_slots)

In [43]:
import sys
import torch
import chess
import chess.svg
from PyQt5.QtWidgets import QApplication, QVBoxLayout, QWidget, QLabel
from PyQt5.QtSvg import QSvgWidget
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QMouseEvent
import numpy as np

class ChessApp(QWidget):
    """PyQt5 window in which the user (White) plays against the engine (Black).

    The engine picks its moves with a Monte Carlo tree search that is guided
    by ``PolicyNetwork`` (move priors) and ``ValueNetwork`` (leaf evaluation).
    """

    # Channel index of each piece symbol in the 12-plane board encoding.
    _PIECE_PLANES = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
    }

    def __init__(self, policy_path="./models/policy_net_0328.pth",
                 value_path="./models/value_net_0328.pth"):
        """Build the UI and load both networks in evaluation mode.

        Args:
            policy_path: File holding the policy network's ``state_dict``.
            value_path: File holding the value network's ``state_dict``.

        Raises:
            RuntimeError: From ``load_state_dict`` when a checkpoint does not
                match the network class defined in this file.
        """
        super().__init__()
        self.board = chess.Board()
        self.selected_square = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.initUI()

        self.policy_net = PolicyNetwork().to(self.device)
        self.value_net = ValueNetwork().to(self.device)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state_dict needs.
        self.policy_net.load_state_dict(
            torch.load(policy_path, map_location=self.device, weights_only=True))
        self.value_net.load_state_dict(
            torch.load(value_path, map_location=self.device, weights_only=True))
        self.policy_net.eval()
        self.value_net.eval()

    def initUI(self):
        """Create the turn label and the clickable SVG board."""
        self.setWindowTitle('AI 체스')
        self.setGeometry(100, 100, 500, 500)

        self.turn_label = QLabel("사용자의 턴", self)
        self.turn_label.setAlignment(Qt.AlignCenter)

        self.svg_widget = QSvgWidget()
        self.svg_widget.setMouseTracking(True)
        self.svg_widget.mousePressEvent = self.handle_click

        layout = QVBoxLayout()
        # The label used to be created but never laid out, so it floated over
        # the top-left corner of the board.
        layout.addWidget(self.turn_label)
        layout.addWidget(self.svg_widget, 1)
        self.setLayout(layout)
        self.update_board()

    def update_board(self):
        """Re-render the current position into the SVG widget."""
        last_move = self.board.peek() if self.board.move_stack else None
        # Rendered without the coordinate margin so that the 8x8 grid spans
        # the whole widget; handle_click() maps pixels to squares on exactly
        # that assumption.
        board_svg = chess.svg.board(self.board, lastmove=last_move, coordinates=False)
        self.svg_widget.load(bytearray(board_svg, encoding='utf-8'))

    def handle_click(self, event: QMouseEvent):
        """Handle a click on the board: select a white piece, then move it.

        The first click selects one of White's pieces; the second click tries
        to move it to the clicked square. A pawn reaching the last rank is
        promoted to a queen. After a legal move the engine replies at once.
        """
        if self.board.turn != chess.WHITE:
            return  # not the user's turn: ignore the click

        width = self.svg_widget.width()
        height = self.svg_widget.height()
        if width <= 0 or height <= 0:
            return
        # Scale before dividing and clamp to 0..7, so clicks on the last
        # pixels of a widget whose size is not a multiple of 8 stay on the
        # board.
        file = min(7, max(0, event.x() * 8 // width))
        rank = 7 - min(7, max(0, event.y() * 8 // height))
        square = chess.square(file, rank)

        if self.selected_square is None:
            if self.board.color_at(square) == chess.WHITE:
                self.selected_square = square
            return

        origin = self.selected_square
        self.selected_square = None
        move = chess.Move(origin, square)
        if move not in self.board.legal_moves:
            # A pawn move to the last rank is only legal with a promotion
            # piece attached; default to a queen.
            move = chess.Move(origin, square, promotion=chess.QUEEN)
            if move not in self.board.legal_moves:
                return

        self.board.push(move)
        self.update_board()
        self.turn_label.setText("AI의 턴")
        QApplication.processEvents()  # repaint before the blocking search
        self.ai_move()

    def ai_move(self):
        """Run the search for the engine's reply and play it.

        If the search has no move to offer, the game is over after the user's
        move: the outcome is printed and the window is closed.
        """
        move = self.mcts_search(num_simulations=1500)
        if move is None:
            print("사용자 승리" if self.board.is_checkmate() else "무승부")
            # closeEvent() needs a Qt event object; close() supplies one.
            self.close()
            return

        self.board.push(move)
        self.update_board()
        if self.board.is_game_over():
            self.turn_label.setText(f"게임 종료: {self.board.result()}")
        else:
            self.turn_label.setText("사용자의 턴")

    def board_to_tensor(self, board, device):
        """Encode ``board`` as a (1, 12, 8, 8) float32 tensor on ``device``.

        Each of the 12 planes is a one-hot map of one piece type and colour;
        a square index ``s`` is written at ``[plane, s // 8, s % 8]``.
        """
        board_matrix = np.zeros((12, 8, 8), dtype=np.float32)
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                row, col = divmod(square, 8)
                board_matrix[self._PIECE_PLANES[piece.symbol()], row, col] = 1
        return torch.from_numpy(board_matrix).unsqueeze(0).to(device)

    def _move_priors(self, node):
        """Return ``{uci: prior}`` for the legal moves at ``node``.

        The priors are the policy network's softmax outputs restricted to the
        legal moves and renormalised (uniform if they sum to zero). They are
        cached on the node so the network runs once per node rather than on
        every traversal.
        """
        cached = getattr(node, "priors", None)
        if cached is not None:
            return cached

        legal_moves = list(node.board.legal_moves)
        with torch.no_grad():
            logits = self.policy_net(self.board_to_tensor(node.board, self.device))
            probs = torch.softmax(logits, dim=1).cpu().numpy().flatten()
        legal_probs = np.array([probs[move_to_index(move.uci())] for move in legal_moves])
        total = legal_probs.sum()
        if total > 0:
            legal_probs = legal_probs / total
        else:
            legal_probs = np.ones(len(legal_moves)) / len(legal_moves)

        priors = {move.uci(): float(prob) for move, prob in zip(legal_moves, legal_probs)}
        node.priors = priors
        return priors

    def _select_child(self, node, c_puct):
        """Return the child of ``node`` with the highest ``Q + U`` score, or ``None``."""
        priors = self._move_priors(node)
        sqrt_parent_visits = math.sqrt(node.visit)
        best_child = None
        best_value = -float('inf')
        for child in node.children:
            prior = priors.get(child.action)
            if prior is None:
                continue
            q_value = child.wins / child.visit if child.visit > 0 else 0
            u_value = c_puct * prior * sqrt_parent_visits / (1 + child.visit)
            if q_value + u_value > best_value:
                best_value = q_value + u_value
                best_child = child
        return best_child

    def mcts_search(self, num_simulations=1500):
        """Search from the current position and return the most visited move.

        Each simulation descends through fully expanded nodes, expands one
        untried move, scores the resulting position with the value network
        and adds that score to every node on the path.

        NOTE(review): the same value is added at every ply, with no change of
        sign or perspective between the two players. Confirm which side the
        value network's output refers to.

        Args:
            num_simulations: Number of simulations to run.

        Returns:
            The chosen ``chess.Move``, or ``None`` when no move was expanded
            (in particular when the game is already over).
        """
        root = Node(self.board.copy())
        if not root.untried_moves:
            # Nothing can be expanded from a finished position.
            return None
        c_puct = 2.0  # exploration constant

        for _ in range(num_simulations):
            node = root

            # Selection: descend while the node is fully expanded.
            while not node.untried_moves and node.children:
                chosen = self._select_child(node, c_puct)
                if chosen is None:
                    break
                node = chosen

            # Expansion: add one child for an untried move.
            if node.untried_moves:
                move = node.untried_moves.pop()
                child_board = node.board.copy()
                child_board.push(move)
                child = Node(child_board, parent=node, action=move.uci())
                node.children.append(child)
                node = child

            # Evaluation: the value network stands in for a random rollout.
            with torch.no_grad():
                value = self.value_net(self.board_to_tensor(node.board, self.device)).item()

            # Backpropagation: update every node on the path up to the root.
            while node is not None:
                node.visit += 1
                node.wins += value
                node = node.parent

        # children is a list and never None, so test for emptiness.
        if not root.children:
            return None

        best_node = max(root.children, key=lambda n: n.visit)
        return chess.Move.from_uci(best_node.action)

    def closeEvent(self, event):
        """Log the shutdown and accept the close request."""
        print("체스 애플리케이션 종료 중...")
        # The window is already closing here; calling self.close() again
        # would only re-enter this handler.
        event.accept()

if __name__ == '__main__':
    # Reuse the QApplication that already exists in this process (as happens
    # when the cell is re-run in a notebook kernel); create one otherwise.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    window = ChessApp()
    window.show()
    app.exec()

RuntimeError: Error(s) in loading state_dict for PolicyNetwork:
	Unexpected key(s) in state_dict: "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.num_batches_tracked", "conv3.weight", "conv3.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "bn3.num_batches_tracked", "conv4.weight", "conv4.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "bn4.num_batches_tracked", "fc3.weight", "fc3.bias". 
	size mismatch for fc1.weight: copying a param with shape torch.Size([512, 16384]) from checkpoint, the shape in current model is torch.Size([256, 4096]).
	size mismatch for fc1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([4672, 256]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([4672]).