In [None]:
!pip install -q git+https://github.com/Farama-Foundation/MAgent2

# Init

In [1]:
# Libraries
import numpy as np
from magent2.environments import battle_v4

# Modules
from model import VdnQNet
from team import TeamManager
from utils import save_data, seed, device, VdnHyperparameters
from train import train, run_episode, run_model_train_test

# Training

In [None]:
# Training: two VDN teams (team 1 = blue, team 2 = red) trained against each other
# in the same MAgent2 battle environment; results are saved per team at the end.

save_name_team1 = 'vdn_blue'
save_name_team2 = 'vdn_red'

# Side length of the square battle map, shared by the training and test environments.
MAP_SIZE = 45

# Hyperparameters (field semantics are defined by utils.VdnHyperparameters / train.py).
hp = VdnHyperparameters(
    lr=0.002,
    gamma=0.99,
    batch_size=512,
    buffer_limit=9000,
    max_episodes=200,
    max_epsilon=0.9,
    min_epsilon=0.1,
    episode_min_epsilon=100,  # presumably the episode at which epsilon reaches min_epsilon — confirm in train.py
    test_episodes=1,
    warm_up_steps=3000,
    update_iter=20,
    chunk_size=1,
    update_target_interval=20,
    recurrent=True
)
print(hp)

# Separate environments for training and evaluation, both reset with the shared seed.
env = battle_v4.parallel_env(map_size=MAP_SIZE)
test_env = battle_v4.parallel_env(map_size=MAP_SIZE)

env.reset(seed=seed)
test_env.reset(seed=seed)
team_manager = TeamManager(env.agents)


def _make_team_qnets(agents):
    """Build the (online, target) VdnQNet pair for one team on `device`.

    The target network is initialised with the online network's weights so that
    early TD targets are computed from the same parameters the agent acts with.
    """
    q_online = VdnQNet(agents, env.observation_spaces, env.action_spaces).to(device)
    q_target = VdnQNet(agents, env.observation_spaces, env.action_spaces).to(device)
    # Two separately constructed networks presumably start from different random
    # weights; sync them (assumes VdnQNet is a torch nn.Module, as `.to(device)` suggests).
    q_target.load_state_dict(q_online.state_dict())
    return q_online, q_target


# Create online/target models for the two teams
q_team1, q_target_team1 = _make_team_qnets(team_manager.get_my_agents())
q_team2, q_target_team2 = _make_team_qnets(team_manager.get_other_agents())

# Run training for both teams
train_scores_team1, train_scores_team2, test_scores_team1, test_scores_team2, losses_team1, losses_team2 = run_model_train_test(
    env,
    test_env,
    q_team1, q_team2,
    q_target_team1, q_target_team2,
    save_name_team1, save_name_team2,
    team_manager,
    hp,
    train,
    run_episode
)

# Save data for both teams as '<team save name>-<series name>', team 1 first.
team_results = (
    (save_name_team1, (('train_scores', train_scores_team1),
                       ('test_scores', test_scores_team1),
                       ('losses', losses_team1))),
    (save_name_team2, (('train_scores', train_scores_team2),
                       ('test_scores', test_scores_team2),
                       ('losses', losses_team2))),
)
for team_save_name, series in team_results:
    for series_name, values in series:
        save_data(np.array(values), f'{team_save_name}-{series_name}')