Skip to content

Commit

Permalink
Back to pooled rollouts, but this time with random seed set using wor…
Browse files Browse the repository at this point in the history
…ker index. (#28)
  • Loading branch information
Raelifin authored and nottombrown committed Aug 5, 2017
1 parent e72c182 commit 572b19f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
35 changes: 30 additions & 5 deletions rl_teacher/segment_sampling.py
@@ -1,4 +1,7 @@
import math
from multiprocessing import Pool
import numpy as np
import gym.spaces.prng as space_prng

from rl_teacher.envs import get_timesteps_per_episode

Expand Down Expand Up @@ -55,10 +58,16 @@ def do_rollout(env, action_function):
"human_obs": np.array(human_obs)}
return path

def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds):
""" Generate a list of path segments by doing random rollouts. """
def basic_segments_from_rand_rollout(
env_id, make_env, n_desired_segments, clip_length_in_seconds,
# These are only for use with multiprocessing
seed=0, _verbose=True, _multiplier=1
):
""" Generate a list of path segments by doing random rollouts. No multiprocessing. """
segments = []
env = make_env(env_id)
env.seed(seed)
space_prng.seed(seed)
segment_length = int(clip_length_in_seconds * env.fps)
while len(segments) < n_desired_segments:
path = do_rollout(env, random_action)
Expand All @@ -70,8 +79,24 @@ def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length
if segment:
segments.append(segment)

if len(segments) % 10 == 0 and len(segments) > 0:
print("Collected %s/%s segments" % (len(segments), n_desired_segments))
if _verbose and len(segments) % 10 == 0 and len(segments) > 0:
print("Collected %s/%s segments" % (len(segments) * _multiplier, n_desired_segments * _multiplier))

print("Successfully collected %s segments" % (len(segments)))
if _verbose:
print("Successfully collected %s segments" % (len(segments) * _multiplier))
return segments

def segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds, workers):
""" Generate a list of path segments by doing random rollouts. Can use multiple processes. """
if workers < 2: # Default to basic segment collection
return basic_segments_from_rand_rollout(env_id, make_env, n_desired_segments, clip_length_in_seconds)

pool = Pool(processes=workers)
segments_per_worker = int(math.ceil(n_desired_segments / workers))
# One job per worker. Only the first worker is verbose.
jobs = [
(env_id, make_env, segments_per_worker, clip_length_in_seconds, i, i == 0, workers)
for i in range(workers)]
results = pool.starmap(basic_segments_from_rand_rollout, jobs)
pool.close()
return [segment for sublist in results for segment in sublist]
5 changes: 3 additions & 2 deletions rl_teacher/teach.py
Expand Up @@ -228,8 +228,9 @@ def main():
label_schedule = ConstantLabelSchedule(pretrain_labels=pretrain_labels)

print("Starting random rollouts to generate pretraining segments. No learning will take place...")
pretrain_segments = segments_from_rand_rollout(env_id, make_with_torque_removed,
n_desired_segments=pretrain_labels * 2, clip_length_in_seconds=CLIP_LENGTH)
pretrain_segments = segments_from_rand_rollout(
env_id, make_with_torque_removed, n_desired_segments=pretrain_labels * 2,
clip_length_in_seconds=CLIP_LENGTH, workers=args.workers)
for i in range(pretrain_labels): # Turn our random segments into comparisons
comparison_collector.add_segment_pair(pretrain_segments[i], pretrain_segments[i + pretrain_labels])

Expand Down

0 comments on commit 572b19f

Please sign in to comment.