diff --git a/maro/rl/agent/abs_agent.py b/maro/rl/agent/abs_agent.py index 0aa9afd76..03967875d 100644 --- a/maro/rl/agent/abs_agent.py +++ b/maro/rl/agent/abs_agent.py @@ -27,11 +27,9 @@ class AbsAgent(ABC): choosing actions and optimizing models. experience_pool (AbsStore): A data store that stores experiences generated by the experience shaper. """ - def __init__(self, - name: str, - algorithm: AbsAlgorithm, - experience_pool: AbsStore - ): + def __init__( + self, name: str, algorithm: AbsAlgorithm, experience_pool: AbsStore + ): self._name = name self._algorithm = algorithm self._experience_pool = experience_pool @@ -82,10 +80,13 @@ def load_model_dict_from_file(self, file_path): def dump_model_dict(self, dir_path: str): """Dump models to disk.""" - torch.save({model_key: model.state_dict() for model_key, model in self._algorithm.model_dict.items()}, - os.path.join(dir_path, self._name)) + torch.save( + {model_key: model.state_dict() for model_key, model in self._algorithm.model_dict.items()}, + os.path.join(dir_path, self._name) + ) - def dump_experience_store(self, dir_path: str): + def dump_experience_pool(self, dir_path: str): """Dump the experience pool to disk.""" - with open(os.path.join(dir_path, self._name)) as fp: + os.makedirs(dir_path, exist_ok=True) + with open(os.path.join(dir_path, self._name), "wb") as fp: pickle.dump(self._experience_pool, fp) diff --git a/maro/rl/storage/column_based_store.py b/maro/rl/storage/column_based_store.py index 8709a8d8e..46398185e 100644 --- a/maro/rl/storage/column_based_store.py +++ b/maro/rl/storage/column_based_store.py @@ -53,6 +53,15 @@ def __next__(self): def __getitem__(self, index: int): return {k: lst[index] for k, lst in self._store.items()} + def __getstate__(self): + """A patch for pickling the object with lambda. + Using the default ``__dict__`` would make the object unpicklable due to the lambda function involved in the + ``defaultdict`` definition of the ``_store`` attribute. 
+ """ + obj_dict = self.__dict__ + obj_dict["_store"] = dict(obj_dict["_store"]) + return obj_dict + @property def capacity(self): """Store capacity. @@ -92,8 +101,9 @@ def put(self, contents: dict, overwrite_indexes: Sequence = None) -> List[int]: self._size += added_size return list(range(self._size - added_size, self._size)) else: - write_indexes = get_update_indexes(self._size, added_size, self._capacity, self._overwrite_type, - overwrite_indexes=overwrite_indexes) + write_indexes = get_update_indexes( + self._size, added_size, self._capacity, self._overwrite_type, overwrite_indexes=overwrite_indexes + ) self.update(write_indexes, contents) self._size = min(self._capacity, self._size + added_size) return write_indexes @@ -125,7 +135,7 @@ def apply_multi_filters(self, filters: Sequence[Callable]): Args: filters (Sequence[Callable]): Filter list, each item is a lambda function, - e.g., [lambda d: d['a'] == 1 and d['b'] == 1]. + e.g., [lambda d: d['a'] == 1 and d['b'] == 1]. Returns: Filtered indexes and corresponding objects. """