diff --git a/pearl/api/environment.py b/pearl/api/environment.py
index ef9b0df..6c26d24 100644
--- a/pearl/api/environment.py
+++ b/pearl/api/environment.py
@@ -16,6 +16,7 @@
 from pearl.api.action_result import ActionResult
 from pearl.api.action_space import ActionSpace
 from pearl.api.observation import Observation
+from pearl.api.space import Space


 class Environment(ABC):
@@ -32,11 +33,11 @@ def action_space(self) -> ActionSpace:
         pass

     # FIXME: add this and in implement in all concrete subclasses
-    # @property
-    # @abstractmethod
-    # def observation_space(self) -> Space:
-    #     """Returns the observation space of the environment."""
-    #     pass
+    @property
+    @abstractmethod
+    def observation_space(self) -> Space:
+        """Returns the observation space of the environment."""
+        pass

     @abstractmethod
     def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
diff --git a/pearl/utils/instantiations/environments/__init__.py b/pearl/utils/instantiations/environments/__init__.py
index 3dc815e..decfb24 100644
--- a/pearl/utils/instantiations/environments/__init__.py
+++ b/pearl/utils/instantiations/environments/__init__.py
@@ -17,8 +17,8 @@
     OneHotObservationsFromDiscrete,
 )
 from .gym_environment import GymEnvironment
-from .reward_is_equal_to_ten_times_action_contextual_bandit_environment import (
-    RewardIsEqualToTenTimesActionContextualBanditEnvironment,
+from .reward_is_equal_to_ten_times_action_multi_arm_bandit_environment import (
+    RewardIsEqualToTenTimesActionMultiArmBanditEnvironment,
 )
 from .sparse_reward_environment import (
     ContinuousSparseRewardEnvironment,
@@ -36,7 +36,7 @@
     "FixedNumberOfStepsEnvironment",
     "GymEnvironment",
     "OneHotObservationsFromDiscrete",
-    "RewardIsEqualToTenTimesActionContextualBanditEnvironment",
+    "RewardIsEqualToTenTimesActionMultiArmBanditEnvironment",
     "SLCBEnvironment",
     "SparseRewardEnvironment",
 ]
diff --git a/pearl/utils/instantiations/environments/contextual_bandit_environment.py b/pearl/utils/instantiations/environments/contextual_bandit_environment.py
index 5d58033..c10f580 100644
--- a/pearl/utils/instantiations/environments/contextual_bandit_environment.py
+++ b/pearl/utils/instantiations/environments/contextual_bandit_environment.py
@@ -15,6 +15,7 @@

 from pearl.api.environment import Environment
 from pearl.api.reward import Reward
+from pearl.api.space import Space


 class ContextualBanditEnvironment(Environment, ABC):
@@ -31,6 +32,16 @@ class ContextualBanditEnvironment(Environment, ABC):
     to determine the ActionResult reward.
     """
+    @property
+    @abstractmethod
+    def observation_space(self) -> Space:
+        """
+        For multi-arm bandit environments (i.e. no context), we should set the observation space to
+        simply be an empty tensor, i.e. {tensor([])}. For contextual bandit environments, if the
+        context space is not known in advance, then it should also be set to an empty tensor.
+ """ + pass + @property @abstractmethod def action_space(self) -> ActionSpace: diff --git a/pearl/utils/instantiations/environments/contextual_bandit_linear_synthetic_environment.py b/pearl/utils/instantiations/environments/contextual_bandit_linear_synthetic_environment.py index de192ad..f23e47c 100644 --- a/pearl/utils/instantiations/environments/contextual_bandit_linear_synthetic_environment.py +++ b/pearl/utils/instantiations/environments/contextual_bandit_linear_synthetic_environment.py @@ -16,9 +16,11 @@ from pearl.api.observation import Observation from pearl.api.reward import Value +from pearl.api.space import Space from pearl.utils.instantiations.environments.contextual_bandit_environment import ( ContextualBanditEnvironment, ) +from pearl.utils.instantiations.spaces.box import BoxSpace from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace @@ -35,7 +37,11 @@ class ContextualBanditLinearSyntheticEnvironment(ContextualBanditEnvironment): "A Contextual-Bandit Approach to Personalized News Article Recommendation," The context for an arm is the concatenation of the observation feature vector - and the arm feature vevctor. + and the arm feature vector. + + In this example, the observation space (i.e. the context space) is taken to be + [0, 1]^{`observation_dim`} where `observation_dim` is specified during initialization. + Observations are taken to be random vectors generated from this observation space. """ def __init__( @@ -58,6 +64,9 @@ def __init__( """ assert isinstance(action_space, DiscreteActionSpace) self._action_space: DiscreteActionSpace = action_space + self._observation_space: Space = BoxSpace( + low=torch.zeros((observation_dim)), high=torch.ones((observation_dim)) + ) self.observation_dim = observation_dim self._arm_feature_vector_dim = arm_feature_vector_dim self.reward_noise_sigma = reward_noise_sigma @@ -71,6 +80,10 @@ def __init__( def action_space(self) -> ActionSpace: return self._action_space + @property + def observation_space(self) -> Optional[Space]: + return self._observation_space + @property def arm_feature_vector_dim(self) -> int: return self._arm_feature_vector_dim @@ -107,7 +120,7 @@ def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]: """ Provides the observation and action space to the agent. 
""" - self._observation = torch.rand(self.observation_dim) + self._observation = self._observation_space.sample() return self._observation, self.action_space def get_reward(self, action: Action) -> Value: @@ -122,7 +135,7 @@ def get_reward(self, action: Action) -> Value: def get_regret(self, action: Action) -> Value: """ Given action, environment will return regret for choosing this action - regret == max(reward over all action) - reward for current action + regret == max(reward over all actions) - reward for current action """ rewards = [ self._compute_reward_from_context(self._get_context_for_arm(i)) diff --git a/pearl/utils/instantiations/environments/contextual_bandit_uci_environment.py b/pearl/utils/instantiations/environments/contextual_bandit_uci_environment.py index a5c91fb..1e224dc 100644 --- a/pearl/utils/instantiations/environments/contextual_bandit_uci_environment.py +++ b/pearl/utils/instantiations/environments/contextual_bandit_uci_environment.py @@ -9,13 +9,23 @@ from pearl.api.observation import Observation from pearl.api.reward import Value +from pearl.api.space import Space from pearl.utils.instantiations.environments.contextual_bandit_environment import ( ContextualBanditEnvironment, ) +from pearl.utils.instantiations.spaces.box import BoxSpace from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace class SLCBEnvironment(ContextualBanditEnvironment): + """ + A contextual bandit environment to work with uci datasets. + + Note that the context features are assumed to be continuous. For datasets with discrete + features, please modify the code used to normalize features and set the _observation_space + attribute to be a `DiscreteSpace` object. + """ + def __init__( self, path_filename: str, @@ -86,8 +96,14 @@ def __init__( ) self._action_dim_env: int = self._action_space[0].shape[0] - # Set observation dimension + # Set observation space and observation dimension self.observation_dim: int = tensor.size()[1] - 1 # 0th index is the target + self._observation_space: Space = ( + BoxSpace( # Box space with low for each dimension = -inf, high for each dimension = inf + high=torch.full((self.observation_dim,), float("inf")), + low=torch.full((self.observation_dim,), float("-inf")), + ) + ) # Set noise to be added to reward self.reward_noise_sigma = reward_noise_sigma @@ -122,6 +138,10 @@ def action_transfomer( else: raise Exception("Invalid action_embeddings type") + @property + def observation_space(self) -> Optional[Space]: + return self._observation_space + @property def action_space(self) -> ActionSpace: return self._action_space diff --git a/pearl/utils/instantiations/environments/environments.py b/pearl/utils/instantiations/environments/environments.py index da75b3c..6ab5fe6 100644 --- a/pearl/utils/instantiations/environments/environments.py +++ b/pearl/utils/instantiations/environments/environments.py @@ -36,9 +36,9 @@ class FixedNumberOfStepsEnvironment(Environment): reward are the number of steps. 
""" - def __init__(self, number_of_steps: int = 100) -> None: + def __init__(self, max_number_of_steps: int = 100) -> None: self.number_of_steps_so_far = 0 - self.number_of_steps: int = number_of_steps + self.max_number_of_steps: int = max_number_of_steps self._action_space = DiscreteActionSpace( [torch.tensor(False), torch.tensor(True)] ) @@ -60,7 +60,20 @@ def render(self) -> None: def action_space(self) -> ActionSpace: return self._action_space + @property + def observation_space(self) -> DiscreteSpace: + return DiscreteSpace( + [torch.tensor(i) for i in range(self.max_number_of_steps + 1)] + ) + def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]: + """ + Provides the observation and action space to the agent. + """ + # clipping the observation to be within the range of [0, max_number_of_steps] + self.number_of_steps_so_far = max( + self.number_of_steps_so_far, self.max_number_of_steps + ) return self.number_of_steps_so_far, self.action_space def __str__(self) -> str: @@ -77,7 +90,9 @@ def __init__( base_environment: Environment, ) -> None: self.base_environment = base_environment - self.observation_space: Space = self.make_observation_space(base_environment) + self._observation_space: Space = self.make_observation_space( + self.base_environment + ) @staticmethod @abstractmethod @@ -92,6 +107,10 @@ def compute_tensor_observation(self, observation: Observation) -> torch.Tensor: def action_space(self) -> ActionSpace: return self.base_environment.action_space + @property + def observation_space(self) -> Space: + return self._observation_space + def step(self, action: Action) -> ActionResult: action_result = self.base_environment.step(action) action_result.observation = self.compute_tensor_observation( @@ -113,11 +132,13 @@ def short_description(self) -> str: class OneHotObservationsFromDiscrete(ObservationTransformationEnvironmentAdapterBase): """ - An environment adapter mapping a Discrete observation space into - a Box observation space with dimension 1 - where the observation is a one-hot vector. + A wrapper around a base environment that transforms the observation space of the base + environment from a DiscreteSpace with a finite subset of integers (for e.g. a Discrete + environment in Gymnasium, gym.spaces.Discrete, {0, 1, 2, ... end}) to a DiscreteSpace + in Pearl where the observations are represented as one hot vectors. - This is useful to use with agents expecting tensor observations. + This is useful to use with agents expecting one-hot tensor observations. One-hot encoding + is a common way to represent discrete observations in RL. 
""" def __init__(self, base_environment: Environment) -> None: @@ -125,8 +146,6 @@ def __init__(self, base_environment: Environment) -> None: @staticmethod def make_observation_space(base_environment: Environment) -> Space: - # pyre-fixme: need to add `observation_space` property in Environment - # and implement it in all concrete subclasses assert isinstance(base_environment.observation_space, DiscreteSpace) n = base_environment.observation_space.n elements = [F.one_hot(torch.tensor(i), n).float() for i in range(n)] @@ -137,8 +156,6 @@ def compute_tensor_observation(self, observation: Observation) -> torch.Tensor: observation_tensor = observation else: observation_tensor = torch.tensor(observation) - # pyre-fixme: need to add `observation_space` property in Environment - # and implement it in all concrete subclasses assert isinstance(self.base_environment.observation_space, DiscreteSpace) return F.one_hot( observation_tensor, diff --git a/pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_contextual_bandit_environment.py b/pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_multi_arm_bandit_environment.py similarity index 68% rename from pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_contextual_bandit_environment.py rename to pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_multi_arm_bandit_environment.py index 0f133ae..05ef5c9 100644 --- a/pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_contextual_bandit_environment.py +++ b/pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_multi_arm_bandit_environment.py @@ -9,20 +9,25 @@ from typing import Optional, Tuple +import torch + from pearl.api.action import Action from pearl.api.action_space import ActionSpace from pearl.api.observation import Observation from pearl.api.reward import Value +from pearl.api.space import Space from pearl.utils.instantiations.environments.contextual_bandit_environment import ( ContextualBanditEnvironment, ) +from pearl.utils.instantiations.spaces.discrete import DiscreteSpace -class RewardIsEqualToTenTimesActionContextualBanditEnvironment( +class RewardIsEqualToTenTimesActionMultiArmBanditEnvironment( ContextualBanditEnvironment ): """ - A example implementation of a contextual bandit environment. + A example implementation of a bandit environment. For simplicity, we assume + no context. Therefore, it is a multi-arm bandit environment. """ def __init__(self, action_space: ActionSpace) -> None: @@ -32,6 +37,13 @@ def __init__(self, action_space: ActionSpace) -> None: def action_space(self) -> ActionSpace: return self._action_space + @property + def observation_space(self) -> Optional[Space]: + # For multi-arm bandit environments (where there are no 'observations'), we set the + # observation space to be a DiscreteSpace with a single element, taken to be an empty + # tensor. + return DiscreteSpace([torch.tensor([])]) + def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]: # Function returning the context and the available action space # Here, we use no context (None), but we could return varied implementations. 
diff --git a/pearl/utils/instantiations/environments/sparse_reward_environment.py b/pearl/utils/instantiations/environments/sparse_reward_environment.py
index d16e238..4e50497 100644
--- a/pearl/utils/instantiations/environments/sparse_reward_environment.py
+++ b/pearl/utils/instantiations/environments/sparse_reward_environment.py
@@ -19,7 +19,7 @@

 There are 2 versions in this file:
 - one for discrete action space
-- one for contineous action space
+- one for continuous action space
 """
 import math
 import random
@@ -33,36 +33,60 @@
 from pearl.api.action_space import ActionSpace
 from pearl.api.environment import Environment
+from pearl.utils.instantiations.spaces.box import BoxSpace
 from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


+# FIXME by @Jalaj: this file needs thorough fixing.
+
 class SparseRewardEnvironment(Environment):
     def __init__(
         self,
-        length: float,
+        width: float,
         height: float,
         max_episode_duration: int = 500,
         reward_distance: float = 1,
     ) -> None:
-        self._length = length
+        self._width = width
         self._height = height
         self._max_episode_duration = max_episode_duration
-        # reset will initialize following
+        self._reward_distance = reward_distance
+
+        # reset will initialize the agent position, goal and step count
         self._agent_position: Optional[Tuple[float, float]] = None
         self._goal: Optional[Tuple[float, float]] = None
         self._step_count = 0
-        self._reward_distance = reward_distance

     @abstractmethod
     def step(self, action: Action) -> ActionResult:
         pass

+    @property
+    def observation_space(self) -> BoxSpace:
+        """
+        The observation space is a 2D box space, with the range in the x-coordinate
+        being [0, width] and the range in the y-coordinate being [0, height].
+        """
+        observation_space = BoxSpace(
+            low=torch.tensor([0, 0]), high=torch.tensor([self._width, self._height])
+        )
+        return observation_space
+
     def reset(self, seed: Optional[int] = None) -> Tuple[torch.Tensor, ActionSpace]:
-        # reset (x, y)
-        self._agent_position = (self._length / 2, self._height / 2)
-        self._goal = (random.uniform(0, self._length), random.uniform(0, self._height))
-        self._step_count = 0
+        # reset (x, y) for agent position
+        self._agent_position = (
+            self._width / 2,
+            self._height / 2,
+        )
+
+        # reset (x, y) for goal
+        self._goal = (
+            random.uniform(0, self._width),
+            random.uniform(0, self._height),
+        )
+
+        self._step_count = 0  # reset step_count
         assert self._agent_position is not None
         assert (goal := self._goal) is not None
         return (
@@ -72,24 +96,27 @@ def reset(self, seed: Optional[int] = None) -> Tuple[torch.Tensor, ActionSpace]:

     def _update_position(self, delta: Tuple[float, float]) -> None:
         """
-        This API is to update and clip and ensure agent always stay in map
+        Update the agent position, say (x, y) --> (x', y') where:
+            x' = x + delta_x
+            y' = y + delta_y
+
+        A clip operation is added to ensure the agent always stays in the 2D grid.
         """
         delta_x, delta_y = delta
         assert self._agent_position is not None
         x, y = self._agent_position
         self._agent_position = (
-            max(min(x + delta_x, self._length), 0),
+            max(min(x + delta_x, self._width), 0),
             max(min(y + delta_y, self._height), 0),
         )

     def _check_win(self) -> bool:
         """
-        Return:
-            True if reached goal
-            False if not reached goal
+        Indicates whether the agent position is close enough (in Euclidean distance) to the goal.
""" assert self._agent_position is not None assert self._goal is not None + if math.dist(self._agent_position, self._goal) < self._reward_distance: return True return False @@ -126,7 +153,7 @@ def action_space(self) -> ActionSpace: class DiscreteSparseRewardEnvironment(ContinuousSparseRewardEnvironment): """ - Given action count N, action index will be 0,...,N-1 + Given action count N, action index will be 0, ..., N-1 For action n, position will be changed by: x += cos(360/N * n) * step_size y += sin(360/N * n) * step_size @@ -135,18 +162,18 @@ class DiscreteSparseRewardEnvironment(ContinuousSparseRewardEnvironment): # FIXME: This environment mixes the concepts of action index and action feature. def __init__( self, - length: float, + width: float, height: float, + action_count: int, + reward_distance: float, step_size: float = 0.01, - action_count: int = 4, max_episode_duration: int = 500, - reward_distance: Optional[float] = None, ) -> None: super(DiscreteSparseRewardEnvironment, self).__init__( - length, - height, - max_episode_duration, - reward_distance if reward_distance is not None else step_size, + width=width, + height=height, + max_episode_duration=max_episode_duration, + reward_distance=reward_distance, ) self._step_size = step_size self._action_count = action_count diff --git a/test/integration/test_integration_replay_buffer.py b/test/integration/test_integration_replay_buffer.py index baa7c01..3319bd9 100644 --- a/test/integration/test_integration_replay_buffer.py +++ b/test/integration/test_integration_replay_buffer.py @@ -46,7 +46,7 @@ def test_her(self) -> None: DQN is not able to solve this problem within 1000 episodes """ env: DiscreteSparseRewardEnvironment = DiscreteSparseRewardEnvironment( - length=50, + width=50, height=50, step_size=1, action_count=8, diff --git a/test/unit/with_pytorch/test_agent.py b/test/unit/with_pytorch/test_agent.py index 9b4805a..8fc3764 100644 --- a/test/unit/with_pytorch/test_agent.py +++ b/test/unit/with_pytorch/test_agent.py @@ -53,8 +53,8 @@ FixedNumberOfStepsEnvironment, ) from pearl.utils.instantiations.environments.gym_environment import GymEnvironment -from pearl.utils.instantiations.environments.reward_is_equal_to_ten_times_action_contextual_bandit_environment import ( # noqa: E501 - RewardIsEqualToTenTimesActionContextualBanditEnvironment, +from pearl.utils.instantiations.environments.reward_is_equal_to_ten_times_action_multi_arm_bandit_environment import ( # Noqa E501 + RewardIsEqualToTenTimesActionMultiArmBanditEnvironment, ) from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace @@ -207,7 +207,7 @@ def test_with_linear_contextual(self) -> None: self.assertTrue(sum(regrets[10:]) >= sum(regrets[-10:])) def test_online_rl(self) -> None: - env = FixedNumberOfStepsEnvironment(number_of_steps=100) + env = FixedNumberOfStepsEnvironment(max_number_of_steps=100) agent = PearlAgent(TabularQLearning()) online_learning(agent, env, number_of_episodes=1000) @@ -241,7 +241,7 @@ def test_tabular_q_learning_online_rl(self) -> None: def test_contextual_bandit_with_tabular_q_learning_online_rl(self) -> None: num_actions = 5 max_action = num_actions - 1 - env = RewardIsEqualToTenTimesActionContextualBanditEnvironment( + env = RewardIsEqualToTenTimesActionMultiArmBanditEnvironment( action_space=DiscreteActionSpace( actions=list(torch.arange(num_actions).view(-1, 1)) ) diff --git a/test/unit/with_pytorch/test_sparse_reward_environment.py b/test/unit/with_pytorch/test_sparse_reward_environment.py index 8b690f3..dedfb61 100644 
--- a/test/unit/with_pytorch/test_sparse_reward_environment.py
+++ b/test/unit/with_pytorch/test_sparse_reward_environment.py
@@ -20,7 +20,7 @@ class TestSparseRewardEnvironment(unittest.TestCase):
     def test_basic(self) -> None:
         env = DiscreteSparseRewardEnvironment(
-            length=100, height=100, step_size=1, action_count=4
+            width=100, height=100, reward_distance=1, step_size=1, action_count=4
         )

         # Test reset
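The test above pins down the updated DiscreteSparseRewardEnvironment signature (width and height plus an explicit reward_distance). Below is a small sketch, not part of the patch, of how that constructor and the new observation_space property fit together, under the assumption that Pearl's BoxSpace exposes low/high tensors.

```python
# Illustrative sketch only; keyword arguments follow the updated signature in this patch.
from pearl.utils.instantiations.environments.sparse_reward_environment import (
    DiscreteSparseRewardEnvironment,
)

env = DiscreteSparseRewardEnvironment(
    width=100, height=100, reward_distance=1, step_size=1, action_count=4
)

observation, action_space = env.reset()  # agent starts at (width / 2, height / 2)

# observation_space is a 2D BoxSpace over [0, width] x [0, height];
# the low/high attributes used below are assumed, not shown in the patch.
space = env.observation_space
print(space.low, space.high)  # expected: tensor([0., 0.]) and tensor([100., 100.])
```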