Skip to content

Commit

Permalink
RNaD: use an enum instead of a bare string.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 492192763
Change-Id: I0419e90e1416b686ef3f4c59c90925b4640afda4
  • Loading branch information
DeepMind Technologies Ltd authored and lanctot committed Dec 5, 2022
1 parent 0e9056c commit 49d3036
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions open_spiel/python/algorithms/rnad/rnad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Python implementation of R-NaD (https://arxiv.org/pdf/2206.15378.pdf)."""

import enum
import functools
from typing import Any, Callable, Sequence, Tuple

Expand Down Expand Up @@ -596,6 +597,11 @@ class NerdConfig:
clip: float = 10_000


class StateRepresentation(str, enum.Enum):
  """Selects which state tensor is used as the content of `EnvStep.obs`.

  Mixing in `str` makes every member compare equal to its plain string value
  (e.g. `StateRepresentation.INFO_SET == "info_set"`), so configurations that
  still pass the bare strings keep matching the enum-based comparisons.
  """

  # Observations are read from `state.information_state_tensor()`.
  INFO_SET = "info_set"
  # Observations are read from `state.observation_tensor()`.
  OBSERVATION = "observation"


@chex.dataclass(frozen=True)
class RNaDConfig:
"""Configuration parameters for the RNaDSolver."""
Expand All @@ -605,7 +611,7 @@ class RNaDConfig:
trajectory_max: int = 10

# The content of the EnvStep.obs tensor.
state_representation: str = "info_set" # or "observation"
state_representation: StateRepresentation = StateRepresentation.INFO_SET

# Network configuration.
policy_network_layers: Sequence[int] = (256, 256)
Expand Down Expand Up @@ -955,14 +961,13 @@ def _state_as_env_step(self, state: pyspiel.State) -> EnvStep:
if not valid:
state = self._ex_state

if self.config.state_representation == "observation":
if self.config.state_representation == StateRepresentation.OBSERVATION:
obs = state.observation_tensor()
elif self.config.state_representation == "info_set":
elif self.config.state_representation == StateRepresentation.INFO_SET:
obs = state.information_state_tensor()
else:
raise ValueError(
f"Invalid state_representation: {self.config.state_representation}. "
"Must be either 'info_set' or 'observation'.")
f"Invalid StateRepresentation: {self.config.state_representation}.")

# TODO(author16): clarify the story around rewards and valid.
return EnvStep(
Expand Down

0 comments on commit 49d3036

Please sign in to comment.