In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from collections import deque, namedtuple
from magent2.environments import battle_v4
import sys
import time
import copy
import os

# Thêm đường dẫn chứa file torch_model.py và final_torch_model.py nếu cần
sys.path.append('/kaggle/input/model-agent') 
from torch_model import QNetwork  # Thay bằng file chứa class QNetwork của bạn
from final_torch_model import QNetwork as FinalQNetwork # Thay bằng file chứa class FinalQNetwork của bạn

class ReplayBuffer(Dataset):
    """Fixed-capacity FIFO store of transitions, usable as a torch ``Dataset``.

    Once ``capacity`` transitions have been stored, each new transition
    overwrites the oldest one (ring-buffer behaviour driven by ``position``).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []  # raw (state, action, reward, next_state, done) tuples
        self.position = 0  # slot that the next add() will write

    def add(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) >= self.capacity:
            # Full: overwrite the oldest slot in place.
            self.buffer[self.position] = transition
        else:
            # Still filling up: ``position`` always equals len(buffer) here.
            self.buffer.append(transition)
        self.position = (self.position + 1) % self.capacity

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return transition ``idx`` as tensors.

        Scalars (action, reward, done) are wrapped in 1-element tensors so a
        batch of them collates to shape ``(batch, 1)``.
        """
        state, action, reward, next_state, done = self.buffer[idx]
        to_float = torch.FloatTensor
        return (
            to_float(state),
            torch.LongTensor([action]),
            to_float([reward]),
            to_float(next_state),
            to_float([done]),
        )

    def can_sample(self, batch_size):
        """True once at least ``batch_size`` transitions are stored."""
        return batch_size <= len(self)

class OpponentManager:
    """Choose actions for the red team.

    The red team plays randomly until self-play is enabled. After that it uses
    a frozen snapshot (``best_model``) of the blue Q-network.
    """

    def __init__(self, env, device):
        # Uniform random policy over the red team's action space.
        self.random_policy = lambda: env.action_space("red_0").sample()
        self.device = device
        self.best_model = None  # frozen Q-network snapshot used for self-play
        self.is_selfplay = False
        # Best reward accepted by update_best_model. Kept per instance rather
        # than in a module-level global so the manager is self-contained.
        self.best_reward = float('-inf')

    def get_action(self, obs, episode):
        """Return a random action during the random phase, else a self-play action.

        Relies on the module-level constant ``RANDOM_PHASE_EPISODES``.
        """
        if episode < RANDOM_PHASE_EPISODES or not self.is_selfplay:
            return self.random_policy()
        return self.get_selfplay_action(obs)

    def get_selfplay_action(self, obs):
        """Greedy action from the snapshot model (random if none is set yet).

        ``obs`` is laid out (H, W, C). It is permuted to (1, C, H, W) before
        the network call.
        """
        if self.best_model is None:
            return self.random_policy()
        state_tensor = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        state_tensor = state_tensor.permute(2, 0, 1).unsqueeze(0)
        with torch.no_grad():
            q_values = self.best_model(state_tensor)
        return q_values.argmax().item()

    def update_best_model(self, model, reward):
        """Snapshot ``model`` as the opponent if ``reward`` beats the best so far.

        Returns:
            True if the snapshot was replaced (which also enables self-play),
            otherwise False.
        """
        if reward <= self.best_reward:
            return False
        # Deep copy so further training of ``model`` doesn't move the opponent;
        # eval() puts any dropout/batch-norm layers into inference mode.
        self.best_model = copy.deepcopy(model).eval()
        self.best_reward = reward
        self.is_selfplay = True
        return True

# Environment: 45x45 battle map, at most 1000 cycles per episode, with shaped
# rewards (per-step cost, attack penalty, bonus for hitting an opponent,
# penalty for dying).
env = battle_v4.env(
    map_size=45,
    max_cycles=1000,
    step_reward=-0.005,
    attack_penalty=-0.1,
    attack_opponent_reward=0.2,
    dead_penalty=-0.1,
)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Online and target Q-networks, both sized from the blue agents' spaces.
blue_obs_shape = env.observation_space("blue_0").shape
blue_n_actions = env.action_space("blue_0").n

q_network = QNetwork(blue_obs_shape, blue_n_actions).to(device)
target_network = QNetwork(blue_obs_shape, blue_n_actions).to(device)
# The target starts as an exact copy of the online network; it is only used
# for inference.
target_network.load_state_dict(q_network.state_dict())
target_network.eval()

# Red-team action provider (random play / self-play).
opponent = OpponentManager(env, device)

optimizer = optim.Adam(q_network.parameters(), lr=1e-4)

# Hyperparameters.
batch_size = 512  # try different values
gamma = 0.99
target_update_freq = 1000
train_freq = 4
epsilon_start = 1.0
epsilon_end = 0.05  # tuned epsilon_end
epsilon_decay = 0.997  # tuned epsilon_decay
num_episodes = 500  # increased number of episodes
replay_buffer_capacity = 20000  # increased replay buffer capacity

# Module-level training state. train() shadows these with its own locals;
# best_reward is kept here for code that tracks it as a global.
epsilon = epsilon_start
step_count = 0
best_reward = float('-inf')
rewards_history = []

# Self-play parameters.
RANDOM_PHASE_EPISODES = 200  # Episodes to train against random
MIN_REWARD_THRESHOLD = 50    # Min reward before self-play
# NOTE(review): MIN_REWARD_THRESHOLD is not referenced anywhere in this file.

# Hàm chọn hành động cho agent xanh
def select_action(state, epsilon):
    """Epsilon-greedy action for a blue agent.

    With probability ``epsilon`` a random blue action is sampled. Otherwise the
    greedy action of the module-level ``q_network`` is returned. ``state`` is
    an (H, W, C) observation, reshaped to (1, C, H, W) for the network.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return env.action_space("blue_0").sample()

    obs_chw = torch.tensor(state).float().permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        return q_network(obs_chw).argmax().item()

# Hàm tối ưu mô hình
def optimize_model(replay_buffer):
    """Run one DQN gradient step on a random minibatch from ``replay_buffer``.

    Does nothing until the buffer holds at least ``batch_size`` transitions.
    Uses the module-level ``q_network``, ``target_network``, ``optimizer``,
    ``batch_size``, ``gamma`` and ``device``.

    Args:
        replay_buffer: a ``ReplayBuffer`` whose items are
            ``(state, [action], [reward], next_state, [done])`` tensors.
    """
    if not replay_buffer.can_sample(batch_size):
        return

    # Sample ONE minibatch without replacement. Iterating a DataLoader over the
    # whole buffer here would run a full epoch on every call.
    indices = random.sample(range(len(replay_buffer)), batch_size)
    samples = [replay_buffer[i] for i in indices]
    states, actions, rewards, next_states, dones = (
        torch.stack(column).to(device) for column in zip(*samples)
    )

    # Observations are stored as (H, W, C); the network takes (N, C, H, W),
    # the same layout select_action builds with permute(2, 0, 1).
    states = states.permute(0, 3, 1, 2)
    next_states = next_states.permute(0, 3, 1, 2)

    # Q(s, a) for the actions actually taken; actions has shape (N, 1).
    current_q_values = q_network(states).gather(1, actions)

    # Bootstrapped target; (1 - dones) stops bootstrapping past terminal states.
    with torch.no_grad():
        next_q_values = target_network(next_states).max(1)[0].unsqueeze(1)
        target_q_values = rewards + gamma * next_q_values * (1 - dones)

    loss = F.mse_loss(current_q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    # Clip each gradient element to [-100, 100] to guard against blow-ups.
    torch.nn.utils.clip_grad_value_(q_network.parameters(), 100)
    optimizer.step()


def _run_episode(env, replay_buffer, epsilon, episode, step_count):
    """Play one episode, storing blue transitions and training online.

    Returns:
        (episode_reward, step_count): the summed blue reward for the episode
        and the updated global blue-agent step counter.
    """
    env.reset()
    episode_reward = 0
    # In the AEC API, the reward/observation from env.last() are the outcome of
    # the agent's *previous* action. A blue transition is therefore completed
    # the next time that agent is selected.
    pending = {}  # blue agent -> (obs, action) awaiting its outcome

    for agent in env.agent_iter():
        obs, reward, termination, truncation, _ = env.last()
        is_blue = agent.split("_")[0] == "blue"

        if is_blue and agent in pending:
            prev_obs, prev_action = pending.pop(agent)
            # Only a real termination stops bootstrapping; a time-limit
            # truncation still has a meaningful next state.
            replay_buffer.add(prev_obs, prev_action, reward, obs, termination)
            episode_reward += reward

        if termination or truncation:
            action = None  # PettingZoo requires None for agents that are done
        elif is_blue:
            action = select_action(obs, epsilon)
            pending[agent] = (obs, action)
        else:
            action = opponent.get_action(obs, episode)

        env.step(action)

        if is_blue:
            if step_count % train_freq == 0:
                optimize_model(replay_buffer)
            step_count += 1
            # Hard target-network sync every target_update_freq blue steps.
            if step_count % target_update_freq == 0:
                target_network.load_state_dict(q_network.state_dict())

    return episode_reward, step_count


def train(env, num_episodes, epsilon_start, epsilon_end, epsilon_decay):
    """Train the blue team's Q-network with DQN against a random/self-play opponent.

    Stops early after a 2-hour wall-clock limit and saves to 'blue_final.pt'.
    Otherwise it saves to 'blue.pt' after ``num_episodes``. It also writes a
    checkpoint every 100 episodes.

    Args:
        env: the MAgent2 AEC environment to play in.
        num_episodes: maximum number of episodes to run.
        epsilon_start, epsilon_end, epsilon_decay: exploration schedule. Epsilon
            is multiplied by ``epsilon_decay`` after each episode and floored
            at ``epsilon_end``.
    """
    start_time = time.time()
    time_limit = 7200  # 2 hours

    replay_buffer = ReplayBuffer(replay_buffer_capacity)
    epsilon = epsilon_start
    step_count = 0
    rewards_history = []

    for episode in range(num_episodes):
        if time.time() - start_time > time_limit:
            print(f"Time limit of {time_limit/3600:.1f} hours reached")
            torch.save(q_network.state_dict(), "blue_final.pt")
            print("Training stopped. Model saved as 'blue_final.pt'")
            return

        episode_reward, step_count = _run_episode(
            env, replay_buffer, epsilon, episode, step_count
        )
        rewards_history.append(episode_reward)

        # Transition to self-play after the random phase.
        if episode == RANDOM_PHASE_EPISODES:
            opponent.best_model = copy.deepcopy(q_network).eval()
            opponent.is_selfplay = True
            print("Transitioning to self-play training")

        # Update the opponent snapshot if performance improves during self-play.
        if episode >= RANDOM_PHASE_EPISODES:
            if opponent.update_best_model(q_network, episode_reward):
                print(f"New best self-play model at episode {episode} with reward {episode_reward:.2f}")

        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        if episode % 10 == 0:
            elapsed_time = time.time() - start_time
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode}/{num_episodes}")
            print(f"Elapsed time: {elapsed_time/3600:.1f} hours")
            print(f"Avg Reward: {avg_reward:.2f}")
            print(f"Epsilon: {epsilon:.4f}")
            print(f"Buffer size: {len(replay_buffer)}")
            print("-" * 50)

        if (episode + 1) % 100 == 0:
            checkpoint_path = f"blue_checkpoint_{episode+1}.pt"
            torch.save(q_network.state_dict(), checkpoint_path)
            print(f"Model checkpoint saved at episode {episode + 1}")

    torch.save(q_network.state_dict(), "blue.pt")
    print("Training complete. Model saved as 'blue.pt'")

def main():
    """Script entry point: train the blue team with the module-level settings."""
    train(
        env,
        num_episodes=num_episodes,
        epsilon_start=epsilon_start,
        epsilon_end=epsilon_end,
        epsilon_decay=epsilon_decay,
    )


if __name__ == "__main__":
    main()

: 