In [None]:
from collections import deque

import jax
import numpy as np
from brax import envs
from tqdm import tqdm
from brax.envs.wrappers import gym as gym_wrapper


# --- Run configuration ---
total_timesteps = 1_000_000  # number of environment steps taken by the rollout loop
seed = 0

# The rollout loop samples actions with NumPy's *global* RNG (np.random.uniform),
# so seed it here; otherwise every kernel restart yields a different action sequence.
np.random.seed(seed)

# --- Environment ---
# Create the environment with a batch of one, then wrap it with the Gym vector wrapper.
env = envs.create("halfcheetah", batch_size=1)
env = gym_wrapper.VectorGymWrapper(env)

# Total number of entries in the action space shape; used as the trailing
# dimension when sampling actions of shape (1, action_size).
# NOTE(review): if the wrapper's action_space shape includes the batch axis, this
# equals the per-environment action dimension only because batch_size is 1 — confirm.
action_size = np.prod(env.action_space.shape)
# Sampling bounds for the uniform random policy.
# NOTE(review): presumably these match the environment's action limits — confirm.
action_low = -1
action_high = 1

# Three JAX PRNG keys derived from `seed`. Only the commented-out JAX sampling
# line in the rollout cell refers to `key`; the active code does not consume it.
key = jax.random.split(jax.random.PRNGKey(seed), 3)

# Reset once before stepping; the returned value is kept but not read by the rollout loop.
state = env.reset()

# --- Logging buffers ---
# Rolling windows holding the five most recent finished episodes.
log_episodic_returns = deque(maxlen=5)
log_episodic_lengths = deque(maxlen=5)

# Per-step rewards of the episode currently in progress (cleared when it ends).
reward_log = []

In [None]:
# Random-policy rollout: take `total_timesteps` steps and report every finished episode.
for global_step in tqdm(range(total_timesteps)):
    # Draw one uniformly random action row for the single batched environment.
    action = np.random.uniform(action_low, action_high, (1, action_size))

    # Advance the environment and remember this step's reward.
    obs, reward, done, info = env.step(action)
    reward_log.append(reward)

    # Nothing to report until the current episode has ended.
    if not done:
        continue

    # Episode finished: record its total reward and its length in the rolling windows.
    log_episodic_returns.append(np.sum(reward_log))
    log_episodic_lengths.append(len(reward_log))

    print(f"Episodic return: {np.sum(reward_log)} | Episodic length: {len(reward_log)}")

    # Start accumulating rewards for the next episode.
    reward_log = []