Add observation_space property to all subclasses of Environment
Summary:
Added the `observation_space` property, akin to the `action_space` property, to all subclasses of `Environment`. Also made the following changes:

1) Fixed some issues in the `SparseRewardEnvironment` class. Not everything there makes sense yet, so I plan to make more changes in a follow-up diff.

2) Renamed `RewardIsEqualToTenTimesActionContextualBanditEnvironment` to `RewardIsEqualToTenTimesActionMultiArmBanditEnvironment` since it is a multi-arm bandit environment.

3) Made a small change to `FixedNumberOfStepsEnvironment` so that it has a well-defined `observation_space`.

Docstrings are missing in many places here, so understanding the code took some effort. I will add them in a follow-up diff.

Reviewed By: rodrigodesalvobraz

Differential Revision: D55748340

fbshipit-source-id: 3b885d6a1029c482634f52055fe6cbee9eff42f6
jb3618 authored and facebook-github-bot committed Apr 27, 2024
1 parent 0dfd91f commit 8a36e3a
Showing 11 changed files with 154 additions and 53 deletions.
11 changes: 6 additions & 5 deletions pearl/api/environment.py
@@ -16,6 +16,7 @@
from pearl.api.action_result import ActionResult
from pearl.api.action_space import ActionSpace
from pearl.api.observation import Observation
from pearl.api.space import Space


class Environment(ABC):
@@ -32,11 +33,11 @@ def action_space(self) -> ActionSpace:
pass

# FIXME: add this and in implement in all concrete subclasses
# @property
# @abstractmethod
# def observation_space(self) -> Space:
# """Returns the observation space of the environment."""
# pass
@property
@abstractmethod
def observation_space(self) -> Space:
"""Returns the observation space of the environment."""
pass

@abstractmethod
def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
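Note: to make the new contract concrete, here is a minimal, self-contained sketch of the pattern this commit enforces. The Space, DiscreteSpace, Environment, and TwoStateEnvironment classes below are simplified stand-ins for illustration only, not Pearl's actual implementations; the point is that a subclass of Environment can no longer be instantiated without defining observation_space.

from abc import ABC, abstractmethod
from typing import List

import torch


class Space(ABC):
    # Stand-in for pearl.api.space.Space, for illustration only.
    @abstractmethod
    def sample(self) -> torch.Tensor:
        ...


class DiscreteSpace(Space):
    # Stand-in for Pearl's DiscreteSpace: a finite set of tensor elements.
    def __init__(self, elements: List[torch.Tensor]) -> None:
        self.elements = elements
        self.n = len(elements)

    def sample(self) -> torch.Tensor:
        return self.elements[int(torch.randint(self.n, (1,)).item())]


class Environment(ABC):
    @property
    @abstractmethod
    def observation_space(self) -> Space:
        """Returns the observation space of the environment."""


class TwoStateEnvironment(Environment):
    # Hypothetical subclass: it must implement observation_space to be instantiable.
    @property
    def observation_space(self) -> Space:
        return DiscreteSpace([torch.tensor(0), torch.tensor(1)])


env = TwoStateEnvironment()            # OK: the abstract property is implemented
print(env.observation_space.sample())  # tensor(0) or tensor(1)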
6 changes: 3 additions & 3 deletions pearl/utils/instantiations/environments/__init__.py
@@ -17,8 +17,8 @@
OneHotObservationsFromDiscrete,
)
from .gym_environment import GymEnvironment
from .reward_is_equal_to_ten_times_action_contextual_bandit_environment import (
RewardIsEqualToTenTimesActionContextualBanditEnvironment,
from .reward_is_equal_to_ten_times_action_multi_arm_bandit_environment import (
RewardIsEqualToTenTimesActionMultiArmBanditEnvironment,
)
from .sparse_reward_environment import (
ContinuousSparseRewardEnvironment,
@@ -36,7 +36,7 @@
"FixedNumberOfStepsEnvironment",
"GymEnvironment",
"OneHotObservationsFromDiscrete",
"RewardIsEqualToTenTimesActionContextualBanditEnvironment",
"RewardIsEqualToTenTimesActionMultiArmBanditEnvironment",
"SLCBEnvironment",
"SparseRewardEnvironment",
]
pearl/utils/instantiations/environments/contextual_bandit_environment.py
@@ -15,6 +15,7 @@

from pearl.api.environment import Environment
from pearl.api.reward import Reward
from pearl.api.space import Space


class ContextualBanditEnvironment(Environment, ABC):
@@ -31,6 +32,16 @@ class ContextualBanditEnvironment(Environment, ABC):
to determine the ActionResult reward.
"""

@property
@abstractmethod
def observation_space(self) -> Space:
"""
For multi-arm bandit environments (i.e. no context), we should set the observation space to
simply be an empty tensor, i.e. {tensor([])}. For contextual bandit environments, if the
context space is not known in advance, then it should also be set to an empty tensor.
"""
pass

@property
@abstractmethod
def action_space(self) -> ActionSpace:
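For illustration, a short sketch of the convention described in the new docstring above, assuming the BoxSpace and DiscreteSpace classes imported elsewhere in this diff; the dimension value is hypothetical.

import torch

from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

# Contextual bandit with a context space known in advance, e.g. 4 features in [0, 1]:
known_context_space = BoxSpace(low=torch.zeros(4), high=torch.ones(4))

# Multi-arm bandit (no context), or contextual bandit with an unknown context space:
# a single-element space whose only element is an empty tensor.
placeholder_space = DiscreteSpace([torch.tensor([])])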
pearl/utils/instantiations/environments/contextual_bandit_linear_synthetic_environment.py
@@ -16,9 +16,11 @@

from pearl.api.observation import Observation
from pearl.api.reward import Value
from pearl.api.space import Space
from pearl.utils.instantiations.environments.contextual_bandit_environment import (
ContextualBanditEnvironment,
)
from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


@@ -35,7 +37,11 @@ class ContextualBanditLinearSyntheticEnvironment(ContextualBanditEnvironment):
"A Contextual-Bandit Approach to Personalized News Article Recommendation,"
The context for an arm is the concatenation of the observation feature vector
and the arm feature vevctor.
and the arm feature vector.
In this example, the observation space (i.e. the context space) is taken to be
[0, 1]^{`observation_dim`} where `observation_dim` is specified during initialization.
Observations are taken to be random vectors generated from this observation space.
"""

def __init__(
@@ -58,6 +64,9 @@ def __init__(
"""
assert isinstance(action_space, DiscreteActionSpace)
self._action_space: DiscreteActionSpace = action_space
self._observation_space: Space = BoxSpace(
low=torch.zeros((observation_dim)), high=torch.ones((observation_dim))
)
self.observation_dim = observation_dim
self._arm_feature_vector_dim = arm_feature_vector_dim
self.reward_noise_sigma = reward_noise_sigma
@@ -71,6 +80,10 @@ def __init__(
def action_space(self) -> ActionSpace:
return self._action_space

@property
def observation_space(self) -> Optional[Space]:
return self._observation_space

@property
def arm_feature_vector_dim(self) -> int:
return self._arm_feature_vector_dim
@@ -107,7 +120,7 @@ def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
"""
Provides the observation and action space to the agent.
"""
self._observation = torch.rand(self.observation_dim)
self._observation = self._observation_space.sample()
return self._observation, self.action_space

def get_reward(self, action: Action) -> Value:
@@ -122,7 +135,7 @@ def get_reward(self, action: Action) -> Value:
def get_regret(self, action: Action) -> Value:
"""
Given action, environment will return regret for choosing this action
regret == max(reward over all action) - reward for current action
regret == max(reward over all actions) - reward for current action
"""
rewards = [
self._compute_reward_from_context(self._get_context_for_arm(i))
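As a worked illustration of the two pieces above — the [0, 1]^observation_dim observation space and the regret definition — here is a brief sketch; the dimension, per-arm rewards, and chosen arm are hypothetical values, not taken from the environment.

import torch

from pearl.utils.instantiations.spaces.box import BoxSpace

# Observation (context) space [0, 1]^observation_dim, as constructed in __init__ above.
observation_dim = 3  # hypothetical value passed at initialization
observation_space = BoxSpace(
    low=torch.zeros(observation_dim), high=torch.ones(observation_dim)
)
observation = observation_space.sample()  # random vector in [0, 1]^3, as reset() now does

# Regret as defined in get_regret(): best reward over all actions minus the obtained reward.
rewards_per_arm = torch.tensor([0.2, 0.7, 0.4])  # hypothetical rewards for 3 arms
chosen_arm = 2
regret = rewards_per_arm.max() - rewards_per_arm[chosen_arm]  # 0.7 - 0.4 = 0.3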
@@ -9,13 +9,23 @@

from pearl.api.observation import Observation
from pearl.api.reward import Value
from pearl.api.space import Space
from pearl.utils.instantiations.environments.contextual_bandit_environment import (
ContextualBanditEnvironment,
)
from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class SLCBEnvironment(ContextualBanditEnvironment):
"""
A contextual bandit environment to work with UCI datasets.
Note that the context features are assumed to be continuous. For datasets with discrete
features, please modify the code used to normalize features and set the _observation_space
attribute to be a `DiscreteSpace` object.
"""

def __init__(
self,
path_filename: str,
@@ -86,8 +96,14 @@ def __init__(
)
self._action_dim_env: int = self._action_space[0].shape[0]

# Set observation dimension
# Set observation space and observation dimension
self.observation_dim: int = tensor.size()[1] - 1 # 0th index is the target
self._observation_space: Space = (
BoxSpace( # Box space with low for each dimension = -inf, high for each dimension = inf
high=torch.full((self.observation_dim,), float("inf")),
low=torch.full((self.observation_dim,), float("-inf")),
)
)

# Set noise to be added to reward
self.reward_noise_sigma = reward_noise_sigma
@@ -122,6 +138,10 @@ def action_transfomer(
else:
raise Exception("Invalid action_embeddings type")

@property
def observation_space(self) -> Optional[Space]:
return self._observation_space

@property
def action_space(self) -> ActionSpace:
return self._action_space
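To illustrate the two cases mentioned in the new SLCBEnvironment docstring — continuous features (the default here) versus discrete features — a brief sketch follows; the dimension and the example observation vectors are hypothetical.

import torch

from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

observation_dim = 5  # hypothetical; in SLCBEnvironment it is (number of columns - 1)

# Continuous features (the default assumption): an unbounded box, (-inf, inf) per dimension.
continuous_observation_space = BoxSpace(
    low=torch.full((observation_dim,), float("-inf")),
    high=torch.full((observation_dim,), float("inf")),
)

# Discrete features (the suggested modification): a DiscreteSpace over the possible
# observation vectors, here three hypothetical two-dimensional rows.
discrete_observation_space = DiscreteSpace(
    [torch.tensor([0.0, 1.0]), torch.tensor([1.0, 0.0]), torch.tensor([1.0, 1.0])]
)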
39 changes: 28 additions & 11 deletions pearl/utils/instantiations/environments/environments.py
@@ -36,9 +36,9 @@ class FixedNumberOfStepsEnvironment(Environment):
reward are the number of steps.
"""

def __init__(self, number_of_steps: int = 100) -> None:
def __init__(self, max_number_of_steps: int = 100) -> None:
self.number_of_steps_so_far = 0
self.number_of_steps: int = number_of_steps
self.max_number_of_steps: int = max_number_of_steps
self._action_space = DiscreteActionSpace(
[torch.tensor(False), torch.tensor(True)]
)
@@ -60,7 +60,20 @@ def render(self) -> None:
def action_space(self) -> ActionSpace:
return self._action_space

@property
def observation_space(self) -> DiscreteSpace:
return DiscreteSpace(
[torch.tensor(i) for i in range(self.max_number_of_steps + 1)]
)

def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
"""
Provides the observation and action space to the agent.
"""
# clipping the observation to be within the range of [0, max_number_of_steps]
self.number_of_steps_so_far = min(
self.number_of_steps_so_far, self.max_number_of_steps
)
return self.number_of_steps_so_far, self.action_space

def __str__(self) -> str:
Expand All @@ -77,7 +90,9 @@ def __init__(
base_environment: Environment,
) -> None:
self.base_environment = base_environment
self.observation_space: Space = self.make_observation_space(base_environment)
self._observation_space: Space = self.make_observation_space(
self.base_environment
)

@staticmethod
@abstractmethod
@@ -92,6 +107,10 @@ def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
def action_space(self) -> ActionSpace:
return self.base_environment.action_space

@property
def observation_space(self) -> Space:
return self._observation_space

def step(self, action: Action) -> ActionResult:
action_result = self.base_environment.step(action)
action_result.observation = self.compute_tensor_observation(
@@ -113,20 +132,20 @@ def short_description(self) -> str:

class OneHotObservationsFromDiscrete(ObservationTransformationEnvironmentAdapterBase):
"""
An environment adapter mapping a Discrete observation space into
a Box observation space with dimension 1
where the observation is a one-hot vector.
A wrapper around a base environment that transforms the observation space of the base
environment from a DiscreteSpace with a finite subset of integers (e.g. a Discrete
environment in Gymnasium, gym.spaces.Discrete, {0, 1, 2, ... end}) to a DiscreteSpace
in Pearl where the observations are represented as one-hot vectors.
This is useful to use with agents expecting tensor observations.
This is useful to use with agents expecting one-hot tensor observations. One-hot encoding
is a common way to represent discrete observations in RL.
"""

def __init__(self, base_environment: Environment) -> None:
super(OneHotObservationsFromDiscrete, self).__init__(base_environment)

@staticmethod
def make_observation_space(base_environment: Environment) -> Space:
# pyre-fixme: need to add `observation_space` property in Environment
# and implement it in all concrete subclasses
assert isinstance(base_environment.observation_space, DiscreteSpace)
n = base_environment.observation_space.n
elements = [F.one_hot(torch.tensor(i), n).float() for i in range(n)]
@@ -137,8 +156,6 @@ def compute_tensor_observation(self, observation: Observation) -> torch.Tensor:
observation_tensor = observation
else:
observation_tensor = torch.tensor(observation)
# pyre-fixme: need to add `observation_space` property in Environment
# and implement it in all concrete subclasses
assert isinstance(self.base_environment.observation_space, DiscreteSpace)
return F.one_hot(
observation_tensor,
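A standalone sketch of the transformation OneHotObservationsFromDiscrete performs, using only torch; the number of discrete observations and the raw observation value are hypothetical.

import torch
import torch.nn.functional as F

# Suppose the base environment's observations are the integers 0 .. n-1 (gym-style Discrete(n)).
n = 5  # hypothetical number of discrete observations

# The wrapper's observation space becomes the set of n one-hot vectors ...
one_hot_elements = [F.one_hot(torch.tensor(i), n).float() for i in range(n)]

# ... and each raw observation is mapped to its one-hot encoding.
raw_observation = 3  # hypothetical observation returned by the base environment
encoded = F.one_hot(torch.tensor(raw_observation), n).float()
print(encoded)  # tensor([0., 0., 0., 1., 0.])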
pearl/utils/instantiations/environments/reward_is_equal_to_ten_times_action_multi_arm_bandit_environment.py
@@ -9,20 +9,25 @@

from typing import Optional, Tuple

import torch

from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.observation import Observation
from pearl.api.reward import Value
from pearl.api.space import Space
from pearl.utils.instantiations.environments.contextual_bandit_environment import (
ContextualBanditEnvironment,
)
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace


class RewardIsEqualToTenTimesActionContextualBanditEnvironment(
class RewardIsEqualToTenTimesActionMultiArmBanditEnvironment(
ContextualBanditEnvironment
):
"""
A example implementation of a contextual bandit environment.
An example implementation of a bandit environment. For simplicity, we assume
no context; therefore, it is a multi-arm bandit environment.
"""

def __init__(self, action_space: ActionSpace) -> None:
@@ -32,6 +37,13 @@ def __init__(self, action_space: ActionSpace) -> None:
def action_space(self) -> ActionSpace:
return self._action_space

@property
def observation_space(self) -> Optional[Space]:
# For multi-arm bandit environments (where there are no 'observations'), we set the
# observation space to be a DiscreteSpace with a single element, taken to be an empty
# tensor.
return DiscreteSpace([torch.tensor([])])

def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
# Function returning the context and the available action space
# Here, we use no context (None), but we could return varied implementations.
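For completeness, a hypothetical usage sketch of the renamed environment, built only from the constructor and properties shown in this diff; the three-arm action space is an assumption.

import torch

from pearl.utils.instantiations.environments import (
    RewardIsEqualToTenTimesActionMultiArmBanditEnvironment,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Hypothetical three-arm bandit with actions 0, 1 and 2.
env = RewardIsEqualToTenTimesActionMultiArmBanditEnvironment(
    action_space=DiscreteActionSpace([torch.tensor(i) for i in range(3)])
)
observation, action_space = env.reset()  # no context: the observation carries no information
print(env.observation_space)             # DiscreteSpace with a single empty-tensor element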
