""" Various auxiliary utilities """
import math
from os.path import join, exists
import torch
from torchvision import transforms
import numpy as np
from models import MDRNNCell, VAE, Controller
import gym
import gym.envs.box2d
# A bit dirty: manually change size of car racing env
gym.envs.box2d.car_racing.STATE_W, gym.envs.box2d.car_racing.STATE_H = 64, 64
# Hardcoded for now
ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE =\
    3, 32, 256, 64, 64

# Same
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((RED_SIZE, RED_SIZE)),
    transforms.ToTensor()
])


def sample_continuous_policy(action_space, seq_len, dt):
    """ Sample a continuous policy.

    At the moment, action_space is assumed to be a Box space. The policy is
    sampled as a Brownian motion a_{t+1} = a_t + sqrt(dt) N(0, 1).

    :args action_space: gym action space
    :args seq_len: number of actions returned
    :args dt: temporal discretization

    :returns: sequence of seq_len actions
    """
    actions = [action_space.sample()]
    for _ in range(seq_len):
        daction_dt = np.random.randn(*actions[-1].shape)
        actions.append(
            np.clip(actions[-1] + math.sqrt(dt) * daction_dt,
                    action_space.low, action_space.high))
    return actions
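
# Rough usage sketch for data collection (comments only, so importing this
# module stays side-effect free). Creating the environment here is an
# assumption for illustration, not something this module does, and the dt
# value is only an example:
#
#     env = gym.make('CarRacing-v0')
#     actions = sample_continuous_policy(env.action_space, seq_len=1000, dt=1. / 50)
#     # len(actions) == seq_len + 1: the initial sample plus one Brownian step
#     # per iteration, each clipped to the Box bounds.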


def save_checkpoint(state, is_best, filename, best_filename):
    """ Save state in filename. Also save in best_filename if is_best. """
    torch.save(state, filename)
    if is_best:
        torch.save(state, best_filename)


def flatten_parameters(params):
    """ Flattening parameters.

    :args params: generator of parameters (as returned by module.parameters())

    :returns: flattened parameters (i.e. one tensor of dimension 1 with all
        parameters concatenated)
    """
    return torch.cat([p.detach().view(-1) for p in params], dim=0).cpu().numpy()


def unflatten_parameters(params, example, device):
    """ Unflatten parameters.

    :args params: parameters as a single 1D np array
    :args example: generator of parameters (as returned by module.parameters()),
        used to reshape params
    :args device: where to store unflattened parameters

    :returns: unflattened parameters
    """
    params = torch.Tensor(params).to(device)
    idx = 0
    unflattened = []
    for e_p in example:
        unflattened += [params[idx:idx + e_p.numel()].view(e_p.size())]
        idx += e_p.numel()
    return unflattened


def load_parameters(params, controller):
    """ Load flattened parameters into controller.

    :args params: parameters as a single 1D np array
    :args controller: module in which params is loaded
    """
    proto = next(controller.parameters())
    params = unflatten_parameters(
        params, controller.parameters(), proto.device)

    for p, p_0 in zip(controller.parameters(), params):
        p.data.copy_(p_0)
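
# Rough round-trip sketch for the three helpers above (comments only, so the
# module has no import-time side effects; the constructor arguments mirror the
# ones used in RolloutGenerator below):
#
#     controller = Controller(LSIZE, RSIZE, ASIZE)
#     flat = flatten_parameters(controller.parameters())  # 1D float32 np array
#     load_parameters(flat, controller)                    # copies it back in place
#
# This flat-vector interface is the kind evolutionary optimizers such as CMA-ES
# typically operate on: they propose candidate vectors, and load_parameters
# writes each candidate into the controller before a rollout is evaluated.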


class RolloutGenerator(object):
    """ Utility to generate rollouts.

    Encapsulate everything that is needed to generate rollouts in the TRUE ENV
    using a controller with previously trained VAE and MDRNN.

    :attr vae: VAE model loaded from mdir/vae
    :attr mdrnn: MDRNN model loaded from mdir/mdrnn
    :attr controller: Controller, either loaded from mdir/ctrl or randomly
        initialized
    :attr env: instance of the CarRacing-v0 gym environment
    :attr device: device used to run VAE, MDRNN and Controller
    :attr time_limit: rollouts have a maximum of time_limit timesteps
    """
    def __init__(self, mdir, device, time_limit):
        """ Build vae, rnn, controller and environment. """
        # Loading world model and vae
        vae_file, rnn_file, ctrl_file = \
            [join(mdir, m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']]

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state = [
            torch.load(fname, map_location={'cuda:0': str(device)})
            for fname in (vae_file, rnn_file)]

        for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
            print("Loading {} at epoch {} "
                  "with test loss {}".format(
                      m, s['epoch'], s['precision']))

        self.vae = VAE(3, LSIZE).to(device)
        self.vae.load_state_dict(vae_state['state_dict'])

        self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
        # The full MDRNN was trained with an LSTM whose parameter names end in
        # '_l0'; dropping that suffix makes the keys match the LSTMCell used here.
        self.mdrnn.load_state_dict(
            {k.strip('_l0'): v for k, v in rnn_state['state_dict'].items()})

        self.controller = Controller(LSIZE, RSIZE, ASIZE).to(device)

        # load controller if it was previously saved
        if exists(ctrl_file):
            ctrl_state = torch.load(ctrl_file, map_location={'cuda:0': str(device)})
            print("Loading Controller with reward {}".format(
                ctrl_state['reward']))
            self.controller.load_state_dict(ctrl_state['state_dict'])

        self.env = gym.make('CarRacing-v0')
        self.device = device

        self.time_limit = time_limit

    def get_action_and_transition(self, obs, hidden):
        """ Get action and transition.

        Encode obs into a latent with the VAE, compute the corresponding
        controller action, then obtain the next hidden state from the MDRNN.

        :args obs: current observation (1 x 3 x 64 x 64) torch tensor
        :args hidden: current hidden state (1 x 256) torch tensor

        :returns: (action, next_hidden)
            - action: 1D np array
            - next_hidden: (1 x 256) torch tensor
        """
        _, latent_mu, _ = self.vae(obs)
        action = self.controller(latent_mu, hidden[0])
        _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
        return action.squeeze().cpu().numpy(), next_hidden

    def rollout(self, params, render=False):
        """ Execute a rollout and return the negative cumulative reward.

        Load :params: into the controller and execute a single rollout. This
        is the main API of this class.

        :args params: parameters as a single 1D np array

        :returns: minus cumulative reward
        """
        # copy params into the controller
        if params is not None:
            load_parameters(params, self.controller)

        obs = self.env.reset()

        # This first render is required!
        self.env.render()

        hidden = [
            torch.zeros(1, RSIZE).to(self.device)
            for _ in range(2)]

        cumulative = 0
        i = 0
        while True:
            obs = transform(obs).unsqueeze(0).to(self.device)
            action, hidden = self.get_action_and_transition(obs, hidden)
            obs, reward, done, _ = self.env.step(action)

            if render:
                self.env.render()

            cumulative += reward
            if done or i > self.time_limit:
                return - cumulative
            i += 1
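

# A minimal, hedged usage sketch. It assumes a directory laid out as
# mdir/{vae,mdrnn,ctrl}/best.tar by the training scripts; 'exp_dir' and the
# time limit below are placeholders, not values this module defines.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = RolloutGenerator('exp_dir', device, time_limit=1000)
    with torch.no_grad():
        # Passing None keeps whatever parameters the controller already holds
        print("Negative cumulative reward:", generator.rollout(None))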