In [1]:
import torch
import numpy as np
from loguru import logger
from pettingzoo.atari import combat_tank_v2
from src.utils import save_episode_as_gif
from src.agent_sac import AgentSAC

import random
import warnings

# Blanket-silence warnings to keep the training log readable.
# NOTE(review): this also hides genuinely useful warnings; consider narrowing it
# with `category=` / `module=` once the noisy source is known.
warnings.filterwarnings("ignore")
torch.set_default_dtype(torch.float32)

# One seed for every RNG the notebook touches so training runs are reproducible.
# The environment itself is seeded on its first reset (env.reset(seed=SEED)).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

2024-12-01 15:02:50.561 | INFO     | src.utils:device:66 - Using cpu device.


In [2]:
# Build the Combat Tank environment (no maze) with array rendering, and seed its
# first reset so the opening state is reproducible.
env = combat_tank_v2.env(render_mode="rgb_array", has_maze=False)
env.reset(seed=SEED)

# env.last() reports on the agent about to act; only its observation is needed
# here, to read the frame dimensions used to size the agents' networks.
observation, *_ = env.last()
H, W, C = observation.shape
# Assumes both players share one discrete action set, so "first_0" is queried.
action_dim = env.action_space("first_0").n

print("Input channels: {}\nAction space: {}".format(C, action_dim))

Input channels: 3
Action space: 18


In [3]:
# Hyperparameters
state_dimension: int = 16
num_actions: int = action_dim
EPOCHS: int = 2000  # number of training iterations; each one plays one full episode
HIDDEN_DIMENSION: int = 16
LEARNING_RATE: float = 3e-4  # Slightly lower learning rate for SAC
DISCOUNT_FACTOR: float = 0.99
# NOTE(review): despite the name this is not an episode count -- the training loop
# uses it as a per-episode step cap (an episode is cut after 2 * EPISODES agent turns).
EPISODES: int = 3_000
ALPHA: float = 0.2  # Temperature parameter
TAU: float = 0.005  # Soft update rate
BATCH_SIZE: int = 256
# NOTE(review): transitions presumably store full (C, H, W) image observations;
# confirm the replay buffer can actually hold this many in memory.
BUFFER_SIZE: int = 1_000_000

# Both SAC agents share one configuration and differ only in their PettingZoo
# agent name; building them from a single kwargs dict keeps them in sync.
sac_config = dict(
    state_dimension=state_dimension,
    action_dim=num_actions,
    hidden_dimension=HIDDEN_DIMENSION,
    learning_rate=LEARNING_RATE,
    obs_dim=(C, H, W),
    alpha=ALPHA,
    gamma=DISCOUNT_FACTOR,
    tau=TAU,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
)

# Initialize SAC agents
player_one = AgentSAC(name='first_0', **sac_config)
player_two = AgentSAC(name='second_0', **sac_config)

In [None]:
# Training loop: every epoch plays one episode with both SAC agents, then runs
# one optimisation pass per agent and checkpoints it.
all_players = [player_one, player_two]
# Key by possible_agents: it is fixed for the env's lifetime, whereas env.agents
# shrinks as agents finish and is empty once an episode is over, which would
# leave agents_map / agent_scores empty if this cell is re-run after training.
agents_map = dict(zip(env.possible_agents, all_players))
agent_scores = {k: 0 for k in env.possible_agents}

# Hard cap on agent turns per episode; the two players alternate turns, hence 2x.
MAX_AGENT_TURNS = 2 * EPISODES
# Save a GIF of a full episode every this many epochs.
GIF_INTERVAL = 50

for epoch in range(EPOCHS):
    # Run one episode
    env.reset()
    # Reset per-episode bookkeeping so that an episode ending without a win
    # (natural termination or hitting the step cap) is logged as "none" rather
    # than raising NameError on the first epoch, reusing a previous epoch's
    # winner, or crediting whichever agent happened to act last.
    winner = "none"
    step = 0
    for step, agent_name in enumerate(env.agent_iter()):
        agent = agents_map[agent_name]
        # NOTE(review): take_action presumably chooses an action, calls env.step and
        # returns a truthy value when this agent won. The recorded run of this cell
        # failed with "AssertionError: action is not in action space"; PettingZoo
        # requires env.step(None) for a terminated/truncated agent and a valid
        # discrete action otherwise -- check how AgentSAC.take_action builds it.
        win = agent.take_action(env)

        if win:
            agent_scores[agent_name] += int(win)
            winner = agent_name
            break
        if step > MAX_AGENT_TURNS:
            break

    # Optimize both agents
    logger.info(f'Epoch: {epoch+1:4}/{EPOCHS} \t| Winner: {winner:10} \t| Steps: {step}')
    for player in all_players:
        critic_loss, reward = player.optimize()

        logger.debug(
            f"Player: {player.name} \t| "
            f"Buffer Size: {len(player.replay_buffer)} \t| "
            f"Loss: {critic_loss:.5f} \t| "
            f"Reward: {reward:.3f}"
        )
        player.save()

    # Save episode visualization
    if (epoch + 1) % GIF_INTERVAL == 0:
        save_episode_as_gif(
            env,
            agents_map,
            save_path=f"episodes/sac_epoch_{epoch+1}.gif",
            fps=60,
        )

AssertionError: action is not in action space

In [None]:
# Display final scores: a header, then one "<agent>: <wins>" line per agent,
# in the order the agents were registered.
print("\nFinal Scores:")
score_lines = [f"{name}: {wins}" for name, wins in agent_scores.items()]
for line in score_lines:
    print(line)