In [None]:
import jax
import brax
from brax import envs
from typing import NamedTuple
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
from gymnax.environments import environment, spaces

import jax.numpy as jnp
from jaxrl_m.run_supersac import create_learner
from typing import Optional, Tuple, Union, Any

from functools import partial
import chex
from flax import struct
import jax

# Do not log every XLA compilation; set to True when hunting down retracing.
jax.config.update("jax_log_compiles", False)

class Transition(NamedTuple):
    """One environment step recorded during a rollout.

    In this notebook `_env_step` builds one per step and `jax.lax.scan`
    stacks them along a leading time axis.
    """

    done: jnp.ndarray  # done flag returned by env.step for this step
    action: jnp.ndarray  # action sampled from the policy
    reward: jnp.ndarray  # reward returned by env.step
    log_prob: jnp.ndarray  # log-probability of `action` under the policy
    obs: jnp.ndarray  # observation the action was chosen from (pre-step obs)
    info: dict  # info dict produced by the env wrappers (a dict, not an array)



class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    environment, so wrappers can be stacked transparently.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # `__getattr__` only runs after normal lookup fails. If `_env` itself is
        # missing (an instance created via `__new__` by copy/pickle before
        # `__init__` ran), forwarding would evaluate `self._env`, re-enter this
        # method and recurse forever. Report a plain AttributeError instead.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)


class VecEnv(GymnaxWrapper):
    """Batch a single-instance env over a leading axis using `jax.vmap`.

    `reset(keys, params)` maps over `keys` and broadcasts `params`;
    `step(keys, states, actions, params)` maps over keys, states and actions
    and broadcasts `params`.
    """

    def __init__(self, env):
        super().__init__(env)
        batched_reset = jax.vmap(self._env.reset, in_axes=(0, None))
        batched_step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
        self.reset = batched_reset
        self.step = batched_step
    
    
class BraxGymnaxWrapper:
    """Expose a Brax environment through a Gymnax-style interface.

    ``reset(key, params)`` returns ``(obs, state)`` and
    ``step(key, state, action, params)`` returns
    ``(obs, state, reward, done, info)``. ``params`` is accepted only for
    interface compatibility and is ignored. Brax's ``step`` is called
    without a key, so ``key`` is ignored in ``step`` as well.

    Args:
        env_name: Name passed to ``brax.envs.get_environment``.
        backend: Brax backend passed to ``get_environment`` (default ``"mjx"``).
        episode_length: If given, wrap the env in Brax's ``EpisodeWrapper`` so
            episodes are cut after this many steps. ``None`` (default) keeps
            the raw environment, as before.
        action_repeat: Forwarded to ``EpisodeWrapper``; only used when
            ``episode_length`` is set.
        auto_reset: If True, additionally wrap the env in Brax's
            ``AutoResetWrapper``. Defaults to False (previous behavior).
    """

    def __init__(
        self,
        env_name,
        backend="mjx",
        episode_length: Optional[int] = None,
        action_repeat: int = 1,
        auto_reset: bool = False,
    ):
        env = envs.get_environment(env_name=env_name, backend=backend)
        if episode_length is not None:
            env = EpisodeWrapper(
                env, episode_length=episode_length, action_repeat=action_repeat
            )
        if auto_reset:
            env = AutoResetWrapper(env)
        self._env = env
        self.action_size = env.action_size
        # NOTE(review): stored as a 1-tuple while action_size is a bare int;
        # left unchanged in case other cells rely on this shape-like form.
        self.observation_size = (env.observation_size,)

    def reset(self, key, params=None):
        """Reset the Brax env with ``key``; return ``(obs, brax_state)``."""
        state = self._env.reset(key)
        return state.obs, state

    def step(self, key, state, action, params=None):
        """Advance one step; return ``(obs, state, reward, done, info)``.

        ``done`` is converted to a boolean by thresholding at 0.5 (presumably
        Brax stores it as a numeric 0/1 value). ``info`` is always a fresh
        empty dict.
        """
        next_state = self._env.step(state, action)
        done = next_state.done > 0.5
        return next_state.obs, next_state, next_state.reward, done, {}

    def observation_space(self, params=None):
        """Unbounded Box with shape ``(observation_size,)``."""
        return spaces.Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=(self._env.observation_size,),
        )

    def action_space(self, params=None):
        """Box in ``[-1, 1]`` with shape ``(action_size,)``."""
        return spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self._env.action_size,),
        )
     
     
@struct.dataclass
class LogEnvState:
    """Episode bookkeeping carried next to the wrapped env's state.

    Fields are positional in the generated constructor, so keep their order
    stable. Values start as Python scalars; under vmap/scan they become arrays.
    """
    env_state: environment.EnvState  # state of the wrapped environment
    # 1 until the first `done` is seen, then 0 forever (multiplied by 1 - done).
    valid_mask : Union[int, chex.Array]
    # discount**t while the first episode is running, 0 after it ends.
    disc_valid_mask : Union[float, chex.Array]
    # Undiscounted sum of rewards up to and including the first `done`.
    episode_returns: Union[float, chex.Array]
    # Sum of discount**t * reward_t over the same steps.
    disc_episode_returns : Union[float, chex.Array]
    
    # Number of steps counted while valid_mask was 1.
    episode_lengths: Union[int, chex.Array]
    # Incremented by valid_mask, i.e. currently tracks the same count as episode_lengths.
    timestep: Union[int, chex.Array]   

class LogWrapper(GymnaxWrapper):
    """Log the returns (undiscounted and discounted) and length of the first episode.

    Statistics accumulate only until an environment's first ``done``: the
    masks are multiplied by ``1 - done`` and so stay at zero afterwards. This
    wrapper does not reset the wrapped env itself.

    Args:
        env: Gymnax-style environment to wrap.
        discount: Per-step discount used for ``disc_episode_returns``.
            Defaults to 0.99, the value previously hard-coded in ``step``.
    """

    def __init__(self, env: environment.Environment, discount: float = 0.99):
        super().__init__(env)
        self.discount = discount

    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, LogEnvState]:
        """Reset the wrapped env and start fresh statistics (masks 1, counters 0)."""
        obs, env_state = self._env.reset(key, params)
        state = LogEnvState(
            env_state=env_state,
            valid_mask=1,
            disc_valid_mask=1,
            episode_returns=0,
            disc_episode_returns=0,
            episode_lengths=0,
            timestep=0,
        )
        return obs, state

    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        Returns the wrapped env's ``(obs, reward, done)`` unchanged, the new
        ``LogEnvState``, and a copy of the env's info dict extended with
        ``timestep``, ``valid_mask``, ``episode_returns``,
        ``disc_episode_returns`` and ``episode_lengths``.
        """
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )

        # Use the masks from *before* this step: the reward of the terminating
        # step is still counted, everything after it is masked out.
        valid_mask = state.valid_mask
        disc_valid_mask = state.disc_valid_mask
        not_done = 1 - done

        state = LogEnvState(
            env_state=env_state,
            valid_mask=valid_mask * not_done,
            disc_valid_mask=disc_valid_mask * self.discount * not_done,
            episode_returns=state.episode_returns + reward * valid_mask,
            disc_episode_returns=state.disc_episode_returns + reward * disc_valid_mask,
            episode_lengths=state.episode_lengths + valid_mask,
            timestep=state.timestep + valid_mask,
        )

        # Copy so the wrapped env's info dict is not mutated in place.
        info = dict(info)
        info["timestep"] = state.timestep
        info["valid_mask"] = state.valid_mask
        info["episode_returns"] = state.episode_returns
        info["disc_episode_returns"] = state.disc_episode_returns
        info["episode_lengths"] = state.episode_lengths
        return obs, state, reward, done, info

In [None]:

# Rollout configuration ("BATCH_SIZE" is not used by the cells shown here).
config = {
    "NUM_ENVS": 10,
    "NUM_STEPS": 1000,
    "BATCH_SIZE": 1,
    "discount": 0.99,
    "ENV_NAME": "halfcheetah",
}

# Root PRNG key for environment resets, action sampling and the rollout.
rng = jax.random.PRNGKey(0)

# Brax env -> episode statistics -> batched over NUM_ENVS with vmap.
env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
env = LogWrapper(env)
env = VecEnv(env)

# INIT ENV: one reset key per parallel environment.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = env.reset(reset_rng, env_params)

# Example action for learner initialisation (presumably only its shape is
# used by create_learner -- confirm). Sample it with a fresh subkey so `rng`,
# which is carried into the rollout, is never consumed twice.
rng, _rng = jax.random.split(rng)
action = env.action_space().sample(_rng)

learner_args = {
    "seed": 42,
    "observations": obsv[0].reshape(1, -1),
    "actions": action.reshape(1, -1),
    "discount": config["discount"],
    "discount_actor": True,
    "discount_entropy": True,
    "num_critics": 5,
}

agent = create_learner(**learner_args)



In [None]:

runner_state = (agent, env_state, obsv, rng)


def _env_step(runner_state, unused):
    """One ``lax.scan`` step: sample actions for every env and step them.

    Args:
        runner_state: ``(agent, env_state, last_obs, rng)`` carry.
        unused: Scanned-over ``xs`` slot (``None`` in this notebook).

    Returns:
        The updated carry ``(agent, env_state, obsv, rng)`` and the
        ``Transition`` recorded for this step.
    """
    agent, env_state, last_obs, rng = runner_state

    # SELECT ACTION -- sample with the fresh subkey. The carried `rng` is only
    # split further; sampling with it as well would reuse a consumed key.
    rng, _rng = jax.random.split(rng)
    pi = agent.actor(last_obs, temperature=1.)
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)

    # STEP ENV -- one key per parallel environment.
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, config["NUM_ENVS"])
    obsv, env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    transition = Transition(done, action, reward, log_prob, last_obs, info)
    runner_state = (agent, env_state, obsv, rng)
    return runner_state, transition



@jax.jit
def run_parallel_envs(runner_state):
    """Roll out ``config["NUM_STEPS"]`` steps of ``_env_step`` with ``lax.scan``.

    Only the stacked ``Transition`` batch is returned; the final carry is
    discarded.
    """
    _final_carry, traj_batch = jax.lax.scan(
        _env_step, runner_state, xs=None, length=config["NUM_STEPS"]
    )
    return traj_batch



In [None]:
# Run the jitted rollout; the first call also triggers compilation.
traj = run_parallel_envs(runner_state=runner_state)