In [None]:
import gym
import torch
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Tuple, Optional
from torch.distributions.categorical import Categorical

import numpy as np
import os
from loguru import logger
from torch.optim import Adam
import matplotlib.pyplot as plt
import copy
import torch.nn.functional as F
from pettingzoo.atari import combat_tank_v2
from pettingzoo.atari import space_war_v2
from pettingzoo.mpe import simple_v3
from itertools import count
from src.utils import (
    
    save_episode_as_gif,
    loss_fn,
    loss_fn_dqn
)
from src.agent import Agent
from src.agent_dqn import Agent_dqn
from src.policy import ValueFunctionQ
from src.buffer import ReplayBuffer

import numpy as np
from IPython.display import clear_output, display
from PIL import Image
from IPython.display import Image as IPImage
import io

# Interactive matplotlib mode so figures refresh without blocking the kernel.
plt.ion()
import warnings
# NOTE(review): this silences *all* warnings, including deprecations coming from
# gym / pettingzoo — narrow the filter if something starts misbehaving.
warnings.filterwarnings("ignore")
torch.set_default_dtype(torch.float32)

# A single seed shared by the environment reset and by the torch / numpy RNGs,
# so that network initialisation and action sampling are repeatable after
# "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:

# Create the Combat tank environment (no maze) and seed its first reset.
# Alternative game: space_war_v2.env(render_mode="rgb_array")
env = combat_tank_v2.env(has_maze=False, render_mode="rgb_array")
env.reset(seed=SEED)

# The first element returned by env.last() is the current observation;
# its shape is unpacked as (height, width, channels).
observation, *_ = env.last()
H, W, C = observation.shape
action_dim = env.action_space("first_0").n

print(f"Input channels: {C}\nAction space: {action_dim}")

## Model Definition

In [None]:


# ----------------------------- Hyper-parameters ------------------------------
# Everything a reader might want to tune lives in this cell.
EPOCHS: int = 250                 # number of passes of the outer training loop
EPISODES: int = 3_000             # the training loop cuts an episode off after 2*EPISODES steps
BATCH_SIZE: int = 8               # only relevant to the DQN agents
HIDDEN_DIMENSION: int = 16
LEARNING_RATE: float = 3e-3
DISCOUNT_FACTOR: float = .97
gamma = DISCOUNT_FACTOR
state_dimension: int = 16
num_actions: int = action_dim
# -----------------------------------------------------------------------------

# Keyword arguments common to both learners; observations are handed to the
# agents as (channels, height, width).
_shared_agent_kwargs = dict(obs_dim=(C, H, W), gamma=DISCOUNT_FACTOR)

# REINFORCE agents. The second player trains with three times the base learning rate.
player_one = Agent(
    'first_0', state_dimension, num_actions, HIDDEN_DIMENSION,
    LEARNING_RATE, **_shared_agent_kwargs,
)
player_two = Agent(
    'second_0', state_dimension, num_actions, HIDDEN_DIMENSION,
    3*LEARNING_RATE, **_shared_agent_kwargs,
)

# DQN alternative (swap in and use loss_fn_dqn in the training loop):
# player_one = Agent_dqn('first_0', state_dimension, num_actions, HIDDEN_DIMENSION, LEARNING_RATE, BATCH_SIZE, obs_dim=(C, H, W), gamma=DISCOUNT_FACTOR)
# player_two = Agent_dqn('second_0', state_dimension, num_actions, HIDDEN_DIMENSION, LEARNING_RATE, BATCH_SIZE, obs_dim=(C, H, W), gamma=DISCOUNT_FACTOR)

In [None]:
# Map every environment agent id to the learner that controls it.
# `possible_agents` is the environment's static roster, whereas `env.agents`
# only lists agents that are still alive — using the static list keeps this
# cell re-runnable after an episode has already been played.
all_players = [player_one, player_two]
agents_map = dict(zip(env.possible_agents, all_players))

# Number of episodes won by each agent id.
agent_scores = {agent_id: 0 for agent_id in agents_map}

max_steps = 2 * EPISODES   # hard cap on agent turns within one episode
gif_every = 50             # record a rollout every `gif_every` epochs
save_dir = "episodes"
no_winner = "none"         # logged when an episode ends without a win

for epoch in range(EPOCHS):
    # ---- play one episode -------------------------------------------------
    env.reset()
    # Reset per-episode bookkeeping so an episode that ends without a win
    # (step cap reached, or agent_iter() exhausted) neither raises NameError
    # nor reports the previous epoch's winner.
    winner = no_winner
    step = 0
    for step, agent_name in enumerate(env.agent_iter()):
        agent = agents_map[agent_name]
        win = agent.take_action(env)

        if win:
            agent_scores[agent_name] += int(win)
            winner = agent_name
            break
        if step > max_steps:
            # Timed out: nobody won, so `winner` stays at the sentinel.
            break

    logger.info(f'Epoch: {epoch+1:4}/{EPOCHS} \t| Winner: {winner:10} \t| Steps: {step}')

    # ---- learn from the episode -------------------------------------------
    for player in all_players:
        loss, reward = player.optimize(loss_fn)  #  loss_fn for REINFORCE, loss_fn_dqn for DQN
        logger.debug(f"Player: {player.name} \t| Cache Size: {len(player.cache)} \t| Loss: {loss*1e8:.5f} \t| Reward: {reward :.3f}")
        player.clear_cache()
        player.save()

    # ---- periodically record a rollout as a GIF ---------------------------
    if (epoch+1) % gif_every == 0:
        os.makedirs(save_dir, exist_ok=True)
        save_episode_as_gif(env, agents_map, save_path=f"{save_dir}/epoch_{epoch+1}.gif", fps=60)

In [None]:
# Per-agent win counts accumulated by the training loop (shown as the cell output).
agent_scores

In [None]:
# Record one more rollout with the trained agents and show it inline.
gif_path = save_episode_as_gif(env, agents_map, fps=60)
# Read the bytes inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(gif_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
IPImage(gif_bytes)