Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional buffer_batch_size param. Fix manifest directories. #5

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion gym_recording/playback.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def scan_recorded_traces(directory, episode_cb=None, max_episodes=None):
added_episode_count = 0
for batch in rdr.get_recorded_batches():
for ep in rdr.get_recorded_episodes(batch):
episode_cb(ep['observations'], ep['actions'], ep['rewards'])
episode_cb(ep['observations'], ep['actions'], ep['rewards'], ep['infos'])
added_episode_count += 1
if max_episodes is not None and added_episode_count >= max_episodes: return
rdr.close()
21 changes: 11 additions & 10 deletions gym_recording/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@

class TraceRecording(object):
_id_counter = 0
def __init__(self, directory=None):
def __init__(self, directory=None, buffer_batch_size=100):
"""
Create a TraceRecording, writing into directory
"""

if directory is None:
directory = os.path.join('/tmp', 'openai.gym.{}.{}'.format(time.time(), os.getpid()))
os.mkdir(directory)
Expand All @@ -30,10 +29,11 @@ def __init__(self, directory=None):
self.actions = []
self.observations = []
self.rewards = []
self.infos = []
self.episode_id = 0

self.buffered_step_count = 0
self.buffer_batch_size = 100
self.buffer_batch_size = buffer_batch_size

self.episodes_first = 0
self.episodes = []
Expand All @@ -44,11 +44,12 @@ def add_reset(self, observation):
self.end_episode()
self.observations.append(observation)

def add_step(self, action, observation, reward):
def add_step(self, action, observation, reward, info):
assert not self.closed
self.actions.append(action)
self.observations.append(observation)
self.rewards.append(reward)
self.infos.append(info)
self.buffered_step_count += 1

def end_episode(self):
Expand All @@ -64,10 +65,12 @@ def end_episode(self):
'actions': optimize_list_of_ndarrays(self.actions),
'observations': optimize_list_of_ndarrays(self.observations),
'rewards': optimize_list_of_ndarrays(self.rewards),
'infos': optimize_list_of_ndarrays(self.infos),
})
self.actions = []
self.observations = []
self.rewards = []
self.infos = []
self.episode_id += 1

if self.buffered_step_count >= self.buffer_batch_size:
Expand All @@ -83,7 +86,6 @@ def save_complete(self):

batch_fn = '{}.ep{:09}.json'.format(self.file_prefix, self.episodes_first)
bin_fn = '{}.ep{:09}.bin'.format(self.file_prefix, self.episodes_first)

with atomic_write.atomic_write(os.path.join(self.directory, batch_fn), False) as batch_f:
with atomic_write.atomic_write(os.path.join(self.directory, bin_fn), True) as bin_f:

Expand All @@ -99,25 +101,24 @@ def json_encode(obj):
return obj

json.dump({'episodes': self.episodes}, batch_f, default=json_encode)

bytes_per_step = float(bin_f.tell() + batch_f.tell()) / float(self.buffered_step_count)

# bytes_per_step = float(bin_f.tell() + batch_f.tell()) / float(self.buffered_step_count)

self.batches.append({
'first': self.episodes_first,
'len': len(self.episodes),
'fn': batch_fn})

manifest = {'batches': self.batches}
manifest_fn = os.path.join(self.directory, '{}.manifest.json'.format(self.file_prefix))
with atomic_write.atomic_write(os.path.join(self.directory, manifest_fn), False) as f:
with atomic_write.atomic_write(manifest_fn, False) as f:
json.dump(manifest, f)

# Adjust batch size, aiming for 5 MB per file.
# This seems like a reasonable tradeoff between:
# writing speed (not too much overhead creating small files)
# local memory usage (buffering an entire batch before writing)
# random read access (loading the whole file isn't too much work when just grabbing one episode)
self.buffer_batch_size = max(1, min(50000, int(5000000 / bytes_per_step + 1)))
# self.buffer_batch_size = max(1, min(50000, int(5000000 / bytes_per_step + 1)))

self.episodes = []
self.episodes_first = None
Expand Down
6 changes: 3 additions & 3 deletions gym_recording/wrappers/trace_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ def episode_cb(observations, actions, rewards):


"""
def __init__(self, env, directory=None):
def __init__(self, env, directory=None, buffer_batch_size=100):
"""
Create a TraceRecordingWrapper around env, writing into directory
"""
super(TraceRecordingWrapper, self).__init__(env)
self.recording = None
trace_record_closer.register(self)

self.recording = TraceRecording(None)
self.recording = TraceRecording(directory, buffer_batch_size)
self.directory = self.recording.directory

def _step(self, action):
observation, reward, done, info = self.env.step(action)
self.recording.add_step(action, observation, reward)
self.recording.add_step(action, observation, reward, info)
return observation, reward, done, info

def _reset(self):
Expand Down
41 changes: 41 additions & 0 deletions test/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from gym_recording import playback
import matplotlib.pyplot as plt
import numpy as np
import glob, os

def handle_ep(observations, actions, rewards, infos):
    """Print every step of one recorded episode and plot its lon/lat track.

    Passed as the episode callback to ``playback.scan_recorded_traces``.
    A new matplotlib figure is created on each call.

    Args:
        observations: sequence of observations for the episode.
        actions: sequence of actions for the episode.
        rewards: sequence of rewards for the episode.
        infos: sequence of per-step info objects. Falsy entries are skipped;
            truthy entries are indexed as ``info['self_state']['lon']`` and
            ``info['self_state']['lat']``.

    Raises:
        KeyError: if a truthy info entry lacks the keys listed above.

    The four sequences are walked with ``zip``, so iteration stops at the
    shortest of them.
    """
    # Plot config: the figure is created with interactive mode switched on,
    # and interactive mode is switched off again before any data is drawn.
    plt.ion()
    fig = plt.figure()
    ax = fig.gca()
    ax.set_aspect('equal', adjustable='box')

    # Empty placeholder line on the fresh axes.
    # NOTE(review): presumably only needed to advance the colour cycle so the
    # track keeps its current colour -- confirm, then drop if not wanted.
    ax.plot([], [])

    plt.ioff()  # turn off interactive mode

    print('\n\nAn episode begins!')

    # Collect coordinates in plain lists; appending to a list is O(1), whereas
    # growing a numpy array copies it on every step.
    xs = []
    ys = []
    for obs, a, r, i in zip(observations, actions, rewards, infos):
        print('Obs: {} a: {} r: {} info: {}'.format(obs, a, r, i))
        if not i:
            continue
        state = i['self_state']
        xs.append(state['lon'])
        ys.append(state['lat'])

    # Draw on the axes created above rather than on the implicit "current" axes.
    ax.plot(xs, ys, 'o-')

if __name__ == '__main__':
    # Directory whose entries are globbed below; each entry is handed to
    # scan_recorded_traces as a recording directory.
    path = '/tmp/gym/traces/CoGLE-nav-virtual-v0/train/'

    # Sorted by name; only the last entry is replayed.
    files = sorted(glob.glob(os.path.join(path, '*')))
    if not files:
        # Fail with a readable message instead of an IndexError on files[-1].
        raise SystemExit('No recorded traces found under {}'.format(path))

    playback.scan_recorded_traces(files[-1], handle_ep)
    plt.show()
17 changes: 17 additions & 0 deletions test/rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import gym
from gym_recording.wrappers import TraceRecordingWrapper

def main(directory='./t', buffer_batch_size=10, num_steps=10000):
    """Record random-action CartPole episodes through TraceRecordingWrapper.

    Args:
        directory: passed to ``TraceRecordingWrapper`` as ``directory``.
        buffer_batch_size: passed to ``TraceRecordingWrapper`` as
            ``buffer_batch_size``.
        num_steps: total number of environment steps to take; the environment
            is reset whenever a step reports ``done``.
    """
    env = gym.make('CartPole-v0')
    env = TraceRecordingWrapper(
        env, directory=directory, buffer_batch_size=buffer_batch_size)
    try:
        print('log dir {}'.format(env.directory))
        print(env.__dict__)
        env.reset()
        for _ in range(num_steps):
            _, _, done, _ = env.step(env.action_space.sample())  # take a random action
            if done:
                env.reset()
    finally:
        # Close the environment even if stepping fails.
        # NOTE(review): presumably this also flushes the wrapper's buffered
        # episodes to disk -- confirm against TraceRecordingWrapper.close.
        env.close()
    print('Done')


if __name__ == '__main__':
    main()