diff --git a/rl_teacher/segment_sampling.py b/rl_teacher/segment_sampling.py
index ee209b9..4c0cde4 100644
--- a/rl_teacher/segment_sampling.py
+++ b/rl_teacher/segment_sampling.py
@@ -1,4 +1,7 @@
+import math
+from multiprocessing import Pool
 import numpy as np
+import gym.spaces.prng as space_prng
 
 from rl_teacher.envs import get_timesteps_per_episode
 
@@ -55,10 +58,16 @@ def do_rollout(env, action_function):
         "human_obs": np.array(human_obs)}
     return path
 
-def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds):
-    """ Generate a list of path segments by doing random rollouts. """
+def basic_segments_from_rand_rollout(
+    env_id, make_env, n_desired_segments, clip_length_in_seconds,
+    # These are only for use with multiprocessing
+    seed=0, _verbose=True, _multiplier=1
+):
+    """ Generate a list of path segments by doing random rollouts. No multiprocessing. """
     segments = []
     env = make_env(env_id)
+    env.seed(seed)
+    space_prng.seed(seed)
     segment_length = int(clip_length_in_seconds * env.fps)
     while len(segments) < n_desired_segments:
         path = do_rollout(env, random_action)
@@ -70,8 +79,30 @@ def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length
         if segment:
             segments.append(segment)
 
-        if len(segments) % 10 == 0 and len(segments) > 0:
-            print("Collected %s/%s segments" % (len(segments), n_desired_segments))
+        if _verbose and len(segments) % 10 == 0 and len(segments) > 0:
+            print("Collected %s/%s segments" % (len(segments) * _multiplier, n_desired_segments * _multiplier))
 
-    print("Successfully collected %s segments" % (len(segments)))
+    if _verbose:
+        print("Successfully collected %s segments" % (len(segments) * _multiplier))
     return segments
+
+def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds, workers=1):
+    """ Generate a list of path segments by doing random rollouts. Can use multiple processes.
+
+    With fewer than two workers, collection runs in the current process. Otherwise each worker
+    collects the same share of the work in a process pool and at most n_desired_segments are returned.
+    """
+    if workers < 2:  # Default to basic segment collection
+        return basic_segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds)
+
+    segments_per_worker = int(math.ceil(n_desired_segments / workers))
+    # One job per worker. Only the first worker is verbose.
+    jobs = [
+        (env_id, make_env, segments_per_worker, clip_length_in_seconds, i, i == 0, workers)
+        for i in range(workers)]
+    # The context manager tears the pool down even if a worker raises.
+    with Pool(processes=workers) as pool:
+        results = pool.starmap(basic_segments_from_rand_rollout, jobs)
+    segments = [segment for sublist in results for segment in sublist]
+    # Rounding up per worker can overshoot; trim so both paths return the same count.
+    return segments[:n_desired_segments]
diff --git a/rl_teacher/teach.py b/rl_teacher/teach.py
index 65ca31d..3a213c0 100644
--- a/rl_teacher/teach.py
+++ b/rl_teacher/teach.py
@@ -228,8 +228,9 @@ def main():
             label_schedule = ConstantLabelSchedule(pretrain_labels=pretrain_labels)
 
         print("Starting random rollouts to generate pretraining segments. No learning will take place...")
-        pretrain_segments = segments_from_rand_rollout(env_id, make_with_torque_removed,
-            n_desired_segments=pretrain_labels * 2, clip_length_in_seconds=CLIP_LENGTH)
+        pretrain_segments = segments_from_rand_rollout(
+            env_id, make_with_torque_removed, n_desired_segments=pretrain_labels * 2,
+            clip_length_in_seconds=CLIP_LENGTH, workers=args.workers)
         for i in range(pretrain_labels):  # Turn our random segments into comparisons
             comparison_collector.add_segment_pair(pretrain_segments[i], pretrain_segments[i + pretrain_labels])
 