# K-Scale Humanoid Benchmark

Welcome to the K-Scale Humanoid Benchmark! This notebook will walk you through training your own reinforcement learning policy, which you can then use to control a K-Scale robot.

## Dependencies and Config

The K-Scale Humanoid Benchmark uses K-Scale's open-source RL framework [K-Sim](https://github.com/kscalelabs/ksim) for training and the [K-Scale API](https://github.com/kscalelabs/kscale) for asset management.

To get your API key, install the K-Scale CLI with `pip install kscale` and run `ks user key` in your terminal.

In [1]:
# Install packages

!pip install -q ksim
!pip install -q kos-sim
!pip uninstall -y xax
!pip install git+https://github.com/kscalelabs/xax.git@fix-for-notebook

In [2]:
# Set up environment variables
import os
from google.colab import userdata
%env TENSORBOARD_PORT=6036
%env MUJOCO_GL=egl


os.environ["KSCALE_API_KEY"] = userdata.get('kscale-api-key')



env: TENSORBOARD_PORT=6036
env: MUJOCO_GL=egl
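Outside Colab, `google.colab.userdata` is not importable. The following is a minimal sketch of a fallback, assuming the key has been exported in the shell as `KSCALE_API_KEY`. The helper name `resolve_api_key` is hypothetical and not part of the K-Scale API.

```python
import os


def resolve_api_key(env_var: str = "KSCALE_API_KEY") -> str:
    """Return the API key from Colab secrets if available, else from the environment."""
    try:
        # Only importable inside Google Colab.
        from google.colab import userdata
    except ImportError:
        userdata = None

    if userdata is not None:
        key = userdata.get("kscale-api-key")
        if key:
            return key

    # Fall back to a variable exported in the shell (an assumption of this sketch).
    key = os.environ.get(env_var, "")
    if not key:
        raise RuntimeError(f"Set {env_var} or add a 'kscale-api-key' Colab secret.")
    return key
```

With this helper, the assignment in the cell above could become `os.environ["KSCALE_API_KEY"] = resolve_api_key()`.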


In [3]:
import asyncio
import math
from dataclasses import dataclass
from typing import Self

import attrs
import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
import ksim
import mujoco
import mujoco_scenes
import mujoco_scenes.mjcf
import optax
import xax
from jaxtyping import Array, PRNGKeyArray
from kscale.web.gen.api import JointMetadataOutput
import nest_asyncio
nest_asyncio.apply()

In [4]:
NUM_JOINTS = 20
NUM_ACTOR_INPUTS = 43
NUM_CRITIC_INPUTS = 444

# These are in the order of the neural network outputs.
ZEROS: list[tuple[str, float]] = [
    ("dof_right_shoulder_pitch_03", 0.0),
    ("dof_right_shoulder_roll_03", math.radians(-10.0)),
    ("dof_right_shoulder_yaw_02", 0.0),
    ("dof_right_elbow_02", math.radians(90.0)),
    ("dof_right_wrist_00", 0.0),
    ("dof_left_shoulder_pitch_03", 0.0),
    ("dof_left_shoulder_roll_03", math.radians(10.0)),
    ("dof_left_shoulder_yaw_02", 0.0),
    ("dof_left_elbow_02", math.radians(-90.0)),
    ("dof_left_wrist_00", 0.0),
    ("dof_right_hip_pitch_04", math.radians(-25.0)),
    ("dof_right_hip_roll_03", 0.0),
    ("dof_right_hip_yaw_03", 0.0),
    ("dof_right_knee_04", math.radians(-50.0)),
    ("dof_right_ankle_02", math.radians(25.0)),
    ("dof_left_hip_pitch_04", math.radians(25.0)),
    ("dof_left_hip_roll_03", 0.0),
    ("dof_left_hip_yaw_03", 0.0),
    ("dof_left_knee_04", math.radians(50.0)),
    ("dof_left_ankle_02", math.radians(-25.0)),
]

# These are the torques we clip outputs to when deploying the policy.
MAX_TORQUE = {
    "00": 1.0,  # 00 motor
    "02": 13.0,  # 02 motor
    "03": 48.0,  # 03 motor
    "04": 96.0,  # 04 motor
}
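The joint names in `ZEROS` end in a two-digit suffix that appears to identify the motor type, which is also the key used in `MAX_TORQUE`. As a sketch under that assumption, the per-joint torque limits can be derived from the names instead of being listed by hand. The helper `torque_limit` is hypothetical, and only the right arm and right leg are shown.

```python
MAX_TORQUE = {"00": 1.0, "02": 13.0, "03": 48.0, "04": 96.0}

# Right-arm and right-leg joint names, copied from ZEROS.
JOINT_NAMES = [
    "dof_right_shoulder_pitch_03",
    "dof_right_shoulder_roll_03",
    "dof_right_shoulder_yaw_02",
    "dof_right_elbow_02",
    "dof_right_wrist_00",
    "dof_right_hip_pitch_04",
    "dof_right_hip_roll_03",
    "dof_right_hip_yaw_03",
    "dof_right_knee_04",
    "dof_right_ankle_02",
]


def torque_limit(joint_name: str) -> float:
    """Look up the torque limit from the two-digit suffix of a joint name."""
    suffix = joint_name.rsplit("_", 1)[-1]
    return MAX_TORQUE[suffix]


ctrl_clip = [torque_limit(name) for name in JOINT_NAMES]
print(ctrl_clip)
```

The resulting list agrees with the right-arm and right-leg entries passed as `ctrl_clip` in `get_actuators` later in this notebook.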

## Rewards

When training a reinforcement learning agent, the most important design decision is the reward the agent should maximize. `ksim` includes a number of useful default rewards for training walking agents, but it is often worth defining new rewards to encourage specific behaviors. The cell below shows how to define a custom reward. A similar pattern can be used to define custom objectives, events, observations, and more.

In [5]:
@attrs.define
class BentArmPenalty(ksim.Reward):
    arm_indices: tuple[int, ...] = attrs.field()
    arm_targets: tuple[float, ...] = attrs.field()

    def get_reward(self, trajectory: ksim.Trajectory) -> Array:
        qpos = trajectory.qpos[..., self.arm_indices]
        qpos_targets = jnp.array(self.arm_targets)
        qpos_diff = qpos - qpos_targets
        return xax.get_norm(qpos_diff, "l1").mean(axis=-1)

    @classmethod
    def create(
        cls,
        model: ksim.PhysicsModel,
        scale: float,
        scale_by_curriculum: bool = False,
    ) -> Self:
        qpos_mapping = ksim.get_qpos_data_idxs_by_name(model)

        names = [
            "dof_right_shoulder_pitch_03",
            "dof_right_shoulder_roll_03",
            "dof_right_shoulder_yaw_02",
            "dof_right_elbow_02",
            "dof_right_wrist_00",
            "dof_left_shoulder_pitch_03",
            "dof_left_shoulder_roll_03",
            "dof_left_shoulder_yaw_02",
            "dof_left_elbow_02",
            "dof_left_wrist_00",
        ]

        zeros = {k: v for k, v in ZEROS}
        arm_indices = [qpos_mapping[name][0] for name in names]
        arm_targets = [zeros[name] for name in names]

        return cls(
            arm_indices=tuple(arm_indices),
            arm_targets=tuple(arm_targets),
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )
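To make the arithmetic in `get_reward` concrete, here is a small standalone sketch in plain JAX. It assumes that `xax.get_norm(x, "l1")` returns the elementwise absolute value. The joint positions and targets are made-up numbers, not values from the robot.

```python
import jax.numpy as jnp

# Hypothetical joint positions for 2 timesteps and 3 arm joints (radians).
qpos = jnp.array([[0.0, 0.5, 1.0], [0.1, -0.1, 1.5]])
targets = jnp.array([0.0, 0.0, 1.5])

# Elementwise absolute error, averaged over the joint axis.
penalty = jnp.abs(qpos - targets).mean(axis=-1)
print(penalty)
```

The value is non-negative, so it only acts as a penalty once it is multiplied by a negative `scale`, as in `BentArmPenalty.create(physics_model, scale=-0.01)` further down.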

## Actor-Critic Model

We train our reinforcement learning agent using an RNN-based actor and critic, which we define below.

In [6]:
class Actor(eqx.Module):
    """Actor for the walking task."""

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear
    num_inputs: int = eqx.field(static=True)
    num_outputs: int = eqx.field(static=True)
    num_mixtures: int = eqx.field(static=True)
    min_std: float = eqx.field(static=True)
    max_std: float = eqx.field(static=True)
    var_scale: float = eqx.field(static=True)

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        num_outputs: int,
        min_std: float,
        max_std: float,
        var_scale: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

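        # Create the RNN layers, giving each GRU cell its own key so that
        # the layers do not start from identical weights.
        key, rnn_key = jax.random.split(key)
        rnn_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            [
                eqx.nn.GRUCell(
                    input_size=hidden_size,
                    hidden_size=hidden_size,
                    key=rnn_keys[i],
                )
                for i in range(depth)
            ]
        )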

        # Project to output
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs * 3 * num_mixtures,
            key=key,
        )

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_mixtures = num_mixtures
        self.min_std = min_std
        self.max_std = max_std
        self.var_scale = var_scale

    def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        # Reshape the output to be a mixture of gaussians.
        slice_len = NUM_JOINTS * self.num_mixtures
        mean_nm = out_n[..., :slice_len].reshape(NUM_JOINTS, self.num_mixtures)
        std_nm = out_n[..., slice_len : slice_len * 2].reshape(NUM_JOINTS, self.num_mixtures)
        logits_nm = out_n[..., slice_len * 2 :].reshape(NUM_JOINTS, self.num_mixtures)

        # Softplus and clip to ensure positive standard deviations.
        std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)

        # Apply bias to the means.
        mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]

        dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)

        return dist_n, jnp.stack(out_carries, axis=0)


class Critic(eqx.Module):
    """Critic for the walking task."""

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        hidden_size: int,
        depth: int,
    ) -> None:
        num_inputs = NUM_CRITIC_INPUTS
        num_outputs = 1

        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

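        # Create the RNN layers, giving each GRU cell its own key so that
        # the layers do not start from identical weights.
        key, rnn_key = jax.random.split(key)
        rnn_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            [
                eqx.nn.GRUCell(
                    input_size=hidden_size,
                    hidden_size=hidden_size,
                    key=rnn_keys[i],
                )
                for i in range(depth)
            ]
        )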

        # Project to output
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs,
            key=key,
        )

    def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        return out_n, jnp.stack(out_carries, axis=0)


class Model(eqx.Module):
    actor: Actor
    critic: Critic

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        num_outputs: int,
        min_std: float,
        max_std: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
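        # Use separate keys so the actor and critic are initialized independently.
        actor_key, critic_key = jax.random.split(key)
        self.actor = Actor(
            actor_key,
            num_inputs=num_inputs,
            num_outputs=num_outputs,
            min_std=min_std,
            max_std=max_std,
            var_scale=0.5,
            hidden_size=hidden_size,
            num_mixtures=num_mixtures,
            depth=depth,
        )
        self.critic = Critic(
            critic_key,
            hidden_size=hidden_size,
            depth=depth,
        )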
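The actor's final linear layer emits `num_outputs * 3 * num_mixtures` values, which `Actor.forward` slices into means, standard deviations, and mixture logits. The sketch below reproduces that slicing on a random vector, using the `min_std`, `max_std`, and `var_scale` values from this notebook. The final softmax only illustrates what the logits represent and is not taken from `ksim.MixtureOfGaussians`.

```python
import jax
import jax.numpy as jnp

num_joints, num_mixtures = 20, 5
min_std, max_std, var_scale = 0.01, 1.0, 0.5

# Stand-in for the output of the final linear layer.
out = jax.random.normal(jax.random.PRNGKey(0), (num_joints * 3 * num_mixtures,))

n = num_joints * num_mixtures
mean_nm = out[:n].reshape(num_joints, num_mixtures)
std_raw_nm = out[n : 2 * n].reshape(num_joints, num_mixtures)
logits_nm = out[2 * n :].reshape(num_joints, num_mixtures)

# Softplus keeps the standard deviation positive; the minimum caps it at max_std.
std_nm = jnp.minimum((jax.nn.softplus(std_raw_nm) + min_std) * var_scale, max_std)

# For illustration: per-joint mixture weights sum to one after a softmax.
weights_nm = jax.nn.softmax(logits_nm, axis=-1)
print(mean_nm.shape, std_nm.shape, weights_nm.shape)
```

Because softplus is positive, the standard deviation can never fall below `min_std * var_scale`, which is 0.005 with these settings.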

## Config

The [ksim framework](https://github.com/kscalelabs/ksim) is built on [xax](https://github.com/kscalelabs/xax), a JAX training library from K-Scale. xax uses a config dataclass to define configuration options and to parse them from the command line. We define the config here.

In [7]:
@dataclass
class HumanoidWalkingTaskConfig(ksim.PPOConfig):
    """Config for the humanoid walking task."""

    # Model parameters.
    hidden_size: int = xax.field(
        value=128,
        help="The hidden size for the GRU layers.",
    )
    depth: int = xax.field(
        value=5,
        help="The number of stacked GRU layers.",
    )
    num_mixtures: int = xax.field(
        value=5,
        help="The number of mixtures for the actor.",
    )
    scale: float = xax.field(
        value=0.1,
        help="The maximum position delta on each step, in radians.",
    )

    # Optimizer parameters.
    learning_rate: float = xax.field(
        value=3e-4,
        help="Learning rate for PPO.",
    )
    max_grad_norm: float = xax.field(
        value=2.0,
        help="Maximum gradient norm for clipping.",
    )
    adam_weight_decay: float = xax.field(
        value=1e-5,
        help="Weight decay for the Adam optimizer.",
    )

    # Curriculum parameters.
    num_curriculum_levels: int = xax.field(
        value=10,
        help="The number of curriculum levels to use.",
    )
    increase_threshold: float = xax.field(
        value=3.0,
        help="Increase the curriculum level when the mean trajectory length is above this threshold.",
    )
    decrease_threshold: float = xax.field(
        value=1.0,
        help="Decrease the curriculum level when the mean trajectory length is below this threshold.",
    )
    min_level_steps: int = xax.field(
        value=50,
        help="The minimum number of steps to wait before changing the curriculum level.",
    )
    min_curriculum_level: float = xax.field(
        value=0.0,
        help="The minimum curriculum level to use.",
    )

    # Rendering parameters.
    render_track_body_id: int | None = xax.field(
        value=0,
        help="The body id to track with the render camera.",
    )
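The curriculum fields describe a simple threshold rule: move up a level when episodes last long enough, move down when they end early, and wait a minimum number of steps between changes. The sketch below is one possible reading of that rule, not ksim's implementation. `LevelState` and `update_level` are hypothetical names, and treating the thresholds as seconds is an assumption.

```python
from dataclasses import dataclass


@dataclass
class LevelState:
    level: int = 0
    steps_at_level: int = 0


def update_level(
    state: LevelState,
    mean_episode_seconds: float,
    *,
    num_levels: int = 10,
    increase_threshold: float = 3.0,
    decrease_threshold: float = 1.0,
    min_level_steps: int = 50,
) -> LevelState:
    """Move one level up or down once enough steps have passed at this level."""
    if state.steps_at_level + 1 < min_level_steps:
        # Too early to change; just count the step.
        return LevelState(state.level, state.steps_at_level + 1)
    if mean_episode_seconds > increase_threshold and state.level < num_levels - 1:
        return LevelState(state.level + 1, 0)
    if mean_episode_seconds < decrease_threshold and state.level > 0:
        return LevelState(state.level - 1, 0)
    return LevelState(state.level, state.steps_at_level + 1)


state = LevelState()
for _ in range(50):
    state = update_level(state, mean_episode_seconds=5.0)
print(state)
```

The gap between the two thresholds gives the rule some hysteresis, so the level does not flip back and forth when episode lengths hover around a single value.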

## Task

The meat-and-potatoes of our training code is the task. This defines the observations, rewards, model calling logic, and everything else needed by `ksim` to train our reinforcement learning agent.
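The input widths declared earlier (`NUM_ACTOR_INPUTS = 43`, `NUM_CRITIC_INPUTS = 444`) have to match what the task concatenates in `run_actor` and `run_critic`. The bookkeeping below is a sanity-check sketch. It assumes the center-of-mass inertia and velocity observations flatten MuJoCo's per-body `cinert` (10 values) and `cvel` (6 values) arrays, which this notebook does not state.

```python
NUM_JOINTS = 20
NUM_ACTOR_INPUTS = 43
NUM_CRITIC_INPUTS = 444

# Actor: joint positions, joint velocities, projected gravity.
actor_sizes = {"joint_pos": NUM_JOINTS, "joint_vel": NUM_JOINTS, "proj_grav": 3}
assert sum(actor_sizes.values()) == NUM_ACTOR_INPUTS

# Critic terms whose widths do not depend on the number of bodies.
critic_fixed = {
    "joint_pos": NUM_JOINTS,
    "joint_vel": NUM_JOINTS,
    "imu_acc": 3,
    "imu_gyro": 3,
    "proj_grav": 3,
    "actuator_force": NUM_JOINTS,
    "base_pos": 3,
    "base_quat": 4,
}
fixed_total = sum(critic_fixed.values())

# Assumption: 10 inertia values plus 6 velocity values per body.
per_body = 10 + 6
num_bodies, leftover = divmod(NUM_CRITIC_INPUTS - fixed_total, per_body)
print(num_bodies, leftover)  # 23 bodies, remainder 0
```

Under that assumption the declared critic width corresponds to a model with 23 bodies. A nonzero remainder would signal that the constants and the observation list have drifted apart.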

In [8]:
class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):
    def get_optimizer(self) -> optax.GradientTransformation:
        optimizer = optax.chain(
            optax.clip_by_global_norm(self.config.max_grad_norm),
            (
                optax.adam(self.config.learning_rate)
                if self.config.adam_weight_decay == 0.0
                else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay)
            ),
        )

        return optimizer

    def get_mujoco_model(self) -> mujoco.MjModel:
        mjcf_path = asyncio.run(ksim.get_mujoco_model_path("kbot-v2-feet", name="robot"))
        return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene="smooth")

    def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> dict[str, JointMetadataOutput]:
        metadata = asyncio.run(ksim.get_mujoco_model_metadata("kbot-v2-feet"))
        if metadata.joint_name_to_metadata is None:
            raise ValueError("Joint metadata is not available")
        return metadata.joint_name_to_metadata

    def get_actuators(
        self,
        physics_model: ksim.PhysicsModel,
        metadata: dict[str, JointMetadataOutput] | None = None,
    ) -> ksim.Actuators:
        assert metadata is not None, "Metadata is required"
        return ksim.MITPositionActuators(
            physics_model=physics_model,
            joint_name_to_metadata=metadata,
            ctrl_clip=[
                # right arm
                MAX_TORQUE["03"],
                MAX_TORQUE["03"],
                MAX_TORQUE["02"],
                MAX_TORQUE["02"],
                MAX_TORQUE["00"],
                # left arm
                MAX_TORQUE["03"],
                MAX_TORQUE["03"],
                MAX_TORQUE["02"],
                MAX_TORQUE["02"],
                MAX_TORQUE["00"],
                # right leg
                MAX_TORQUE["04"],
                MAX_TORQUE["03"],
                MAX_TORQUE["03"],
                MAX_TORQUE["04"],
                MAX_TORQUE["02"],
                # left leg
                MAX_TORQUE["04"],
                MAX_TORQUE["03"],
                MAX_TORQUE["03"],
                MAX_TORQUE["04"],
                MAX_TORQUE["02"],
            ],
        )

    def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:
        return [
            ksim.StaticFrictionRandomizer(),
            ksim.FloorFrictionRandomizer.from_geom_name(physics_model, "floor", scale_lower=0.8, scale_upper=1.2),
            ksim.ArmatureRandomizer(),
            ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),
            ksim.JointDampingRandomizer(),
            ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)),
        ]

    def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:
        return [
            ksim.PushEvent(
                x_force=1.5,
                y_force=1.5,
                z_force=0.1,
                x_angular_force=0.1,
                y_angular_force=0.1,
                z_angular_force=0.3,
                interval_range=(0.5, 4.0),
            ),
        ]

    def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:
        return [
            ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1),
            ksim.RandomJointVelocityReset(),
        ]

    def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:
        return [
            ksim.JointPositionObservation(),
            ksim.JointVelocityObservation(),
            ksim.ActuatorForceObservation(),
            ksim.CenterOfMassInertiaObservation(),
            ksim.CenterOfMassVelocityObservation(),
            ksim.BasePositionObservation(),
            ksim.BaseOrientationObservation(),
            ksim.BaseLinearVelocityObservation(),
            ksim.BaseAngularVelocityObservation(),
            ksim.BaseLinearAccelerationObservation(),
            ksim.BaseAngularAccelerationObservation(),
            ksim.ProjectedGravityObservation.create(
                physics_model=physics_model,
                framequat_name="base_link_quat",
                lag_range=(0.0, 0.5),
            ),
            ksim.ActuatorAccelerationObservation(),
            ksim.BasePositionObservation(),
            ksim.BaseOrientationObservation(),
            ksim.BaseLinearVelocityObservation(),
            ksim.BaseAngularVelocityObservation(),
            ksim.CenterOfMassVelocityObservation(),
            ksim.SensorObservation.create(physics_model=physics_model, sensor_name="imu_acc"),
            ksim.SensorObservation.create(physics_model=physics_model, sensor_name="imu_gyro"),
        ]

    def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
        return []

    def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
        return [
            # Standard rewards.
            ksim.StayAliveReward(scale=1.0),
            ksim.NaiveForwardReward(clip_min=0.0, clip_max=0.5, scale=1.0),
            ksim.UprightReward(index="x", inverted=False, scale=0.1),
            # Normalization penalties.
            ksim.ActionInBoundsReward.create(physics_model, scale=0.01),
            ksim.ActionSmoothnessPenalty(scale=-0.01),
            ksim.ActuatorJerkPenalty(ctrl_dt=self.config.ctrl_dt, scale=-0.001),
            ksim.ActuatorRelativeForcePenalty.create(physics_model, scale=-0.001),
            ksim.AngularVelocityPenalty(index="x", scale=-0.0005),
            ksim.AngularVelocityPenalty(index="y", scale=-0.0005),
            ksim.AngularVelocityPenalty(index="z", scale=-0.0005),
            ksim.LinearVelocityPenalty(index="y", scale=-0.0005),
            ksim.LinearVelocityPenalty(index="z", scale=-0.0005),
            # Bespoke rewards.
            BentArmPenalty.create(physics_model, scale=-0.01),
        ]

    def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:
        return [
            ksim.BadZTermination(unhealthy_z_lower=0.9, unhealthy_z_upper=1.6),
            ksim.PitchTooGreatTermination(max_pitch=math.radians(30)),
            ksim.RollTooGreatTermination(max_roll=math.radians(30)),
            ksim.HighVelocityTermination(),
            ksim.FarFromOriginTermination(max_dist=10.0),
        ]

    def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum:
        return ksim.EpisodeLengthCurriculum(
            num_levels=self.config.num_curriculum_levels,
            increase_threshold=self.config.increase_threshold,
            decrease_threshold=self.config.decrease_threshold,
            min_level_steps=self.config.min_level_steps,
            dt=self.config.ctrl_dt,
            min_level=self.config.min_curriculum_level,
        )

    def get_model(self, key: PRNGKeyArray) -> Model:
        return Model(
            key,
            num_inputs=NUM_ACTOR_INPUTS,
            num_outputs=NUM_JOINTS,
            min_std=0.01,
            max_std=1.0,
            hidden_size=self.config.hidden_size,
            num_mixtures=self.config.num_mixtures,
            depth=self.config.depth,
        )

    def run_actor(
        self,
        model: Actor,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        carry: Array,
    ) -> tuple[distrax.Distribution, Array]:
        joint_pos_n = observations["joint_position_observation"]
        joint_vel_n = observations["joint_velocity_observation"]
        proj_grav_3 = observations["projected_gravity_observation"]

        obs_n = jnp.concatenate(
            [
                joint_pos_n,  # NUM_JOINTS
                joint_vel_n,  # NUM_JOINTS
                proj_grav_3,  # 3
            ],
            axis=-1,
        )

        action, carry = model.forward(obs_n, carry)

        return action, carry

    def run_critic(
        self,
        model: Critic,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        carry: Array,
    ) -> tuple[Array, Array]:
        dh_joint_pos_j = observations["joint_position_observation"]
        dh_joint_vel_j = observations["joint_velocity_observation"]
        com_inertia_n = observations["center_of_mass_inertia_observation"]
        com_vel_n = observations["center_of_mass_velocity_observation"]
        imu_acc_3 = observations["sensor_observation_imu_acc"]
        imu_gyro_3 = observations["sensor_observation_imu_gyro"]
        proj_grav_3 = observations["projected_gravity_observation"]
        act_frc_obs_n = observations["actuator_force_observation"]
        base_pos_3 = observations["base_position_observation"]
        base_quat_4 = observations["base_orientation_observation"]

        obs_n = jnp.concatenate(
            [
                dh_joint_pos_j,  # NUM_JOINTS
                dh_joint_vel_j / 10.0,  # NUM_JOINTS
                com_inertia_n,  # 230 (10 per body)
                com_vel_n,  # 138 (6 per body)
                imu_acc_3,  # 3
                imu_gyro_3,  # 3
                proj_grav_3,  # 3
                act_frc_obs_n / 100.0,  # NUM_JOINTS
                base_pos_3,  # 3
                base_quat_4,  # 4
            ],
            axis=-1,
        )

        return model.forward(obs_n, carry)

    def get_ppo_variables(
        self,
        model: Model,
        trajectory: ksim.Trajectory,
        model_carry: tuple[Array, Array],
        rng: PRNGKeyArray,
    ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:
        def scan_fn(
            actor_critic_carry: tuple[Array, Array],
            transition: ksim.Trajectory,
        ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:
            actor_carry, critic_carry = actor_critic_carry
            actor_dist, next_actor_carry = self.run_actor(
                model=model.actor,
                observations=transition.obs,
                commands=transition.command,
                carry=actor_carry,
            )
            log_probs = actor_dist.log_prob(transition.action)
            assert isinstance(log_probs, Array)
            value, next_critic_carry = self.run_critic(
                model=model.critic,
                observations=transition.obs,
                commands=transition.command,
                carry=critic_carry,
            )

            transition_ppo_variables = ksim.PPOVariables(
                log_probs=log_probs,
                values=value.squeeze(-1),
            )

            next_carry = jax.tree.map(
                lambda x, y: jnp.where(transition.done, x, y),
                self.get_initial_model_carry(rng),
                (next_actor_carry, next_critic_carry),
            )

            return next_carry, transition_ppo_variables

        next_model_carry, ppo_variables = jax.lax.scan(scan_fn, model_carry, trajectory)

        return ppo_variables, next_model_carry

    def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
        return (
            jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
            jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
        )

    def sample_action(
        self,
        model: Model,
        model_carry: tuple[Array, Array],
        physics_model: ksim.PhysicsModel,
        physics_state: ksim.PhysicsState,
        observations: xax.FrozenDict[str, Array],
        commands: xax.FrozenDict[str, Array],
        rng: PRNGKeyArray,
        argmax: bool,
    ) -> ksim.Action:
        actor_carry_in, critic_carry_in = model_carry

        # Runs the actor model to get the action distribution.
        action_dist_j, actor_carry = self.run_actor(
            model=model.actor,
            observations=observations,
            commands=commands,
            carry=actor_carry_in,
        )

        action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)

        return ksim.Action(
            action=action_j,
            carry=(actor_carry, critic_carry_in),
            aux_outputs=None,
        )

## Launching an Experiment

To launch an experiment with `xax`, call `Task.launch(config)`. This is normally invoked from the command line, so by default it also tries to parse additional command-line arguments. Pass `use_cli=False` to disable that, as we do here.
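The timing parameters in the launch call determine how much data each training step collects. The arithmetic below is a sketch that assumes one sample per environment per control step. It uses the values passed to the config in this section.

```python
dt = 0.002  # physics timestep, seconds
ctrl_dt = 0.02  # control timestep, seconds
rollout_length_seconds = 1.0
num_envs = 2048

physics_steps_per_action = round(ctrl_dt / dt)
actions_per_rollout = round(rollout_length_seconds / ctrl_dt)
samples_per_step = num_envs * actions_per_rollout

print(physics_steps_per_action, actions_per_rollout, samples_per_step)
```

The result, 102400, agrees with the `Samples: 102,400` line reported after the first step in the training output.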

In [None]:
HumanoidWalkingTask.launch(
    HumanoidWalkingTaskConfig(
        # Training parameters.
        num_envs=2048,
        batch_size=128,
        num_passes=2,
        epochs_per_log_step=1,
        rollout_length_seconds=1.0,
        # Simulation parameters.
        dt=0.002,
        ctrl_dt=0.02,
        iterations=8,
        ls_iterations=8,
        max_action_latency=0.01,
        # Checkpointing parameters.
        save_every_n_seconds=60,
        # Xax parameters.
        disable_multiprocessing=True,
    ),
    use_cli=False,
)

No config file was found in /root/.xax.yml; writing one...

INFO:xax.task.mixins.compile:Setting JAX logging level to INFO


  [1;36mINFO[0m  [90m2025-05-02 22:13:22[0m [[1;34mxax.task.mixins.compile[0m] Setting JAX logging level to INFO


INFO:xax.task.mixins.compile:Setting JAX compilation cache directory to /root/.cache/jax/jaxcache


  [1;36mINFO[0m  [90m2025-05-02 22:13:22[0m [[1;34mxax.task.mixins.compile[0m] Setting JAX compilation cache directory to /root/.cache/jax/jaxcache


INFO:xax.task.mixins.compile:Configuring JAX compilation cache parameters


  [1;36mINFO[0m  [90m2025-05-02 22:13:22[0m [[1;34mxax.task.mixins.compile[0m] Configuring JAX compilation cache parameters


INFO:2025-05-02 22:13:23,760:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


  [1;36mINFO[0m  [90m2025-05-02 22:13:23[0m [[1;34mjax._src.xla_bridge[0m] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-05-02 22:13:23,788:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


  [1;36mINFO[0m  [90m2025-05-02 22:13:23[0m [[1;34mjax._src.xla_bridge[0m] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory






STATUS:xax.task.mixins.artifacts:/content/humanoid_walking_task/run_0


 [1;32mSTATUS[0m [90m2025-05-02 22:13:25[0m [[1;34mxax.task.mixins.artifacts[0m] /content/humanoid_walking_task/run_0












STATUS:xax.task.mixins.train:/content


 [1;32mSTATUS[0m [90m2025-05-02 22:13:25[0m [[1;34mxax.task.mixins.train[0m] /content


STATUS:xax.task.mixins.train:humanoid_walking_task


 [1;32mSTATUS[0m [90m2025-05-02 22:13:25[0m [[1;34mxax.task.mixins.train[0m] humanoid_walking_task


STATUS:xax.task.mixins.train:JAX devices: [CudaDevice(id=0)]


 [1;32mSTATUS[0m [90m2025-05-02 22:13:25[0m [[1;34mxax.task.mixins.train[0m] JAX devices: [CudaDevice(id=0)]






INFO:httpx:HTTP Request: GET https://api.kscale.dev/robots/urdf/kbot-v2-feet "HTTP/1.1 200 OK"


  [1;36mINFO[0m  [90m2025-05-02 22:13:26[0m [[1;34mhttpx[0m] HTTP Request: GET https://api.kscale.dev/robots/urdf/kbot-v2-feet "HTTP/1.1 200 OK"


INFO:kscale.web.clients.robot_class:Downloading URDF file from https://kscale-www-production.s3.amazonaws.com/urdfs/81d7c38e0690537e/robot.tgz?AWSAccessKeyId=ASIA2R4HRCAH4IJGHQEC&Signature=7dn5jhz2suGee4wQJOI7IU%2BusSE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEEYaCXVzLWVhc3QtMSJHMEUCIQDKeUPH6mSTaWmge7NSWzZ%2F%2FsrmIp3MeyQFSz812iNAugIgRx7GGXWQIB0pOVdkiDCeUnB3FUvvQuS1TBfJYNzLlYkq6gII3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FARAAGgw3MjU1OTY4MzU4NTUiDBlczJNUQg75%2B1tkHyq%2BApe847C3oV33Ur%2BFa2nOyobu15iBog8JAG%2F1kIXex3%2B%2BnyrEsPAv%2BG7caLjlah5WHbWiM0WB6alilFxeyrT%2BzQpnAJ6Prk40eobgQDbyLUVDMPZ9wnlpqaLKdowoebRb6UIVlSfbzWokELzwbRkkfpIMrEXJ4ioQE1hLLHMdUH6hj4JCc%2BvvwjhO2KuYMRvnFKvC%2FcAcYwshlL1sfmjit146YwAnIV%2FYAAkUOBuF7GzNxPc0zDZkqW1SMk68DNr3JaiP0sbcejSEcY%2FbASHQ8dOZ4qKbWoX%2BcGMiz%2BY3IDPL7sUCUcQ2SOxEtp3eogQcEZHT%2Bjs2LeAC%2FYmsp0vlut4PMDln4hW73BT7Y7e%2B5nIH6sV8AXKz1gY6NhFHcpekJSzKmueE3dWCZZNHQt43gHWaQETfRJP7ijaRnTdq1jDN%2F9TABjqPARuwhVSOLo%2BjewGPs4f5gTTVl8pJF7gutdDqAnVy3lYh2qbsWFFVCgV88voNVGhpC

  [1;36mINFO[0m  [90m2025-05-02 22:13:26[0m [[1;34mkscale.web.clients.robot_class[0m] Downloading URDF file from https://kscale-www-production.s3.amazonaws.com/urdfs/81d7c38e0690537e/robot.tgz?AWSAccessKeyId=ASIA2R4HRCAH4IJGHQEC&Signature=7dn5jhz2suGee4wQJOI7IU%2BusSE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEEYaCXVzLWVhc3QtMSJHMEUCIQDKeUPH6mSTaWmge7NSWzZ%2F%2FsrmIp3MeyQFSz812iNAugIgRx7GGXWQIB0pOVdkiDCeUnB3FUvvQuS1TBfJYNzLlYkq6gII3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FARAAGgw3MjU1OTY4MzU4NTUiDBlczJNUQg75%2B1tkHyq%2BApe847C3oV33Ur%2BFa2nOyobu15iBog8JAG%2F1kIXex3%2B%2BnyrEsPAv%2BG7caLjlah5WHbWiM0WB6alilFxeyrT%2BzQpnAJ6Prk40eobgQDbyLUVDMPZ9wnlpqaLKdowoebRb6UIVlSfbzWokELzwbRkkfpIMrEXJ4ioQE1hLLHMdUH6hj4JCc%2BvvwjhO2KuYMRvnFKvC%2FcAcYwshlL1sfmjit146YwAnIV%2FYAAkUOBuF7GzNxPc0zDZkqW1SMk68DNr3JaiP0sbcejSEcY%2FbASHQ8dOZ4qKbWoX%2BcGMiz%2BY3IDPL7sUCUcQ2SOxEtp3eogQcEZHT%2Bjs2LeAC%2FYmsp0vlut4PMDln4hW73BT7Y7e%2B5nIH6sV8AXKz1gY6NhFHcpekJSzKmueE3dWCZZNHQt43gHWaQETfRJP7ijaRnTdq1jDN%2F9TABjqPARuwhVSOLo%

INFO:httpx:HTTP Request: GET https://kscale-www-production.s3.amazonaws.com/urdfs/81d7c38e0690537e/robot.tgz?AWSAccessKeyId=ASIA2R4HRCAH4IJGHQEC&Signature=7dn5jhz2suGee4wQJOI7IU%2BusSE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEEYaCXVzLWVhc3QtMSJHMEUCIQDKeUPH6mSTaWmge7NSWzZ%2F%2FsrmIp3MeyQFSz812iNAugIgRx7GGXWQIB0pOVdkiDCeUnB3FUvvQuS1TBfJYNzLlYkq6gII3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FARAAGgw3MjU1OTY4MzU4NTUiDBlczJNUQg75%2B1tkHyq%2BApe847C3oV33Ur%2BFa2nOyobu15iBog8JAG%2F1kIXex3%2B%2BnyrEsPAv%2BG7caLjlah5WHbWiM0WB6alilFxeyrT%2BzQpnAJ6Prk40eobgQDbyLUVDMPZ9wnlpqaLKdowoebRb6UIVlSfbzWokELzwbRkkfpIMrEXJ4ioQE1hLLHMdUH6hj4JCc%2BvvwjhO2KuYMRvnFKvC%2FcAcYwshlL1sfmjit146YwAnIV%2FYAAkUOBuF7GzNxPc0zDZkqW1SMk68DNr3JaiP0sbcejSEcY%2FbASHQ8dOZ4qKbWoX%2BcGMiz%2BY3IDPL7sUCUcQ2SOxEtp3eogQcEZHT%2Bjs2LeAC%2FYmsp0vlut4PMDln4hW73BT7Y7e%2B5nIH6sV8AXKz1gY6NhFHcpekJSzKmueE3dWCZZNHQt43gHWaQETfRJP7ijaRnTdq1jDN%2F9TABjqPARuwhVSOLo%2BjewGPs4f5gTTVl8pJF7gutdDqAnVy3lYh2qbsWFFVCgV88voNVGhpC0%2Bp086%2F6bw1P0DKtVlJCJE8zZD9o%2

  [1;36mINFO[0m  [90m2025-05-02 22:13:26[0m [[1;34mhttpx[0m] HTTP Request: GET https://kscale-www-production.s3.amazonaws.com/urdfs/81d7c38e0690537e/robot.tgz?AWSAccessKeyId=ASIA2R4HRCAH4IJGHQEC&Signature=7dn5jhz2suGee4wQJOI7IU%2BusSE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEEYaCXVzLWVhc3QtMSJHMEUCIQDKeUPH6mSTaWmge7NSWzZ%2F%2FsrmIp3MeyQFSz812iNAugIgRx7GGXWQIB0pOVdkiDCeUnB3FUvvQuS1TBfJYNzLlYkq6gII3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FARAAGgw3MjU1OTY4MzU4NTUiDBlczJNUQg75%2B1tkHyq%2BApe847C3oV33Ur%2BFa2nOyobu15iBog8JAG%2F1kIXex3%2B%2BnyrEsPAv%2BG7caLjlah5WHbWiM0WB6alilFxeyrT%2BzQpnAJ6Prk40eobgQDbyLUVDMPZ9wnlpqaLKdowoebRb6UIVlSfbzWokELzwbRkkfpIMrEXJ4ioQE1hLLHMdUH6hj4JCc%2BvvwjhO2KuYMRvnFKvC%2FcAcYwshlL1sfmjit146YwAnIV%2FYAAkUOBuF7GzNxPc0zDZkqW1SMk68DNr3JaiP0sbcejSEcY%2FbASHQ8dOZ4qKbWoX%2BcGMiz%2BY3IDPL7sUCUcQ2SOxEtp3eogQcEZHT%2Bjs2LeAC%2FYmsp0vlut4PMDln4hW73BT7Y7e%2B5nIH6sV8AXKz1gY6NhFHcpekJSzKmueE3dWCZZNHQt43gHWaQETfRJP7ijaRnTdq1jDN%2F9TABjqPARuwhVSOLo%2BjewGPs4f5gTTVl8pJF7gutdDqAnVy3lY

INFO  2025-05-02 22:13:27 [kscale.web.clients.robot_class] Checking MD5 hash of downloaded file


INFO  2025-05-02 22:13:27 [kscale.web.clients.robot_class] Updating downloaded file information


INFO  2025-05-02 22:13:27 [kscale.web.clients.robot_class] Unpacking URDF file


INFO  2025-05-02 22:13:27 [kscale.web.clients.robot_class] Updating downloaded file information


INFO  2025-05-02 22:13:28 [xax.task.mixins.train] Starting a new training run


PING  2025-05-02 22:13:31 [ksim.task.rl] Model size: 1,089,581 parameters


PING  2025-05-02 22:13:31 [ksim.task.rl] Optimizer size: 2,179,162 parameters


INFO  2025-05-02 22:13:32 [httpx] HTTP Request: GET https://api.kscale.dev/robots/name/kbot-v2-feet "HTTP/1.1 200 OK"

Phase: train
 ↪ Steps: 1
 ↪ Samples: 102,400
 ↪ Elapsed Time: 0s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001238
 ↪ action_smoothness_penalty: -5.036e-05
 ↪ actuator_jerk_penalty: -6.031e-06
 ↪ actuator_relative_force_penalty: -5.12e-06
 ↪ bent_arm_penalty: -2.36e-05
 ↪ stay_alive_reward: 0.001084
 ↪ total: 0.004403
 ↪ upright_reward: 0.003135
 ↪ x_angular_velocity_penalty: -1.518e-05
 ↪ x_naive_forward_reward: 0.0002308
 ↪ y_angular_velocity_penalty: -3.02e-05
 ↪ y_linear_velocity_penalty: -3.103e-07
 ↪ z_angular_velocity_penalty: -3.812e-05
 ↪ z_linear_velocity_penalty: -2.335e-06

🕒 timers
 ↪ steps/second: 0
 ↪ dt: 235.1

Status
 ✦ JAX devices: [CudaDevice(id=0)]
 ✦ humanoid_walking_task
 ✦ /content
 ✦ /content/huma

STATUS  2025-05-02 22:17:18 [ksim.task.rl] First step time: 2m, 53s


STATUS  2025-05-02 22:17:26 [root] Tensorboard: http://3521daa4530a:6036/


INFO  2025-05-02 22:17:35 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.1.bin

Phase: valid
 ↪ Steps: 1
 ↪ Samples: 102,400
 ↪ Elapsed Time: 2m, 53s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001239
 ↪ action_smoothness_penalty: -5.028e-05
 ↪ actuator_jerk_penalty: -6.001e-06
 ↪ actuator_relative_force_penalty: -5.231e-06
 ↪ bent_arm_penalty: -2.411e-05
 ↪ stay_alive_reward: 0.001023
 ↪ total: 0.004335
 ↪ upright_reward: 0.003132
 ↪ x_angular_velocity_penalty: -1.545e-05
 ↪ x_naive_forward_reward: 0.0002296
 ↪ y_angular_velocity_penalty: -3.013e-05
 ↪ y_linear_velocity_penalty: -3.795e-07
 ↪ z_angular_velocity_penalty: -3.982e-05
 ↪ z_linear_velocity_penalty: -2.061e-06

🕒 timers
 ↪ steps/second: 0
 ↪ dt: 253.1

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s
 ✦

INFO  2025-05-02 22:19:16 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.5.bin

Phase: train
 ↪ Steps: 6
 ↪ Samples: 614,400
 ↪ Elapsed Time: 4m, 2s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001255
 ↪ action_smoothness_penalty: -4.839e-05
 ↪ actuator_jerk_penalty: -5.932e-06
 ↪ actuator_relative_force_penalty: -5.273e-06
 ↪ bent_arm_penalty: -2.446e-05
 ↪ stay_alive_reward: 0.00139
 ↪ total: 0.005039
 ↪ upright_reward: 0.003126
 ↪ x_angular_velocity_penalty: -1.572e-05
 ↪ x_naive_forward_reward: 0.0005677
 ↪ y_angular_velocity_penalty: -2.712e-05
 ↪ y_linear_velocity_penalty: -6.225e-07
 ↪ z_angular_velocity_penalty: -4.065e-05
 ↪ z_linear_velocity_penalty: -1.693e-06

🕒 timers
 ↪ steps/second: 0.07236
 ↪ dt: 17.21

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:20:26 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.9.bin

Phase: train
 ↪ Steps: 10
 ↪ Samples: 1,024,000
 ↪ Elapsed Time: 5m, 12s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001315
 ↪ action_smoothness_penalty: -4.684e-05
 ↪ actuator_jerk_penalty: -5.804e-06
 ↪ actuator_relative_force_penalty: -5.25e-06
 ↪ bent_arm_penalty: -2.46e-05
 ↪ stay_alive_reward: 0.001502
 ↪ total: 0.00587
 ↪ upright_reward: 0.003107
 ↪ x_angular_velocity_penalty: -1.56e-05
 ↪ x_naive_forward_reward: 0.001299
 ↪ y_angular_velocity_penalty: -2.8e-05
 ↪ y_linear_velocity_penalty: -7.809e-07
 ↪ z_angular_velocity_penalty: -4.108e-05
 ↪ z_linear_velocity_penalty: -1.64e-06

🕒 timers
 ↪ steps/second: 0.06467
 ↪ dt: 17.45

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s
 ✦

INFO  2025-05-02 22:21:58 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.13.bin

Phase: train
 ↪ Steps: 14
 ↪ Samples: 1,433,600
 ↪ Elapsed Time: 6m, 22s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001366
 ↪ action_smoothness_penalty: -4.522e-05
 ↪ actuator_jerk_penalty: -5.706e-06
 ↪ actuator_relative_force_penalty: -5.181e-06
 ↪ bent_arm_penalty: -2.454e-05
 ↪ stay_alive_reward: 0.001526
 ↪ total: 0.006497
 ↪ upright_reward: 0.00309
 ↪ x_angular_velocity_penalty: -1.523e-05
 ↪ x_naive_forward_reward: 0.001913
 ↪ y_angular_velocity_penalty: -2.908e-05
 ↪ y_linear_velocity_penalty: -7.515e-07
 ↪ z_angular_velocity_penalty: -4.055e-05
 ↪ z_linear_velocity_penalty: -1.631e-06

🕒 timers
 ↪ steps/second: 0.05638
 ↪ dt: 17.23

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:23:08 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.17.bin

Phase: train
 ↪ Steps: 18
 ↪ Samples: 1,843,200
 ↪ Elapsed Time: 7m, 32s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001411
 ↪ action_smoothness_penalty: -4.392e-05
 ↪ actuator_jerk_penalty: -5.477e-06
 ↪ actuator_relative_force_penalty: -5.135e-06
 ↪ bent_arm_penalty: -2.461e-05
 ↪ stay_alive_reward: 0.001565
 ↪ total: 0.006556
 ↪ upright_reward: 0.00311
 ↪ x_angular_velocity_penalty: -1.458e-05
 ↪ x_naive_forward_reward: 0.001904
 ↪ y_angular_velocity_penalty: -2.647e-05
 ↪ y_linear_velocity_penalty: -7.788e-07
 ↪ z_angular_velocity_penalty: -4.034e-05
 ↪ z_linear_velocity_penalty: -1.443e-06

🕒 timers
 ↪ steps/second: 0.05653
 ↪ dt: 17.27

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:24:38 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.21.bin

Phase: train
 ↪ Steps: 22
 ↪ Samples: 2,252,800
 ↪ Elapsed Time: 8m, 42s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001412
 ↪ action_smoothness_penalty: -4.33e-05
 ↪ actuator_jerk_penalty: -5.386e-06
 ↪ actuator_relative_force_penalty: -5.09e-06
 ↪ bent_arm_penalty: -2.479e-05
 ↪ stay_alive_reward: 0.00159
 ↪ total: 0.006566
 ↪ upright_reward: 0.003104
 ↪ x_angular_velocity_penalty: -1.384e-05
 ↪ x_naive_forward_reward: 0.001891
 ↪ y_angular_velocity_penalty: -2.478e-05
 ↪ y_linear_velocity_penalty: -8.145e-07
 ↪ z_angular_velocity_penalty: -4.036e-05
 ↪ z_linear_velocity_penalty: -1.353e-06

🕒 timers
 ↪ steps/second: 0.05369
 ↪ dt: 17.3

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:25:48 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.25.bin

Phase: train
 ↪ Steps: 26
 ↪ Samples: 2,662,400
 ↪ Elapsed Time: 9m, 52s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001431
 ↪ action_smoothness_penalty: -4.268e-05
 ↪ actuator_jerk_penalty: -5.251e-06
 ↪ actuator_relative_force_penalty: -5.053e-06
 ↪ bent_arm_penalty: -2.477e-05
 ↪ stay_alive_reward: 0.00158
 ↪ total: 0.006826
 ↪ upright_reward: 0.00312
 ↪ x_angular_velocity_penalty: -1.325e-05
 ↪ x_naive_forward_reward: 0.002141
 ↪ y_angular_velocity_penalty: -2.474e-05
 ↪ y_linear_velocity_penalty: -7.55e-07
 ↪ z_angular_velocity_penalty: -4.1e-05
 ↪ z_linear_velocity_penalty: -1.339e-06

🕒 timers
 ↪ steps/second: 0.05423
 ↪ dt: 17.29

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:27:19 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.29.bin

Phase: train
 ↪ Steps: 30
 ↪ Samples: 3,072,000
 ↪ Elapsed Time: 11m, 2s

🎁 reward
 ↪ action_in_bounds_reward: 0.000144
 ↪ action_smoothness_penalty: -4.228e-05
 ↪ actuator_jerk_penalty: -5.148e-06
 ↪ actuator_relative_force_penalty: -5.02e-06
 ↪ bent_arm_penalty: -2.503e-05
 ↪ stay_alive_reward: 0.001585
 ↪ total: 0.007067
 ↪ upright_reward: 0.003117
 ↪ x_angular_velocity_penalty: -1.277e-05
 ↪ x_naive_forward_reward: 0.00238
 ↪ y_angular_velocity_penalty: -2.474e-05
 ↪ y_linear_velocity_penalty: -8.659e-07
 ↪ z_angular_velocity_penalty: -4.197e-05
 ↪ z_linear_velocity_penalty: -1.316e-06

🕒 timers
 ↪ steps/second: 0.05255
 ↪ dt: 17.37

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:28:29 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.33.bin

Phase: train
 ↪ Steps: 34
 ↪ Samples: 3,481,600
 ↪ Elapsed Time: 12m, 12s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001458
 ↪ action_smoothness_penalty: -4.153e-05
 ↪ actuator_jerk_penalty: -5.04e-06
 ↪ actuator_relative_force_penalty: -5.013e-06
 ↪ bent_arm_penalty: -2.492e-05
 ↪ stay_alive_reward: 0.001606
 ↪ total: 0.007104
 ↪ upright_reward: 0.003139
 ↪ x_angular_velocity_penalty: -1.233e-05
 ↪ x_naive_forward_reward: 0.00237
 ↪ y_angular_velocity_penalty: -2.359e-05
 ↪ y_linear_velocity_penalty: -8.454e-07
 ↪ z_angular_velocity_penalty: -4.191e-05
 ↪ z_linear_velocity_penalty: -1.266e-06

🕒 timers
 ↪ steps/second: 0.05308
 ↪ dt: 17.27

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:29:38 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.37.bin

Phase: valid
 ↪ Steps: 37
 ↪ Samples: 3,788,800
 ↪ Elapsed Time: 13m, 22s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001499
 ↪ action_smoothness_penalty: -4.062e-05
 ↪ actuator_jerk_penalty: -4.946e-06
 ↪ actuator_relative_force_penalty: -4.979e-06
 ↪ bent_arm_penalty: -2.488e-05
 ↪ stay_alive_reward: 0.001621
 ↪ total: 0.007206
 ↪ upright_reward: 0.003147
 ↪ x_angular_velocity_penalty: -1.197e-05
 ↪ x_naive_forward_reward: 0.002444
 ↪ y_angular_velocity_penalty: -2.292e-05
 ↪ y_linear_velocity_penalty: -8.872e-07
 ↪ z_angular_velocity_penalty: -4.281e-05
 ↪ z_linear_velocity_penalty: -1.219e-06

🕒 timers
 ↪ steps/second: 0.007486
 ↪ dt: 177.5

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 

INFO  2025-05-02 22:31:09 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.41.bin

Phase: train
 ↪ Steps: 42
 ↪ Samples: 4,300,800
 ↪ Elapsed Time: 14m, 31s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001509
 ↪ action_smoothness_penalty: -3.991e-05
 ↪ actuator_jerk_penalty: -4.799e-06
 ↪ actuator_relative_force_penalty: -4.96e-06
 ↪ bent_arm_penalty: -2.501e-05
 ↪ stay_alive_reward: 0.00164
 ↪ total: 0.007185
 ↪ upright_reward: 0.003158
 ↪ x_angular_velocity_penalty: -1.159e-05
 ↪ x_naive_forward_reward: 0.002388
 ↪ y_angular_velocity_penalty: -2.243e-05
 ↪ y_linear_velocity_penalty: -9.811e-07
 ↪ z_angular_velocity_penalty: -4.198e-05
 ↪ z_linear_velocity_penalty: -1.146e-06

🕒 timers
 ↪ steps/second: 0.05244
 ↪ dt: 17.25

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53s

INFO  2025-05-02 22:32:19 [xax.task.mixins.checkpointing] Saving checkpoint to /content/humanoid_walking_task/run_0/checkpoints/ckpt.45.bin

Phase: train
 ↪ Steps: 46
 ↪ Samples: 4,710,400
 ↪ Elapsed Time: 15m, 41s

🎁 reward
 ↪ action_in_bounds_reward: 0.0001538
 ↪ action_smoothness_penalty: -3.929e-05
 ↪ actuator_jerk_penalty: -4.691e-06
 ↪ actuator_relative_force_penalty: -4.942e-06
 ↪ bent_arm_penalty: -2.48e-05
 ↪ stay_alive_reward: 0.001641
 ↪ total: 0.007247
 ↪ upright_reward: 0.003161
 ↪ x_angular_velocity_penalty: -1.144e-05
 ↪ x_naive_forward_reward: 0.002445
 ↪ y_angular_velocity_penalty: -2.186e-05
 ↪ y_linear_velocity_penalty: -1.079e-06
 ↪ z_angular_velocity_penalty: -4.377e-05
 ↪ z_linear_velocity_penalty: -1.123e-06

🕒 timers
 ↪ steps/second: 0.05282
 ↪ dt: 17.35

Status
 ✦ Tensorboard: http://3521daa4530a:6036/
 ✦ First step time: 2m, 53