In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
all_input_paths = (
    os.path.join(folder, file_name)
    for folder, _subfolders, file_names in os.walk('/kaggle/input')
    for file_name in file_names
)
for input_path in all_input_paths:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/trained_agent.pth
/kaggle/input/blue_agent_14.pth
/kaggle/input/red.pt
/kaggle/input/blue_agent_13.pth


In [None]:
! pip install magent2 pytorch_lightning
! pip install pettingzoo==1.22.0


Collecting magent2
  Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting pygame>=2.1.0 (from magent2)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading magent2-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hDownloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m87.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: pygame, magent2
Successfully installed magent2-0.3.3 pygame-2.6.1


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import random
import os
from magent2.environments import battle_v4
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class RLReplayDataset(Dataset):
    """Map-style dataset that serves replay-buffer transitions as tensors.

    Every buffer entry is unpacked as a
    ``(state, action, reward, next_state, done)`` tuple. The buffer is kept by
    reference, not copied.
    """

    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer

    def __len__(self):
        return len(self.replay_buffer)

    @staticmethod
    def _as_float(value):
        """Convert ``value`` to a float32 tensor."""
        return torch.tensor(value, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return transition ``idx`` as ``(state, action, reward, next_state, done)``.

        The action is converted to ``torch.long``; every other field is
        converted to ``torch.float32``. Observation layout is left untouched
        (noted as (H,W,C) by the original author).
        """
        obs, act, rew, obs_next, terminal = self.replay_buffer[idx]
        return (
            self._as_float(obs),
            torch.tensor(act, dtype=torch.long),
            self._as_float(rew),
            self._as_float(obs_next),
            self._as_float(terminal),
        )

def collate_fn(batch):
    """Stack a list of transitions into batched tensors.

    Args:
        batch: sequence of ``(state, action, reward, next_state, done)`` tensor
            tuples, as produced by ``RLReplayDataset.__getitem__``.

    Returns:
        ``({'blue': states}, actions, rewards, {'blue': next_states}, dones)``
        where each tensor gains a leading batch dimension.
    """
    # Transpose the batch (list of tuples -> tuple of fields) and stack each
    # field along a new leading axis.
    columns = [torch.stack(field, dim=0) for field in zip(*batch)]
    states, actions, rewards, next_states, dones = columns
    return {'blue': states}, actions, rewards, {'blue': next_states}, dones

class SpatialCNN(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by a ReLU.

    With kernel size 3 and padding 1 the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels`` -> 16 -> ``out_channels``).
    """

    def __init__(self, in_channels=5, out_channels=32):
        super().__init__()
        hidden_channels = 16
        layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # Keep the attribute name and layer order: state_dict keys
        # ("conv.0.*", "conv.2.*") depend on them.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to features of shape (B, out_channels, H, W)."""
        feature_map = self.conv(x)
        return feature_map

class FunctionalPolicyAgent(pl.LightningModule):
    """Q-learning agent: a small CNN followed by an MLP head over actions.

    Args:
        action_space_size: number of discrete actions (size of the Q output).
        embed_dim: number of channels in the observation.
        height: observation height; with ``width`` it fixes the size of the
            flattened feature vector fed to the first linear layer.
        width: observation width.
        hidden_dim: width of the hidden linear layer.
        dropout: dropout probability applied after the hidden layer.
        epsilon: probability of picking a uniformly random action in
            ``select_action`` when ``eval_mode`` is False.
        gamma: discount factor used for the bootstrapped target in
            ``training_step`` (defaults to the previously hard-coded 0.9).
    """

    def __init__(self, action_space_size, embed_dim=5, height=13, width=13,
                 hidden_dim=256, dropout=0.3, epsilon=0.2, gamma=0.9):
        super().__init__()
        self.action_space_size = action_space_size
        self.epsilon = epsilon
        self.height = height
        self.width = width
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.gamma = gamma

        # Spatial CNN
        self.spatial = SpatialCNN(in_channels=embed_dim, out_channels=32)

        # Q-network: 1x1 conv squeezes 32 feature maps down to 3 before the MLP.
        self.q_network = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1),
            nn.Flatten(),
            nn.Linear(height * width * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, action_space_size)
        )

    def forward(self, obs):
        """Return Q-values of shape (B, action_space_size) for ``obs`` of shape (B, H, W, C)."""
        obs = obs.permute(0, 3, 1, 2).contiguous()  # (B,C,H,W)
        spatial_features = self.spatial(obs)  # (B,32,H,W)
        return self.q_network(spatial_features)

    def select_action(self, obs, eval_mode=False):
        """Pick an action for a single observation.

        Args:
            obs: tensor of shape (H, W, C) or (1, H, W, C).
            eval_mode: when True, exploration is disabled: no epsilon-random
                action is taken and dropout is switched off for the forward
                pass so the returned action is deterministic.

        Returns:
            The chosen action index as a Python int.
        """
        if obs.dim() == 3:
            obs = obs.unsqueeze(0)  # (1,H,W,C)
        if not eval_mode and random.random() < self.epsilon:
            return random.randint(0, self.action_space_size - 1)

        # The module may live on a different device than the caller's tensor.
        obs = obs.to(next(self.parameters()).device)

        was_training = self.training
        if eval_mode:
            # Without this, dropout keeps perturbing Q-values during evaluation.
            self.eval()
        try:
            with torch.no_grad():
                q_values = self(obs)
        finally:
            if eval_mode:
                self.train(was_training)
        return torch.argmax(q_values, dim=-1).item()

    def training_step(self, batch, batch_idx):
        """Compute the one-step TD loss on a batch produced by ``collate_fn``.

        ``batch`` is ``(states, actions, rewards, next_states, dones)`` where
        ``states`` and ``next_states`` are dicts holding the observation batch
        under the ``'blue'`` key.
        """
        states, actions, rewards, next_states, dones = batch
        blue_obs = states['blue']
        next_blue_obs = next_states['blue']

        q_values = self(blue_obs)
        # The bootstrap target comes from the same network, with gradients blocked.
        # NOTE(review): dropout is still active for this pass while the module
        # is in train mode — confirm that is acceptable.
        with torch.no_grad():
            max_next_q = self(next_blue_obs).max(dim=1)[0]
        # (1 - dones) zeroes the bootstrap term for terminal transitions.
        target = rewards + self.gamma * max_next_q * (1 - dones)

        q_values_current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        loss = nn.functional.mse_loss(q_values_current, target)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return optim.Adam(self.parameters(), lr=0.001)


In [None]:
# Value triples noted by the author; the first matches
# (step_reward, attack_opponent_reward, dead_penalty) in the env call below.
# 0.005 3 -0.5
# 0.003 2 -0.2
env = battle_v4.env(map_size=45, max_cycles=200, step_reward=0.005, attack_opponent_reward=3, dead_penalty=-0.5)
replay_buffer = []
max_buffer_size = 10000
batch_size = 256
n_episodes = 100
action_space_size = 21
red_update_interval = 10
checkpoint_path = "/kaggle/input/rl-parameter/blue_agent_13.pth"
max_train_devices = 2

use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
# Never request more GPUs than are visible; CPU keeps the original device count.
train_devices = min(max_train_devices, torch.cuda.device_count()) if use_gpu else max_train_devices


def resolve_checkpoint(path, search_root="/kaggle/input"):
    """Return ``path`` if it exists, else the first file under ``search_root`` with the same name.

    Raises:
        FileNotFoundError: if no matching file is found.
    """
    if os.path.exists(path):
        return path
    wanted = os.path.basename(path)
    for folder, _subfolders, file_names in os.walk(search_root):
        if wanted in file_names:
            return os.path.join(folder, wanted)
    raise FileNotFoundError(f"Checkpoint {path!r} not found (also searched {search_root!r} for {wanted!r})")


def make_trainer():
    """Build a fresh 3-epoch Trainer.

    A new Trainer is created for every training round, as the original loop
    did (presumably because a Trainer that has reached ``max_epochs`` will not
    train further — confirm against the Lightning docs).
    """
    return pl.Trainer(max_epochs=3, devices=train_devices, accelerator='gpu' if use_gpu else 'cpu')


def preprocess_observation(obs, agent_team):
    """Return ``obs`` unchanged for blue agents and reversed along axis 1 for red agents.

    Axis 1 is presumably the width axis, given the (H,W,C) layout noted
    elsewhere in this file. The reversed array is copied so it has positive
    strides.
    """
    if agent_team == 'red':
        return obs[:, ::-1, :].copy()
    return obs


blue_agent = FunctionalPolicyAgent(action_space_size, embed_dim=5, height=13, width=13)
red_agent = FunctionalPolicyAgent(action_space_size, embed_dim=5, height=13, width=13)
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location lets a checkpoint saved from GPU load on a CPU-only session.
blue_agent.load_state_dict(torch.load(resolve_checkpoint(checkpoint_path), map_location="cpu"))
# Red starts as a copy of blue and is re-synced every `red_update_interval` episodes.
red_agent.load_state_dict(blue_agent.state_dict())

# Last observation / action per agent, used to assemble (s, a, r, s', done) transitions.
prev_states = {}
prev_actions = {}

try:
    for episode in tqdm(range(n_episodes), desc="Training episodes"):
        env.reset()
        prev_states.clear()
        prev_actions.clear()

        for agent_name in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()
            agent_team = agent_name.split('_')[0]
            done_flag = termination or truncation

            if done_flag:
                # A finished agent is stepped with None.
                action = None
            else:
                processed_obs = preprocess_observation(obs, agent_team)
                obs_tensor = torch.tensor(processed_obs, dtype=torch.float32)
                acting_agent = blue_agent if agent_team == 'blue' else red_agent
                action = acting_agent.select_action(obs_tensor)

            # Close out the transition started on this agent's previous turn.
            # NOTE(review): the buffer stores the raw (unmirrored) observation
            # even for red agents, whose policy saw the mirrored one — confirm
            # this is intended.
            if agent_name in prev_states and prev_actions[agent_name] is not None:
                replay_buffer.append(
                    (prev_states[agent_name], prev_actions[agent_name], float(reward), obs, float(done_flag))
                )
                if len(replay_buffer) > max_buffer_size:
                    # Drop the oldest entries in place, keeping the newest max_buffer_size.
                    del replay_buffer[:-max_buffer_size]

            prev_states[agent_name] = obs
            # `action` is already None for finished agents, which prevents a
            # further transition from being recorded for them.
            prev_actions[agent_name] = action

            env.step(action)

        if len(replay_buffer) >= batch_size:
            dataset = RLReplayDataset(replay_buffer)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
            trainer = make_trainer()
            trainer.fit(blue_agent, dataloader)
        if episode % red_update_interval == 0:
            red_agent.load_state_dict(blue_agent.state_dict())

    torch.save(blue_agent.state_dict(), "blue_agent.pth")
    print("Model parameters saved.")
finally:
    # Release the environment even if training is interrupted or fails.
    env.close()