In [None]:
from latch.env.finger.finger import Finger

from latch.models import ModelState

from latch.rollout import eval_actor

from latch.policy.actor_policy import ActorPolicy

from latch.latch_config import LatchConfig
from latch.config import TrainConfig, NetConfig, configure_state

import orbax.checkpoint as ocp

import jax
from jax.tree_util import Partial
import jax.numpy as jnp
import optax

import jax.experimental.host_callback

from einops import einops, einsum, rearrange
import matplotlib.pyplot as plt

import orbax.checkpoint as ocp
import shutil

from pathlib import Path

import os

%matplotlib inline

In [None]:
# Network dimensions; these must match the checkpoint restored below.
cfg = TrainConfig(
    net_config=NetConfig(
        latent_state_dim=6,
        state_dim=6,
        latent_action_dim=2,
        action_dim=2,
        latent_state_radius=1.5,
        latent_action_radius=2.0,
    )
)

# Environment plus a freshly configured train state. The fresh state is passed
# to the restore call as `item`, giving the structure to restore into.
env = Finger.init()
train_state = configure_state(train_config=cfg, env=env)

# The checkpoint ships as a zip archive: extract it next to the archive (same
# path with the ".zip" suffix dropped) and restore from the extracted folder.
checkpoint_path = Path("../../checkpoints/9wawoqrb/checkpoint_latest.zip")
checkpoint_dir = checkpoint_path.with_suffix("")
shutil.unpack_archive(checkpoint_path, checkpoint_dir)

checkpointer = ocp.StandardCheckpointer()
train_state = checkpointer.restore(checkpoint_dir.absolute(), item=train_state)

In [None]:
def eval_model(train_state, theta, num_rollouts=32):
    """Roll out the actor policy toward ``theta`` on state dimension 0.

    Only state dimension 0 is targeted (weight 1.0); every other dimension gets
    zero weight. ``num_rollouts`` rollouts run in parallel via ``jax.vmap``,
    each starting from the environment's reset state with its own PRNG key.
    Nothing is logged; the results are only returned.

    Args:
        train_state: Restored train state. Supplies the config (state
            dimension and environment) and the PRNG key, and is passed through
            to ``eval_actor``.
        theta: Target value for state dimension 0.
        num_rollouts: Number of independent rollouts to run (default 32).

    Returns:
        ``(result_states, eval_infos, dense_states)`` as produced by
        ``eval_actor``, each batched along a new leading axis of size
        ``num_rollouts``.

    Note:
        The key split off ``train_state`` is not handed back to the caller, so
        repeated calls with the same ``train_state`` reuse the same randomness.
    """
    state_dim = train_state.config.state_dim
    state_target = jnp.zeros(state_dim).at[0].set(theta)
    state_weights = jnp.zeros(state_dim).at[0].set(1.0)
    policy = ActorPolicy(state_target=state_target, state_weights=state_weights)

    key, train_state = train_state.split_key()
    rng, _ = jax.random.split(key)
    rollout_keys = jax.random.split(rng, num_rollouts)

    # The Partial's bound arguments are closed over (shared by every rollout);
    # the keyword argument `key` is mapped over its leading axis by vmap.
    result_states, eval_infos, dense_states = jax.vmap(
        Partial(
            eval_actor,
            start_state=train_state.config.env.reset(),
            train_state=train_state,
            policy=policy,
        )
    )(key=rollout_keys)

    return result_states, eval_infos, dense_states

In [None]:
# Batched rollouts targeting -6.0 on state dimension 0. The infos and dense
# states get their own names so that later cells, which rebind `infos` and
# `dense_states` with different (unbatched) shapes, don't silently clobber them.
result_states, eval_infos, eval_dense_states = eval_model(train_state, jnp.array(-6.0))

In [None]:
# Axis 0 indexes the vmapped rollouts from eval_model; show entry -2 (along
# axis 1) of the first rollout.
first_rollout = result_states[0]
first_rollout[-2]

## Actually run the actor evals (this will take a while)

In [None]:
# Actor policy targeting -6.0 on state dimension 0 (all other dimensions have
# zero weight). The step sizes / step counts are passed straight through to
# ActorPolicy; see its definition for their meaning.
policy = ActorPolicy(
    big_step_size=0.5,
    small_step_size=0.005,
    big_steps=2048,
    small_steps=2048,
    big_post_steps=32,
    small_post_steps=32,
    state_target=jnp.array([-6.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
    state_weights=jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
)

In [None]:
# `env_cls` is never defined in this notebook. Take the start state from the
# environment built in the setup cell; eval_model likewise uses the env's
# reset() as its start state.
start_state = env.reset()

## Optimize a sequence of actions

In [None]:
# No cell above defines `key`. Seed it explicitly so this cell, and the cells
# below that keep splitting `key`, run on a fresh kernel; re-running this cell
# restarts the same random stream.
# TODO(review): `policy_aux` is not defined anywhere in this notebook, and
# `make_init_carry`, `train_state.train_config` and `target_net_state` look like
# an older API than the `train_state.config` used above — confirm and update.
SEED = 0
key = jax.random.PRNGKey(SEED)

rng, key = jax.random.split(key)
(optimized_actions, aux), infos, costs = jax.jit(policy.make_init_carry)(
    key=rng,
    start_state=start_state,
    aux=policy_aux,
    net_state=train_state.target_net_state,
    train_config=train_state.train_config,
)

In [None]:
# Cost trace returned by make_init_carry.
plt.plot(costs)
plt.title("Costs returned by make_init_carry")
plt.xlabel("index")
plt.ylabel("cost");

## Now let's see what the actor thinks would happen

In [None]:
# Encode the start state, then infer the latent trajectory for the optimized
# actions. Each stochastic call gets its own freshly split key: reusing one key
# for both would feed them identical random bits.
# NOTE(review): `encode_state` and `infer_states` are not imported in this
# notebook — confirm where they should come from.
rng, key = jax.random.split(key)
latent_start_state = encode_state(
    rng, start_state, train_state.target_net_state, train_state.train_config
)

rng, key = jax.random.split(key)
expected_latent_states = jax.jit(infer_states)(
    rng,
    latent_start_state,
    optimized_actions,
    train_state.target_net_state,
    train_state.train_config,
)

In [None]:
# Decode every predicted latent state back to state space, one PRNG key per
# latent state.
rng, key = jax.random.split(key)
decode_one = Partial(
    decode_state,
    net_state=train_state.target_net_state,
    train_config=train_state.train_config,
)
rngs = jax.random.split(rng, expected_latent_states.shape[0])
expected_states = jax.jit(jax.vmap(decode_one))(rngs, expected_latent_states)

In [None]:
# Decoded prediction of state dimension 2 along the predicted trajectory.
plt.plot(expected_states[..., 2])
plt.title("Predicted state[2] under the optimized actions")
plt.xlabel("predicted trajectory index")
plt.ylabel("state[2]");

In [None]:
import mediapy as media

# Render the decoded predicted trajectory (non-dense states). The transpose
# presumably moves a channel axis last, (T, C, H, W) -> (T, H, W, C) — confirm.
# NOTE(review): `env_cls` is not defined in this notebook — confirm its source.
frames = env_cls.host_make_video(
    expected_states, train_state.train_config.env_config, dense=False
)
media.show_video(frames.transpose(0, 2, 3, 1))

## Roll out the policy in the environment

In [None]:
def scanf(carry, key):
    """Advance the environment by one policy step (body of ``jax.lax.scan``).

    Args:
        carry: ``(state, i, policy_carry)`` — current state, step index and the
            policy's carry.
        key: PRNG key for this step.

    Returns:
        ``(next_carry, ((state, action), dense_states, policy_info))`` where
        ``state`` is the state *before* this step and ``dense_states`` is
        whatever the environment's ``step`` returns alongside the next state.
    """
    # Read everything from the train state; the bare global `train_config`
    # used here before is never defined in this notebook.
    train_config = train_state.train_config
    state, i, policy_carry = carry

    rng, key = jax.random.split(key)
    action, policy_carry, policy_info = policy(
        key=rng,
        state=state,
        i=i,
        carry=policy_carry,
        net_state=train_state.target_net_state,
        train_config=train_config,
    )
    # Clamp into the env's action bounds (first column = lower, last = upper).
    # Positional bounds work with both the old (a_min/a_max) and new (min/max)
    # jnp.clip signatures.
    action = jnp.clip(
        action,
        train_config.env_config.action_bounds[..., 0],
        train_config.env_config.action_bounds[..., -1],
    )
    next_state, dense_states = train_config.env_cls.step(
        state, action, train_config.env_config
    )

    return (next_state, i + 1, policy_carry), (
        (state, action),
        dense_states,
        policy_info,
    )

# TODO(review): the `policy(...)` call signature and `train_state.train_config`
# look like an older API than `train_state.config` used at the top — confirm.
rng, key = jax.random.split(key)
scan_rngs = jax.random.split(rng, train_state.train_config.rollout_length - 1)
(final_state, _, _), ((states, actions), dense_states, policy_info) = jax.lax.scan(
    scanf,
    (start_state, 0, (optimized_actions, aux)),
    scan_rngs,
)

# dense_states comes back as (steps, substeps, state); keep the substep count
# so each action can be repeated once per substep instead of a hard-coded 32.
substeps = dense_states.shape[1]
dense_states = rearrange(dense_states, "t u s -> (t u) s")

# `states` holds the state *before* each step, so it already begins with
# start_state. Append the state reached after the final step so that
# states[k + 1] is the result of actions[k] (later cells compare predictions
# against latent_states[1:]).
dense_states = jnp.concatenate([start_state[None], dense_states])
states = jnp.concatenate([states, final_state[None]])

dense_actions = jnp.repeat(actions, substeps, axis=0)

In [None]:
def _encode_states(key, batch):
    """Encode each row of ``batch`` with its own PRNG key; returns ``(key, latents)``."""
    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, batch.shape[0])
    encode_one = Partial(
        encode_state,
        net_state=train_state.target_net_state,
        train_config=train_state.train_config,
    )
    return key, jax.vmap(encode_one)(key=rngs, state=batch)


# Latents for the per-step states and for the dense sub-step states.
key, latent_states = _encode_states(key, states)
key, dense_latent_states = _encode_states(key, dense_states)

In [None]:
def _encode_actions(key, action_batch, latent_state_batch):
    """Encode each action given its paired latent state; returns ``(key, latent_actions)``."""
    rng, key = jax.random.split(key)
    rngs = jax.random.split(rng, action_batch.shape[0])
    encode_one = Partial(
        encode_action,
        net_state=train_state.target_net_state,
        train_config=train_state.train_config,
    )
    return key, jax.vmap(encode_one)(
        key=rngs, action=action_batch, latent_state=latent_state_batch
    )


# Each action is paired with the latent state it was taken from, hence [:-1].
key, latent_actions = _encode_actions(key, actions, latent_states[:-1])
key, dense_latent_actions = _encode_actions(
    key, dense_actions, dense_latent_states[:-1]
)

In [None]:
# Starting from the first encoded visited state, infer latent states for the
# encoded actions; the next cell compares these with latent_states[1:].
rng, key = jax.random.split(key)
inferred_latent_states = infer_states(
    key=rng,
    latent_start_state=latent_states[0],
    latent_actions=latent_actions,
    current_action_i=0,
    net_state=train_state.target_net_state,
    train_config=train_state.train_config,
)

In [None]:
# Per-step L1 distance between inferred latents and the encoded visited states
# (latent_states[0] is the start state, so it is skipped).
diffs = inferred_latent_states - latent_states[1:]
diff_norms = jnp.sum(jnp.abs(diffs), axis=-1)

In [None]:
# Labeled summary instead of a bare number.
print(f"mean L1 latent prediction error: {float(jnp.mean(diff_norms)):.4f}")

In [None]:
# Per-step L1 error between inferred and encoded visited latent states.
plt.plot(diff_norms)
plt.title("Latent prediction error per step")
plt.xlabel("step")
plt.ylabel("L1 error");

In [None]:
# Render the states visited during the environment rollout (dense sub-steps).
# Set SHOW_INFERRED_VIDEO = True to render instead the inferred latent
# trajectory decoded back to state space. That trajectory has one state per
# action, so it is rendered with dense=False, as the predicted-video cell does
# for decoded states.
# NOTE(review): `env_cls` and `decode_state` are not defined in this notebook's
# setup cells — confirm their source.
SHOW_INFERRED_VIDEO = False

rng, key = jax.random.split(key)
rngs = jax.random.split(rng, inferred_latent_states.shape[0])

if SHOW_INFERRED_VIDEO:
    video_states = jax.vmap(
        Partial(
            decode_state,
            net_state=train_state.target_net_state,
            train_config=train_state.train_config,
        )
    )(key=rngs, latent_state=inferred_latent_states)
else:
    video_states = dense_states

media.show_video(
    env_cls.host_make_video(
        video_states,
        env_config=train_state.train_config.env_config,
        dense=not SHOW_INFERRED_VIDEO,
    ).transpose([0, 2, 3, 1])
)

## Let's make a scatterplot of all of the latent states and actions the algorithm decided on

In [None]:
import numpy as np

# Dense latent states projected onto latent dimensions 4 and 5, colored by
# their order in the rollout.
plt.scatter(
    dense_latent_states[..., 4],
    dense_latent_states[..., 5],
    c=range(dense_latent_states.shape[0]),
    cmap="viridis",
)
plt.title("Dense latent states (dims 4 vs 5)")
plt.xlabel("latent dim 4")
plt.ylabel("latent dim 5")
plt.colorbar(label="dense step index");

In [None]:
# Optimized actions (first two components), colored by their index in the
# sequence.
plt.scatter(
    optimized_actions[..., 0],
    optimized_actions[..., 1],
    c=range(len(optimized_actions)),
    cmap="viridis",
)
plt.title("Optimized actions")
plt.xlabel("optimized_actions[..., 0]")
plt.ylabel("optimized_actions[..., 1]")
plt.colorbar(label="action index");

## Let's plot the achieved final costs (currently disabled: the cell below references `info`, which is not defined in this notebook)

In [None]:
# plt.hist(info.plain_infos['final_cost'])

## Let's visualize what the algorithm did

In [None]:
import mediapy as media

# Video of the dense rollout states at 24 fps. `env_config` is never defined in
# this notebook; use the train config's env_config as the other video cells do.
# dense=True because these are the dense sub-step states (as in the earlier
# rollout video cell).
# NOTE(review): `env_cls` is also undefined here — confirm its source.
video = env_cls.host_make_video(
    dense_states, train_state.train_config.env_config, dense=True
).transpose(0, 2, 3, 1)
media.show_video(video, fps=24)

## Let's investigate which action-space actions are available from a visited state (the cell below decodes at `latent_states[12]`, not at the start state)

In [None]:
# Sample latent actions uniformly from unit L1 balls centred at -e_i and +e_i
# for every latent action dimension i. Decode them at one fixed visited latent
# state, then scatter the resulting action-space actions, one color per ball.
# NOTE(review): `decode_action` is not defined in this notebook's setup cells.
sample_count = 512
ball_radius = 1.0
latent_state_index = 12  # which visited latent state to decode at

latent_action_dim = train_state.train_config.latent_action_dim
eye = jnp.eye(latent_action_dim)
# Row order: -e_0, +e_0, -e_1, +e_1, ... (for 2-D: [-1,0], [1,0], [0,-1], [0,1]).
ball_centers = jnp.stack([-eye, eye], axis=1).reshape(-1, latent_action_dim)
num_balls = ball_centers.shape[0]

rng, key = jax.random.split(key)
action_samples = (
    jax.random.ball(rng, d=latent_action_dim, p=1, shape=(num_balls, sample_count))
    * ball_radius
) + ball_centers[:, None, :]

rng, key = jax.random.split(key)
rngs = jax.random.split(rng, (num_balls, sample_count))
action_space_actions = jax.vmap(
    jax.vmap(
        Partial(
            decode_action,
            latent_state=latent_states[latent_state_index],
            net_state=train_state.target_net_state,
            train_config=train_state.train_config,
        )
    )
)(rngs, action_samples)

for i in range(num_balls):
    plt.scatter(
        x=action_space_actions[i, ..., 0],
        y=action_space_actions[i, ..., 1],
        label=f"ball at {ball_centers[i].tolist()}",
    )
plt.title(f"Decoded actions at latent_states[{latent_state_index}]")
plt.xlabel("action[0]")
plt.ylabel("action[1]")
plt.legend();