# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch

import mbrl.types

from . import Model

class ModelEnv:
    """Wraps a dynamics model into a gym-like environment.

    This class can wrap a dynamics model to be used as an environment. The only
    requirement on the model is that it provides a method called ``predict()`` with
    signature ``next_observs, rewards = model.predict(obs, actions, sample=, rng=)``.

    Args:
        env (gym.Env): the original gym environment for which the model was trained.
        model (:class:`mbrl.models.Model`): the model to wrap.
        termination_fn (callable): a function that receives actions and observations, and
            returns a boolean flag indicating whether the episode should end or not.
        reward_fn (callable, optional): a function that receives actions and observations
            and returns the value of the resulting reward in the environment.
            Defaults to ``None``, in which case predicted rewards will be used.
        generator (torch.Generator, optional): a torch random number generator (must be on
            the same device as the given model). If ``None`` (default), a new generator
            will be created using the default torch seed.
    """
    def __init__(
        self,
        env: gym.Env,
        model: Model,
        termination_fn: mbrl.types.TermFnType,
        reward_fn: Optional[mbrl.types.RewardFnType] = None,
        generator: Optional[torch.Generator] = None,
    ):
        self.dynamics_model = model
        self.termination_fn = termination_fn
        self.reward_fn = reward_fn
        self.device = model.device

        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self._current_obs: Optional[torch.Tensor] = None
        self._propagation_method: Optional[str] = None
        self._model_indices = None
        if generator:
            self._rng = generator
        else:
            self._rng = torch.Generator(device=self.device)
        self._return_as_np = True
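
    # A reproducibility note (a sketch, not part of the class API): to get
    # deterministic rollouts, pass a seeded generator created on the model's
    # device. `env`, `model`, and `term_fn` below are assumed to exist.
    #
    #     rng = torch.Generator(device=model.device)
    #     rng.manual_seed(0)
    #     model_env = ModelEnv(env, model, term_fn, generator=rng)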
    def reset(
        self, initial_obs_batch: np.ndarray, return_as_np: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Resets the model environment.

        Args:
            initial_obs_batch (np.ndarray): a batch of initial observations. One episode
                will be run in parallel for each observation. Shape must be ``B x D``,
                where ``B`` is batch size, and ``D`` is the observation dimension.
            return_as_np (bool): if ``True``, this method and :meth:`step` will return
                numpy arrays, otherwise they return torch tensors on the same device as
                the model. Defaults to ``True``.

        Returns:
            (dict(str, tensor)): the model state returned by ``self.dynamics_model.reset()``.
        """
        if isinstance(self.dynamics_model, mbrl.models.OneDTransitionRewardModel):
            assert len(initial_obs_batch.shape) == 2  # batch, obs_dim
        with torch.no_grad():
            model_state = self.dynamics_model.reset(
                initial_obs_batch.astype(np.float32), rng=self._rng
            )
        self._return_as_np = return_as_np
        return model_state if model_state is not None else {}
    def step(
        self,
        actions: mbrl.types.TensorType,
        model_state: Dict[str, torch.Tensor],
        sample: bool = False,
    ) -> Tuple[mbrl.types.TensorType, mbrl.types.TensorType, np.ndarray, Dict]:
        """Steps the model environment with the given batch of actions.

        Args:
            actions (torch.Tensor or np.ndarray): the actions for each "episode" to rollout.
                Shape must be ``B x A``, where ``B`` is the batch size (i.e., number of
                episodes), and ``A`` is the action dimension. Note that ``B`` must match
                the batch size used when calling :meth:`reset`. If a np.ndarray is given,
                it's converted to a torch.Tensor and sent to the model device.
            model_state (dict(str, tensor)): the model state as returned by :meth:`reset()`.
            sample (bool): if ``True``, model predictions are stochastic. Defaults to ``False``.

        Returns:
            (tuple): contains the predicted next observation, reward, done flag, and metadata.
            The done flag is computed using the termination_fn passed in the constructor.
        """
        assert len(actions.shape) == 2  # batch, action_dim
        with torch.no_grad():
            # if actions is a tensor, the code assumes it's already on self.device
            if isinstance(actions, np.ndarray):
                actions = torch.from_numpy(actions).to(self.device)
            (
                next_observs,
                pred_rewards,
                pred_terminals,
                next_model_state,
            ) = self.dynamics_model.sample(
                actions,
                model_state,
                deterministic=not sample,
                rng=self._rng,
            )
            rewards = (
                pred_rewards
                if self.reward_fn is None
                else self.reward_fn(actions, next_observs)
            )
            dones = self.termination_fn(actions, next_observs)

            if pred_terminals is not None:
                raise NotImplementedError(
                    "ModelEnv doesn't yet support simulating terminal indicators."
                )

            if self._return_as_np:
                next_observs = next_observs.cpu().numpy()
                rewards = rewards.cpu().numpy()
                dones = dones.cpu().numpy()
            return next_observs, rewards, dones, next_model_state
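
    # A one-step rollout sketch (illustrative only; `obs_batch` and `act_batch`
    # are hypothetical arrays, not part of the API): reset once with a batch of
    # observations, then step repeatedly with the same batch size.
    #
    #     state = model_env.reset(obs_batch)            # obs_batch: B x D
    #     next_obs, rewards, dones, state = model_env.step(
    #         act_batch, state, sample=True             # act_batch: B x A
    #     )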
    def render(self, mode="human"):
        pass
    def evaluate_action_sequences(
        self,
        action_sequences: torch.Tensor,
        initial_state: np.ndarray,
        num_particles: int,
    ) -> torch.Tensor:
        """Evaluates a batch of action sequences on the model.

        Args:
            action_sequences (torch.Tensor): a batch of action sequences to evaluate. Shape
                must be ``B x H x A``, where ``B``, ``H``, and ``A`` represent batch size,
                horizon, and action dimension, respectively.
            initial_state (np.ndarray): the initial state for the trajectories.
            num_particles (int): number of times each action sequence is replicated. The
                final value of each sequence is the average over its particles' values.

        Returns:
            (torch.Tensor): the accumulated reward for each action sequence, averaged over
            its particles.
        """
        with torch.no_grad():
            assert len(action_sequences.shape) == 3
            population_size, horizon, action_dim = action_sequences.shape
            # either 1-D state or 3-D pixel observation
            assert initial_state.ndim in (1, 3)
            tiling_shape = (num_particles * population_size,) + tuple(
                [1] * initial_state.ndim
            )
            initial_obs_batch = np.tile(initial_state, tiling_shape).astype(np.float32)
            model_state = self.reset(initial_obs_batch, return_as_np=False)
            batch_size = initial_obs_batch.shape[0]
            total_rewards = torch.zeros(batch_size, 1).to(self.device)
            terminated = torch.zeros(batch_size, 1, dtype=bool).to(self.device)
            for time_step in range(horizon):
                action_for_step = action_sequences[:, time_step, :]
                # replicate each sequence's action num_particles times, keeping the
                # particles for a given sequence contiguous (matches the reshape below)
                action_batch = torch.repeat_interleave(
                    action_for_step, num_particles, dim=0
                )
                _, rewards, dones, model_state = self.step(
                    action_batch, model_state, sample=True
                )
                # once a particle terminates, its subsequent rewards are masked to 0
                rewards[terminated] = 0
                terminated |= dones
                total_rewards += rewards
            total_rewards = total_rewards.reshape(-1, num_particles)
            return total_rewards.mean(dim=1)
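

# A minimal usage sketch (not part of the library; it assumes a trained dynamics
# model `model`, a gymnasium environment `env`, and a termination function such
# as those provided in `mbrl.env.termination_fns`). Shapes follow the docstrings
# above:
#
#     import mbrl.env.termination_fns as term_fns
#
#     model_env = ModelEnv(env, model, term_fns.no_termination)
#     obs, _ = env.reset()
#     action_sequences = torch.rand(500, 15, env.action_space.shape[0])  # B x H x A
#     values = model_env.evaluate_action_sequences(
#         action_sequences, initial_state=obs, num_particles=20
#     )  # shape (500,): mean accumulated reward per sequence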