In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
import os
import sys
sys.path.append(f'{os.getcwd()}/../')

In [38]:
from typing import Union, Optional, Any, Set, Sequence
from abc import ABC

import functools
import random

import gymnasium
import numpy as np
import numpy.typing as npt
import gymnasium.spaces
from gymnasium.spaces import Discrete, Space, Dict

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from pettingzoo.test import api_test
from pettingzoo.utils.env import AECIterable, AgentID, ActionType, ObsType

from src.wrapper import RestrictionWrapper
from examples.envs.rps import RPSEnvironment

In [39]:
class Restriction(ABC, gymnasium.Space):
    """Abstract base for restriction spaces built on top of a base gymnasium space.

    The restriction takes its ``shape`` and ``dtype`` from ``base_space`` and keeps
    a reference to that space in ``self.base_space``.

    Args:
        base_space: The space this restriction is derived from.
        seed: Optional seed (or generator) forwarded to ``gymnasium.Space``.
    """

    def __init__(
        self,
        base_space: gymnasium.Space,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        super().__init__(shape=base_space.shape, dtype=base_space.dtype, seed=seed)
        self.base_space = base_space

    def __repr__(self) -> str:
        """Return the name of the concrete restriction class."""
        return type(self).__name__

class DiscreteRestriction(Restriction, ABC):
    """Abstract base for restrictions whose base space is a ``Discrete`` space.

    Args:
        base_space: The discrete space this restriction is derived from.
        seed: Optional seed (or generator) passed on to ``Restriction``.
    """

    def __init__(
        self,
        base_space: gymnasium.spaces.Discrete,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        # Nothing discrete-specific to set up yet; simply delegate to the parent.
        super().__init__(base_space, seed=seed)

class ContinuousRestriction(Restriction, ABC):
    """Abstract base for restrictions whose base space is a ``Box`` space.

    Args:
        base_space: The box space this restriction is derived from.
        seed: Optional seed (or generator) passed on to ``Restriction``.
    """

    def __init__(
        self,
        base_space: gymnasium.spaces.Box,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        # Nothing continuous-specific to set up yet; simply delegate to the parent.
        super().__init__(base_space, seed=seed)

class DiscreteSetRestriction(DiscreteRestriction):
    """Restriction of a ``Discrete`` space given as an explicit set of allowed actions.

    Args:
        base_space: The discrete space being restricted.
        allowed_actions: Actions that remain valid. ``None`` allows every action
            of ``base_space`` (``start`` up to ``start + n - 1``).
        seed: Optional seed (or generator) that drives :meth:`sample`.
    """

    def __init__(self, base_space: gymnasium.spaces.Discrete, *, allowed_actions: Optional[Set[int]] = None, seed: int | np.random.Generator | None = None):
        super().__init__(base_space, seed=seed)

        if allowed_actions is None:
            allowed_actions = range(base_space.start, base_space.start + base_space.n)
        # Copy into a fresh set so that later mutation of the caller's collection
        # cannot silently change this restriction.
        self.allowed_actions = set(allowed_actions)

    @property
    def is_np_flattenable(self) -> bool:
        """Report this space as flattenable."""
        return True

    def sample(self) -> int:
        """Draw one allowed action uniformly at random.

        The draw uses the space's own ``np_random`` generator, so the ``seed``
        given to the constructor (or a later ``seed()`` call) makes sampling
        reproducible.

        Raises:
            IndexError: If there are no allowed actions to choose from.
        """
        # Sort so that the candidate order (and thus the seeded draw) is deterministic.
        candidates = sorted(self.allowed_actions)
        if not candidates:
            raise IndexError('Cannot sample from a restriction with no allowed actions')
        return candidates[int(self.np_random.integers(len(candidates)))]

    def contains(self, x: int) -> bool:
        """Return whether ``x`` is one of the allowed actions."""
        return x in self.allowed_actions

    def __repr__(self) -> str:
        """Return the class name followed by the allowed action set."""
        return f'{self.__class__.__name__}({self.allowed_actions})'

class DiscreteVectorRestriction(DiscreteRestriction):
    """Restriction of a ``Discrete`` space given as a boolean mask.

    ``allowed_actions[i]`` tells whether action ``start + i`` is allowed.

    The class previously defined ``__init__`` twice; the second definition
    shadowed the first and never initialised the underlying ``Space``. The
    single constructor below keeps the signature that was actually in effect
    (``allowed_actions, start=0``) and derives the base space from the mask.

    Args:
        allowed_actions: One-dimensional, non-empty boolean mask.
        start: Action value that corresponds to index 0 of the mask.
        seed: Optional seed (or generator) that drives :meth:`sample`.

    Raises:
        ValueError: If the mask is empty or not one-dimensional.
    """

    def __init__(self, allowed_actions: np.ndarray[bool], start: int = 0, *, seed: int | np.random.Generator | None = None):
        mask = np.asarray(allowed_actions, dtype=bool)
        if mask.ndim != 1 or mask.size == 0:
            raise ValueError('allowed_actions must be a non-empty one-dimensional boolean mask')

        # The mask implicitly describes Discrete(len(mask), start=start); using it as
        # the base space gives this restriction a valid shape, dtype and RNG.
        super().__init__(gymnasium.spaces.Discrete(mask.size, start=start), seed=seed)

        self.allowed_actions = mask
        self.start = start

    @property
    def is_np_flattenable(self) -> bool:
        """Report this space as flattenable."""
        return True

    def sample(self) -> int:
        """Draw one allowed action uniformly at random using the space's own RNG.

        Raises:
            IndexError: If the mask allows no action at all.
        """
        allowed_offsets = np.flatnonzero(self.allowed_actions)
        if allowed_offsets.size == 0:
            raise IndexError('Cannot sample from a restriction with no allowed actions')
        pick = self.np_random.integers(allowed_offsets.size)
        return int(self.start + allowed_offsets[pick])

    def contains(self, x: int) -> bool:
        """Return whether ``x`` is an allowed action.

        Values outside ``[start, start + len(mask))`` and non-integer values are
        reported as not contained instead of raising or wrapping around.
        """
        if not isinstance(x, (int, np.integer)):
            return False
        offset = x - self.start
        # Explicit bounds check: a negative offset would otherwise index from the end.
        if offset < 0 or offset >= len(self.allowed_actions):
            return False
        return bool(self.allowed_actions[offset])

    def __repr__(self) -> str:
        """Return the class name followed by the mask."""
        return f'{self.__class__.__name__}({self.allowed_actions})'

class IntervalUnionRestriction(ContinuousRestriction):
    """Placeholder continuous restriction; adds nothing to ``ContinuousRestriction`` yet."""

In [40]:
class Restrictor:
    """Base class for restrictors.

    Stores the action space of every possible agent of ``env`` and a mapping
    from action-space type to the restriction class to use for it.

    Args:
        env: Environment exposing ``possible_agents`` and ``action_space(agent)``.
        restriction_classes: Optional mapping from space type to restriction
            class. Any falsy value (``None`` or an empty mapping) selects
            ``default_restriction_classes``.
    """

    # Restriction class used for each supported action-space type by default.
    default_restriction_classes = {
        gymnasium.spaces.Discrete: DiscreteSetRestriction,
        gymnasium.spaces.Box: IntervalUnionRestriction,
    }

    def __init__(self, env, restriction_classes=None) -> None:
        spaces_by_agent = {}
        for agent in env.possible_agents:
            spaces_by_agent[agent] = env.action_space(agent)
        self.action_spaces = spaces_by_agent

        if restriction_classes:
            self.restriction_classes = restriction_classes
        else:
            self.restriction_classes = self.default_restriction_classes

class RPSRestrictor(Restrictor):
    """Restrictor for the rock-paper-scissors example.

    :meth:`preprocess_observation` reduces the environment to the acting agent
    and a ``last_action`` value; :meth:`act` answers with a restriction over
    ``Discrete(3)`` that excludes that ``last_action``.
    """

    def __init__(self, env, restriction_classes=None) -> None:
        super().__init__(env, restriction_classes=restriction_classes)

        self.observation_space = Discrete(1)

    def preprocess_observation(self, env):
        """Flatten the environment into the dict that :meth:`act` consumes.

        Returns:
            ``{'agent': <acting agent>, 'last_action': <int>}`` where the integer
            is the observation of the *other* agent, i.e. the agent whose index
            is ``1 - index`` of the acting agent.
        """
        acting_agent = env.agent_selection
        # NOTE(review): ``1 - index`` assumes exactly two agents mapped to 0 and 1.
        other_index = 1 - env.agent_name_mapping[acting_agent]
        other_agent = env.possible_agents[other_index]
        return {
            'agent': acting_agent,
            'last_action': int(env.observe(other_agent)),
        }

    def act(self, observation):
        """Return the restriction for an observation produced by :meth:`preprocess_observation`.

        The result allows every action in ``{0, 1, 2}`` except ``last_action``.
        """
        # The agent entry is read to match the observation schema but is not used further.
        _agent = observation['agent']
        blocked_action = observation['last_action']

        remaining_actions = {0, 1, 2}.difference({blocked_action})
        return DiscreteSetRestriction(Discrete(3), allowed_actions=remaining_actions)

In [41]:
# class RestrictorActionSpace(ABC, gymnasium.Space):
#     @property
#     def is_np_flattenable(self) -> bool:
#         return False

#     def sample(self) -> Restriction:
#         return DiscreteVectorRestriction(np.array([True]))
        
#     def contains(self, x: Restriction) -> bool:
#         return True
    
#     def __repr__(self) -> str:
#         return f'{self.__class__.__name__}'

In [42]:
def restriction_aware_random_policy(observation):
    """Pick a random action that respects the current restriction.

    Args:
        observation: Mapping with the keys ``'observation'`` and
            ``'restriction'``; the restriction must offer ``sample()``.

    Returns:
        Whatever ``restriction.sample()`` returns.

    Raises:
        KeyError: If either key is missing from ``observation``.
    """
    # The raw environment observation is looked up but plays no role in the choice.
    _env_observation = observation['observation']
    current_restriction = observation['restriction']
    return current_restriction.sample()

def create_policies(env, restrictor, restrictor_key='restrictor_0'):
    """Build the policy table for all agents plus the restrictor.

    Args:
        env: Environment exposing ``possible_agents``.
        restrictor: Object whose ``act`` method serves as the restrictor policy.
        restrictor_key: Key under which the restrictor policy is stored.

    Returns:
        Dict mapping every possible agent to ``restriction_aware_random_policy``
        and ``restrictor_key`` to ``restrictor.act``.
    """
    policies = dict.fromkeys(env.possible_agents, restriction_aware_random_policy)
    # Assigned last so the restrictor entry wins if the key collides with an agent name.
    policies[restrictor_key] = restrictor.act
    return policies

In [43]:
def play(env, policies, *, max_iter=1_000, verbose=False):
    """Run one episode of ``env``, choosing actions with ``policies``.

    The environment is reset and rendered, then stepped once per agent yielded
    by ``env.agent_iter()`` and rendered again at the end.

    Args:
        env: AEC-style environment providing ``reset``, ``render``,
            ``agent_iter``, ``last`` and ``step``.
        policies: Mapping from agent id to a callable ``observation -> action``.
        max_iter: Maximum number of agent steps to execute. The parameter was
            previously accepted but ignored; it is now enforced. ``None``
            disables the limit.
        verbose: If true, print what each agent observes and which action it takes.
    """
    env.reset()
    env.render()

    for step_count, agent in enumerate(env.agent_iter()):
        # Stop once the step budget is used up so a non-terminating game cannot loop forever.
        if max_iter is not None and step_count >= max_iter:
            break

        observation, reward, termination, truncation, info = env.last()
        if verbose:
            print(f'{agent=}, {observation=}, {reward=}, {termination=}, {truncation=}, {info=}')

        # Agents that are terminated or truncated must be stepped with ``None``.
        action = policies[agent](observation) if not termination and not truncation else None
        if verbose:
            print(f'{action=}')

        env.step(action)

    env.render()

In [45]:
# Demo run: play the rock-paper-scissors example with a restrictor in the loop.
# `render_mode='human'` makes the environment print the players' moves (see the output below).
env = RPSEnvironment(render_mode='human')
restrictor = RPSRestrictor(env) # Restrictor blocks each player's last action
# The wrapper is handed the restrictor's preprocessing function so it can build
# the restrictor's observations from the wrapped environment.
wrapper = RestrictionWrapper(env, preprocess_restrictor_observation_fn=restrictor.preprocess_observation)

# Policies are built from the base env's agents plus the restrictor; the game itself runs on the wrapper.
play(wrapper, create_policies(env, restrictor))

player_0: NONE, player_1: NONE
player_0: ROCK, player_1: NONE
player_0: ROCK, player_1: SCISSORS
player_0: SCISSORS, player_1: NONE
player_0: SCISSORS, player_1: PAPER
player_0: PAPER, player_1: NONE
player_0: PAPER, player_1: SCISSORS
player_0: ROCK, player_1: NONE
player_0: ROCK, player_1: PAPER
player_0: PAPER, player_1: NONE
player_0: PAPER, player_1: SCISSORS
player_0: SCISSORS, player_1: NONE
player_0: SCISSORS, player_1: PAPER
player_0: ROCK, player_1: NONE
player_0: ROCK, player_1: ROCK
player_0: PAPER, player_1: NONE
player_0: PAPER, player_1: SCISSORS
player_0: ROCK, player_1: NONE
player_0: ROCK, player_1: ROCK
player_0: SCISSORS, player_1: NONE
player_0: SCISSORS, player_1: PAPER
Game over
