Merge pull request #945 from deepmind:lanctot-patch-43
PiperOrigin-RevId: 479897183
Change-Id: I2580ec42849ff87a45b046a92d53843e33b9a081
lanctot committed Oct 9, 2022
2 parents f9279c6 + 1ac0c33 commit 1617276
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions open_spiel/python/algorithms/rnad/rnad.py
@@ -833,8 +833,7 @@ def update_parameters(
       timestep: TimeStep,
       alpha: float,
       learner_steps: int,
-      update_target_net: bool,
-  ) -> Tuple[Tuple[Any, Any, Any, Any, Any, Any], dict[str, float]]:
+      update_target_net: bool):
     """A jitted pure-functional part of the `step`."""
     loss_val, grad = self._loss_and_grad(params, params_target, params_prev,
                                          params_prev_, timestep, alpha,
@@ -859,7 +858,7 @@ def update_parameters(
     return (params, params_target, params_prev, params_prev_, optimizer,
             optimizer_target), logs
 
-  def __getstate__(self) -> dict[str, Any]:
+  def __getstate__(self):
     """To serialize the agent."""
     return dict(
         # RNaD config.
@@ -883,7 +882,7 @@ def __getstate__(self) -> dict[str, Any]:
         optimizer_target=self.optimizer_target.state,
     )
 
-  def __setstate__(self, state: dict[str, Any]):
+  def __setstate__(self, state):
     """To deserialize the agent."""
     # RNaD config.
     self.config = state["config"]
@@ -907,7 +906,7 @@ def __setstate__(self, state: dict[str, Any]):
     self.optimizer.state = state["optimizer"]
     self.optimizer_target.state = state["optimizer_target"]
 
-  def step(self) -> dict[str, float]:
+  def step(self):
     """One step of the algorithm, that plays the game and improves params."""
     timestep = self.collect_batch_trajectory()
     alpha, update_target_net = self._entropy_schedule(self.learner_steps)
@@ -957,7 +956,7 @@ def _state_as_env_step(self, state: pyspiel.State) -> EnvStep:
 
   def action_probabilities(self,
                            state: pyspiel.State,
-                           player_id: Any = None) -> dict[int, float]:
+                           player_id: Any = None):
     """Returns action probabilities dict for a single batch."""
     env_step = self._batch_of_states_as_env_step([state])
     probs = self._network_jit_apply_and_post_process(
@@ -975,7 +974,7 @@ def _network_jit_apply_and_post_process(
 
   @functools.partial(jax.jit, static_argnums=(0,))
   def actor_step(self, env_step: EnvStep,
-                 rng_key: chex.PRNGKey) -> Tuple[chex.Array, ActorStep]:
+                 rng_key: chex.PRNGKey):
     pi, _, _, _ = self.network.apply(self.params, env_step)
     # TODO(author18): is this policy normalization really needed?
     pi = pi / jnp.sum(pi, axis=-1, keepdims=True)
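
Note: the commit message does not say why these return-type annotations were dropped. A plausible (unconfirmed) reason is compatibility with Python versions older than 3.9, where built-in generics such as dict[str, float] cannot be subscripted when an annotation is evaluated at definition time. The sketch below is illustrative only (the function names are invented, not from rnad.py) and shows the failure mode together with two standard workarounds that keep the annotations:

# Illustrative sketch only; assumes the motivation was pre-3.9 compatibility,
# which the commit itself does not state. None of these names come from rnad.py.
#
# On Python 3.8, an eagerly evaluated built-in generic annotation fails:
#
#   def step(self) -> dict[str, float]:
#       ...
#   # TypeError: 'type' object is not subscriptable
#
# Two standard ways to keep the annotation and stay 3.8-compatible:

from __future__ import annotations  # PEP 563: annotations stored as strings, never evaluated

from typing import Dict


def step_with_builtin_generic() -> dict[str, float]:
  """OK on Python 3.7+ thanks to the __future__ import above."""
  return {"loss": 0.0}


def step_with_typing_generic() -> Dict[str, float]:
  """OK on Python 3.5+ even without the __future__ import."""
  return {"loss": 0.0}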