In [None]:
import os

import gymnasium as gym
import mediapy as media
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm

from legato.nets import *
from legato.sampler import *

In [None]:
# Render MuJoCo headlessly through EGL.
# NOTE(review): MuJoCo picks this up when its GL backend is initialised; confirm nothing in the
# import cell (e.g. legato.*) has already imported mujoco, otherwise this setting is too late.
os.environ.update(MUJOCO_GL="egl")

In [None]:
# Load the trained networks. They are called directly as modules further down, so the .pt files
# appear to hold whole pickled objects rather than state_dicts. PyTorch >= 2.6 defaults
# `torch.load` to weights_only=True, which refuses to unpickle such objects, so opt out explicitly.
# Only do this for trusted local checkpoints: unpickling can execute arbitrary code.
MODEL_DIR = "../trained_net_params"


def load_trained_net(name):
    """Load the full pickled object stored in ``MODEL_DIR/<name>.pt``."""
    return torch.load(os.path.join(MODEL_DIR, f"{name}.pt"), weights_only=False)


state_encoder = load_trained_net("state_encoder")
action_encoder = load_trained_net("action_encoder")
transiton_model = load_trained_net("transition_model")  # (sic) variable name used by later cells
state_decoder = load_trained_net("state_decoder")
action_decoder = load_trained_net("action_decoder")

In [None]:
# Action bound handed to ActorPolicy (its second positional argument).
# NOTE(review): presumably must match the env's action-space bound — confirm.
action_space_size: float = 1.0

In [None]:
# Now let's optimize a trajectory: at every env step the latent planner is called on the current
# state and goal, fed the latent plan it returned on the previous step (None on the first step),
# and the action it returns is executed in the environment.

N_CONTROL_STEPS = 64  # env steps to roll out
SEED = 0  # seeds the maze reset and torch so reruns give the same rollout

torch.manual_seed(SEED)

actor = ActorPolicy(
    2,  # NOTE(review): presumably the action dimension (PointMaze actions are 2-D) — confirm
    action_space_size,
    state_encoder,
    transiton_model,
    action_decoder,
    lr=0.01,
    decay=1.0,
    horizon=16,
    iters=256,
).cuda()
# Optional: actor = torch.compile(actor)

env = gym.make("PointMaze_Medium-v3", render_mode="rgb_array")
obs, info = env.reset(seed=SEED)
state = torch.tensor(obs["observation"], dtype=torch.float32, device="cuda")
latent_action_plan = None

initial_state = state
# Goal = desired goal position padded with two zeros to match the observation length.
# NOTE(review): presumably the zeros are target velocities — confirm the observation layout.
target_state = torch.tensor(
    [*obs["desired_goal"], 0, 0], dtype=torch.float32, device="cuda"
)

# Actions sent to env.step. Deliberately not called `actions`: the dataset cell below rebinds
# that name to the dataset's action tensor.
executed_actions = []
states = [initial_state]
frames = [env.render()]
loss_curves = []

for i in range(N_CONTROL_STEPS):
    next_action, latent_action_plan, loss_curve = actor(
        state[None], target_state[None], latent_action_plan, return_curve=True
    )
    loss_curves.append(loss_curve)

    # Convert once; detach() in case the planner's output still requires grad.
    action = next_action[0].detach().cpu().numpy()

    obs, rew, term, trunc, info = env.step(action)
    state = torch.tensor(obs["observation"], dtype=torch.float32, device="cuda")

    executed_actions.append(action)
    states.append(state)
    frames.append(env.render())
    print(f"Iteration {i}, action: {action}")

    # Gymnasium expects reset() before stepping an episode that has ended.
    if term or trunc:
        break

# Stack to one ndarray first: building a tensor from a list of ndarrays is very slow.
traj_actions = torch.tensor(np.stack(executed_actions), dtype=torch.float32, device="cuda")
traj_states = torch.stack(states, dim=0)  # initial state first, then one per executed action

In [None]:
# Loss curve of the planner's inner optimisation at one control step. Step 1 is the first call
# that receives a plan returned by a previous call (step 0 is called with latent_action_plan=None).
PLOT_STEP = 1
curve = loss_curves[PLOT_STEP]
ax = sns.lineplot(x=range(len(curve)), y=curve)
ax.set(
    xlabel="planner iteration",
    ylabel="loss",
    title=f"Planning loss at control step {PLOT_STEP}",
);

In [None]:
# Play back the rollout: the frame rendered after reset, then one frame per env step.
VIDEO_FPS = 30
media.show_video(frames, fps=VIDEO_FPS)

In [None]:
# Load the offline dataset from data.npz as float32 tensors.
DATA_DEVICE = "cuda"  # switch to "cpu" if the dataset does not fit on the GPU
DATA_LIMIT = None  # e.g. 4096 to keep only the first N rows for quick experiments


def torchify(*xs, device=DATA_DEVICE):
    """Convert each array to a float32 tensor on ``device``; returns a list in argument order."""
    return [torch.tensor(x, dtype=torch.float32, device=device) for x in xs]


# np.load on an .npz returns an NpzFile that keeps the archive open; the context manager
# closes it once both arrays have been copied into tensors.
with np.load("../data.npz") as data:
    observations, actions = torchify(data["observations"], data["actions"])

if DATA_LIMIT is not None:
    observations, actions = observations[:DATA_LIMIT], actions[:DATA_LIMIT]

In [None]:
# Check the world model on the executed trajectory: encode visited states and actions, let the
# transition model predict future latents from the encoded initial state plus the encoded action
# sequence, then decode those predictions. Pure evaluation, so no autograd graph is built.
with torch.no_grad():
    latent_traj_states = state_encoder(traj_states)
    # Action k was chosen at state k, so actions pair with every state except the last.
    latent_traj_actions = action_encoder((traj_actions, traj_states[..., :-1, :]))
    latent_initial_state = state_encoder(initial_state[None])

    predicted_fut_latent_states, mask = transiton_model(
        latent_initial_state, latent_traj_actions[None], return_mask=True
    )
    predicted_fut_states = state_decoder(predicted_fut_latent_states)

    # Ground truth: the states actually reached after each action, and their encodings.
    actual_fut_states = traj_states[1:]
    actual_fut_latent_states = state_encoder(actual_fut_states)

In [None]:
# Shapes of predicted vs. encoder latents for the reached states, to confirm they line up
# before comparing values (both are produced by the evaluation cell above; no need to re-encode).
predicted_fut_latent_states.shape, actual_fut_latent_states.shape

In [None]:
# Actions executed during the rollout (float32 CUDA tensor, one row per env step).
traj_actions

In [None]:
# States visited during the rollout, initial state first, stacked along dim 0.
traj_states

In [None]:
# Round-trip the executed actions through the action autoencoder: decode each latent action
# together with the latent of the state it was taken from (all states but the last).
# Evaluation only, so skip building an autograd graph.
with torch.no_grad():
    reconstructed_traj_actions = action_decoder(
        (latent_traj_actions, latent_traj_states[..., :-1, :])
    )

In [None]:
# Action-decoder reconstructions; compare against traj_actions.
reconstructed_traj_actions

In [None]:
# Latent states predicted by the transition model from the encoded initial state and actions.
predicted_fut_latent_states

In [None]:
# Encoder latents of the states actually reached (traj_states[1:]).
actual_fut_latent_states

In [None]:
# Predicted future latents decoded back to state space.
predicted_fut_states

In [None]:
# Ground-truth future states (traj_states[1:]).
actual_fut_states

In [None]:
# Random samplers over the latent state and latent action spaces.
# NOTE(review): PBallSampler's positional arguments are not visible here; 4 and 2 look like the
# latent state/action sizes and 2.0/1.0 like radii — confirm against legato.sampler.
latent_state_sampler = PBallSampler(4, 1, 2.0, device="cuda")
latent_action_sampler = PBallSampler(2, 1, 1.0, device="cuda")


def latent_action_state_sampler(n):
    """Call both samplers with ``n`` and return ``(latent_actions, latent_states)``."""
    # Same call order as before (action sampler first) so any RNG consumption is unchanged.
    sampled_actions = latent_action_sampler(n)
    sampled_states = latent_state_sampler(n)
    return sampled_actions, sampled_states

In [None]:
# Decode random latent (action, state) pairs to inspect the range of actions the decoder emits.
# Evaluation only, so skip building an autograd graph.
N_DECODER_SAMPLES = 4096
with torch.no_grad():
    action_samples = action_decoder(latent_action_state_sampler(N_DECODER_SAMPLES))

In [None]:
# Actions decoded from randomly sampled latent (action, state) pairs.
action_samples

In [None]:
# Dataset actions loaded from data.npz (assuming the cells ran top to bottom, the dataset cell
# rebinds `actions` to this tensor), for comparison with the decoded action_samples.
actions