In [1]:
from functools import partial

import gymnasium as gym
import numpy as np
from scipy.stats import multivariate_normal

import pyrlmala.envs

In [2]:
sample_dim = 2
num_envs = 10

# Target distribution: N(target_mean, target_cov) in `sample_dim` dimensions.
# Both the log-density and its gradient are derived from these two arrays so they
# cannot silently disagree if the target is changed.
target_mean = np.zeros(sample_dim)
target_cov = np.eye(sample_dim)
target_precision = np.linalg.inv(target_cov)

log_target_pdf = partial(multivariate_normal.logpdf, mean=target_mean, cov=target_cov)


def grad_log_target_pdf(x):
    """Gradient of the Gaussian log-density, ``-Sigma^{-1} (x - mu)``.

    Accepts a single point of shape ``(sample_dim,)`` or a batch of shape
    ``(..., sample_dim)``; the result has the same shape as ``x``. For the
    default standard-normal target this reduces to ``-x``.
    """
    centered = np.asarray(x) - target_mean
    # Right-multiplying keeps batch axes leading; the precision matrix of a
    # covariance is symmetric, so no transpose is needed.
    return -(centered @ target_precision)


# Every sub-environment starts from the origin.
initial_sample = np.zeros(sample_dim)

In [3]:
# Build `num_envs` copies of the Barker environment, vectorised synchronously
# (in-process). Every copy shares the same target log-density, its gradient and
# the starting point defined in the configuration cell. The "BarkerEnv-v1.0" id is
# presumably registered by `import pyrlmala.envs` — confirm.
envs = gym.make_vec(
    "BarkerEnv-v1.0",
    num_envs=num_envs,
    vectorization_mode="sync",
    log_target_pdf_unsafe=log_target_pdf,
    grad_log_target_pdf_unsafe=grad_log_target_pdf,
    initial_sample=initial_sample,
)

In [4]:
# Reset every sub-environment with a fixed seed so the run is reproducible.
# The reset info gets a real name instead of `_`: assigning `_` in an IPython
# kernel shadows the output cache, and IPython then stops updating `_`, `__` and
# `___` for the rest of the session.
reset_seed = 1234
obs, reset_infos = envs.reset(seed=reset_seed)
obs

array([[ 0.        ,  0.        , -1.60383681,  0.06409991],
       [ 0.        ,  0.        , -0.90966538,  0.2652688 ],
       [ 0.        ,  0.        , -1.35852801, -1.15616959],
       [ 0.        ,  0.        ,  0.0024004 ,  0.33966628],
       [ 0.        ,  0.        ,  0.97214775, -0.31469331],
       [ 0.        ,  0.        ,  0.75944749, -1.6306404 ],
       [ 0.        ,  0.        , -0.46964345, -0.85787206],
       [ 0.        ,  0.        , -1.82276655,  1.66011418],
       [ 0.        ,  0.        ,  1.06659101, -1.06469565],
       [ 0.        ,  0.        ,  0.24460552,  0.87543632]])

In [5]:
step_size = np.array([0.1])
# Ask each sub-environment to apply its own `inverse_softplus` to the step size
# (presumably the env maps actions back through softplus — confirm in BarkerEnv).
# The result is one value per sub-environment.
unconstrained_step = np.asarray(envs.call("inverse_softplus", step_size))
# Duplicate each value so every sub-environment receives a 2-entry action row.
actions = np.repeat(unconstrained_step, 2).reshape(num_envs, 2)
observations, rewards, terminations, truncations, infos = envs.step(actions)

In [6]:
# Post-step observations, one row per sub-environment.
# NOTE(review): each row appears to be [current sample, proposed sample]
# (sample_dim entries each). Rows where the first half equals the previous
# proposal look like accepted moves. Confirm against BarkerEnv's observation spec.
observations

array([[ 0.        ,  0.        ,  0.01526192,  0.08637439],
       [-0.90966538,  0.2652688 , -0.79067223,  0.32657826],
       [-1.35852801, -1.15616959, -1.32450949, -1.16475811],
       [ 0.0024004 ,  0.33966628, -0.04646464,  0.36359864],
       [ 0.97214775, -0.31469331,  0.90651413, -0.17758806],
       [ 0.        ,  0.        ,  0.17243632,  0.00887074],
       [-0.46964345, -0.85787206, -0.35675755, -0.99046971],
       [-1.82276655,  1.66011418, -1.97864842,  1.77535415],
       [ 1.06659101, -1.06469565,  0.93275125, -1.10599102],
       [ 0.24460552,  0.87543632,  0.48402639,  0.91642673]])

In [7]:
# One scalar reward per sub-environment for the step taken in the previous step cell.
rewards

array([129.47703584,  77.33997263, 174.04786875,   5.94803361,
        88.5334897 , 150.9022721 ,  83.63933504, 104.6026067 ,
       162.84315547,  71.59068804])