In [1]:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from flax import linen as nn
from flax import struct
from evojax.policy.base import PolicyNetwork, PolicyState


class MinimalRNN(nn.Module):
    """Single-layer Elman-style RNN cell with a linear read-out.

    One call computes ``h_new = tanh(Dense([x, h]))`` and
    ``out = Dense(h_new)`` and returns ``(out, h_new)``.
    """

    # Kept for bookkeeping; __call__ does not read it (Dense infers its input size).
    input_dim: int
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x, h):
        """Advance the recurrence by one step.

        Args:
            x: input features on the last axis.
            h: previous hidden state (last axis of size ``hidden_dim``); any
                leading axes must match those of ``x``.

        Returns:
            ``(out, h_new)``: the read-out (last axis ``output_dim``) and the
            updated hidden state (last axis ``hidden_dim``).
        """
        # Join on the feature (last) axis so leading batch axes are preserved;
        # for 1-D inputs this is identical to the default axis=0.
        xh = jnp.concatenate([x, h], axis=-1)
        h_new = nn.tanh(nn.Dense(self.hidden_dim)(xh))
        out = nn.Dense(self.output_dim)(h_new)
        return out, h_new


@struct.dataclass
class RNNPolicyState(PolicyState):
    """EvoJAX policy state extended with the per-environment RNN hidden state."""

    # Shape (num_envs, hidden_dim) when created by SimpleRNNPolicy.reset.
    hidden: jnp.ndarray


class SimpleRNNPolicy(PolicyNetwork):
    """EvoJAX policy that drives a MinimalRNN and carries its hidden state.

    Parameters exist in two forms:
      * the Flax ``params`` pytree produced by ``model.init`` (``get_params``);
      * a flat 1-D vector of length ``param_size`` / ``num_params``, the form
        population-based solvers such as PGPE sample.
    ``forward`` accepts either form.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=32):
        """Build the model and initialize its parameters with ``PRNGKey(0)``.

        Args:
            input_dim: observation size.
            output_dim: action size.
            hidden_dim: RNN hidden-state size.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.model = MinimalRNN(input_dim, hidden_dim, output_dim)

        dummy_input = jnp.zeros((input_dim,))
        dummy_hidden = jnp.zeros((hidden_dim,))
        self._param_tree = self.model.init(jax.random.PRNGKey(0), dummy_input, dummy_hidden)
        # ravel_pytree yields the flat length and the inverse mapping used by forward().
        flat_params, self._unravel_params = ravel_pytree(self._param_tree["params"])
        self._param_size = int(flat_params.size)
        # EvoJAX's PolicyNetwork interface names the parameter count `num_params`.
        self.num_params = self._param_size
        # Maps over the leading (environment) axis of params, obs and hidden together.
        self._batched_forward = jax.vmap(self.forward)

    def get_params(self):
        """Return the initial Flax ``params`` pytree."""
        return self._param_tree["params"]

    @property
    def param_size(self):
        """Total number of scalar parameters (length of a flat parameter vector)."""
        return self._param_size

    def forward(self, params, obs, hidden):
        """Run one RNN step.

        Args:
            params: either the Flax ``params`` pytree or a flat 1-D vector of
                length ``param_size``.
            obs: observation array.
            hidden: previous hidden state.

        Returns:
            ``(action, new_hidden)`` as returned by ``MinimalRNN``.
        """
        # A flat vector has ndim == 1; a dict/FrozenDict pytree has no `ndim`.
        if getattr(params, "ndim", None) == 1:
            params = self._unravel_params(params)
        return self.model.apply({'params': params}, obs, hidden)

    def reset(self, states):
        """Create the initial policy state: zero hidden state per environment.

        Args:
            states: EvoJAX task state; only ``states.obs.shape[0]`` (the number
                of environments) is read.
        """
        num_envs = states.obs.shape[0]
        keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
        return RNNPolicyState(keys=keys, hidden=jnp.zeros((num_envs, self.hidden_dim)))

    def get_actions(self, t_states, params, p_states):
        """EvoJAX entry point: one RNN step for every environment.

        Args:
            t_states: task state; ``t_states.obs`` is stacked per environment.
            params: flat parameter vectors, one row per environment.
            p_states: ``RNNPolicyState`` from ``reset`` or a previous call.

        Returns:
            ``(actions, new_p_states)`` with the hidden state advanced.
        """
        actions, hidden = self._batched_forward(params, t_states.obs, p_states.hidden)
        return actions, p_states.replace(hidden=hidden)


In [2]:
import jax
import jax.numpy as jnp
import mujoco
import mujoco.mjx as mjx
from evojax.task.base import VectorizedTask


class MJXDummyTask(VectorizedTask):
    """Placeholder task that loads a MuJoCo model into MJX but never simulates it.

    Observations are ``qpos`` concatenated with ``qvel`` from a fresh
    ``mjx.make_data``. ``step`` returns the observation unchanged, a reward
    equal to the negative squared action norm, and never signals termination.

    NOTE(review): EvoJAX's ``VectorizedTask`` contract is
    ``reset(key) -> TaskState`` and ``step(state, action) -> (TaskState, reward, done)``.
    This class returns raw arrays/tuples instead, so ``evojax.Trainer`` cannot
    drive it as-is. TODO: wrap obs/hidden in a TaskState dataclass.
    """

    def __init__(self, xml_path="mjcf/half_cheetah.xml", episode_length=200, hidden_dim=32):
        """Load the MJCF model and record observation/action sizes.

        Args:
            xml_path: MJCF file path. A relative path is resolved against the
                process's current working directory, not this notebook's
                location; pass an absolute path if the file lives elsewhere.
            episode_length: episode horizon, also exposed as ``max_steps``.
            hidden_dim: size of the zero hidden state returned by ``reset``;
                must equal the policy's ``hidden_dim``.

        Raises:
            ValueError: raised by ``mujoco.MjModel.from_xml_path`` when the file
                cannot be opened or parsed.
        """
        self.model = mjx.put_model(mujoco.MjModel.from_xml_path(xml_path))
        self.episode_length = episode_length
        # EvoJAX reads the horizon from `max_steps`.
        self.max_steps = episode_length
        self.hidden_dim = hidden_dim

        self.obs_size = self.model.nq + self.model.nv
        self.act_size = self.model.nu

    @property
    def obs_shape(self):
        """Shape of one observation: ``(nq + nv,)``."""
        return (self.obs_size,)

    @property
    def action_shape(self):
        """Shape of one action: ``(nu,)``."""
        return (self.act_size,)

    @property
    def act_shape(self):
        """Alias of ``action_shape`` under the name EvoJAX's VectorizedTask uses."""
        return self.action_shape

    def reset(self, rng):
        """Return the initial observation and a zero hidden state.

        ``rng`` is currently unused: the initial state is deterministic.
        """
        data = mjx.make_data(self.model)
        obs = jnp.concatenate([data.qpos, data.qvel])
        hidden = jnp.zeros((self.hidden_dim,))
        return obs, hidden

    def step(self, obs, action):
        """Score an action without advancing the simulation.

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is ``-sum(action**2)``
            over the last axis, so a batch of actions gets one reward per row;
            ``done`` is an all-False bool array shaped like ``reward``.
        """
        # Reduce only the action dimension so leading batch axes survive.
        reward = -jnp.sum(action**2, axis=-1)
        # Episodes never terminate early.
        done = jnp.zeros(jnp.shape(reward), dtype=bool)
        return obs, reward, done, {}


In [None]:
from evojax import Trainer
from evojax.algo import PGPE


def main(pop_size=64, max_iter=100, test_interval=10):
    """Train ``SimpleRNNPolicy`` on ``MJXDummyTask`` with PGPE via EvoJAX's Trainer.

    Args:
        pop_size: PGPE population size.
        max_iter: number of training iterations.
        test_interval: iterations between evaluations on the test task.

    Returns:
        Whatever ``Trainer.run`` returns (presumably the best test score; confirm
        against the installed EvoJAX version).

    NOTE(review): ``MJXDummyTask`` does not yet return EvoJAX ``TaskState``
    objects, so ``run`` will fail inside the rollout until the task is adapted.
    """
    train_task = MJXDummyTask()
    test_task = MJXDummyTask()
    # The task exposes its action size as `action_shape`; `act_shape` is not
    # defined on every version of MJXDummyTask.
    policy = SimpleRNNPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.action_shape[0],
    )
    solver = PGPE(
        pop_size=pop_size,
        param_size=policy.param_size,
    )

    # EvoJAX's Trainer takes policy/solver/train_task/test_task/max_iter and is
    # started with run(); it has no task/algo/max_iterations kwargs or train().
    trainer = Trainer(
        policy=policy,
        solver=solver,
        train_task=train_task,
        test_task=test_task,
        max_iter=max_iter,
        test_interval=test_interval,
    )

    return trainer.run()

main()


ValueError: ParseXML: Error opening file 'mjcf/half_cheetah.xml': No such file or directory