In [9]:
from env.finger import Finger

from learning.train_config import TrainConfig
from learning.train_state import TrainState, NetState

from policy.policy import Policy
from policy.actor_policy import ActorPolicy

from nets.nets import (
    StateEncoder,
    StateDecoder,
    TransitionModel,
    ActionEncoder,
    ActionDecoder,
)

from infos import Infos

import orbax.checkpoint as ocp
import optax

import jax
from jax import numpy as jnp

from pathlib import Path
import tempfile
import shutil

# Initialize the train state

In [10]:
seed = 0

# Generate random key
key = jax.random.PRNGKey(seed)

checkpointer = ocp.PyTreeCheckpointer()

# Save a list of the most recent checkpoints
checkpoint_paths = []
checkpoint_count = 3

# Set up the training config
learning_rate = float(1e-3)
every_k = 1

# Set the environment class
env_cls = Finger

# Grab the default environment config from the env class
env_config = env_cls.get_config()

# Set the latent state and action dimensions (here I've just set them the same as the state and action dims)
latent_state_dim = 6
latent_action_dim = 2

# Actually create the training config
train_config = TrainConfig.init(
    learning_rate=learning_rate,
    # Make the optimizer
    optimizer=optax.chain(
        optax.zero_nans(),
        optax.adamw(
            learning_rate=optax.cosine_onecycle_schedule(
                transition_steps=8192,
                peak_value=learning_rate,
                pct_start=0.3,
                div_factor=25.0,
                final_div_factor=2.0,
            ),
        ),
    ),
    # Instantiate all of the networks
    state_encoder=StateEncoder(latent_state_dim=latent_state_dim),
    action_encoder=ActionEncoder(latent_action_dim=latent_action_dim),
    transition_model=TransitionModel(
        latent_state_dim=latent_state_dim, n_layers=8, latent_dim=64, heads=4
    ),
    state_decoder=StateDecoder(state_dim=env_config.state_dim),
    action_decoder=ActionDecoder(act_dim=env_config.act_dim),
    latent_state_dim=latent_state_dim,
    latent_action_dim=latent_action_dim,
    env_config=env_config,
    env_cls=env_cls,
    seed=seed,
    target_net_tau=0.05,
    transition_factor=10.0,
    rollouts=256,
    epochs=64,
    batch_size=64,
    every_k=every_k,
    traj_per_rollout=1024,
    rollout_length=64,
    state_radius=1.6,
    action_radius=2.0,
    reconstruction_weight=1.0,
    forward_weight=1.0,
    smoothness_weight=1.0,
    condensation_weight=1.0,
    dispersion_weight=10.0,
    forward_gate_sharpness=8,
    smoothness_gate_sharpness=1,
    dispersion_gate_sharpness=1,
    condensation_gate_sharpness=8,
    forward_gate_center=0.00025,
    smoothness_gate_center=-3,
    dispersion_gate_center=-3,
    condensation_gate_center=0.00025,
)

# Create the train state that contains all of the network and optimizer parameters
rng, key = jax.random.split(key)
train_state = TrainState.init(rng, train_config)

# Restore the train state to the checkpoint

In [17]:
checkpoint_dir = Path("checkpoints")
run_id = "zc71g4ms"
checkpoint_path = checkpoint_dir / f"checkpoints_{run_id}" / "checkpoint_latest.zip"

# Make a temporary directory and unzip to there
with tempfile.TemporaryDirectory() as tmpdirname:
    tmpdir = Path(tmpdirname)
    # Unzip the checkpoint
    shutil.unpack_archive(checkpoint_path, tmpdir / "checkpoint_latest")
    # Load the checkpoint
    train_state = checkpointer.restore(
        tmpdir / "checkpoint_latest",
        item=train_state,
    )

KeyError: 'FC3'

: 

In [12]:
def collect_single_rollout(
    key,
    start_state,
    policy: Policy,
    policy_aux,
    net_state: NetState,
    train_config: TrainConfig,
):
    rng, key = jax.random.split(key)
    init_policy_carry, init_policy_info = policy.make_init_carry(
        key=rng,
        start_state=start_state,
        aux=policy_aux,
        net_state=net_state,
        train_config=train_config,
    )

    # Collect a rollout of physics data
    def scanf(carry, key):
        """Scans to collect a single rollout of physics data."""
        state, i, policy_carry = carry

        rng, key = jax.random.split(key)
        action, policy_carry, policy_info = policy(
            key=rng,
            state=state,
            i=i,
            carry=policy_carry,
            net_state=net_state,
            train_config=train_config,
        )
        action = jnp.clip(
            action,
            a_min=train_config.env_config.action_bounds[..., 0],
            a_max=train_config.env_config.action_bounds[..., -1],
        )
        next_state, dense_states = train_config.env_cls.step(
            state, action, train_config.env_config
        )

        return (next_state, i + 1, policy_carry), (
            (state, action),
            dense_states,
            policy_info,
        )

    rng, key = jax.random.split(key)
    scan_rngs = jax.random.split(rng, train_config.rollout_length - 1)
    _, ((states, actions), dense_states, policy_info) = jax.lax.scan(
        scanf,
        (start_state, 0, init_policy_carry),
        scan_rngs,
    )

    dense_states = rearrange(dense_states, "t u s -> (t u) s")

    dense_states = jnp.concatenate([start_state[None], dense_states])
    states = jnp.concatenate([states, start_state[None]])

    return (states, actions), Infos.merge(policy_info, init_policy_info), dense_states

In [13]:
def eval_single_actor(
    key,
    start_state,
    net_state: NetState,
    train_config: TrainConfig,
    target_q=1.0,
    big_step_size=0.5,
    big_steps=2048,
    small_step_size=0.005,
    small_steps=2048,
    big_post_steps=32,
    small_post_steps=32,
):
    policy = ActorPolicy.init()
    policy_aux = policy.make_aux(target_q=target_q)

    rng, key = jax.random.split(key)
    (
        (result_states, result_actions),
        rollout_infos,
        result_dense_states,
    ) = collect_single_rollout(
        key=rng,
        start_state=start_state,
        policy=policy,
        policy_aux=policy_aux,
        net_state=net_state,
        train_config=train_config,
    )

    def cost_func(state):
        state_cost = jnp.abs(state[0] - target_q)

        return state_cost

    final_cost = jnp.mean(jax.vmap(cost_func)(result_states))
    infos = rollout_infos.add_plain_info("final_cost", final_cost)

    return (
        (result_states, result_actions),
        infos,
        result_dense_states,
    )

In [14]:
(states, actions), infos, dense_states = eval_single_actor(
    key=jax.random.PRNGKey(0),
    start_state=jnp.array([0.0]),
    net_state=,
    train_config=TrainConfig.init(),
)

TypeError: NetState.init() missing 2 required positional arguments: 'key' and 'train_config'