Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions maro/rl/agent/abs_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ class AbsAgent(ABC):
choosing actions and optimizing models.
experience_pool (AbsStore): A data store that stores experiences generated by the experience shaper.
"""
def __init__(self,
name: str,
algorithm: AbsAlgorithm,
experience_pool: AbsStore
):
def __init__(
self, name: str, algorithm: AbsAlgorithm, experience_pool: AbsStore
):
self._name = name
self._algorithm = algorithm
self._experience_pool = experience_pool
Expand Down Expand Up @@ -82,10 +80,13 @@ def load_model_dict_from_file(self, file_path):

def dump_model_dict(self, dir_path: str):
"""Dump models to disk."""
torch.save({model_key: model.state_dict() for model_key, model in self._algorithm.model_dict.items()},
os.path.join(dir_path, self._name))
torch.save(
{model_key: model.state_dict() for model_key, model in self._algorithm.model_dict.items()},
os.path.join(dir_path, self._name)
)

def dump_experience_store(self, dir_path: str):
def dump_experience_pool(self, dir_path: str):
"""Dump the experience pool to disk."""
with open(os.path.join(dir_path, self._name)) as fp:
os.makedirs(dir_path, exist_ok=True)
with open(os.path.join(dir_path, self._name), "wb") as fp:
pickle.dump(self._experience_pool, fp)
16 changes: 13 additions & 3 deletions maro/rl/storage/column_based_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def __next__(self):
def __getitem__(self, index: int):
    """Return the entry stored at ``index`` as a dict.

    Args:
        index (int): Position to read from every list held in ``_store``.

    Returns:
        dict: Maps each key of ``_store`` to the element found at ``index`` in that key's list.
    """
    row = {}
    for column_name, column_values in self._store.items():
        row[column_name] = column_values[index]
    return row

def __getstate__(self):
    """Return a picklable snapshot of the instance state.

    Using the default ``__dict__`` would make the object unpicklable due to the lambda function involved in the
    ``defaultdict`` definition of the ``_store`` attribute, so the snapshot carries a plain ``dict`` copy of
    ``_store`` instead.

    The conversion is applied to a shallow copy of ``__dict__``, never to ``__dict__`` itself. Mutating the
    live attribute dict would silently replace the object's own ``_store`` with a plain ``dict`` each time it is
    pickled.

    Returns:
        dict: A shallow copy of the instance attributes in which ``_store`` is a plain ``dict``.
    """
    # Shallow copy: rebinding "_store" below must not touch the live object.
    obj_dict = self.__dict__.copy()
    obj_dict["_store"] = dict(obj_dict["_store"])
    # NOTE(review): an unpickled instance gets a plain dict for ``_store`` unless a ``__setstate__``
    # defined elsewhere restores the defaultdict -- confirm this is intended.
    return obj_dict

@property
def capacity(self):
"""Store capacity.
Expand Down Expand Up @@ -92,8 +101,9 @@ def put(self, contents: dict, overwrite_indexes: Sequence = None) -> List[int]:
self._size += added_size
return list(range(self._size - added_size, self._size))
else:
write_indexes = get_update_indexes(self._size, added_size, self._capacity, self._overwrite_type,
overwrite_indexes=overwrite_indexes)
write_indexes = get_update_indexes(
self._size, added_size, self._capacity, self._overwrite_type, overwrite_indexes=overwrite_indexes
)
self.update(write_indexes, contents)
self._size = min(self._capacity, self._size + added_size)
return write_indexes
Expand Down Expand Up @@ -125,7 +135,7 @@ def apply_multi_filters(self, filters: Sequence[Callable]):

Args:
filters (Sequence[Callable]): Filter list, each item is a lambda function,
e.g., [lambda d: d['a'] == 1 and d['b'] == 1].
e.g., [lambda d: d['a'] == 1 and d['b'] == 1].
Returns:
Filtered indexes and corresponding objects.
"""
Expand Down