Dev #4

Merged: 14 commits, Oct 10, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -9,3 +9,4 @@ build/
 baselines_DEPRL/
 _static
 _templates
+.DS_Store
58 changes: 30 additions & 28 deletions deprl/custom_distributed.py
@@ -9,7 +9,7 @@
 def proc(
     action_pipe,
     output_queue,
-    seed,
+    group_seed,
     build_dict,
     max_episode_steps,
     index,
@@ -18,10 +18,8 @@ def proc(
     header,
 ):
     """Process holding a sequential group of environments."""
-    envs = Sequential(
-        build_dict, max_episode_steps, workers, index, env_args, header
-    )
-    envs.initialize(seed)
+    envs = Sequential(build_dict, max_episode_steps, workers, env_args, header)
+    envs.initialize(group_seed)
 
     observations = envs.start()
     output_queue.put((index, observations))
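
For context, a minimal standalone sketch of the worker pattern proc uses above: each process receives actions over a Pipe and reports (index, observations) through a shared Queue. Everything below (toy_worker, the integer "observations") is illustrative and not part of deprl.

import multiprocessing as mp

def toy_worker(action_pipe, output_queue, index):
    observation = 0  # stand-in for envs.start()
    output_queue.put((index, observation))
    while True:
        action = action_pipe.recv()
        if action is None:  # shutdown signal
            break
        observation += action  # stand-in for stepping the environments
        output_queue.put((index, observation))

if __name__ == "__main__":
    context = mp.get_context("spawn")
    output_queue = context.Queue()
    parent_end, worker_end = context.Pipe()
    worker = context.Process(target=toy_worker, args=(worker_end, output_queue, 0))
    worker.start()
    print(output_queue.get())   # (0, 0): initial observations
    parent_end.send(1)
    print(output_queue.get())   # (0, 1): observations after one "step"
    parent_end.send(None)       # ask the worker to exit
    worker.join()
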
@@ -36,21 +34,20 @@ class Sequential:
     """A group of environments used in sequence."""
 
     def __init__(
-        self, build_dict, max_episode_steps, workers, index, env_args, header
+        self, build_dict, max_episode_steps, workers, env_args, header
     ):
         if header is not None:
             with stdout_suppression():
                 exec(header)
-        if hasattr(build_env_from_dict(build_dict)().unwrapped, "environment"):
+        if hasattr(build_env_from_dict(build_dict).unwrapped, "environment"):
             # its a deepmind env
             self.environments = [
                 build_env_from_dict(build_dict)() for i in range(workers)
             ]
         else:
             # its a gym env
             self.environments = [
-                build_env_from_dict(build_dict)(identifier=index * workers + i)
-                for i in range(workers)
+                build_env_from_dict(build_dict) for i in range(workers)
             ]
         if env_args is not None:
             [x.merge_args(env_args) for x in self.environments]
@@ -62,6 +59,7 @@ def __init__(
         self.num_workers = workers
 
     def initialize(self, seed):
+        # group seed is given, the others are determined from it
         for i, environment in enumerate(self.environments):
             environment.seed(seed + i)
 
@@ -145,7 +143,7 @@ def __init__(
         self.header = header
 
     def initialize(self, seed):
-        dummy_environment = build_env_from_dict(self.build_dict)()
+        dummy_environment = build_env_from_dict(self.build_dict)
         dummy_environment.merge_args(self.env_args)
         dummy_environment.apply_args()
 
@@ -164,15 +162,14 @@ def initialize(self, seed):
             pipe, worker_end = context.Pipe()
             self.action_pipes.append(pipe)
             group_seed = (
-                seed * (self.worker_groups * self.workers_per_group)
-                + i * self.workers_per_group
+                seed * self.workers_per_group + i * self.workers_per_group
             )
 
-            # required for spawnstart_method
+            # required for spawnstart_method for macos and windows
            proc_kwargs = {
                 "action_pipe": worker_end,
                 "output_queue": self.output_queue,
-                "seed": group_seed,
+                "group_seed": group_seed,
                 "build_dict": self.build_dict,
                 "max_episode_steps": self._max_episode_steps,
                 "index": i,
@@ -254,41 +251,46 @@ def close(self):
 
 
 def distribute(
-    build_dict,
-    worker_groups=1,
-    workers_per_group=1,
-    env_args=None,
-    header=None,
+    environment,
+    tonic_conf,
+    env_args,
+    parallel=None,
+    sequential=None,
 ):
     """Distributes workers over parallel and sequential groups."""
+    parallel = tonic_conf["parallel"] if parallel is None else parallel
+    sequential = tonic_conf["sequential"] if sequential is None else sequential
+    build_dict = dict(
+        env=environment, parallel=parallel, sequential=sequential
+    )
 
-    dummy_environment = build_env_from_dict(build_dict)()
+    dummy_environment = build_env_from_dict(build_dict)
     max_episode_steps = dummy_environment._max_episode_steps
     del dummy_environment
 
-    if worker_groups < 2:
+    if parallel < 2:
         return Sequential(
             build_dict=build_dict,
             max_episode_steps=max_episode_steps,
-            workers=workers_per_group,
+            workers=sequential,
             env_args=env_args,
-            header=header,
-            index=0,
+            header=tonic_conf["header"],
         )
     return Parallel(
         build_dict,
-        worker_groups=worker_groups,
-        workers_per_group=workers_per_group,
+        worker_groups=parallel,
+        workers_per_group=sequential,
         max_episode_steps=max_episode_steps,
         env_args=env_args,
-        header=header,
+        header=tonic_conf["header"],
     )
 
 
 def build_env_from_dict(build_dict):
     assert build_dict["env"] is not None
     if type(build_dict) == dict:
         from deprl import env_tonic_compat
 
         return env_tonic_compat(**build_dict)
     else:
-        return lambda identifier=0: build_dict()
+        return build_dict()
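
A hedged sketch of how the reworked distribute signature could be called; the tonic_conf keys ("parallel", "sequential", "header") are the ones read above, while the environment string and the values are purely illustrative.

from deprl.custom_distributed import distribute

# Illustrative values only; the environment string must be an expression that
# env_tonic_compat can eval into an environment instance.
tonic_conf = {"parallel": 2, "sequential": 4, "header": None}
env_group = distribute(
    environment='gym.make("Pendulum-v1")',
    tonic_conf=tonic_conf,
    env_args=None,
)
# With parallel < 2 this would return a Sequential group instead of a Parallel one.
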
5 changes: 4 additions & 1 deletion deprl/custom_trainer.py
@@ -142,8 +142,11 @@ def run(self, params, steps=0, epochs=0, episodes=0):
                             os.remove(os.path.join(path, file))
                 checkpoint_name = f"step_{self.steps}"
                 save_path = os.path.join(path, checkpoint_name)
+                # save agent checkpoint
                 self.agent.save(save_path, full_save=self.full_save)
-                # logger.save(save_path)
+                # save logger checkpoint
+                logger.save(save_path)
+                # save time iteration dict
                 self.save_time(save_path, epochs, episodes)
                 steps_since_save = self.steps % self.save_steps
             current_time = time.time()
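
For orientation, a small illustrative sketch (directory and step count made up) of the checkpoint prefix this block writes the agent, logger state, and timing info under.

import os

path = os.path.join("output", "checkpoints")  # illustrative output directory
steps = 100_000                               # illustrative step count
checkpoint_name = f"step_{steps}"
save_path = os.path.join(path, checkpoint_name)
# agent.save(save_path, ...), logger.save(save_path) and save_time(save_path, ...)
# then all write their files under this common prefix.
print(save_path)  # output/checkpoints/step_100000
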
4 changes: 2 additions & 2 deletions deprl/env_wrappers/__init__.py
@@ -12,11 +12,11 @@ def apply_wrapper(env):
         return GymWrapper(env)
 
 
-def env_tonic_compat(env, preid=5, parallel=1, sequential=1):
+def env_tonic_compat(env, id=5, parallel=1, sequential=1):
     """
     Applies wrapper for tonic and passes random seed.
     """
-    return lambda identifier=0: apply_wrapper(eval(env))
+    return apply_wrapper(eval(env))
 
 
 __all__ = [env_tonic_compat, apply_wrapper]
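
A brief standalone sketch of the pattern env_tonic_compat now follows (the gym import and environment string are assumptions for illustration, not taken from this repository): the environment is passed as a string expression, eval'd directly, and the result is wrapped, rather than returning a lambda.

import gym  # assumption: a Gym-style environment library is installed

def tonic_compat_sketch(env, id=5, parallel=1, sequential=1):
    # deprl additionally passes the result through apply_wrapper here.
    return eval(env)

environment = tonic_compat_sketch('gym.make("Pendulum-v1")')
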