Functional API and proof-of-concept jax classic-control envs #2958

Closed
wants to merge 15 commits into from

Conversation

RedTachyon
Contributor

@RedTachyon RedTachyon commented Jul 11, 2022

This is very WIP and more of a proof-of-concept to show how I envision a functional API that could support jax/brax envs, in conjunction with #2955 (which would convert the functional API into the classic OOP API).

The idea is that this would, at first, be a relatively spartan internal API to have a coherent way of creating jittable envs. Over time, we would potentially evolve this as an alternative API, which would allow for jitting entire workflows. (I'm pretty sure #2955 can't be fully jitted by jax due to the mutable state)

Known sharp edges:

  • In the Jax envs, reusing vmapped functions within other vmapped functions can break stuff. For example calling env.terminal(state) inside of env.reward(state) will cause issues. Don't do that.
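To make the idea concrete, here is a minimal sketch of what a jittable functional env could look like. It is not the code from this PR: `PointEnv`, its constants, and its dynamics are hypothetical, and the method names (`initial`, `transition`, `reward`, `terminal`) are assumed from the discussion in this thread. The sketch keeps the per-environment functions unbatched and applies `jax.vmap` once at the call site, which is one way to avoid nesting vmapped functions inside each other.

```python
import jax
import jax.numpy as jnp


class PointEnv:
    """Hypothetical functional env: the class is only a namespace for pure functions."""

    goal = 1.0       # constant, never mutated
    step_size = 0.1  # constant, never mutated

    def initial(self, rng):
        # Sample a scalar start position in [-1, 0).
        return jax.random.uniform(rng, (), minval=-1.0, maxval=0.0)

    def transition(self, state, action):
        # Returns a new state; the input state is not modified.
        return state + self.step_size * action

    def reward(self, state):
        # Negative distance to the goal.
        return -jnp.abs(self.goal - state)

    def terminal(self, state):
        return state >= self.goal


env = PointEnv()

# Batch the unbatched functions once, at the outermost level, then jit them.
batched_initial = jax.jit(jax.vmap(env.initial))
batched_transition = jax.jit(jax.vmap(env.transition))
batched_reward = jax.jit(jax.vmap(env.reward))

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one key per environment
states = batched_initial(keys)
actions = jnp.ones(4)
next_states = batched_transition(states, actions)
rewards = batched_reward(next_states)
```

Because every function takes the state as an argument and returns a new value, the same definitions work for a single environment or for a batch without any changes to the class.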

@RedTachyon RedTachyon marked this pull request as draft July 11, 2022 10:09
@pseudo-rnd-thoughts
Contributor

pseudo-rnd-thoughts commented Jul 13, 2022

The working functional API specification and ideas by @RedTachyon, @balisujohn and myself

  1. The functional API should be separate from gym.Env - The current gym.Env follows a strictly object-oriented style that is in some ways quite separate from POMDP theory, which is functional in nature. Therefore, we propose that this functional API doesn't follow the gym.Env API but is rather a separate abstract implementation.
  2. Agnostic transformations - This idea originally began as thinking about how to implement new Jax-based environments that use jax.jit and jax.vmap optimisations. However, we realised that a more general implementation that could support new or alternative optimisation features, e.g. numba, could be helpful.
  3. OOP translation - The advantage of having environments written in the functional API is that users can take full advantage of the "raw" reset and step functions. However, the majority of users will want to use the normal OOP gym.Env style. Therefore, for the Jax environments, an interface class is developed that can automatically transform a Jax-based functional environment into a gym.Env. This allows more interface classes to be developed and added to gym as technologies are developed.
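The OOP translation in point 3 can be sketched roughly as follows. This is an illustration only, not the interface class from #2955: `CounterFuncEnv` and `FuncToOOPEnv` are hypothetical names, the adapter does not subclass gym.Env so that the example stays self-contained, and the 4-tuple returned by `step` is an assumption about the OOP API being targeted.

```python
class CounterFuncEnv:
    """Hypothetical functional env: pure functions, no mutable attributes."""

    limit = 3  # constant, never mutated

    def initial(self):
        return 0

    def transition(self, state, action):
        return state + action

    def observation(self, state):
        return state

    def reward(self, state, action, next_state):
        return float(next_state)

    def terminal(self, state):
        return state >= self.limit


class FuncToOOPEnv:
    """Hypothetical adapter: it owns the mutable state so the functional env does not."""

    def __init__(self, func_env):
        self.func_env = func_env
        self.state = None

    def reset(self):
        self.state = self.func_env.initial()
        return self.func_env.observation(self.state)

    def step(self, action):
        next_state = self.func_env.transition(self.state, action)
        reward = self.func_env.reward(self.state, action, next_state)
        done = self.func_env.terminal(next_state)
        self.state = next_state  # the only mutation happens here
        return self.func_env.observation(next_state), reward, done, {}


env = FuncToOOPEnv(CounterFuncEnv())
first_obs = env.reset()
trajectory = []
done = False
while not done:
    obs, reward, done, info = env.step(1)
    trajectory.append((obs, reward, done))
```

The adapter is the only place where state is stored and overwritten, so the same functional env could in principle be wrapped by different adapters for different backends.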

@verbose-void

verbose-void commented Jul 26, 2022

it's funny how you're building a functional API but using a class to do it 😂 of course, it makes sense it basically is just like a module.

just for clarity, these functional envs are required to have no internal state right?

@RedTachyon
Contributor Author

RedTachyon commented Jul 26, 2022

functional API but using a class to do it

'tis the Python life. The class is meant to effectively work as a namespace. Maybe a typeclass in some more sane languages. Ultimately, everything in Python is an object, so we wouldn't escape it anyways.

these functional envs are required to have no internal state right?

Maybe not necessarily required, but it is strongly encouraged, yes. The only way in which the "internal state" is used is as a namespace for constants, so that we don't have to define global parameters, or drop magic numbers everywhere.

In the internal gym code, we'll definitely write everything in a stateless/immutable/pure way. If someone decides they want to do this and mutate the class somehow... that's their funeral.
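As a small illustration of "class as a namespace for constants", the sketch below uses a frozen dataclass so that accidental mutation raises an error. This is an assumption about how one might enforce the convention, not something the PR does; `ToyEnv` and its fields are hypothetical.

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class ToyEnv:
    """Hypothetical env: fields are constants, methods are pure functions."""

    max_speed: float = 8.0
    dt: float = 0.05

    def transition(self, state, action):
        # state is a (position, velocity) tuple; a new tuple is returned.
        position, velocity = state
        new_velocity = velocity + action * self.dt
        new_velocity = max(-self.max_speed, min(self.max_speed, new_velocity))
        return (position + new_velocity * self.dt, new_velocity)


env = ToyEnv()
state = (0.0, 0.0)
next_state = env.transition(state, 2.0)

try:
    env.dt = 0.1  # attempted mutation of a constant
    mutated = True
except FrozenInstanceError:
    mutated = False
```

With a plain class the assignment would silently succeed, which is the "that's their funeral" case described above.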

@RedTachyon RedTachyon changed the title First version of a functional API Functional API and proof-of-concept classic-control envs Jul 28, 2022
@RedTachyon RedTachyon changed the title Functional API and proof-of-concept classic-control envs Functional API and proof-of-concept jax classic-control envs Jul 28, 2022
@RedTachyon RedTachyon marked this pull request as ready for review July 28, 2022 11:45
@DavidSlayback

Just found this PR and would be extremely interested in having a functional API. I work with a lot of POMDPs in my own work, as well as planning algorithms (which need a functional simulator) and batching/vectorization (which is much easier with jax.vmap()). I had a couple of points I want to raise:

  1. In the POMDP literature, the observation function is typically a function of state AND action o ~ O(s,a). Additionally, environment-specific observation noise is a key part of many environments (e.g., sensor noise in sampling for Rock Sample). For observations, I'd propose:
    observation(self, state: StateType, action: Optional[ActType], rng: Optional[Any]):

  2. Similarly, terminal may depend on the previous action. For instance, in the Tiger problem, the state is whether the tiger is behind the left or right door. The episode ends on opening the door, but the actual state of the environment doesn't change. Maybe: terminal(state: StateType, action: Optional[ActType])?

  3. Finally, it feels like rng might be a valid input to any of the functions given. Bandit rewards are typically stochastic, and termination conditions sometimes are as well. Obviously, noise could also just be implemented as an additional function or wrapper, so I'm not sure whether it makes sense to clutter up all the functions with it.
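The three points above can be sketched with the Tiger problem. Everything here is hypothetical: `TigerFuncEnv`, the action and state encodings, and the `listen_accuracy` value are chosen for illustration, and an integer `seed` stands in for the `rng` argument so that the observation stays a pure function of its inputs.

```python
import random
from typing import Optional

LISTEN, OPEN_LEFT, OPEN_RIGHT = 0, 1, 2
TIGER_LEFT, TIGER_RIGHT = 0, 1


class TigerFuncEnv:
    """Hypothetical Tiger env using the action-dependent signatures proposed above."""

    listen_accuracy = 0.85  # assumed value, for illustration only

    def transition(self, state: int, action: int) -> int:
        # The tiger never moves, so no action changes the state.
        return state

    def observation(self, state: int, action: Optional[int], seed: Optional[int]) -> Optional[int]:
        # Only listening yields a (noisy) hint about where the tiger is.
        if action != LISTEN:
            return None
        correct = random.Random(seed).random() < self.listen_accuracy
        return state if correct else 1 - state

    def terminal(self, state: int, action: Optional[int]) -> bool:
        # Termination depends on the action (opening a door), not on the state.
        return action in (OPEN_LEFT, OPEN_RIGHT)


env = TigerFuncEnv()
state = TIGER_LEFT
hint_a = env.observation(state, LISTEN, seed=7)
hint_b = env.observation(state, LISTEN, seed=7)  # same inputs, same output
ended_after_listen = env.terminal(state, LISTEN)
ended_after_open = env.terminal(env.transition(state, OPEN_LEFT), OPEN_LEFT)
```

A state-only `terminal(state)` could not express this episode ending, because the state before and after opening a door is identical.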

@DavidSlayback

Sorry, I had another idea I wanted to bring up! I'm less wedded to it, but thought it might be worth considering.

A possible advantage of the functional API is that we could adjust environment parameters as we run. This could be used for meta-learning. Maybe we vmap over a batch of cartpole environments with different physics parameters so that our agent has to learn a more robust policy. Or maybe we test transfer by switching goals in the middle of a run. Specifically, I'm taking inspiration from the gymnax environment.

Each environment file defines its own EnvState and EnvParams objects as seen here (CartPole):

import flax.struct
import jax.numpy as jnp


@flax.struct.dataclass
class EnvState:
    x: float
    x_dot: float
    theta: float
    theta_dot: float
    time: int


@flax.struct.dataclass
class EnvParams:
    gravity: float = 9.8
    masscart: float = 1.0
    masspole: float = 0.1
    total_mass: float = 1.0 + 0.1  # (masscart + masspole)
    length: float = 0.5
    polemass_length: float = 0.05  # (masspole * length)
    force_mag: float = 10.0
    tau: float = 0.02
    theta_threshold_radians: float = 12 * 2 * jnp.pi / 360
    x_threshold: float = 2.4
    max_steps_in_episode: int = 500  # v0 had only 200 steps!

EnvState gets used in the same way as our StateType from above.

EnvParams is passed to step (transition), reset (initial), reward, and terminal, as well as being used to define state/action/observation spaces. If none is provided, the default parameters (which would currently be stored in our FuncEnv) are used. But we can also use different params for each environment, or change them during training, without changing the underlying environment and breaking its static property.

I'm not sure about it because I think it's probably adding significant complexity to what is meant to be a minimal, flexible base API, and I also think that some of these parameters probably should be a component of the environment, but I wanted to raise the idea!
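A minimal sketch of the "vmap over a batch of environments with different physics parameters" idea follows. It is not gymnax's API: the reduced `EnvParams`, the toy falling-body `transition`, and the gravity values are all hypothetical, and a `NamedTuple` is used in place of `flax.struct.dataclass` only to keep the example free of extra dependencies.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvParams(NamedTuple):
    # Hypothetical, reduced parameter set; a NamedTuple is a valid JAX pytree.
    gravity: float = 9.8
    tau: float = 0.02


def transition(state, action, params):
    # Toy falling-body dynamics (not CartPole): state is [height, velocity].
    height, velocity = state
    new_velocity = velocity + (action - params.gravity) * params.tau
    new_height = height + new_velocity * params.tau
    return jnp.stack([new_height, new_velocity])


# One environment per gravity value; tau is replicated so that every
# leaf of the params pytree has a leading batch axis of size 3.
batched_params = EnvParams(
    gravity=jnp.array([1.6, 3.7, 9.8]),
    tau=jnp.full(3, 0.02),
)
states = jnp.zeros((3, 2))
actions = jnp.zeros(3)

# vmap maps over the leading axis of states, actions, and every params leaf.
batched_transition = jax.jit(jax.vmap(transition))
next_states = batched_transition(states, actions, batched_params)
```

Since the parameters are ordinary function arguments, changing them between calls does not require rebuilding the environment object, which is the property described above.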

5 participants