In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/rl-parameter/trained_agent.pth
/kaggle/input/rl-parameter/red.pt


In [2]:
! pip install magent2 pytorch_lightning
! pip install pettingzoo==1.22.0


Collecting magent2
  Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting pygame>=2.1.0 (from magent2)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m92.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pygame, magent2
Successfully installed magent2-0.3.3 pygame-2.6.1
Collecting pettingzoo==1.22.0
  Downloading PettingZoo-1.22.0-py3-none-any.whl.metadata (5.0 kB)
Downloading PettingZoo-1.22.0-py3-none-any.whl (823 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
class PPOActorCriticConv(nn.Module):
    def __init__(self, input_channels, input_size, action_space_size):
        super(PPOActorCriticConv, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),  # (input_channels x 13 x 13 -> 32 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),             # (32 x 13 x 13 -> 64 x 13 x 13)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)                              # (64 x 13 x 13 -> 64 x 6 x 6)
        )
        # Flatten output from Conv2D
        conv_output_size = 64 * (input_size // 2) * (input_size // 2)  # For 13x13 input -> 64x6x6 = 2304

        # Fully Connected Shared Layer
        self.shared_layer = nn.Sequential(
            nn.Linear(conv_output_size, 256),
            nn.ReLU()
        )
        
        # Actor Head
        self.actor = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_space_size),
            nn.Softmax(dim=-1)  # Output probabilities for actions
        )

        # Critic Head
        self.critic = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Output value of the state
        )

    def forward(self, x):
        # Input shape: (batch_size, input_channels, input_size, input_size)
        conv_output = self.conv_layers(x)  # Convolutional layers
        flat_output = torch.flatten(conv_output, start_dim=1)  # Flatten (batch_size, conv_output_size)
        shared_output = self.shared_layer(flat_output)  # Shared layer
        policy = self.actor(shared_output)  # Actor output
        value = self.critic(shared_output)  # Critic output
        return policy, value

class PPOAgentWithLightning(pl.LightningModule):
    def __init__(self, input_channels, input_size, action_space_size, lr=3e-4, gamma=0.99, clip_epsilon=0.2):
        super(PPOAgentWithLightning, self).__init__()
        self.model = PPOActorCriticConv(input_channels, input_size, action_space_size)
        self.lr = lr
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon

    def forward(self, x):
        policy, value = self.model(x)
        return policy, value

    def compute_loss(self, batch):
        states, actions, rewards, dones, old_policies = batch
        policy, value = self(states)
        value = value.squeeze(-1)

        # Compute Advantage
        returns, advantages = self.compute_advantages(rewards, value.detach(), dones)

        # Compute Policy Loss
        new_policies = policy.gather(1, actions.unsqueeze(-1)).squeeze(-1)
        policy_ratio = new_policies / old_policies
        clipped_ratio = torch.clamp(policy_ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
        policy_loss = -torch.min(policy_ratio * advantages, clipped_ratio * advantages).mean()

        # Compute Value Loss
        value_loss = nn.MSELoss()(value, returns)

        # Combine Losses
        total_loss = policy_loss + 0.5 * value_loss
        return total_loss

    def compute_advantages(self, rewards, values, dones):
        returns = []
        advantages = []
        G = 0
        for r, v, d in zip(reversed(rewards), reversed(values), reversed(dones)):
            G = r + (1 - d) * self.gamma * G
            returns.insert(0, G)
            advantages.insert(0, G - v)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        return returns, advantages

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)


In [4]:
from torch.utils.data import Dataset

class RLReplayDataset(Dataset):
    def __init__(self, buffer):
        self.buffer = buffer

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        state, action, reward, next_state, done, old_policy = self.buffer[idx]
        return (
            torch.tensor(state, dtype=torch.float32).permute(2, 0, 1),  # (channels, height, width)
            torch.tensor(action, dtype=torch.long),
            torch.tensor(reward, dtype=torch.float32),
            torch.tensor(done, dtype=torch.float32),
            torch.tensor(old_policy, dtype=torch.float32)
        )


In [5]:
from magent2.environments import battle_v4
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

# Initialize environment
env = battle_v4.env(map_size=45, max_cycles=300, step_reward=0.001, attack_opponent_reward=1, dead_penalty=-0.4)
env.reset()

# Hyperparameters
input_channels = 5
input_size = 13
action_space_size = 21
lr = 3e-4
gamma = 0.99
clip_epsilon = 0.2
replay_buffer = []
max_buffer_size = 10000
batch_size = 256
n_episodes = 40

# Initialize Lightning Module
agent = PPOAgentWithLightning(input_channels, input_size, action_space_size, lr, gamma, clip_epsilon)

# Trainer
trainer = pl.Trainer(max_epochs=n_episodes, devices=1 if torch.cuda.is_available() else 4, accelerator='gpu' if torch.cuda.is_available() else 'cpu')

for episode in tqdm(range(n_episodes), desc="Training Episodes"):
    env.reset()
    episode_buffer = []
    for agent_name in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()
        team = agent_name.split('_')[0]

        if termination or truncation:
            action = None
        else:
            observation_tensor = torch.tensor(observation, dtype=torch.float32).permute(2, 0, 1)
            policy, _ = agent(observation_tensor.unsqueeze(0))  # Add batch dimension
            action = torch.multinomial(policy.squeeze(0), 1).item()

            # Store transition
            next_observation, _, _, _, _ = env.last()
            old_policy = policy[0, action].item()
            done = 1 if termination or truncation else 0
            episode_buffer.append((observation, action, reward, next_observation, done, old_policy))

        env.step(action)

    replay_buffer.extend(episode_buffer)
    if len(replay_buffer) > max_buffer_size:
        replay_buffer = replay_buffer[-max_buffer_size:]

    # Training step
    if len(replay_buffer) >= batch_size:
        dataset = RLReplayDataset(replay_buffer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        trainer.fit(agent, dataloader)

env.close()


Training Episodes:   0%|          | 0/40 [00:00<?, ?it/s]/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (40) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Training Episodes:   2%|▎         | 1/40 [01:52<1:13:15, 112.71s/it]/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /kaggle/working/lightning_logs/version_0/checkpoints exists and is not empty.
Training Episodes: 100%|██████████| 40/40 [29:52<00:00, 44.81s/it]


In [6]:
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import pytorch_lightning as pl
# import random
# import os
# from magent2.environments import battle_v4
# from torch.utils.data import DataLoader, Dataset
# from tqdm import tqdm

# class RLReplayDataset(Dataset):
#     def __init__(self, replay_buffer):
#         self.replay_buffer = replay_buffer

#     def __len__(self):
#         return len(self.replay_buffer)

#     def __getitem__(self, idx):
#         state, action, reward, next_state, done = self.replay_buffer[idx]
#         state = torch.tensor(state, dtype=torch.float32)
#         next_state = torch.tensor(next_state, dtype=torch.float32)
#         action = torch.tensor(action, dtype=torch.long)
#         reward = torch.tensor(reward, dtype=torch.float32)
#         done = torch.tensor(done, dtype=torch.float32)
#         return state, action, reward, next_state, done

# def collate_fn(batch):
#     states_list, actions_list, rewards_list, next_states_list, dones_list = zip(*batch)
#     states = torch.stack(states_list, dim=0)      # (B,H,W,C)
#     next_states = torch.stack(next_states_list,0) # (B,H,W,C)
#     actions = torch.stack(actions_list)
#     rewards = torch.stack(rewards_list)
#     dones = torch.stack(dones_list)

#     # Để tương thích với logic cũ, ta giả lập dict {'blue': ...}
#     return {'blue': states}, actions, rewards, {'blue': next_states}, dones

# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1):
#         super().__init__()
#         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
#         self.bn1 = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
#         self.bn2 = nn.BatchNorm2d(out_channels)

#         if in_channels != out_channels or stride != 1:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
#                 nn.BatchNorm2d(out_channels)
#             )
#         else:
#             self.shortcut = nn.Identity()

#     def forward(self, x):
#         identity = self.shortcut(x)
#         out = self.relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += identity
#         out = self.relu(out)
#         return out

# class QNetwork(pl.LightningModule):
#     def __init__(self, observation_shape=(13,13,5), action_shape=21, epsilon=0.2):
#         super().__init__()
#         self.save_hyperparameters()
#         self.action_shape = action_shape
#         self.epsilon = epsilon
#         C = observation_shape[-1]
#         H, W = observation_shape[0], observation_shape[1]

#         # ResNet-like structure
#         self.stage1 = nn.Sequential(
#             ResidualBlock(C, C, stride=1),
#             ResidualBlock(C, C, stride=1)
#         )

#         self.stage2 = nn.Sequential(
#             ResidualBlock(C, C*2, stride=2),
#             ResidualBlock(C*2, C*2, stride=1)
#         )

#         self.stage3 = nn.Sequential(
#             ResidualBlock(C*2, C*4, stride=2),
#             ResidualBlock(C*4, C*4, stride=1)
#         )

#         self.upsample = nn.Upsample(size=(H, W), mode='bilinear', align_corners=False)

#         with torch.no_grad():
#             dummy_input = torch.randn(*observation_shape).permute(2,0,1).unsqueeze(0)
#             x = self.stage1(dummy_input)
#             x = self.stage2(x)
#             x = self.stage3(x)
#             x = self.upsample(x)
#             flatten_dim = x.view(1, -1).shape[1]

#         self.network = nn.Sequential(
#             nn.Linear(flatten_dim, 120),
#             nn.ReLU(),
#             nn.Linear(120, 84),
#             nn.ReLU(),
#             nn.Linear(84, action_shape),
#         )

#     def forward(self, obs):
#         # obs: (B,H,W,C)
#         obs = obs.permute(0,3,1,2).contiguous() # (B,C,H,W)
#         x = self.stage1(obs)
#         x = self.stage2(x)
#         x = self.stage3(x)
#         x = self.upsample(x)
#         x = x.reshape(x.shape[0], -1)
#         return self.network(x)

#     def select_action(self, obs, eval_mode=False):
#         if len(obs.shape) == 3:
#             obs = obs.unsqueeze(0)
#         if not eval_mode and random.random() < self.epsilon:
#             return random.randint(0, self.action_shape - 1)
#         with torch.no_grad():
#             q_values = self(obs)
#         return torch.argmax(q_values, dim=-1).item()

#     def training_step(self, batch, batch_idx):
#         states, actions, rewards, next_states, dones = batch
#         blue_obs = states['blue']
#         next_blue_obs = next_states['blue']
#         q_values = self(blue_obs)
#         with torch.no_grad():
#             q_values_next = self(next_blue_obs)
#         max_next_q = q_values_next.max(dim=1)[0]
#         target = rewards + 0.9 * max_next_q * (1 - dones)

#         q_values_current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
#         loss = nn.MSELoss()(q_values_current, target)
#         self.log('train_loss', loss, on_step=True, on_epoch=True)
#         return loss

#     def configure_optimizers(self):
#         return optim.Adam(self.parameters(), lr=0.001)

