In [30]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to the local `src` / `examples` packages
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
import os
import sys
# Put the parent of the current working directory on the import path so the
# `src` and `examples` packages imported further down can be resolved
# (assumes the notebook runs from a direct subdirectory of the repository
# root -- TODO confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends the same entry to sys.path again.
_parent_dir = f'{os.getcwd()}/../'
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [32]:
from typing import Union, Optional, Any, Set
from abc import ABC

import random

import gymnasium
import numpy as np
from numpy.random import Generator
import gymnasium.spaces
from gymnasium.spaces import Discrete, Box

from src.restrictor import Restrictor, RestrictorActionSpace
from src.wrapper import RestrictionWrapper
from src.restrictions import DiscreteSetRestriction, IntervalUnionRestriction

from examples.envs.nfg import NFGEnvironment
from examples.utils import play, restriction_aware_random_policy

In [33]:
def _make_utility(player, opponent):
    """Build the payoff function of `player` when facing `opponent`.

    The returned callable takes the joint `actions` mapping and evaluates
    -a**2 - a*b + 108*a, where a = actions[player] and b = actions[opponent].
    """
    def utility(actions):
        return -actions[player] ** 2 - actions[player] * actions[opponent] + 108 * actions[player]
    return utility


_players = ('player_0', 'player_1')

# Every player gets its own Box(0, 120) instance for observations and actions.
observation_spaces = {player: Box(0, 120) for player in _players}
action_spaces = {player: Box(0, 120) for player in _players}

# Symmetric payoffs: each player's utility is the same formula with roles swapped.
utilities = {
    'player_0': _make_utility('player_0', 'player_1'),
    'player_1': _make_utility('player_1', 'player_0'),
}

# NOTE(review): the meaning of the positional argument 10 is defined by
# NFGEnvironment, which is not visible here -- confirm against its signature.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')

## Play without wrapper

In [37]:
def agent_policy(observation):
    """Choose an action in response to the opponent's last observed action.

    Args:
        observation: Indexable whose first element is the opponent's latest
            action, or None when no action has been observed yet.

    Returns:
        A random integer from 0 to 120 (inclusive) when nothing has been
        observed; otherwise 54 - opponent_action / 2 clipped to [0, 120].
        That expression is the stationary point in the player's own action of
        the utility -a**2 - a*b + 108*a used for this game in this notebook.
    """
    # Debug trace of the raw observation (part of the recorded cell output).
    print(f'{observation=}')
    last_opponent_move = observation[0]

    # Nothing observed yet: fall back to a uniformly random integer action.
    if last_opponent_move is None:
        return np.random.randint(0, 121)

    print(last_opponent_move)
    response = 54 - last_opponent_move / 2
    return np.clip(response, 0, 120)

# Both players follow the same policy function.
policies = dict.fromkeys(('player_0', 'player_1'), agent_policy)

# Run the game on the bare (unwrapped) environment.
play(env, policies)

player_0: None, player_1: None
observation=array([None], dtype=object)
player_0: 60, player_1: None
observation=array([60])
60
player_0: 60, player_1: 24.0
observation=array([24.])
24.0
player_0: 42.0, player_1: None
observation=array([42.])
42.0
player_0: 42.0, player_1: 33.0
observation=array([33.])
33.0
player_0: 37.5, player_1: None
observation=array([37.5])
37.5
player_0: 37.5, player_1: 35.25
observation=array([35.25])
35.25
player_0: 36.375, player_1: None
observation=array([36.375])
36.375
player_0: 36.375, player_1: 35.8125
observation=array([35.8125])
35.8125
player_0: 36.09375, player_1: None
observation=array([36.09375])
36.09375
player_0: 36.09375, player_1: 35.953125
observation=array([35.953125])
35.953125
player_0: 36.0234375, player_1: None
observation=array([36.0234375])
36.0234375
player_0: 36.0234375, player_1: 35.98828125
observation=array([35.98828125])
35.98828125
player_0: 36.005859375, player_1: None
observation=array([36.00585938])
36.005859375
player_0: 36.00

## Play with wrapper

In [43]:
class NFGRestrictor(Restrictor):
    """Restrictor for the normal-form game that always returns the same restriction."""

    def __init__(self) -> None:
        # No state of its own; only the base-class initialisation is run.
        super().__init__()

    def act(self, observation):
        """Return the restriction to apply; `observation` is not read.

        Per the original author's note, `observation` equals env.state()
        because no preprocessing was applied.
        """
        allowed_interval = Box(0, 30)
        return IntervalUnionRestriction(allowed_interval)

In [44]:
def restriction_aware_agent_policy(observation):
    """Choose an action from a wrapped observation carrying a restriction.

    Args:
        observation: Mapping with the keys 'observation' (indexable whose
            first element is the opponent's latest action, or None) and
            'restriction'.

    Returns:
        A random integer from 0 to 120 (inclusive) when no opponent action
        has been observed; otherwise 54 - opponent_action / 2 clipped to
        [0, 120].

    NOTE(review): the restriction is read but never applied, so the returned
    action is not constrained by it -- confirm whether that is intended.
    """
    wrapped = observation
    inner_observation = wrapped['observation']
    _restriction = wrapped['restriction']  # looked up, currently unused

    last_opponent_move = inner_observation[0]

    # Nothing observed yet: fall back to a uniformly random integer action.
    if last_opponent_move is None:
        return np.random.randint(0, 121)

    print(last_opponent_move)
    response = 54 - last_opponent_move / 2
    return np.clip(response, 0, 120)

# Both players share the restriction-aware policy function.
policies = dict.fromkeys(('player_0', 'player_1'), restriction_aware_agent_policy)

In [45]:
# Fresh environment, built with the same arguments as the unwrapped run.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')
restrictor = NFGRestrictor()
wrapper = RestrictionWrapper(env)

# The restrictor's `act` is registered as the policy for 'restrictor_0',
# alongside the two player policies.
play(wrapper, policies=dict(policies, restrictor_0=restrictor.act))

player_0: None, player_1: None
player_0: 74, player_1: None
74
player_0: 74, player_1: 17.0
17.0
player_0: 45.5, player_1: None
45.5
player_0: 45.5, player_1: 31.25
31.25
player_0: 38.375, player_1: None
38.375
player_0: 38.375, player_1: 34.8125
34.8125
player_0: 36.59375, player_1: None
36.59375
player_0: 36.59375, player_1: 35.703125
35.703125
player_0: 36.1484375, player_1: None
36.1484375
player_0: 36.1484375, player_1: 35.92578125
35.92578125
player_0: 36.037109375, player_1: None
36.037109375
player_0: 36.037109375, player_1: 35.9814453125
35.9814453125
player_0: 36.00927734375, player_1: None
36.00927734375
player_0: 36.00927734375, player_1: 35.995361328125
35.995361328125
player_0: 36.0023193359375, player_1: None
36.0023193359375
player_0: 36.0023193359375, player_1: 35.99884033203125
35.99884033203125
player_0: 36.000579833984375, player_1: None
36.000579833984375
player_0: 36.000579833984375, player_1: 35.99971008300781
35.99971008300781
player_0: 36.000144958496094, playe