Skip to content

Commit

Permalink
[RLlib] Provide msgpack checkpoint translation utility for Policy-onl…
Browse files Browse the repository at this point in the history
…y cases. (ray-project#38825)

Signed-off-by: Jim Thompson <jimthompson5802@gmail.com>
  • Loading branch information
sven1977 authored and jimthompson5802 committed Sep 12, 2023
1 parent 7183d3c commit e70e743
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
10 changes: 8 additions & 2 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,15 @@ def from_checkpoint(

# Policy checkpoint: Return a single Policy instance.
else:
msgpack = None
if checkpoint_info.get("format") == "msgpack":
msgpack = try_import_msgpack(error=True)

with open(checkpoint_info["state_file"], "rb") as f:
state = pickle.load(f)
if msgpack is not None:
state = msgpack.load(f)
else:
state = pickle.load(f)
return Policy.from_state(state)

@staticmethod
Expand Down Expand Up @@ -1843,7 +1850,6 @@ def get_gym_space_from_struct_of_tensors(
value: Union[Mapping, Tuple, List, TensorType],
batched_input=True,
) -> gym.Space:

start_idx = 1 if batched_input else 0
struct = tree.map_structure(
lambda x: gym.spaces.Box(
Expand Down
36 changes: 36 additions & 0 deletions rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,42 @@ def convert_to_msgpack_checkpoint(
return msgpack_checkpoint_dir


@PublicAPI(stability="beta")
def convert_to_msgpack_policy_checkpoint(
    policy_checkpoint: Union[str, Checkpoint, NewCheckpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Translates a pickle-based Policy checkpoint into a msgpack-based one.

    The Policy is restored from `policy_checkpoint` and its state is then
    re-exported into `msgpack_checkpoint_dir` using the "msgpack" checkpoint
    format. Msgpack has the advantage of being python version independent.

    Args:
        policy_checkpoint: The (pickle based) Policy checkpoint to convert. It is
            handed as-is to `Policy.from_checkpoint()`.
        msgpack_checkpoint_dir: The directory, in which to create the new msgpack
            based checkpoint. It is created if it does not exist yet.

    Returns:
        The directory in which the msgpack checkpoint has been created. This is
        always the same value as `msgpack_checkpoint_dir`.
    """
    # Function-scope import, as in the original implementation.
    from ray.rllib.policy.policy import Policy

    restored_policy = Policy.from_checkpoint(policy_checkpoint)

    # Make sure the target directory exists before exporting into it.
    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)

    export_kwargs = {
        "policy_state": restored_policy.get_state(),
        "checkpoint_format": "msgpack",
    }
    restored_policy.export_checkpoint(msgpack_checkpoint_dir, **export_kwargs)

    # Drop the local references so that everything held by the Policy can be
    # released.
    del export_kwargs, restored_policy

    return msgpack_checkpoint_dir


@PublicAPI
def try_import_msgpack(error: bool = False):
"""Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.
Expand Down
45 changes: 43 additions & 2 deletions rllib/utils/tests/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.simple_q import SimpleQConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.checkpoints import (
get_checkpoint_info,
convert_to_msgpack_checkpoint,
convert_to_msgpack_policy_checkpoint,
)
from ray.rllib.utils.test_utils import check
from ray import tune
Expand Down Expand Up @@ -86,7 +88,7 @@ def test_get_policy_checkpoint_info_v1_1(self):
def test_msgpack_checkpoint_translation(self):
"""Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
... and recovered back into and Algorithm, which is identical to a
... and recovered back into an Algorithm, which is identical to a
pickle-checkpoint-recovered Algorithm (given same initial config).
"""
# Base config used for both pickle-based checkpoint and msgpack-based one.
Expand Down Expand Up @@ -141,9 +143,10 @@ def test_msgpack_checkpoint_translation(self):
def test_msgpack_checkpoint_translation_multi_agent(self):
"""Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
... and recovered back into and Algorithm, which is identical to a
... and recovered back into an Algorithm, which is identical to a
pickle-checkpoint-recovered Algorithm (given same initial config).
"""

# Base config used for both pickle-based checkpoint and msgpack-based one.
def mapping_fn(aid, episode, worker, **kwargs):
return "pol" + str(aid)
Expand Down Expand Up @@ -223,6 +226,44 @@ def mapping_fn(aid, episode, worker, **kwargs):
algo1.stop()
algo2.stop()

def test_msgpack_policy_checkpoint_translation(self):
    """Tests whether a Policy checkpoint can be translated into msgpack ...

    ... and recovered back into a Policy, which is identical to a
    pickle-checkpoint-recovered Policy (given same initial config).
    """
    # Base config used for both pickle-based checkpoint and msgpack-based one.
    config = SimpleQConfig().environment("CartPole-v1")
    # Build algorithm/policy objects.
    algo1 = config.build()
    pol1 = algo1.get_policy()
    # Get its state.
    pickle_state = pol1.get_state()

    # Create standard (pickle-based) checkpoint.
    with tempfile.TemporaryDirectory() as pickle_cp_dir:
        pol1.export_checkpoint(pickle_cp_dir)
        # Now convert pickle checkpoint to msgpack using the provided
        # utility function.
        with tempfile.TemporaryDirectory() as msgpack_cp_dir:
            convert_to_msgpack_policy_checkpoint(pickle_cp_dir, msgpack_cp_dir)
            msgpack_cp_info = get_checkpoint_info(msgpack_cp_dir)
            # Dedicated assert methods report both operands on failure.
            self.assertEqual(msgpack_cp_info["type"], "Policy")
            self.assertEqual(msgpack_cp_info["format"], "msgpack")
            self.assertIsNone(msgpack_cp_info["policy_ids"])
            # Try recreating a new policy object from the msgpack checkpoint.
            pol2 = Policy.from_checkpoint(msgpack_cp_dir)
            # Get the state of the policy recovered from msgpack.
            msgpack_state = pol2.get_state()

    # Make sure the states match 100%. Our `check` utility
    # cannot handle comparing types/classes, so we'll have to serialize the
    # pickle'd config (which contains types, rather than class strings).
    pickle_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict(
        pickle_state["policy_spec"]["config"]
    )
    check(pickle_state, msgpack_state)

    # Stop the algorithm built for this test, as the sibling tests in this
    # file do, so it does not stay alive after the test has finished.
    algo1.stop()


if __name__ == "__main__":
import pytest
Expand Down

0 comments on commit e70e743

Please sign in to comment.