In [1]:
from finger import Finger

from physics.gen.mass_matrix import mass_matrix
from physics.gen.bias_forces import bias_forces

from physics.simulate import step
from physics.visualize import animate

import jax
import jax.numpy as jnp
import jax.profiler
import flax
from flax import linen as nn
from flax import struct
import optax

import jax.experimental.host_callback

from training.infos import Infos

import shutil

from einops import einops, einsum
import matplotlib.pyplot as plt

from embeds import EmbeddingLayer
from training.rollout import collect_rollout

from training.eval_actor import evaluate_actor
from training.vibe_state import (
    VibeState,
    TrainConfig,
)
from training.nets import (
    StateEncoder,
    ActionEncoder,
    TransitionModel,
    StateDecoder,
    ActionDecoder,
    encoded_state_dim,
    encoded_action_dim,
)

import orbax.checkpoint as ocp

# from unitree_go1 import UnitreeGo1

from training.train import train_step, dump_to_wandb
from policy import (
    random_policy,
    random_repeat_policy,
    make_target_conf_policy,
    make_piecewise_actor,
    random_action,
)  # , max_dist_policy

import os

%matplotlib inline



In [2]:
seed = 1
key = jax.random.PRNGKey(seed)
### Set up RL stuff

checkpoint_dir = "checkpoints"

checkpointer = ocp.PyTreeCheckpointer()

learning_rate = float(1e-4)
every_k = 4

env_cls = Finger
start_state = env_cls.init()

env_config = env_cls.get_config()

schedule = optax.join_schedules(
    [
        optax.cosine_onecycle_schedule(
            4096,
            peak_value=learning_rate,
            pct_start=0.125,
            div_factor=1.0,
            final_div_factor=10.0,
        ),
        optax.linear_schedule(learning_rate / 10, learning_rate / 100, 32_768),
    ],
    [32_768],
)

vibe_config = TrainConfig.init(
    learning_rate=learning_rate,
    optimizer=optax.MultiSteps(
        optax.chain(
            optax.zero_nans(),
            optax.clip_by_global_norm(200.0),
            optax.lion(learning_rate=schedule),
        ),
        every_k_schedule=every_k,
    ),
    state_encoder=StateEncoder(),
    action_encoder=ActionEncoder(),
    transition_model=TransitionModel(1e4, 6, 64, 4),
    state_decoder=StateDecoder(env_config.state_dim),
    action_decoder=ActionDecoder(env_config.act_dim),
    env_config=env_config,
    seed=seed,
    rollouts=256,
    epochs=256,
    batch_size=128,
    every_k=every_k,
    traj_per_rollout=4096,
    rollout_length=512,
    reconstruction_weight=1.0,
    forward_weight=1.0,
    smoothness_weight=5.0,
    condensation_weight=10.0,
    dispersion_weight=10.0,
    inverse_reconstruction_gate_sharpness=1,
    inverse_forward_gate_sharpness=1,
    inverse_reconstruction_gate_center=-9,
    inverse_forward_gate_center=-9,
    forward_gate_sharpness=1,
    smoothness_gate_sharpness=1,
    dispersion_gate_sharpness=1,
    condensation_gate_sharpness=1,
    forward_gate_center=-2,
    smoothness_gate_center=-5,
    dispersion_gate_center=-9,
    condensation_gate_center=-9,
)

rng, key = jax.random.split(key)
vibe_state = VibeState.init(rng, vibe_config)

In [3]:
rng, key = jax.random.split(key)
vibe_state = VibeState.init(rng, vibe_config)

vibe_state = checkpointer.restore(
    os.path.join(checkpoint_dir, "checkpoint_r19_s8192.0"), item=vibe_state
)

In [4]:
vibe_state

VibeState(step=array(155648, dtype=int32), state_encoder_params={'params': {'FC0': {'bias': array([ 0.01506531, -0.01167228, -0.03017055,  0.00839612,  0.03984173,
        0.01835875,  0.08825713,  0.03908937,  0.01117716, -0.00252026,
       -0.01051669,  0.00981598, -0.01466805, -0.01567   ,  0.01133265,
        0.03520511,  0.05879371,  0.02894277,  0.04745984,  0.01086036,
        0.03935527,  0.06160103,  0.09904744, -0.00297538, -0.57944703,
        0.07525074,  0.00133884,  0.00985229, -0.06112331,  0.01954275,
        0.0652156 ,  0.0523013 ,  0.03253957, -0.5787335 ,  0.06711385,
        0.07164592, -0.1009863 ,  0.05320012,  0.05006368,  0.02417651,
        0.01792047,  0.02602118,  0.02243868,  0.06643349,  0.01223909,
       -0.00999015,  0.02076035,  0.05287045,  0.06194263,  0.06999426,
       -0.08673208,  0.0669498 ,  0.03612506,  0.02056428, -0.0151463 ,
        0.05329606,  0.02659886,  0.02977215,  0.03319614,  0.04233658,
       -0.0525182 ,  0.00797765,  0.03478298

## Actually run the actor evals, this will take a while

In [5]:
eval_count = 256

In [None]:
rng, key = jax.random.split(key)
rngs = jax.random.split(rng, eval_count)

evaluate_actor_partial = jax.tree_util.Partial(
    evaluate_actor,
    big_step_size=0.5,
    big_steps=512,
    small_step_size=0.001,
    small_steps=1536,
    big_post_steps=16,
    small_post_steps=240,
)

(result_states, result_actions), info = jax.vmap(
    evaluate_actor, in_axes=(0, None, None, None, None)
)(
    rngs,
    start_state,
    env_cls,
    vibe_state,
    vibe_config,
)

In [None]:
info.plain_infos['final_cost']

In [None]:
plt.hist(info.plain_infos['final_cost'])