In [None]:
import gymnasium as gym
import numpy as np
import pandas
import seaborn as sns
import torch
from einops import einsum, pack, rearrange, repeat
from matplotlib import pyplot as plt
from torch import nn

from nets import Perceptron, TransitionModel

In [None]:
# Seed every RNG this notebook draws from. The batch sampling later uses
# torch.randint on torch's *global* RNG, so torch must be seeded too
# (torch.manual_seed seeds the RNGs on all devices, CUDA included).
# The seeding call comes first so its returned Generator is not echoed as cell output.
torch.manual_seed(0)
np_rng = np.random.default_rng(0)

In [None]:
# Load data from data.npz together with the train/test split saved at training time.
# `with` closes the .npz file handles. Each array is read exactly once, because
# every NpzFile["key"] access re-reads and decompresses that member from disk.
with np.load("data.npz") as data, np.load("trained_net_params/indices.npz") as indices:
    all_observations = data["observations"]
    all_actions = data["actions"]
    train_indices, test_indices = indices["train_indices"], indices["test_indices"]

observations_train = all_observations[train_indices]
actions_train = all_actions[train_indices]

# Test split goes to the GPU as float32 tensors for model evaluation.
observations_test = torch.tensor(
    all_observations[test_indices], dtype=torch.float32, device="cuda"
)
actions_test = torch.tensor(
    all_actions[test_indices], dtype=torch.float32, device="cuda"
)

In [None]:
# Load models
def load_trained_module(name: str) -> nn.Module:
    """Load a network saved as a whole pickled module and switch it to eval mode.

    The loaded objects are called directly later in the notebook, so these
    checkpoints hold complete modules rather than state_dicts. Unpickling them
    needs their classes importable (presumably why Perceptron / TransitionModel
    are imported from `nets`). It also needs ``weights_only=False``: PyTorch >= 2.6
    defaults to ``weights_only=True``, which refuses arbitrary pickled objects,
    and the keyword itself requires PyTorch >= 1.13.
    Unpickling can execute arbitrary code, so only load checkpoints you trust.

    Args:
        name: File stem under ``trained_net_params/`` (without ``.pt``).

    Returns:
        The loaded module in evaluation mode, so dropout / batch-norm layers
        (if any) behave deterministically during the analysis below.
    """
    module = torch.load(f"trained_net_params/{name}.pt", weights_only=False)
    return module.eval()


state_encoder = load_trained_module("state_encoder")
action_encoder = load_trained_module("action_encoder")
transition_model = load_trained_module("transition_model")
state_decoder = load_trained_module("state_decoder")
action_decoder = load_trained_module("action_decoder")

In [None]:
# Merge every leading axis into one sample axis, keeping the trailing feature axis.
flat_states = observations_test.reshape(-1, observations_test.shape[-1])
flat_actions = actions_test.reshape(-1, actions_test.shape[-1])

In [None]:
# Per-dimension spread of the latent codes over all flattened test samples,
# plus decoder reconstructions of those samples (inference only, no autograd).
with torch.no_grad():
    encoded_states = state_encoder(flat_states)
    state_std = encoded_states.std(dim=0)

    # The action encoder is conditioned on the state: its input is [action | state].
    encoded_actions = action_encoder(torch.cat([flat_actions, flat_states], dim=-1))
    action_std = encoded_actions.std(dim=0)

    recovered_states = state_decoder(encoded_states)
    # The action decoder likewise sees [latent action | latent state].
    recovered_actions = action_decoder(torch.cat([encoded_actions, encoded_states], dim=-1))

In [None]:
# One line per latent space; the arrays are moved to host memory for printing.
print(
    f"State std: {state_std.cpu().numpy()}",
    f"Action std: {action_std.cpu().numpy()}",
    sep="\n",
)

In [None]:
# Test forward model on a random batch of test trajectories.
transition_batch_size = 32

# Which test trajectories to roll out (independent draws, so repeats are possible).
traj_inds = torch.randint(
    low=0,
    high=observations_test.shape[0],
    size=(transition_batch_size,),
    device="cuda",
)
# Rollout start per trajectory, drawn from roughly the first 1/1.1 of axis -2
# (assumes observations_test is (trajectory, time, feature) — TODO confirm).
test_start_inds = torch.randint(
    low=0,
    high=int(observations_test.shape[-2] // 1.1),
    size=(transition_batch_size,),
    device="cuda",
)

test_actions = actions_test[traj_inds]
test_states = observations_test[traj_inds]

# Pairwise gather: row i is observations_test[traj_inds[i], test_start_inds[i]].
start_states = observations_test[traj_inds, test_start_inds]

In [None]:
# Sanity check: one gathered start state per sampled trajectory.
start_states.size()

In [None]:
# Roll the transition model forward from the sampled start states. This is
# evaluation only, so autograd is disabled: otherwise every op records a graph
# that keeps intermediate activations alive on the GPU.
with torch.no_grad():
    latent_start_states = state_encoder(start_states)

    # Actions are encoded conditioned on the observed states ([action | state]),
    # the same concatenation order used for action_encoder elsewhere in this notebook.
    latent_traj_actions = action_encoder(torch.cat([test_actions, test_states], dim=-1))

    latent_pred_fut_states = transition_model(
        latent_start_states, latent_traj_actions, start_indices=test_start_inds
    )

In [None]:
# Encode the ground-truth test trajectories for comparison with the rollout.
# Evaluation only, so skip building an autograd graph.
with torch.no_grad():
    latent_fut_states = state_encoder(test_states)

# Inspect the first sampled trajectory: batch row 0 was taken from
# observations_test[traj_inds[0]], and its rollout starts at test_start_inds[0].
traj_ind = traj_inds[0]
start_ind = test_start_inds[0]

# Keep timesteps from the rollout start onward, for truth and prediction alike.
latent_fut_states_select = latent_fut_states[0, start_ind:]
latent_pred_fut_states_select = latent_pred_fut_states[0, start_ind:]

In [None]:
# Side-by-side view: the new trailing axis holds (ground truth, prediction).
torch.stack(
    (latent_fut_states_select, latent_pred_fut_states_select),
    dim=-1,
)

In [None]:
# Measure error per time into the future.
# Entry [b, t] = t - test_start_inds[b]: the offset of timestep t from
# trajectory b's rollout start (negative before the start).
time_into_future = (
    torch.arange(latent_fut_states.shape[-2], device="cuda").unsqueeze(0)
    - test_start_inds.unsqueeze(-1)
)

In [None]:
# Mean absolute error between true and predicted latents, averaged over the last axis.
mae_errors = (latent_fut_states - latent_pred_fut_states).abs().mean(dim=-1)

In [None]:
# Flatten the aligned error / offset grids into 1-D columns (same element order)
# and keep only timesteps at or after each rollout start.
mae_errors_flat = mae_errors.detach().reshape(-1).cpu().numpy()
time_into_future_flat = time_into_future.detach().reshape(-1).cpu().numpy()

df = pandas.DataFrame(
    {"mae_error": mae_errors_flat, "time_into_future": time_into_future_flat}
).query("time_into_future >= 0")

In [None]:
# Latent-space prediction error vs. steps past the rollout start. Seaborn
# aggregates the rows sharing each x value (mean by default) and draws an error band.
fig, ax = plt.subplots()
sns.lineplot(data=df, x="time_into_future", y="mae_error", ax=ax)
ax.set(
    # NOTE(review): fixed 1024-step horizon; assumes trajectories are about this long — confirm.
    xlim=(0, 1024),
    xlabel="Steps into the future",
    ylabel="Latent-state MAE",
    title="Transition-model prediction error vs. rollout horizon",
);  # trailing ';' suppresses the repr of ax.set's return value