In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append(f'{os.getcwd()}/../')

In [None]:
from typing import Union, Optional, Any, Set
from abc import ABC

import random

import gymnasium
import numpy as np
from numpy.random import Generator
import gymnasium.spaces
from gymnasium.spaces import Discrete, Box

from src.restrictor import Restrictor, RestrictorActionSpace
from src.wrapper import RestrictionWrapper
from src.restriction import DiscreteSetRestriction, IntervalUnionRestriction, PredicateRestriction

from examples.envs.nfg import NFGEnvironment
from examples.utils import play, restriction_aware_random_policy

In [None]:
def _quadratic_utility(own, other):
    """Return the payoff function of player `own` playing against player `other`.

    The payoff is u_own = -a_own**2 - a_own * a_other + 108 * a_own, whose
    maximiser over a_own is the best response a_own = 54 - a_other / 2.
    """
    return lambda actions: -actions[own] ** 2 - actions[own] * actions[other] + 108 * actions[own]


# Both players observe and act in the same Box(0, 120) space.
observation_spaces = {player: Box(0, 120) for player in ('player_0', 'player_1')}
action_spaces = {player: Box(0, 120) for player in ('player_0', 'player_1')}
utilities = {
    'player_0': _quadratic_utility('player_0', 'player_1'),
    'player_1': _quadratic_utility('player_1', 'player_0'),
}

# NOTE(review): the positional 10 is presumably the number of rounds — confirm against NFGEnvironment.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')

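## Play without wrapper

Each player answers the opponent action found in its observation with $a_i = 54 - a_j / 2$, which maximises its utility, clipped to $[0, 120]$. When no opponent action is available yet, it picks a random integer action instead.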

In [None]:
def agent_policy(observation):
    """Best-response policy for the quadratic game defined in `utilities`.

    Maximising u_i = -a_i**2 - a_i * a_j + 108 * a_i over a_i gives
    a_i = 54 - a_j / 2, which is clipped to the action bounds [0, 120].

    Args:
        observation: Indexable whose first element is the opponent's action,
            or None when no opponent action is available yet.
            NOTE(review): structure inferred from the indexing below — confirm
            against NFGEnvironment.

    Returns:
        A random integer in [0, 120] when there is no opponent action,
        otherwise the clipped best response.
    """
    opponent_action = observation[0]
    if opponent_action is None:
        # `high` is exclusive in randint, so this samples from [0, 120].
        return np.random.randint(0, 121)
    return np.clip(54 - opponent_action / 2, 0, 120)

# Both players share the same best-response policy in the unwrapped game.
policies = dict.fromkeys(('player_0', 'player_1'), agent_policy)
play(env, policies=policies)

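## Play with wrapper

The same game is wrapped in a `RestrictionWrapper`. An extra `restrictor_0` agent (`NFGRestrictor`) supplies the restriction $[0, 30]$ every time it acts.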

In [None]:
class NFGRestrictor(Restrictor):
    """Restrictor that always allows the same single interval of actions.

    Args:
        allowed_low: Lower bound of the allowed interval (default 0).
        allowed_high: Upper bound of the allowed interval (default 30).
    """

    def __init__(self, allowed_low: float = 0, allowed_high: float = 30) -> None:
        super().__init__()
        self.allowed_low = allowed_low
        self.allowed_high = allowed_high

    def act(self, observation) -> RestrictorActionSpace:
        """Return the constant restriction [allowed_low, allowed_high].

        The observation is ignored; a new restriction object is created on every call.
        """
        # observation = env.state() since no preprocessing was applied
        return IntervalUnionRestriction(Box(self.allowed_low, self.allowed_high))

In [None]:
def restriction_aware_agent_policy(observation):
    """Best-response policy for players in the restriction-wrapped game.

    Args:
        observation: Mapping with key 'observation' (the underlying env
            observation, whose first element is the opponent's action or None)
            and key 'restriction' (the restriction supplied by the restrictor).

    Returns:
        A random integer in [0, 120] when there is no opponent action,
        otherwise 54 - opponent_action / 2 clipped to [0, 120].
    """
    env_observation = observation['observation']
    # NOTE(review): the restriction is read but never used, so the returned
    # action may lie outside the allowed set (NFGRestrictor allows only [0, 30]).
    # TODO: project onto / sample from `restriction` once the
    # IntervalUnionRestriction API is confirmed.
    restriction = observation['restriction']  # noqa: F841 (currently unused)

    opponent_action = env_observation[0]
    if opponent_action is None:
        # `high` is exclusive in randint, so this samples from [0, 120].
        return np.random.randint(0, 121)
    return np.clip(54 - opponent_action / 2, 0, 120)


policies = {'player_0': restriction_aware_agent_policy, 'player_1': restriction_aware_agent_policy}

In [None]:
# Build a new environment instance for the wrapped run and attach a restrictor.
env = NFGEnvironment(observation_spaces, action_spaces, utilities, 10, render_mode='human')
restrictor = NFGRestrictor()
wrapper = RestrictionWrapper(env)

# The player policies plus the restrictor's policy under the 'restrictor_0' key.
# NOTE(review): 'restrictor_0' is presumably the agent id RestrictionWrapper assigns
# to the restrictor — confirm against src.wrapper.
play(wrapper, policies=dict(policies, restrictor_0=restrictor.act))