In [None]:
import inspect
import functools
import random
from copy import copy

import numpy as np
from gymnasium.spaces import Discrete, MultiDiscrete

from pettingzoo.utils.env import ParallelEnv
from pettingzoo.test import parallel_api_test

In [None]:
class PrisonEnvironment(ParallelEnv):
    """Two-agent pursuit game on a square grid.

    The prisoner starts at cell (0, 0), the guard starts at the opposite
    corner, and an escape cell is drawn at random with both coordinates in
    2..5. The episode terminates when the guard lands on the prisoner
    (guard +1, prisoner -1) or when the prisoner lands on the escape cell
    (prisoner +1, guard -1). It is truncated, with zero reward for both
    agents, once the step counter exceeds ``MAX_TIMESTEP``.

    Actions (per agent): 0 -> x - 1, 1 -> x + 1, 2 -> y - 1, 3 -> y + 1.
    Moves that would leave the grid are ignored.

    Observations (identical for both agents) are a 3-tuple of flattened cell
    indices ``x + GRID_SIZE * y`` for the prisoner, the guard and the escape
    cell, in that order.
    """

    metadata = {
        "name": "prison_environment_v0",
    }

    # Side length of the grid; valid coordinates are 0..GRID_SIZE - 1.
    GRID_SIZE = 7
    # The episode is truncated once the step counter exceeds this value.
    MAX_TIMESTEP = 100

    def __init__(self):
        """Declare the state attributes; real values are assigned by ``reset``."""
        self.escape_y = None
        self.escape_x = None
        self.guard_y = None
        self.guard_x = None
        self.prisoner_y = None
        self.prisoner_x = None
        self.timestep = None
        self.possible_agents = ["prisoner", "guard"]

        # Source of randomness for the escape cell. The module-level generator
        # is used until ``reset`` is called with an explicit seed.
        self._rng = random
        # Per-instance space caches, so every call returns the same object for
        # a given agent without a class-level cache holding on to instances.
        self._observation_spaces = {}
        self._action_spaces = {}

    def _observations(self):
        """Return the shared observation tuple for every live agent."""
        size = self.GRID_SIZE
        encoded = (
            self.prisoner_x + size * self.prisoner_y,
            self.guard_x + size * self.guard_y,
            self.escape_x + size * self.escape_y,
        )
        return {a: encoded for a in self.agents}

    def _move(self, x, y, action):
        """Return the position reached from ``(x, y)`` by ``action``, clamped to the grid."""
        last = self.GRID_SIZE - 1
        if action == 0 and x > 0:
            x -= 1
        elif action == 1 and x < last:
            x += 1
        elif action == 2 and y > 0:
            y -= 1
        elif action == 3 and y < last:
            y += 1
        return x, y

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. When given, the escape cell is drawn from a
                generator seeded with it, making the episode reproducible.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observations, infos)``, both dicts keyed by agent name.
        """
        if seed is not None:
            self._rng = random.Random(seed)

        self.agents = copy(self.possible_agents)
        self.timestep = 0

        self.prisoner_x = 0
        self.prisoner_y = 0

        # Opposite corner from the prisoner; must be a valid grid index.
        self.guard_x = self.GRID_SIZE - 1
        self.guard_y = self.GRID_SIZE - 1

        self.escape_x = self._rng.randint(2, 5)
        self.escape_y = self._rng.randint(2, 5)

        observations = self._observations()
        infos = {a: {} for a in self.agents}
        return observations, infos

    def step(self, actions):
        """Apply one action per agent and advance the episode by one step.

        Args:
            actions: Dict with an action (0..3) under the keys ``"prisoner"``
                and ``"guard"``.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``,
            each a dict keyed by agent name.
        """
        # Execute actions
        self.prisoner_x, self.prisoner_y = self._move(
            self.prisoner_x, self.prisoner_y, actions["prisoner"]
        )
        self.guard_x, self.guard_y = self._move(
            self.guard_x, self.guard_y, actions["guard"]
        )

        # Check termination conditions
        terminations = {a: False for a in self.agents}
        rewards = {a: 0 for a in self.agents}
        if self.prisoner_x == self.guard_x and self.prisoner_y == self.guard_y:
            rewards = {"prisoner": -1, "guard": 1}
            terminations = {a: True for a in self.agents}

        elif self.prisoner_x == self.escape_x and self.prisoner_y == self.escape_y:
            rewards = {"prisoner": 1, "guard": -1}
            terminations = {a: True for a in self.agents}

        # Check truncation conditions (overwrites termination rewards)
        truncations = {a: False for a in self.agents}
        if self.timestep > self.MAX_TIMESTEP:
            rewards = {"prisoner": 0, "guard": 0}
            truncations = {"prisoner": True, "guard": True}
        self.timestep += 1

        # Build observations and infos while the agent list is still populated,
        # so the final step reports an entry for every agent.
        observations = self._observations()

        # Get dummy infos (not used in this example)
        infos = {a: {} for a in self.agents}

        print(f'{terminations=}, {truncations=}')

        # The episode is over for everyone: no agents remain until reset().
        if any(terminations.values()) or all(truncations.values()):
            self.agents = []

        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Print the grid with ``P`` (prisoner), ``G`` (guard) and ``E`` (escape).

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.timestep is None:
            raise RuntimeError("render() called before reset()")

        # A string-typed grid is required to hold the letter markers.
        grid = np.full((self.GRID_SIZE, self.GRID_SIZE), " ")
        grid[self.prisoner_y, self.prisoner_x] = "P"
        grid[self.guard_y, self.guard_x] = "G"
        grid[self.escape_y, self.escape_x] = "E"
        print(f"{grid} \n")

    def observation_space(self, agent):
        """Return the observation space for ``agent`` (same object on every call)."""
        if agent not in self._observation_spaces:
            # Each component is a flattened cell index in 0..GRID_SIZE**2 - 1,
            # so the space needs GRID_SIZE**2 distinct values per component.
            self._observation_spaces[agent] = MultiDiscrete(
                [self.GRID_SIZE * self.GRID_SIZE] * 3
            )
        return self._observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space for ``agent`` (same object on every call)."""
        if agent not in self._action_spaces:
            self._action_spaces[agent] = Discrete(4)
        return self._action_spaces[agent]


In [None]:
class Wrapper:
    """Placeholder wrapper type; carries no state and defines no behavior."""

    def __init__(self) -> None:
        """Create an empty wrapper; nothing is initialized."""
        return None

In [None]:
# Smoke check: build the environment and reset it once to start an episode.
env = PrisonEnvironment()
env.reset()

In [None]:
# Run PettingZoo's parallel API test on a fresh environment instance,
# requesting 1,000,000 cycles.
env = PrisonEnvironment()
parallel_api_test(env, num_cycles=1_000_000)

In [None]:
def test(env, *, env_config=None, num_steps=20, num_episodes=5, env_type='parallel'):
    """Drive an environment with randomly sampled actions, printing each transition.

    The environment is expected to expose ``reset()`` returning a dict of
    observations keyed by agent id, ``step(actions)`` returning
    ``(obs, rewards, done, info)`` where ``done`` has an ``'__all__'`` key,
    and the attributes ``action_space`` and ``governance_action_space``.
    The id ``'gov'`` is treated specially: it receives samples from
    ``governance_action_space`` instead of ``action_space``.

    Args:
        env: An environment instance, or an environment class that is
            instantiated as ``env(env_config=...)``.
        env_config: Config passed to the class when ``env`` is a class.
            Defaults to a fresh empty dict on every call.
        num_steps: Stop after this many steps in total; ``None`` disables
            the limit.
        num_episodes: Stop after this many finished episodes; ``None``
            disables the limit.
        env_type: With ``'parallel'`` the ``'gov'`` action is a dict holding
            one governance sample per entry of ``env.env.agents``; with any
            other value it is a single governance sample.

    Note:
        If both ``num_steps`` and ``num_episodes`` are ``None`` the loop
        never terminates (a warning is printed).
    """
    if inspect.isclass(env):
        # Build a new dict per call so instances never share a default config.
        env = env(env_config={} if env_config is None else env_config)

    print(f'Testing {type(env)}...')

    if num_episodes is None and num_steps is None:
        print('Warning: Test will run forever!')

    obs = env.reset()
    print(f'Reset: {obs}')

    current_episode, current_step = 0, 0
    while True:
        actions = {}
        for agent_id in obs:
            if agent_id != 'gov':
                actions[agent_id] = env.action_space.sample()
            elif env_type == 'parallel':
                # One governance action per agent of the wrapped environment.
                actions[agent_id] = {
                    agent: env.governance_action_space.sample()
                    for agent in env.env.agents
                }
            else:
                actions[agent_id] = env.governance_action_space.sample()

        print(f'Actions: {actions}')
        obs, rewards, done, info = env.step(actions)
        current_step += 1
        print(f'Step:  {obs}, {rewards}, {done}')

        if done['__all__']:
            current_episode += 1
            obs = env.reset()
            print(f'Reset: {obs}')

        if num_steps is not None and current_step >= num_steps:
            break

        if num_episodes is not None and current_episode >= num_episodes:
            break

    print('Test finished!')


# This is a manual driver that needs an ``env`` argument, not a pytest test
# case; its name matches pytest's default ``test*`` pattern, so opt out of
# collection explicitly.
test.__test__ = False