In [7]:
import shutil
from pathlib import Path
import os
import sys
import time

# Set the cwd to the parent of the directory this notebook lives in
# (presumably so the project packages imported in the next cell resolve
# against that directory -- confirm against the repository layout).
#
# The starting directory is recorded the first time this cell runs and reused
# afterwards. Without that, every re-execution of this cell in the same kernel
# would climb one more level up the tree.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()

os.chdir(NOTEBOOK_DIR.parent)

In [2]:
from env.finger import Finger

from learning.eval_actor import eval_batch_actor

from learning.train_state import TrainConfig, TrainState

from learning.training.train_step import train_step

from nets.nets import (
    StateEncoder,
    ActionEncoder,
    TransitionModel,
    StateDecoder,
    ActionDecoder,
)

import orbax.checkpoint as ocp
import optax

import jax
from jax import numpy as jnp
from jax.experimental.host_callback import id_tap

import wandb

import argparse

In [3]:
# Reproducibility: one integer seed drives the root JAX PRNG key
seed = 0
key = jax.random.PRNGKey(seed)

# Checkpointing: an orbax PyTree checkpointer, a list meant to hold the most
# recent checkpoint paths, and a checkpoint count (presumably how many recent
# checkpoints to keep -- confirm where it is consumed)
checkpointer = ocp.PyTreeCheckpointer()
checkpoint_count = 3
checkpoint_paths = []

# Optimisation hyperparameters handed to the training config
learning_rate = 2.5e-4
every_k = 1

# Environment: the class to train on and the default config it provides
env_cls = Finger
env_config = env_cls.get_config()

# Latent state / action sizes (per the original note, chosen to equal the raw
# state and action dims)
latent_state_dim, latent_action_dim = 6, 2

# One-cycle cosine learning-rate schedule peaking at `learning_rate`
lr_schedule = optax.cosine_onecycle_schedule(
    transition_steps=8192,
    peak_value=learning_rate,
    pct_start=0.3,
    div_factor=25.0,
    final_div_factor=1.0,
)

# Optimizer: replace NaNs in the updates with zeros, then AdamW on the schedule
optimizer = optax.chain(
    optax.zero_nans(),
    optax.adamw(learning_rate=lr_schedule),
)

# The five networks that make up the model
networks = dict(
    state_encoder=StateEncoder(latent_state_dim=latent_state_dim),
    action_encoder=ActionEncoder(latent_action_dim=latent_action_dim),
    transition_model=TransitionModel(
        latent_state_dim=latent_state_dim, n_layers=8, latent_dim=64, heads=4
    ),
    state_decoder=StateDecoder(state_dim=env_config.state_dim),
    action_decoder=ActionDecoder(act_dim=env_config.act_dim),
)

# Relative weights of the individual loss terms
loss_weights = dict(
    reconstruction_weight=1.0,
    forward_weight=1.0,
    smoothness_weight=1.0,
    condensation_weight=1.0,
    dispersion_weight=10.0,
)

# Sharpness and center of the gate attached to each gated loss term
gate_settings = dict(
    forward_gate_sharpness=1,
    smoothness_gate_sharpness=1,
    dispersion_gate_sharpness=1,
    condensation_gate_sharpness=1,
    forward_gate_center=-6,
    smoothness_gate_center=-3,
    dispersion_gate_center=-3,
    condensation_gate_center=-6,
)

# Assemble the training config from the pieces above
train_config = TrainConfig.init(
    learning_rate=learning_rate,
    optimizer=optimizer,
    **networks,
    latent_state_dim=latent_state_dim,
    latent_action_dim=latent_action_dim,
    env_config=env_config,
    env_cls=env_cls,
    seed=seed,
    target_net_tau=0.05,
    transition_factor=10.0,
    rollouts=256,
    epochs=64,
    batch_size=64,
    every_k=every_k,
    traj_per_rollout=1024,
    rollout_length=64,
    state_radius=1.6,
    action_radius=2.0,
    **loss_weights,
    **gate_settings,
)

# Build the train state (network and optimizer parameters) from a fresh subkey;
# the other half of the split is carried forward as `key`
rng, key = jax.random.split(key)
train_state = TrainState.init(rng, train_config)

In [4]:
# Initialise wandb in "disabled" mode for this notebook
wandb.init(mode="disabled")

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




# Let's just do a sanity check and see if we can train anything at all

In [5]:
def gen_dat_train_step(train_state, key):
    """Run one training step on freshly sampled standard-normal data.

    Shaped as a `jax.lax.scan` body: `train_state` is the carry and `key` is
    the per-step input.

    Args:
        train_state: Current train state. Its `train_config` supplies
            `batch_size`, `rollout_length`, and (through `env_config`)
            `state_dim` and `act_dim`.
        key: PRNG key used to sample this step's states and actions.

    Returns:
        A `(train_state, None)` pair: the state returned by `train_step`,
        and no per-step output.
    """
    cfg = train_state.train_config
    n_batch = cfg.batch_size
    n_steps = cfg.rollout_length

    # States are drawn with the first half of a split of `key`; actions with
    # the first half of a split of the remaining half.
    state_rng, remaining_key = jax.random.split(key)
    action_rng = jax.random.split(remaining_key)[0]

    state_data = jax.random.normal(
        state_rng, (n_batch, n_steps, cfg.env_config.state_dim)
    )
    # One fewer action than states along the rollout axis
    action_data = jax.random.normal(
        action_rng, (n_batch, n_steps - 1, cfg.env_config.act_dim)
    )

    updated_state = train_step(state_data, action_data, train_state)
    return updated_state, None

In [6]:
# Sanity run: drive `gen_dat_train_step` for a fixed number of iterations with
# `jax.lax.scan`, handing each iteration its own PRNG key.
n_sanity_steps = 16384

rng, key = jax.random.split(key)
rngs = jax.random.split(rng, n_sanity_steps)
train_state, _ = jax.lax.scan(f=gen_dat_train_step, init=train_state, xs=rngs)

Step 0:
Losses:
		state_reconstruction_loss: 0.3926679193973541
		action_reconstruction_loss: 0.48948368430137634

Step 1:
Losses:
		state_reconstruction_loss: 0.41215747594833374
		action_reconstruction_loss: 0.44861045479774475

Step 2:
Losses:
		state_reconstruction_loss: 0.39103737473487854
		action_reconstruction_loss: 0.45167019963264465

Step 3:
Losses:
		state_reconstruction_loss: 0.39208364486694336
		action_reconstruction_loss: 0.46379655599594116

Step 4:
Losses:
		state_reconstruction_loss: 0.39477190375328064
		action_reconstruction_loss: 0.41120803356170654

Step 5:
Losses:
		state_reconstruction_loss: 0.40082722902297974
		action_reconstruction_loss: 0.44826969504356384

Step 6:
Losses:
		state_reconstruction_loss: 0.3656995892524719
		action_reconstruction_loss: 0.43818405270576477

Step 7:
Losses:
		state_reconstruction_loss: 0.4132826030254364
		action_reconstruction_loss: 0.44849857687950134

Step 8:
Losses:
		state_reconstruction_loss: 0.38991302251815796
		action_r