In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Put the parent of the current working directory on the import path; the
# `src` and `examples` packages imported below are resolved relative to it.
sys.path.append(os.getcwd() + '/../')

In [8]:
import numpy as np
from numpy.random import Generator
from gymnasium.spaces import Box

from src.wrapper import RestrictionWrapper
from src.restrictors import Restrictor
from src.restrictions import IntervalUnionRestriction

from examples.envs.nfg import NFGEnvironment
from examples.utils import play

In [4]:
# Both players get identical Box(0, 120) observation and action spaces.
observation_spaces = dict(player_0=Box(0, 120), player_1=Box(0, 120))
action_spaces = dict(player_0=Box(0, 120), player_1=Box(0, 120))

# Each utility maps the joint action dict to one player's payoff:
#   u_i(a) = -a_i^2 - a_0 * a_1 + 108 * a_i
utilities = dict(
    player_0=lambda actions: (
        -actions['player_0'] ** 2
        - actions['player_0'] * actions['player_1']
        + 108 * actions['player_0']
    ),
    player_1=lambda actions: (
        -actions['player_1'] ** 2
        - actions['player_0'] * actions['player_1']
        + 108 * actions['player_1']
    ),
)

# NOTE(review): the positional 10 is presumably the number of rounds to play;
# confirm against the NFGEnvironment signature.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')

## Play without wrapper

In [5]:
def agent_policy(observation):
    """Pick an action in response to the opponent's last action.

    Args:
        observation: Indexable whose first element is the opponent's last
            action, or ``None`` when no action has been observed yet.

    Returns:
        A random integer in [0, 120] when there is nothing to respond to;
        otherwise ``54 - opponent_action / 2`` clipped to [0, 120]. That value
        maximises ``-a**2 - a * opponent_action + 108 * a`` over ``a``.
    """
    print(f'{observation=}')
    rival_move = observation[0]

    # Opening move: nothing to respond to yet, so act at random.
    if rival_move is None:
        return np.random.randint(0, 121)

    print(rival_move)
    best_response = 54 - rival_move / 2
    return np.clip(best_response, 0, 120)

# Both players use the same policy function; run the game on the bare environment.
policies = dict.fromkeys(('player_0', 'player_1'), agent_policy)
play(env, policies)

player_0: None, player_1: None
observation=array([None], dtype=object)
player_0: 66, player_1: None
observation=array([66])
66
player_0: 66, player_1: 21.0
observation=array([21.])
21.0
player_0: 43.5, player_1: None
observation=array([43.5])
43.5
player_0: 43.5, player_1: 32.25
observation=array([32.25])
32.25
player_0: 37.875, player_1: None
observation=array([37.875])
37.875
player_0: 37.875, player_1: 35.0625
observation=array([35.0625])
35.0625
player_0: 36.46875, player_1: None
observation=array([36.46875])
36.46875
player_0: 36.46875, player_1: 35.765625
observation=array([35.765625])
35.765625
player_0: 36.1171875, player_1: None
observation=array([36.1171875])
36.1171875
player_0: 36.1171875, player_1: 35.94140625
observation=array([35.94140625])
35.94140625
player_0: 36.029296875, player_1: None
observation=array([36.02929688])
36.029296875
player_0: 36.029296875, player_1: 35.9853515625
observation=array([35.98535156])
35.9853515625
player_0: 36.00732421875, player_1: None
o

## Play with wrapper

In [13]:
class NFGRestrictor(Restrictor):
    """Restrictor that returns the same fixed restriction on every call."""

    def __init__(self, observation_space, action_space) -> None:
        """Forward both spaces unchanged to the ``Restrictor`` base class."""
        super().__init__(observation_space=observation_space, action_space=action_space)

    def act(self, observation) -> IntervalUnionRestriction:
        """Return an ``IntervalUnionRestriction`` built from ``Box(0.0, 30.0)``.

        The ``observation`` argument is not consulted.
        """
        # Per the original author: observation = env.state() since no
        # preprocessing was applied (not verifiable from this notebook).
        allowed_actions = Box(0.0, 30.0)
        return IntervalUnionRestriction(allowed_actions)

In [14]:
def restriction_aware_agent_policy(observation):
    """Respond to the opponent's last action from a wrapped observation.

    Args:
        observation: Mapping with an ``'observation'`` entry (indexable whose
            first element is the opponent's last action, or ``None``) and a
            ``'restriction'`` entry.

    Returns:
        A random integer in [0, 120] when the opponent action is ``None``;
        otherwise ``54 - opponent_action / 2`` clipped to [0, 120].

    NOTE(review): the ``'restriction'`` entry is read but never used, so the
    returned action is not limited to the allowed set. Confirm whether the
    policy is meant to project its action onto the restriction.
    """
    raw_observation = observation['observation']
    restriction = observation['restriction']  # read, but currently unused

    rival_move = raw_observation[0]
    if rival_move is not None:
        print(rival_move)
        return np.clip(54 - rival_move / 2, 0, 120)

    # Opening move: nothing to respond to yet, so act at random.
    return np.random.randint(0, 121)


# Both players use the same restriction-aware policy function.
policies = dict.fromkeys(('player_0', 'player_1'), restriction_aware_agent_policy)

In [17]:
# NOTE(review): this import would conventionally sit in the notebook's import cell.
from src.restrictors import IntervalUnionActionSpace

# Build a new environment for this run and wrap it with the restrictor.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')
restrictor = NFGRestrictor(
    observation_space=Box(0, 120),
    action_space=IntervalUnionActionSpace(Box(0, 120)),
)
wrapper = RestrictionWrapper(env, restrictor)

# The restrictor is passed to `play` as one more policy, under the key 'restrictor_0'.
play(wrapper, policies=dict(policies, restrictor_0=restrictor.act))

player_0: None, player_1: None
player_0: 0, player_1: None
0
player_0: 0, player_1: 54.0
54.0
player_0: 27.0, player_1: None
27.0
player_0: 27.0, player_1: 40.5
40.5
player_0: 33.75, player_1: None
33.75
player_0: 33.75, player_1: 37.125
37.125
player_0: 35.4375, player_1: None
35.4375
player_0: 35.4375, player_1: 36.28125
36.28125
player_0: 35.859375, player_1: None
35.859375
player_0: 35.859375, player_1: 36.0703125
36.0703125
player_0: 35.96484375, player_1: None
35.96484375
player_0: 35.96484375, player_1: 36.017578125
36.017578125
player_0: 35.9912109375, player_1: None
35.9912109375
player_0: 35.9912109375, player_1: 36.00439453125
36.00439453125
player_0: 35.997802734375, player_1: None
35.997802734375
player_0: 35.997802734375, player_1: 36.0010986328125
36.0010986328125
player_0: 35.99945068359375, player_1: None
35.99945068359375
player_0: 35.99945068359375, player_1: 36.000274658203125
36.000274658203125
player_0: 35.99986267089844, player_1: None
35.99986267089844
player_0: