In [1]:
import gymnasium as gym
import numpy as np
import torch
from einops import einsum, pack, rearrange, repeat
from torch import nn

from nets import Perceptron, TransitionModel

In [2]:
np_rng = np.random.default_rng(0)

In [3]:
# Load data from data.npz
data = np.load("data.npz")
indices = np.load("trained_net_params/indices.npz")
train_indices, test_indices = indices["train_indices"], indices["test_indices"]

observations_train = data["observations"][train_indices]
actions_train = data["actions"][train_indices]

observations_test = torch.tensor(data["observations"][test_indices], dtype=torch.float32, device="cuda")
actions_test = torch.tensor(data["actions"][test_indices], dtype=torch.float32, device="cuda")

In [4]:
# Load models
state_encoder = torch.load("trained_net_params/state_encoder.pt")
action_encoder = torch.load("trained_net_params/action_encoder.pt")
transition_model = torch.load("trained_net_params/transition_model.pt")
state_decoder = torch.load("trained_net_params/state_decoder.pt")
action_decoder = torch.load("trained_net_params/action_decoder.pt")

In [5]:
flat_states = rearrange(observations_test, "... f -> (...) f")
flat_actions = rearrange(actions_test, "... f -> (...) f")

In [6]:
# Get stdev of encoded states and actions across each element
with torch.no_grad():
    encoded_states = state_encoder(flat_states)
    state_std = torch.std(encoded_states, dim=0)

    encoded_actions = action_encoder(torch.cat([flat_actions, flat_states], dim=-1))

    recovered_states = state_decoder(encoded_states)
    recovered_actions = action_decoder(
        torch.cat([encoded_actions, encoded_states], dim=-1)
    )

    action_std = torch.std(encoded_actions, dim=0)

In [7]:
print(f"State std: {state_std.cpu().numpy()}\nAction std: {action_std.cpu().numpy()}")

State std: [0.18352377 0.19787525 0.20208438 0.08535806]
Action std: [0.94992095 0.91848886]


In [13]:
transition_batch_size = 32
# Test forward model
traj_inds = torch.randint(
    0, observations_test.shape[0], (transition_batch_size,), device="cuda"
)
test_start_inds = torch.randint(
    0, int(observations_test.shape[-2] // 1.1), (transition_batch_size,), device="cuda"
)

test_states = observations_test[traj_inds]
test_actions = actions_test[traj_inds]

start_states = observations_test[traj_inds, test_start_inds]

In [14]:
start_states.shape

torch.Size([32, 4])

In [19]:
latent_start_states = state_encoder(start_states)

latent_traj_actions = action_encoder(torch.cat([test_actions, test_states], dim=-1))

latent_pred_fut_states = transition_model(
    latent_start_states, latent_traj_actions, start_indices=test_start_inds
)

In [22]:
start_ind

tensor(605, device='cuda:0')

In [23]:
latent_fut_states = state_encoder(test_states) 

traj_ind = traj_inds[0]
start_ind = test_start_inds[0]

latent_fut_states_select = latent_fut_states[0, start_ind :]
latent_pred_fut_states_select = latent_pred_fut_states[0, start_ind:]

In [24]:
latent_fut_states_select, latent_pred_fut_states_select

(tensor([[-0.0056, -0.1353,  0.1802,  0.0828],
         [-0.0273, -0.1289,  0.1737,  0.0794],
         [-0.0239, -0.1289,  0.1636,  0.0794],
         ...,
         [ 0.0614, -0.0443, -0.0229, -0.0284],
         [ 0.0567, -0.0472, -0.0299, -0.0350],
         [ 0.0453, -0.0462, -0.0243, -0.0392]], device='cuda:0',
        grad_fn=<SliceBackward0>),
 tensor([[-0.0015, -0.1084, -0.0285,  0.0764],
         [-0.0230, -0.1141, -0.0157,  0.0792],
         [ 0.0030, -0.1120, -0.0202,  0.0749],
         ...,
         [ 0.0334, -0.0987, -0.0442,  0.0651],
         [ 0.0142, -0.0876, -0.0324,  0.0685],
         [ 0.0044, -0.0965, -0.0246,  0.0701]], device='cuda:0',
        grad_fn=<SliceBackward0>))