In [None]:
import torch
import numpy as np
import random
import copy
import os
import threading
import time
from queue import Queue
import AlphaExitNet

# Device identifier string for this notebook. Despite its name, `model` holds a
# device name, not a network. Prefer CUDA, then Apple MPS, and fall back to CPU
# so the value is valid on machines with neither accelerator (the previous
# version returned 'mps' unconditionally when CUDA was missing).
_mps_backend = getattr(torch.backends, "mps", None)
model = (
    "cuda" if torch.cuda.is_available()
    else "mps" if _mps_backend is not None and _mps_backend.is_available()
    else "cpu"
)

class GPUOptimizedAlphaTrainingApp:
    """Self-play training harness for the AlphaExitNet AlphaZero-style network.

    Construction builds the environment and network, loads
    ``alphazero_model.pth`` if it exists, and starts two daemon threads: one
    that plays self-play episodes and trains (``training_loop``) and one that
    prints per-episode summaries (``terminal_output_loop``).
    """

    def __init__(self, batch_size=128, training_step_delay=1):
        """Create the trainer and immediately start its background threads.

        Args:
            batch_size: Number of replay samples drawn per training step.
            training_step_delay: Polling interval, in seconds, of the terminal
                output thread.
        """
        self.training_step_delay = training_step_delay
        self.batch_size = batch_size

        # Device setup: CUDA, then Apple MPS, then CPU, so the trainer can
        # still start on machines with neither accelerator.
        self.device = self._select_device()
        if self.device.type == "cuda":
            # Let cuDNN autotune convolution algorithms.
            torch.backends.cudnn.benchmark = True

        # Environment and network.
        self.env = AlphaExitNet.ExitStrategyEnv()
        self.network = AlphaExitNet.AlphaZeroNet(
            board_size=7,
            in_channels=2,
            num_res_blocks=3,
            num_filters=64
        )
        self.network.to(self.device)
        if os.path.exists("alphazero_model.pth"):
            AlphaExitNet.load_model(self.network, "alphazero_model.pth", self.device)

        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=2e-4)

        # Training data. The replay buffer holds (state, mcts_policy, return)
        # tuples in chronological order, oldest first.
        self.replay_buffer = []
        self.episode_data = []
        self.episode_count = 0
        self.max_replay_buffer_size = 10000
        # Write a checkpoint every `save_every` episodes (1 = every episode).
        self.save_every = 1

        # MCTS settings.
        self.num_simulations = 100
        self.temperature = 1.0

        # Status messages consumed by the terminal output thread.
        self.state_queue = Queue()

        # Initial state. All attributes above must exist before the threads
        # below start reading them.
        self.current_state = self.env.reset()
        self.start_training_thread()
        self.start_terminal_output_thread()

    @staticmethod
    def _select_device():
        """Return the best available torch device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def terminal_output_loop(self):
        """Print a summary line for every "episode_end" message on the queue.

        Polls ``state_queue`` every ``training_step_delay`` seconds, forever.
        Any error while printing is reported and the loop keeps running.
        """
        while True:
            try:
                while not self.state_queue.empty():
                    msg = self.state_queue.get_nowait()
                    # msg is a dict describing a finished episode.
                    if msg.get("type") == "episode_end":
                        ep = msg.get("episode_count")
                        loss = msg.get("loss")
                        phase = msg.get("phase")
                        loss_str = f"{loss:.4f}" if loss is not None else "N/A"
                        print(f"Episode {ep} finished: Loss: {loss_str}, Phase: {phase}")
            except Exception as e:
                # Thread boundary: report and keep the printer alive.
                print("Terminal output error:", e)
            time.sleep(self.training_step_delay)

    def start_terminal_output_thread(self):
        """Start ``terminal_output_loop`` on a daemon thread."""
        output_thread = threading.Thread(target=self.terminal_output_loop, daemon=True)
        output_thread.start()

    def run_mcts(self, state):
        """Run MCTS from ``state`` and return ``mcts_search``'s action probabilities.

        The whole search runs under ``torch.no_grad()``. The network is only
        used for inference here, so recording autograd graphs for every
        simulation's evaluation would only waste memory and time.
        """
        def evaluate(s):
            return AlphaExitNet.neural_net_fn(
                s, self.network, self.device,
                AlphaExitNet.get_legal_moves_mask(s, self.env),
            )

        with torch.no_grad():
            legal_moves_mask = AlphaExitNet.get_legal_moves_mask(state, self.env)
            initial_policy, _ = AlphaExitNet.neural_net_fn(
                state, self.network, self.device, legal_moves_mask
            )

            root_node = AlphaExitNet.Node(state, prior=1.0)
            root_node.expand(initial_policy, AlphaExitNet.next_state_func)

            return AlphaExitNet.mcts_search(
                root_node,
                evaluate,
                num_simulations=self.num_simulations,
                c_puct=5.0,
                next_state_func=AlphaExitNet.next_state_func,
                is_terminal_func=AlphaExitNet.is_terminal_func
            )

    def _phase_loss(self, examples, phase, action_size):
        """Compute (without backpropagating) the loss for one phase's examples.

        ``examples`` are (state, policy_dict, outcome) tuples. The policy
        target is a dense vector of length ``action_size`` built from the dict;
        actions are assumed to be int keys in ``range(action_size)`` (TODO
        confirm against AlphaExitNet).
        """
        states, policies, outcomes = zip(*examples)
        state_tensors = torch.stack([
            torch.tensor(s["board"], dtype=torch.float32)
            for s in states
        ]).to(self.device)
        policy_tensors = torch.stack([
            torch.tensor([policy.get(a, 0.0) for a in range(action_size)], dtype=torch.float32)
            for policy in policies
        ]).to(self.device)
        outcome_tensors = torch.tensor(outcomes, dtype=torch.float32).view(-1, 1).to(self.device)

        log_policy, predicted_value = self.network(state_tensors, phase=phase)
        return AlphaExitNet.compute_loss(
            log_policy, predicted_value, policy_tensors, outcome_tensors, self.network
        )

    def train_network(self, batch):
        """Run one optimizer step on ``batch`` and return the summed loss.

        The batch is split by ``state["phase"]``. Each phase has its own head
        and policy size (placement: 49, presumably 7x7 cells; movement: 24).
        Both phases' gradients accumulate before a single optimizer step.

        NOTE(review): the network is not switched to ``train()`` mode here;
        confirm whether AlphaExitNet.neural_net_fn leaves it in eval mode.
        """
        total_loss = 0
        self.optimizer.zero_grad()

        for phase, action_size in (("placement", 49), ("movement", 24)):
            examples = [ex for ex in batch if ex[0]["phase"] == phase]
            if not examples:
                continue
            loss = self._phase_loss(examples, phase, action_size)
            loss.backward()
            total_loss += loss.item()

        self.optimizer.step()
        return total_loss

    @staticmethod
    def _sample_action(action_probs):
        """Sample an action from an ``{action: probability}`` mapping.

        Probabilities are renormalised. If they sum to zero or less, the
        action is sampled uniformly instead of passing NaNs to
        ``np.random.choice``.

        Raises:
            ValueError: if ``action_probs`` is empty.
        """
        actions = list(action_probs.keys())
        if not actions:
            raise ValueError("MCTS returned no actions to sample from")
        probs = np.array([action_probs[a] for a in actions])
        total = np.sum(probs)
        if total > 0:
            probs = probs / total
        else:
            probs = np.full(len(actions), 1.0 / len(actions))
        return np.random.choice(actions, p=probs)

    def training_loop(self):
        """Play self-play episodes forever, training after each one.

        Runs on the daemon thread started by ``start_training_thread``.
        """
        while True:
            state = self.env.reset()
            self.episode_data = []

            while True:
                action_probs = self.run_mcts(state)
                # NOTE: self.temperature is decayed below but is not currently
                # applied; sampling is proportional to the MCTS probabilities.
                action = self._sample_action(action_probs)

                next_state, reward, done, info = self.env.step(action)
                # Store a deep copy so the stored sample cannot be changed by
                # anything that later mutates `state`.
                self.episode_data.append((
                    copy.deepcopy(state),
                    action_probs,
                    state["current_player"],
                    reward,
                    info
                ))

                if done:
                    break
                state = next_state

            # Post-episode processing.
            self.episode_count += 1
            self.process_episode()

            # Checkpoint.
            if self.episode_count % self.save_every == 0:
                AlphaExitNet.save_model(self.network, "alphazero_model.pth")

            # Temperature decay (see NOTE above: currently unused).
            if self.episode_count % 100 == 0:
                self.temperature = max(0.1, self.temperature * 0.95)

    def process_episode(self):
        """Turn the finished episode into replay samples, train, and report.

        Each step's value target is its discounted return (gamma = 0.9),
        accumulated backwards from the end of ``episode_data``. Steps whose
        ``info`` has ``max_turn_penalty`` set get an extra -1.0. New samples
        are appended to the replay buffer, and the OLDEST samples are dropped
        once it exceeds ``max_replay_buffer_size``. If the buffer then holds at
        least ``batch_size`` samples, one training step runs on a random batch.
        An "episode_end" message is queued for every episode; its loss is None
        when no training step ran.

        NOTE(review): the stored ``player`` is unused, so returns are not
        flipped per player. Confirm the env's ``reward`` is already from the
        intended perspective.
        """
        gamma = 0.9
        penalty = -1.0

        samples = []
        cumulative_return = 0.0
        for s, mcts_policy, _player, r, info in reversed(self.episode_data):
            step_reward = r + penalty if info.get("max_turn_penalty", False) else r
            cumulative_return = step_reward + gamma * cumulative_return
            samples.append((s, mcts_policy, cumulative_return))
        samples.reverse()  # back to chronological order

        # Append newest samples at the end, then drop the oldest from the front.
        self.replay_buffer.extend(samples)
        overflow = len(self.replay_buffer) - self.max_replay_buffer_size
        if overflow > 0:
            del self.replay_buffer[:overflow]

        loss = None
        if len(self.replay_buffer) >= self.batch_size:
            batch = random.sample(self.replay_buffer, self.batch_size)
            loss = self.train_network(batch)

        self.state_queue.put({
            "type": "episode_end",
            "episode_count": self.episode_count,
            "loss": loss,
            "phase": self.env.phase
        })

    def start_training_thread(self):
        """Start ``training_loop`` on a daemon thread."""
        training_thread = threading.Thread(target=self.training_loop, daemon=True)
        training_thread.start()

if __name__ == "__main__":
    # Constructing the app starts the daemon training and output threads.
    # The main thread only has to stay alive, since daemon threads are stopped
    # when it exits.
    app = GPUOptimizedAlphaTrainingApp()
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl-C / kernel interrupt: leave the keep-alive loop quietly instead
        # of printing a KeyboardInterrupt traceback.
        print("Training interrupted.")


Model loaded from alphazero_model.pth
Model saved to alphazero_model.pth
Episode 1 finished: Loss: 33.9413, Phase: movement
Model saved to alphazero_model.pth
Episode 2 finished: Loss: 23.8086, Phase: movement
Model saved to alphazero_model.pth
Episode 3 finished: Loss: 15.8153, Phase: movement
Model saved to alphazero_model.pth
Episode 4 finished: Loss: 13.9746, Phase: movement
Model saved to alphazero_model.pth
Episode 5 finished: Loss: 16.1795, Phase: movement
Model saved to alphazero_model.pth
Episode 6 finished: Loss: 8.2604, Phase: movement
Model saved to alphazero_model.pth
Episode 7 finished: Loss: 14.0755, Phase: movement
Model saved to alphazero_model.pth
Episode 8 finished: Loss: 12.8680, Phase: movement
Model saved to alphazero_model.pth
Episode 9 finished: Loss: 8.6040, Phase: movement
Model saved to alphazero_model.pth
Episode 10 finished: Loss: 10.6733, Phase: movement
Model saved to alphazero_model.pth
Episode 11 finished: Loss: 10.5047, Phase: movement
Model saved to a