-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathspdr.py
More file actions
542 lines (459 loc) · 24.7 KB
/
spdr.py
File metadata and controls
542 lines (459 loc) · 24.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
# Copyright (c) 2020, Fabio Muratore, Honda Research Institute Europe GmbH, and
# Technical University of Darmstadt.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the name of Fabio Muratore, Honda Research Institute Europe GmbH,
# or Technical University of Darmstadt, nor the names of its contributors may
# be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL FABIO MURATORE, HONDA RESEARCH INSTITUTE EUROPE GMBH,
# OR TECHNICAL UNIVERSITY OF DARMSTADT BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import os.path
from csv import DictWriter
from typing import Callable, Iterator, List, Optional, Tuple
import numpy as np
import torch as to
from scipy.optimize import NonlinearConstraint, minimize
from torch.distributions import MultivariateNormal
import pyrado
from pyrado.algorithms.base import Algorithm
from pyrado.algorithms.step_based.actor_critic import ActorCritic
from pyrado.algorithms.utils import RolloutSavingWrapper, until_thold_exceeded
from pyrado.domain_randomization.domain_parameter import SelfPacedDomainParam
from pyrado.domain_randomization.transformations import DomainParamTransform
from pyrado.environment_wrappers.domain_randomization import DomainRandWrapper
from pyrado.environment_wrappers.utils import typed_env
from pyrado.environments.base import Env
from pyrado.policies.base import Policy
from pyrado.sampling.step_sequence import StepSequence
def ravel_tril_elements(A: to.Tensor) -> to.Tensor:
    """
    Flatten the lower-triangular part of a square matrix into a vector, row by row.

    :param A: two-dimensional square tensor; shape `(k, k)`
    :return: one-dimensional tensor holding the `k * (k + 1) / 2` lower-triangular entries in row-major order
    """
    if A.dim() != 2:
        raise pyrado.ShapeErr(msg="A must be two-dimensional")
    if A.shape[0] != A.shape[1]:
        raise pyrado.ShapeErr(msg="A must be square")
    # Collect row i's entries up to (and including) the diagonal, then concatenate them all.
    row_slices = []
    for row in range(A.shape[0]):
        row_slices.append(A[row, : row + 1])
    return to.cat(row_slices, dim=0)
def unravel_tril_elements(a: to.Tensor) -> to.Tensor:
    """
    Inverse of `ravel_tril_elements`: rebuild a lower-triangular square matrix from its row-major raveled entries.

    :param a: one-dimensional tensor of length `k * (k + 1) / 2` for some integer `k`
    :return: lower-triangular tensor of shape `(k, k)` and dtype `double`; entries above the diagonal are zero
    """
    if not (len(a.shape) == 1):
        raise pyrado.ShapeErr(msg="a must be one-dimensional")
    raveled_dim = a.shape[0]
    dim = int((np.sqrt(8 * raveled_dim + 1) - 1) / 2)  # inverse Gaussian summation formula
    # Reject lengths that are not triangular numbers; previously such inputs were silently truncated
    # to the nearest smaller matrix, dropping trailing elements without any error.
    if dim * (dim + 1) // 2 != raveled_dim:
        raise pyrado.ShapeErr(msg="a must have length k * (k + 1) / 2 for some integer k")
    A = to.zeros((dim, dim)).double()
    for i in range(dim):
        # Row i occupies the slice [i * (i + 1) / 2, (i + 1) * (i + 2) / 2) of the raveled vector.
        A[i, : i + 1] = a[i * (i + 1) // 2 :][: i + 1]
    return A
class MultivariateNormalWrapper:
    """
    A wrapper for PyTorch's multivariate normal distribution with a full covariance matrix, parameterized by the
    mean and the (raveled) lower triangular of the Cholesky decomposition of the covariance.

    It is used to get a SciPy optimizer-ready version of the parameters of a distribution,
    i.e. a vector that can be used as the target variable.
    """

    def __init__(self, mean: to.Tensor, cov_chol: to.Tensor):
        """
        Constructor.

        :param mean: mean of the distribution; shape `(k,)`
        :param cov_chol: Cholesky decomposition of the covariance matrix; must be lower triangular; shape `(k, k)`
                         if it is the actual matrix or shape `(k * (k + 1) / 2,)` if it is raveled
        """
        if not (len(mean.shape) == 1):
            raise pyrado.ShapeErr(msg="mean must be one-dimensional")
        self._k = mean.shape[0]
        # A one-dimensional cov_chol is interpreted as the already-raveled lower triangular (see ravel_tril_elements).
        cov_chol_tril_is_raveled = len(cov_chol.shape) == 1
        if cov_chol_tril_is_raveled:
            if not (cov_chol.shape[0] == self._k * (self._k + 1) / 2):
                raise pyrado.ShapeErr(msg="raveled cov_chol must have shape (k (k + 1) / 2,)")
        else:
            if not (len(cov_chol.shape) == 2):
                raise pyrado.ShapeErr(msg="cov_chol must be two-dimensional")
            if not (cov_chol.shape[0] == cov_chol.shape[1]):
                raise pyrado.ShapeErr(msg="cov_chol must be square")
            if not (cov_chol.shape[0] == mean.shape[0]):
                raise pyrado.ShapeErr(msg="cov_chol and mean must have same size")
            if not (to.allclose(cov_chol, cov_chol.tril())):
                raise pyrado.ValueErr(msg="cov_chol must be lower triangular")
        # Detach from any upstream graph; gradients are taken w.r.t. these parameters directly (see parameters()).
        self._mean = mean.clone().detach().requires_grad_(True)
        cov_chol_tril = cov_chol.clone().detach()
        if not cov_chol_tril_is_raveled:
            cov_chol_tril = ravel_tril_elements(cov_chol_tril)
        self._cov_chol_tril = cov_chol_tril.requires_grad_(True)
        self._update_distribution()

    @staticmethod
    def from_stacked(dim: int, stacked: np.ndarray) -> "MultivariateNormalWrapper":
        r"""
        Creates an instance of this class from the given stacked numpy array as generated e.g. by
        `MultivariateNormalWrapper.get_stacked(self)`.

        :param dim: dimensionality `k` of the random variable
        :param stacked: array containing the mean and standard deviations of shape `(k + k * (k + 1) / 2,)`, where the
                        first `k` entries are the mean and the last `k * (k + 1) / 2` entries are lower triangular
                        entries of the Cholesky decomposition of the covariance matrix
        :return: a `MultivariateNormalWrapper` with the given mean/cov.
        """
        if not (len(stacked.shape) == 1):
            raise pyrado.ValueErr(msg="Stacked has invalid shape! Must be 1-dimensional.")
        if not (stacked.shape[0] == dim + dim * (dim + 1) / 2):
            raise pyrado.ValueErr(given_name="stacked", msg="invalid size, must be dim + dim * (dim + 1) / 2)")
        mean = stacked[:dim]
        cov_chol_tril = stacked[dim:]
        # Tensors are promoted to double so that SciPy's float64 optimizer variables round-trip without precision loss.
        return MultivariateNormalWrapper(to.tensor(mean).double(), to.tensor(cov_chol_tril).double())

    @property
    def dim(self):
        """Get the size (dimensionality) of the random variable."""
        return self._mean.shape[0]

    @property
    def mean(self):
        """Get the mean."""
        return self._mean

    @mean.setter
    def mean(self, mean: to.Tensor):
        """Set the mean."""
        if not (mean.shape == self.mean.shape):
            raise pyrado.ShapeErr(given_name="mean", expected_match=self.mean.shape)
        self._mean = mean
        self._update_distribution()

    @property
    def cov(self):
        """Get the covariance matrix, shape `(k, k)`."""
        # Reconstructed from the Cholesky factor: cov = L @ L^T.
        return self.cov_chol @ self.cov_chol.T

    @property
    def cov_chol(self) -> to.Tensor:
        """Get the Cholesky decomposition of the covariance; shape `(k, k)`."""
        return unravel_tril_elements(self._cov_chol_tril)

    @property
    def cov_chol_tril(self) -> to.Tensor:
        """Get the lower triangular of the Cholesky decomposition of the covariance; shape `(k * (k + 1) / 2)`."""
        return self._cov_chol_tril

    @cov_chol_tril.setter
    def cov_chol_tril(self, cov_chol_tril: to.Tensor):
        """Set the lower triangular of the Cholesky decomposition of the covariance; shape `(k * (k + 1) / 2,)`."""
        if not (cov_chol_tril.shape == self.cov_chol_tril.shape):
            raise pyrado.ShapeErr(given_name="cov_chol_tril", expected_match=self.cov_chol_tril.shape)
        self._cov_chol_tril = cov_chol_tril
        self._update_distribution()

    def parameters(self) -> Iterator[to.Tensor]:
        """Get the parameters (mean and lower triangular covariance Cholesky) of this distribution."""
        yield self.mean
        yield self.cov_chol_tril

    def get_stacked(self) -> np.ndarray:
        """
        Get the numpy representations of the mean and transformed covariance stacked on top of each other.

        :return: stacked mean and transformed covariance; shape `(k + k * (k + 1) / 2,)`
        """
        return np.concatenate([self.mean.detach().numpy(), self.cov_chol_tril.detach().numpy()])

    def _update_distribution(self):
        """Update `self.distribution` according to the current mean and covariance."""
        self.distribution = MultivariateNormal(self.mean, self.cov)
class SPDR(Algorithm):
    """
    Self-Paced Domain Randomization (SPDR)

    This algorithm wraps another algorithm. The main purpose is to apply self-paced RL [1].

    .. seealso::
        [1] P. Klink, H. Abdulsamad, B. Belousov, C. D'Eramo, J. Peters, and J. Pajarinen,
        "A Probabilistic Interpretation of Self-Paced Learning with Applications to Reinforcement Learning", arXiv, 2021
    """

    name: str = "spdr"

    def __init__(
        self,
        env: DomainRandWrapper,
        subroutine: Algorithm,
        kl_constraints_ub: float,
        max_iter: int,
        performance_lower_bound: float,
        var_lower_bound: Optional[float] = 0.04,
        kl_threshold: float = 0.1,
        optimize_mean: bool = True,
        optimize_cov: bool = True,
        max_subrtn_retries: int = 1,
    ):
        """
        Constructor

        :param env: environment wrapped in a DomainRandWrapper
        :param subroutine: algorithm which performs the policy/value-function optimization, which
                           must expose its sampler
        :param kl_constraints_ub: upper bound for the KL-divergence
        :param max_iter: Maximal iterations for the SPDR algorithm (not for the subroutine)
        :param performance_lower_bound: lower bound for the performance SPDR tries to stay above
                                        during distribution updates
        :param var_lower_bound: clipping value for the variance, necessary when using very small target variances;
                                prefer a log-transformation instead
        :param kl_threshold: threshold for the KL-divergence until which std_lower_bound is enforced
        :param optimize_mean: whether the mean should be changed or considered fixed
        :param optimize_cov: whether the (co-)variance should be changed or considered fixed
        :param max_subrtn_retries: how often a failed (median performance < 30 % of performance_lower_bound)
                                   training attempt of the subroutine should be reattempted
        """
        if not isinstance(subroutine, Algorithm):
            raise pyrado.TypeErr(given_name="subroutine", given=subroutine, expected_type=Algorithm)
        if not hasattr(subroutine, "sampler"):
            raise AttributeError("The subroutine must have a sampler attribute!")
        if not typed_env(env, DomainRandWrapper):
            raise pyrado.TypeErr(given_name="env", given=env, expected_type=DomainRandWrapper)
        # Call Algorithm's constructor with the subroutine's properties
        super().__init__(subroutine.save_dir, max_iter, subroutine.policy, subroutine.logger)
        # Wrap the sampler of the subroutine with an rollout saving wrapper
        self._subrtn = subroutine
        self._subrtn.sampler = RolloutSavingWrapper(subroutine.sampler)
        self._subrtn.save_name = self._subrtn.name
        self._env = env
        # Properties for the variance bound and kl constraint
        self._kl_constraints_ub = kl_constraints_ub
        self._var_lower_bound = var_lower_bound
        self._kl_threshold = kl_threshold
        # Properties of the performance constraint
        self._performance_lower_bound = performance_lower_bound
        self._performance_lower_bound_reached = False
        self._optimize_mean = optimize_mean
        self._optimize_cov = optimize_cov
        self._max_subrtn_retries = max_subrtn_retries
        # Find the single self-paced parameter in the randomizer; exactly one is supported.
        self._spl_parameter = None
        for param in env.randomizer.domain_params:
            if isinstance(param, SelfPacedDomainParam):
                if self._spl_parameter is None:
                    self._spl_parameter = param
                else:
                    raise pyrado.ValueErr(msg="randomizer contains more than one spl param")
        # evaluation multidim
        # NOTE(review): this CSV writer is stored in a module-level global and the file handle is never closed
        # (buffering=1 gives line-buffered writes). It is also not referenced anywhere else in this file —
        # consider making it an instance attribute and closing it on teardown.
        header = ["iteration", "objective_output", "status", "cg_stop_cond", "mean", "cov"]
        f = open(os.path.join(subroutine.save_dir, "optimizer.csv"), "w", buffering=1)
        global optimize_logger
        optimize_logger = DictWriter(f, fieldnames=header)
        optimize_logger.writeheader()

    @property
    def sample_count(self) -> int:
        # Forward to subroutine
        return self._subrtn.sample_count

    @property
    def dim(self) -> int:
        """Dimensionality `k` of the self-paced domain parameter's random variable."""
        return self._spl_parameter.target_mean.shape[0]

    @property
    def subrtn_sampler(self) -> RolloutSavingWrapper:
        # It is checked in the constructor that the sampler is a RolloutSavingWrapper.
        # noinspection PyTypeChecker
        return self._subrtn.sampler

    def step(self, snapshot_mode: str, meta_info: dict = None):
        """
        Perform a step of SPDR. This includes training the subroutine and updating the context distribution accordingly.
        For a description of the parameters see `pyrado.algorithms.base.Algorithm.step`.
        """
        self.save_snapshot()

        # Add these keys to the logger as dummy values.
        self.logger.add_value("sprl number of particles", 0)
        self.logger.add_value("spdr constraint kl", 0.0)
        self.logger.add_value("spdr constraint performance", 0.0)
        self.logger.add_value("spdr objective", 0.0)
        self._log_context_distribution()

        # If we are in the first iteration and have a bad performance,
        # we want to completely reset the policy if training is unsuccessful
        reset_policy = False
        if self.curr_iter == 0:
            reset_policy = True
        # Train the subroutine, retrying (up to _max_subrtn_retries times) while the median return stays
        # below 30 % of the performance lower bound.
        until_thold_exceeded(self._performance_lower_bound * 0.3, self._max_subrtn_retries)(
            self._train_subroutine_and_evaluate_perf
        )(snapshot_mode, meta_info, reset_policy)

        # Current (proposal) and target context distributions, both in double precision for SciPy.
        previous_distribution = MultivariateNormalWrapper(
            self._spl_parameter.context_mean.double(), self._spl_parameter.context_cov_chol.double()
        )
        target_distribution = MultivariateNormalWrapper(
            self._spl_parameter.target_mean.double(), self._spl_parameter.target_cov_chol.double()
        )

        proposal_rollouts = self._sample_proposal_rollouts()
        contexts, contexts_old_log_prob, values = self._extract_particles(proposal_rollouts, previous_distribution)

        # Define the SPRL optimization problem
        kl_constraint = self._make_kl_constraint(previous_distribution, self._kl_constraints_ub)
        performance_constraint = self._make_performance_constraint(
            contexts, contexts_old_log_prob, values, self._performance_lower_bound
        )
        constraints = [kl_constraint, performance_constraint]
        objective_fn = self._make_objective_fn(target_distribution)
        x0 = previous_distribution.get_stacked()

        minimize_args = dict(
            fun=objective_fn,
            x0=x0,
            method="trust-constr",
            jac=True,
            constraints=constraints,
            options={"gtol": 1e-4, "xtol": 1e-6},
            # bounds=bounds,
        )

        print("Performing SPDR update.")
        try:
            result = minimize(**minimize_args)
            new_x = result.x
            # Reset parameters if optimization was not successful
            if not result.success:
                # If optimization process was not a success
                old_f = objective_fn(previous_distribution.get_stacked())[0]
                # Even on "failure", accept the iterate if it satisfies all constraints and improves the objective.
                constraints_satisfied = all((const.lb <= const.fun(result.x) <= const.ub for const in constraints))
                # std_ok = bounds is None or (np.all(bounds.lb <= result.x)) and np.all(result.x <= bounds.ub)
                std_ok = True
                update_successful = constraints_satisfied and std_ok and result.fun < old_f
                if not update_successful:
                    print(f"Update unsuccessful, keeping old SPDR parameters.")
                    new_x = x0
        except ValueError as e:
            print(f"Update failed with error, keeping old SPDR parameters.", e)
            new_x = x0

        self._adapt_parameters(new_x)
        # we can't use the stored values here as new_x might not be result.x
        self.logger.add_value("spdr constraint kl", kl_constraint.fun(new_x))
        self.logger.add_value("spdr constraint performance", performance_constraint.fun(new_x))
        self.logger.add_value("spdr objective", objective_fn(new_x)[0])

    def reset(self, seed: int = None):
        """Reset the subroutine with the given seed and drop all stored rollouts."""
        # Forward to subroutine
        self._subrtn.reset(seed)
        self.subrtn_sampler.reset_rollouts()

    def save_snapshot(self, meta_info: dict = None):
        """Save a snapshot; the stored rollouts are cleared first so they are not pickled with the sampler."""
        self.subrtn_sampler.reset_rollouts()
        super().save_snapshot(meta_info)
        if meta_info is None:
            # This algorithm instance is not a subroutine of another algorithm
            self._subrtn.save_snapshot(meta_info)

    def load_snapshot(self, parsed_args) -> Tuple[Env, Policy, dict]:
        """Load a snapshot; additionally restores the value function when the subroutine is an actor-critic."""
        env, policy, extra = super().load_snapshot(parsed_args)

        # Algorithm specific
        if isinstance(self._subrtn, ActorCritic):
            ex_dir = self._save_dir or getattr(parsed_args, "dir", None)
            extra["vfcn"] = pyrado.load(
                f"{parsed_args.vfcn_name}.pt", ex_dir, obj=self._subrtn.critic.vfcn, verbose=True
            )

        return env, policy, extra

    def _make_objective_fn(
        self, target_distribution: MultivariateNormalWrapper
    ) -> Callable[[np.ndarray], Tuple[float, np.ndarray]]:
        """Build the SciPy objective: KL(candidate || target) with its analytic gradient (jac=True)."""

        def objective_fn(x):
            """Tries to find the minimum kl divergence between the current and the update distribution, which
            still satisfies the minimum update constraint and the performance constraint."""
            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
            kl_divergence = to.distributions.kl_divergence(distribution.distribution, target_distribution.distribution)
            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))

            return (
                kl_divergence.detach().numpy().item(),
                np.concatenate([g.detach().numpy() for g in grads]),
            )

        return objective_fn

    def _make_kl_constraint(
        self, previous_distribution: MultivariateNormalWrapper, kl_constraint_ub: float
    ) -> NonlinearConstraint:
        """Build the SciPy constraint KL(previous || candidate) <= kl_constraint_ub with analytic gradient."""

        def kl_constraint_fn(x):
            """Compute the constraint for the KL-divergence between current and proposed distribution."""
            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
            kl_divergence = to.distributions.kl_divergence(
                previous_distribution.distribution, distribution.distribution
            )
            return kl_divergence.detach().numpy().item()

        def kl_constraint_fn_prime(x):
            """Compute the derivative for the KL-constraint (used for scipy optimizer)."""
            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
            kl_divergence = to.distributions.kl_divergence(
                previous_distribution.distribution, distribution.distribution
            )
            grads = to.autograd.grad(kl_divergence, list(distribution.parameters()))
            return np.concatenate([g.detach().numpy() for g in grads])

        return NonlinearConstraint(
            fun=kl_constraint_fn,
            lb=-np.inf,
            ub=kl_constraint_ub,
            jac=kl_constraint_fn_prime,
        )

    def _make_performance_constraint(
        self, contexts: to.Tensor, contexts_old_log_prob: to.Tensor, values: to.Tensor, performance_lower_bound: float
    ) -> NonlinearConstraint:
        """Build the SciPy constraint: importance-weighted expected performance >= performance_lower_bound."""

        def performance_constraint_fn(x):
            """Compute the constraint for the expected performance under the proposed distribution."""
            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
            return performance.detach().numpy().item()

        def performance_constraint_fn_prime(x):
            """Compute the derivative for the performance-constraint (used for scipy optimizer)."""
            distribution = MultivariateNormalWrapper.from_stacked(self.dim, x)
            performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
            grads = to.autograd.grad(performance, list(distribution.parameters()))
            return np.concatenate([g.detach().numpy() for g in grads])

        return NonlinearConstraint(
            fun=performance_constraint_fn,
            lb=performance_lower_bound,
            ub=np.inf,
            jac=performance_constraint_fn_prime,
        )

    def _log_context_distribution(self):
        """Log the current context mean and (Cholesky of the) covariance, entry by entry."""
        context_mean = self._spl_parameter.context_mean.double()
        context_cov = self._spl_parameter.context_cov.double()
        context_cov_chol = self._spl_parameter.context_cov_chol.double()
        for param_a_idx, param_a_name in enumerate(self._spl_parameter.name):
            for param_b_idx, param_b_name in enumerate(self._spl_parameter.name):
                self.logger.add_value(
                    f"context cov for {param_a_name}--{param_b_name}",
                    context_cov[param_a_idx, param_b_idx].item(),
                )
                self.logger.add_value(
                    f"context cov_chol for {param_a_name}--{param_b_name}",
                    context_cov_chol[param_a_idx, param_b_idx].item(),
                )
                # Stop at the diagonal, so only the lower triangular of the (symmetric) covariance is logged.
                if param_a_name == param_b_name:
                    self.logger.add_value(f"context mean for {param_a_name}", context_mean[param_a_idx].item())
                    break

    def _sample_proposal_rollouts(self) -> List[List[StepSequence]]:
        """Return the rollouts recorded by the subroutine's sampler wrapper during the last training run."""
        return self.subrtn_sampler.rollouts

    def _extract_particles(
        self, rollouts_all: List[List[StepSequence]], distribution: MultivariateNormalWrapper
    ) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
        """
        Extract the contexts (domain parameter values), their log-probabilities under the given distribution,
        and the undiscounted returns from the recorded rollouts.

        :return: tuple of contexts (shape `(num_rollouts, k)`), their log-probs, and the per-rollout returns
        """

        def get_domain_param_value(ro: StepSequence, param_name: str) -> np.ndarray:
            # Prefer the untransformed parameter value if a domain-parameter transformation was applied.
            domain_param_dict = ro.rollout_info["domain_param"]
            untransformed_param_name = param_name + DomainParamTransform.UNTRANSFORMED_DOMAIN_PARAMETER_SUFFIX
            if untransformed_param_name in domain_param_dict:
                return domain_param_dict[untransformed_param_name]
            return domain_param_dict[param_name]

        # Built parameter-major, then transposed to shape (num_rollouts, k).
        contexts = to.tensor(
            [
                [to.from_numpy(get_domain_param_value(ro, name)) for rollouts in rollouts_all for ro in rollouts]
                for name in self._spl_parameter.name
            ],
            requires_grad=True,
        ).T
        self.logger.add_value("sprl number of particles", contexts.shape[0])
        contexts_log_prob = distribution.distribution.log_prob(contexts.double())
        values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])
        return contexts, contexts_log_prob, values

    # noinspection PyMethodMayBeStatic
    def _compute_expected_performance(
        self, distribution: MultivariateNormalWrapper, context: to.Tensor, old_log_prop: to.Tensor, values: to.Tensor
    ) -> to.Tensor:
        """Calculate the expected performance after an update step (importance sampling w.r.t. the old log-probs)."""
        context_ratio = to.exp(distribution.distribution.log_prob(context) - old_log_prop)
        return to.mean(context_ratio * values)

    def _adapt_parameters(self, result: np.ndarray) -> None:
        """Update the parameters of the distribution based on the result of
        the optimization step and the general algorithm settings."""
        context_distr = MultivariateNormalWrapper.from_stacked(self.dim, result)
        self._spl_parameter.adapt("context_mean", context_distr.mean)
        self._spl_parameter.adapt("context_cov_chol", context_distr.cov_chol)

    def _train_subroutine_and_evaluate_perf(
        self, snapshot_mode: str, meta_info: dict = None, reset_policy: bool = False, **kwargs
    ) -> float:
        """
        Internal method required by the `until_thold_exceeded` function.
        The parameters are the same as for the regular `train()` call and are explained there.

        :param reset_policy: if `True` the policy will be reset before training
        :return: the median undiscounted return
        """
        if reset_policy:
            self._subrtn.init_modules(False)
        self._subrtn.reset()
        self._subrtn.train(snapshot_mode, None, meta_info)
        rollouts_all = self.subrtn_sampler.rollouts
        return np.median([[ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts]]).item()