/
rollout.py
169 lines (138 loc) · 6.62 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from collections import deque
import numpy as np
import pickle
from baselines.her.util import convert_episode_to_batch_major, store_args
class RolloutWorker:
    """Collects experience by stepping vectorized goal-based environments with a policy.

    Besides returning episodes, the worker keeps smoothed statistics (success rate and,
    optionally, mean Q value) over the most recent `history_len` calls to
    `generate_rollouts`.
    """

    @store_args
    def __init__(self, venv, policy, dims, logger, T, rollout_batch_size=1,
                 exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
                 random_eps=0, history_len=100, render=False, monitor=False, **kwargs):
        """Rollout worker generates experience by interacting with one or many environments.

        Args:
            venv: vectorized gym environments.
            policy (object): the policy that is used to act
            dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u);
                keys of the form 'info_<name>' declare extra per-step info values to record
            logger (object): the logger that is used by the rollout worker
            T (int): maximum number of environment steps per rollout; must be positive
            rollout_batch_size (int): the number of parallel rollouts that should be used
            exploit (boolean): whether or not to exploit, i.e. to act optimally according to the
                current policy without any exploration
            use_target_net (boolean): whether or not to use the target net for rollouts
            compute_Q (boolean): whether or not to compute the Q values alongside the actions
            noise_eps (float): scale of the additive Gaussian noise
            random_eps (float): probability of selecting a completely random action
            history_len (int): length of history for statistics smoothing
            render (boolean): whether or not to render the rollouts
            monitor (boolean): accepted for configuration compatibility; not read by the
                methods of this class
            **kwargs: extra options; not read by the methods of this class
        """
        # NOTE(review): @store_args is expected to copy every constructor argument onto
        # self (that is why self.T, self.venv, ... are available) — confirm in baselines.her.util.
        assert self.T > 0
        self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]

        self.success_history = deque(maxlen=history_len)
        self.Q_history = deque(maxlen=history_len)

        self.n_episodes = 0
        self.reset_all_rollouts()
        self.clear_history()

    def reset_all_rollouts(self):
        """Reset every environment and cache the initial observation, achieved goal and goal."""
        self.obs_dict = self.venv.reset()
        self.initial_o = self.obs_dict['observation']
        self.initial_ag = self.obs_dict['achieved_goal']
        self.g = self.obs_dict['desired_goal']

    def generate_rollouts(self, max_nan_retries=1000):
        """Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
        policy acting on it accordingly.

        If the environments return NaN observations, the attempt is discarded, the environments
        are reset and the rollout is restarted from scratch.

        Args:
            max_nan_retries (int): how many restarts to allow after NaN observations before
                giving up. Retries run in a loop (not by recursion), so repeated failures
                cannot exhaust the interpreter stack.

        Returns:
            dict: the episode as returned by `convert_episode_to_batch_major`, with keys
            'o', 'u', 'g', 'ag' and one 'info_<key>' entry per info key.

        Raises:
            RuntimeError: if NaN observations are still produced after `max_nan_retries` restarts.
        """
        retries = 0
        while True:
            result = self._rollout_once()
            if result is not None:
                break
            if retries >= max_nan_retries:
                raise RuntimeError(
                    'NaN observations in {} consecutive rollout attempts'.format(retries + 1))
            retries += 1
            self.logger.warn('NaN caught during rollout generation. Trying again...')
        episode, successes, Qs = result

        # stats: success is judged on the last recorded step of each parallel rollout.
        # NOTE(review): if an env reports done on the very first step, `successes` is empty
        # and the indexing below raises — same as before; confirm this cannot happen.
        successful = np.array(successes)[-1, :]
        assert successful.shape == (self.rollout_batch_size,)
        success_rate = np.mean(successful)
        self.success_history.append(success_rate)
        if self.compute_Q:
            self.Q_history.append(np.mean(Qs))
        self.n_episodes += self.rollout_batch_size

        return convert_episode_to_batch_major(episode)

    def _rollout_once(self):
        """Run one rollout attempt; return (time-major episode, successes, Qs), or None on NaN."""
        self.reset_all_rollouts()

        # current observations / achieved goals, overwritten in place after every step
        o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32)
        ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32)
        o[:] = self.initial_o
        ag[:] = self.initial_ag

        obs, achieved_goals, acts, goals, successes = [], [], [], [], []
        Qs = []
        # One row per possible step (T rows) so the write below can never go out of bounds;
        # only the rows that were actually recorded are kept when the episode is assembled.
        # np.empty rows are uninitialised, so unrecorded rows must never reach the episode.
        info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32)
                       for key in self.info_keys]

        for t in range(self.T):
            policy_output = self.policy.get_actions(
                o, ag, self.g,
                compute_Q=self.compute_Q,
                noise_eps=self.noise_eps if not self.exploit else 0.,
                random_eps=self.random_eps if not self.exploit else 0.,
                use_target_net=self.use_target_net)

            if self.compute_Q:
                u, Q = policy_output
                Qs.append(Q)
            else:
                u = policy_output

            if u.ndim == 1:
                # The non-batched case should still have a reasonable shape.
                u = u.reshape(1, -1)

            # compute new states and observations
            obs_dict_new, _, done, info = self.venv.step(u)
            o_new = obs_dict_new['observation']
            ag_new = obs_dict_new['achieved_goal']
            success = np.array([info_dict.get('is_success', 0.0) for info_dict in info])

            if any(done):
                # here we assume all environments are done in ~same number of steps, so we terminate
                # rollouts whenever any of the envs returns done.
                # trick with using vecenvs is not to add the obs from the environments that are "done",
                # because those are already observations after a reset
                break

            for i, info_dict in enumerate(info):
                for idx, key in enumerate(self.info_keys):
                    info_values[idx][t, i] = info_dict[key]

            if np.isnan(o_new).any():
                # signal the caller to reset and retry this whole rollout
                return None

            obs.append(o.copy())
            achieved_goals.append(ag.copy())
            successes.append(success.copy())
            acts.append(u.copy())
            goals.append(self.g.copy())
            o[...] = o_new
            ag[...] = ag_new

        # observations/achieved goals carry one more entry than actions: the final state
        obs.append(o.copy())
        achieved_goals.append(ag.copy())

        n_steps = len(acts)
        episode = dict(o=obs,
                       u=acts,
                       g=goals,
                       ag=achieved_goals)
        for key, value in zip(self.info_keys, info_values):
            # keep only the rows written for recorded steps (T - 1 when done fires at the last step)
            episode['info_{}'.format(key)] = value[:n_steps]
        return episode, successes, Qs

    def clear_history(self):
        """Clears all histories that are used for statistics
        """
        self.success_history.clear()
        self.Q_history.clear()

    def current_success_rate(self):
        """Mean success rate over the stored history (nan, with a NumPy warning, if empty)."""
        return np.mean(self.success_history)

    def current_mean_Q(self):
        """Mean of the stored per-rollout mean Q values (nan, with a NumPy warning, if empty)."""
        return np.mean(self.Q_history)

    def save_policy(self, path):
        """Pickles the current policy for later inspection.
        """
        with open(path, 'wb') as f:
            pickle.dump(self.policy, f)

    def logs(self, prefix='worker'):
        """Generates a list of (name, value) pairs with all collected statistics.

        Names are prefixed with `prefix` followed by '/'; a prefix that already ends in
        '/' is used as-is, and an empty prefix leaves the names unprefixed.
        """
        logs = []
        logs += [('success_rate', self.current_success_rate())]
        if self.compute_Q:
            logs += [('mean_Q', self.current_mean_Q())]
        logs += [('episode', self.n_episodes)]

        if not prefix:
            return logs
        if not prefix.endswith('/'):
            prefix += '/'
        return [(prefix + key, val) for key, val in logs]