In [None]:
import os

import gymnasium as gym
import mediapy as media
import numpy as np
import torch
from tqdm import tqdm

from stuff import *

In [None]:
# Ask MuJoCo to render through its EGL backend by exporting MUJOCO_GL.
os.environ.update({"MUJOCO_GL": "egl"})

In [None]:
# Restore the five trained networks from disk, in a fixed order.
# torch.load unpickles each checkpoint as-is, so only point this at files you trust.
# NOTE(review): presumably these are whole pickled modules rather than state_dicts,
# since the loaded objects are called directly further down — confirm against the training script.
(
    state_encoder,
    action_encoder,
    transiton_model,
    state_decoder,
    action_decoder,
) = (
    torch.load(checkpoint_path)
    for checkpoint_path in (
        'trained_net_params/state_encoder.pt',
        'trained_net_params/action_encoder.pt',
        'trained_net_params/transition_model.pt',
        'trained_net_params/state_decoder.pt',
        'trained_net_params/action_decoder.pt',
    )
)

In [None]:
# Read the recorded dataset archive and convert its arrays to CUDA tensors.
data = np.load("data.npz")


def torchify(*xs):
    """Return a list holding each argument converted to a float32 tensor on the CUDA device."""
    return [torch.tensor(x, dtype=torch.float32, device="cuda") for x in xs]


# CPU-only alternative: drop device="cuda" from the torch.tensor call in torchify.
observations, actions = torchify(data["observations"], data["actions"])
# To work on a subset, slice both tensors, e.g. observations[:4096] and actions[:4096].

In [None]:
# Scalar handed to ActorPolicy as its second positional argument.
# NOTE(review): presumably the bound on action magnitude — confirm against ActorPolicy's definition.
action_space_size = 1.75

In [None]:
# Roll out the planner (ActorPolicy) in PointMaze_Large-v3 and record the visited
# states, the executed actions and the rendered frames for later inspection.

MAX_ROLLOUT_STEPS = 1024  # upper bound on env steps; the rollout stops earlier if the episode ends


def _observation_to_state(observation):
    """Convert the env's "observation" entry into a float32 tensor on the CUDA device."""
    return torch.tensor(observation["observation"], dtype=torch.float32, device="cuda")


actor = ActorPolicy(
    2,
    action_space_size,
    state_encoder,
    transiton_model,
    action_decoder,
    horizon=128,
    iters=32,
).cuda()

env = gym.make("PointMaze_Large-v3", render_mode="rgb_array")
obs, info = env.reset()
state = _observation_to_state(obs)
# No plan exists before the first call; the actor returns one that is fed back in on the next step.
latent_action_plan = None

initial_state = state
# Goal coordinates followed by two zeros.
# NOTE(review): presumably the zeros are the desired velocity, matching the
# observation layout — confirm against the environment documentation.
target_state = torch.tensor(
    [*obs["desired_goal"], 0, 0], dtype=torch.float32, device="cuda"
)

actions = []
states = [initial_state]

frames = [env.render()]
for i in tqdm(range(MAX_ROLLOUT_STEPS), disable=True):
    next_action, latent_action_plan = actor(
        state[None], target_state[None], latent_action_plan
    )
    # Copy the action to the host once and reuse it for stepping, recording and logging.
    action_np = next_action[0].cpu().numpy()
    obs, rew, term, trunc, info = env.step(action_np)
    state = _observation_to_state(obs)

    actions.append(action_np)
    states.append(state)

    frames.append(env.render())
    print(f"Iteration {i}, action: {action_np}")

    # Gymnasium requires a reset once an episode is terminated or truncated;
    # stepping further is outside the API contract, so end the rollout here.
    if term or trunc:
        break

# Stack on the NumPy side first: building a tensor from a list of ndarrays is the slow path.
traj_actions = torch.tensor(np.stack(actions), dtype=torch.float32, device="cuda")
traj_states = torch.stack(states, dim=0)

In [None]:
# Show the video: play back the frames collected from env.render() at 30 fps.
media.show_video(frames, fps=30)

In [None]:
# Open-loop model check: encode the executed trajectory, roll the transition model
# forward from the initial latent state using the encoded actions, and decode the
# result so it can be compared with the states the environment actually produced.
# This is evaluation only, so it runs under no_grad to avoid keeping an autograd
# graph for the whole trajectory alive in the notebook's globals.
with torch.no_grad():
    latent_traj_states = state_encoder(traj_states)
    # Each action is concatenated with the latent of the state at the following index
    # (the [1:] slice drops the initial state).
    # NOTE(review): if action_encoder expects the state the action was taken *from*,
    # the slice should be [..., :-1, :] instead — confirm against the training code.
    latent_traj_actions = action_encoder(
        torch.cat([traj_actions, latent_traj_states[..., 1:, :]], dim=-1)
    )
    latent_initial_state = state_encoder(initial_state[None])

    # [None] prepends a size-1 axis to the action sequence, as is done for initial_state above.
    predicted_fut_latent_states = transiton_model(
        latent_initial_state, latent_traj_actions[None]
    )
    predicted_fut_states = state_decoder(predicted_fut_latent_states)

# Ground truth for the comparison: every recorded state after the initial one.
actual_fut_states = traj_states[1:]

In [None]:
# Display the decoded output of the transition-model rollout (bare expression so the notebook renders it).
predicted_fut_states

In [None]:
# Display the recorded states after the initial one (traj_states[1:]) for comparison with predicted_fut_states.
actual_fut_states