In [5]:
# Install magent2
!pip install git+https://github.com/Farama-Foundation/MAgent2



In [None]:
# Import the required libraries
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from collections import deque
from magent2.environments import battle_v4
from torch_model import QNetwork  # Đảm bảo rằng bạn đã tải lên file torch_model.py
from torch.utils.data import Dataset, DataLoader

# Replay buffer definition
class ReplayBuffer(Dataset):
    """Bounded FIFO store of transitions, exposed as a torch ``Dataset``.

    Transitions are kept as raw ``(state, action, reward, next_state, done)``
    tuples in a ``deque`` capped at ``capacity`` entries, so the oldest entry
    is dropped once the cap is reached. Conversion to tensors happens lazily
    in ``__getitem__``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def __len__(self):
        return len(self.buffer)

    @staticmethod
    def _channels_first(observation):
        """Convert an observation to a float tensor with its last axis moved first."""
        return torch.FloatTensor(observation).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return the transition at ``idx`` as a 5-tuple of tensors.

        Observations are permuted from (dim0, dim1, dim2) to (dim2, dim0, dim1);
        action, reward and done are each wrapped in a length-1 tensor.
        """
        obs, act, rew, following_obs, terminal = self.buffer[idx]
        obs_tensor = self._channels_first(obs)
        act_tensor = torch.LongTensor([act])
        rew_tensor = torch.FloatTensor([rew])
        following_tensor = self._channels_first(following_obs)
        terminal_tensor = torch.FloatTensor([terminal])
        return (obs_tensor, act_tensor, rew_tensor, following_tensor, terminal_tensor)

# Create the environment and pick the compute device
env = battle_v4.env(map_size=45, max_cycles=300)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both networks are sized from the "blue_0" agent's observation and action spaces
obs_shape = env.observation_space("blue_0").shape
n_actions = env.action_space("blue_0").n

# Online Q-network (the one being trained)
q_network = QNetwork(obs_shape, n_actions).to(device)

# Target network: starts as an exact copy of the online network and is kept
# in eval mode
target_network = QNetwork(obs_shape, n_actions).to(device)
target_network.load_state_dict(q_network.state_dict())
target_network.eval()

# Random policy for red agents
def random_policy():
    """Return an action sampled from the action space of agent "red_0"."""
    red_action_space = env.action_space("red_0")
    return red_action_space.sample()

# Create the optimizer (Adam over the online Q-network's parameters)
# and the replay buffer
optimizer = optim.Adam(params=q_network.parameters(), lr=1e-4)
replay_buffer = ReplayBuffer(10000)

# Hyperparameters
batch_size, gamma = 64, 0.99                          # minibatch size, discount factor
train_freq, target_update_freq = 4, 1000              # step intervals for training / target sync
epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.1, 0.995  # exploration schedule
num_episodes = 800

# Mutable training state
epsilon = epsilon_start
step_count = 0

# Action selection for a blue agent
def select_action(state, epsilon):
    """Pick an action epsilon-greedily.

    With probability ``epsilon`` an action is sampled from the "blue_0" action
    space; otherwise the action with the highest Q-value under ``q_network``
    is returned.

    Args:
        state: Observation with the channel axis last (it is permuted to
            channel-first before being fed to the network).
        epsilon: Exploration probability.

    Returns:
        The chosen action.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return env.action_space("blue_0").sample()

    # Greedy branch: build a single-item, channel-first batch on the device.
    batch = torch.tensor(state).float().permute(2, 0, 1).unsqueeze(0)
    batch = batch.to(device)
    with torch.no_grad():
        scores = q_network(batch)
    return scores.argmax().item()

# Model optimisation
def optimize_model():
    """Run one shuffled pass of Q-learning updates over the whole replay buffer.

    Every call iterates over *all* minibatches currently in ``replay_buffer``
    (not a single sampled batch), performing one optimizer step per minibatch
    with an MSE loss against the one-step TD target computed from
    ``target_network``.

    Does nothing when the replay buffer is empty.
    """
    # A shuffling DataLoader over an empty dataset has nothing to sample and
    # may raise when its sampler is built, so bail out early.
    if len(replay_buffer) == 0:
        return

    train_loader = DataLoader(
        replay_buffer,
        batch_size=batch_size,
        shuffle=True
    )

    for states, actions, rewards, next_states, dones in train_loader:
        states = states.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        next_states = next_states.to(device)
        dones = dones.to(device)

        # Q(s, a) for the actions that were actually taken.
        q_values = q_network(states).gather(1, actions)
        # Bootstrap from the target network without tracking gradients.
        with torch.no_grad():
            next_q_values = target_network(next_states).max(1)[0].unsqueeze(1)
        # (1 - dones) zeroes the bootstrap term for terminal transitions.
        target_q_values = rewards + gamma * next_q_values * (1 - dones)

        loss = F.mse_loss(q_values, target_q_values)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Main training loop
#
# Transitions are assembled across two consecutive turns of the same agent:
# in PettingZoo's AEC API, env.last() returns the agent's *current*
# observation together with the reward earned since its previous action, so
# the outcome of an action is only known the next time that agent comes up.

# step_count values at which training / target sync last happened. Using
# "at least N steps since last time" instead of `step_count % N == 0` keeps
# these checks from being skipped when the counter jumps past a multiple of N
# during an episode.
last_train_step = step_count
last_target_sync_step = step_count

for episode in range(num_episodes):
    env.reset()
    total_reward = 0
    # agent name -> (observation, action) of its last move, awaiting its outcome
    pending = {}

    # agent_iter() keeps yielding agents until every agent is finished, so a
    # single pass covers the whole episode.
    for agent in env.agent_iter():
        obs, reward, termination, truncation, _ = env.last()
        finished = termination or truncation

        # Red team: random opponent, nothing is recorded.
        if agent.split("_")[0] != "blue":
            env.step(None if finished else random_policy())
            continue

        # Includes the reward delivered on an agent's final (finished) turn.
        total_reward += reward

        # Complete the transition started on this agent's previous turn.
        previous = pending.pop(agent, None)
        if previous is not None:
            prev_obs, prev_action = previous
            # Defensive: if no observation is provided for a finished agent,
            # store zeros; the done flag masks the bootstrap term anyway.
            next_obs = obs if obs is not None else np.zeros_like(prev_obs)
            replay_buffer.add(prev_obs, prev_action, reward, next_obs, finished)

        if finished:
            # Finished agents must be stepped with None.
            action = None
        else:
            action = select_action(obs, epsilon)
            pending[agent] = (obs, action)
            step_count += 1

        env.step(action)

    # Train once the episode is over (optimize_model sweeps the whole buffer).
    if len(replay_buffer) > 0 and step_count - last_train_step >= train_freq:
        optimize_model()
        last_train_step = step_count

    # Decay epsilon
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    # Sync the target network
    if step_count - last_target_sync_step >= target_update_freq:
        target_network.load_state_dict(q_network.state_dict())
        last_target_sync_step = step_count

    print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward:.2f}, Epsilon: {epsilon:.4f}")

    # Save a model checkpoint every 100 episodes
    if (episode + 1) % 100 == 0:
        checkpoint_path = f"blue_vs_final_checkpoint_{episode + 1}.pt"
        torch.save(q_network.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved at episode {episode + 1}")

# Save the trained model
torch.save(q_network.state_dict(), "blue_vs_random.pt")
print("Training complete. Model saved as 'blue_vs_random.pt'")