In [None]:
import os
import sys
import warnings

import gymnasium as gym
import mediapy as media
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm
import wandb

from legato.nets import *
from legato.sampler import *

In [None]:
# Set mujoco to use EGL (used for the rgb_array renders of the envs below).
os.environ["MUJOCO_GL"] = "egl"

# NOTE(review): MuJoCo reads MUJOCO_GL when the `mujoco` module is first
# imported, so this only takes effect if nothing has imported mujoco yet
# (gymnasium presumably imports it lazily when the env is built). Warn loudly
# instead of silently rendering with the wrong backend.
if "mujoco" in sys.modules:
    warnings.warn(
        "MUJOCO_GL was set after mujoco was imported and may be ignored; "
        "set it before importing gymnasium/legato.",
        stacklevel=2,
    )

In [None]:
# Fetch the trained network files of a previous wandb run; each file lands at
# ../trained_net_params/<filename> and overwrites any local copy.
run_id = "dmqzn687"
run_path = f"mishmish66/legato/{run_id}"
checkpoint_files = (
    "state_encoder.pt",
    "action_encoder.pt",
    "transition_model.pt",
    "state_decoder.pt",
    "action_decoder.pt",
    "indices.npz",
)
for filename in checkpoint_files:
    wandb.restore(
        f"trained_net_params/{filename}",
        run_path=run_path,
        replace=True,
        root="../",
    )

In [None]:
# Load models.
# These files are loaded as whole pickled objects (they are called directly
# below), so `weights_only=False` is required on PyTorch >= 2.6, where the
# default flipped to True.
# SECURITY: torch.load unpickles arbitrary objects — only load files you trust
# (here: artifacts restored from our own wandb run).
params_dir = "../trained_net_params"


def _load_net(name):
    """Load the pickled network stored at ``params_dir/<name>.pt``."""
    return torch.load(os.path.join(params_dir, f"{name}.pt"), weights_only=False)


state_encoder = _load_net("state_encoder")
action_encoder = _load_net("action_encoder")
transiton_model = _load_net("transition_model")  # (sic) name used by later cells
state_decoder = _load_net("state_decoder")
action_decoder = _load_net("action_decoder")

In [None]:
# Action bound passed to ActorPolicy as its second positional argument.
# NOTE(review): presumably matches the PointMaze action range [-1, 1] — confirm.
action_space_size = 1.0

In [None]:
# Now let's optimize a trajectory.
# At every control step the actor plans from the *current* observed states
# towards each env's desired goal, and the returned action is executed in all
# parallel envs.

actor = ActorPolicy(
    2,
    action_space_size,
    state_encoder,
    transiton_model,
    state_decoder,
    action_decoder,
    lr=0.1,
    decay=1.0,
    horizon=1024,
    iters=256,
).cuda()

parallel_env_count = 8
control_steps = 256

# Make parallel env
envs = gym.vector.make(
    "PointMaze_Medium-v3",
    num_envs=parallel_env_count,
    asynchronous=True,
    render_mode="rgb_array",
    continuing_task=True,
)


def to_cuda_tensor(array):
    """Convert an observation array to a float32 tensor on the GPU."""
    return torch.tensor(array, dtype=torch.float32, device="cuda")


obses, infos = envs.reset()
states = to_cuda_tensor(obses["observation"])

latent_action_plan = None

initial_state = states

traj_actions = []
traj_states = [initial_state]

frame_stacks = [[frame] for frame in envs.call("render")]

loss_curves = []
for i in tqdm(range(control_steps), disable=True):
    # Target: the desired goal in the leading coordinates, zeros in the rest.
    desired_goals = to_cuda_tensor(obses["desired_goal"])
    target_states = torch.zeros_like(states)
    target_states[:, : desired_goals.shape[-1]] = desired_goals

    # NOTE(review): the returned latent_action_plan is never fed back (third
    # argument is None); if ActorPolicy supports warm-starting from a previous
    # plan, passing it may speed up convergence — confirm.
    next_actions, latent_action_plan, loss_curve = actor(
        states, target_states, None, return_curve=True
    )
    # Drop any autograd history so stored actions don't keep the planner's graph alive.
    next_actions = next_actions.detach()

    loss_curves.append(loss_curve)
    obses, rew, term, trunc, info = envs.step(next_actions.cpu().numpy())

    # Advance `states` so the next plan starts from where the envs actually are.
    states = to_cuda_tensor(obses["observation"])

    traj_actions.append(next_actions)
    traj_states.append(states)

    for frame_stack, frame in zip(frame_stacks, envs.call("render")):
        frame_stack.append(frame)
    print(f"Iteration {i}, action: {next_actions[0].cpu().numpy()}")

# Stack along a new time axis. Assuming per-step tensors are (envs, dim), this
# gives traj_actions (envs, T, action_dim) and traj_states (envs, T + 1, state_dim).
traj_actions = torch.stack(traj_actions, dim=1)
traj_states = torch.stack(traj_states, dim=1)

In [None]:
# Plot the actor's loss curve for one control step, with i as the x axis.
curve_index = 5  # which control step's optimization curve to inspect
curve = loss_curves[curve_index]
ax = sns.lineplot(x=range(len(curve)), y=curve)
ax.set(
    title=f"Actor loss curve at control step {curve_index}",
    xlabel="iteration",
    ylabel="loss",
);

In [None]:
# Play back the frames recorded for one of the parallel envs.
video_env_index = 6
media.show_video(frame_stacks[video_env_index], fps=30)

In [None]:
# Load data from data.npz and move the arrays used below to float32 tensors.
data = np.load("../data.npz")


def torchify(*xs, device="cuda"):
    """Convert each array in ``xs`` to a float32 tensor on ``device``.

    Returns a list with one tensor per input, in order. Pass ``device="cpu"``
    to keep the data off the GPU.
    """
    return [torch.tensor(x, dtype=torch.float32, device=device) for x in xs]


observations, actions = torchify(data["observations"], data["actions"])

In [None]:
# Compare the transition model's predictions from the initial state against the
# states the envs actually visited. Evaluation only, so no autograd graph is built.
with torch.no_grad():
    latent_traj_states = state_encoder(traj_states)
    # Each action is paired with the state preceding it (all but the last state).
    latent_traj_actions = action_encoder((traj_actions, traj_states[..., :-1, :]))
    latent_initial_state = state_encoder(initial_state[None])

    # NOTE(review): `[None]` adds a leading batch axis; if the trajectory tensors
    # are already batched over envs this may add an extra dimension — confirm the
    # input rank the transition model expects.
    predicted_fut_latent_states, mask = transiton_model(
        latent_initial_state, latent_traj_actions[None], return_mask=True
    )
    predicted_fut_states = state_decoder(predicted_fut_latent_states)

    # Drop the first step along the time axis (mirrors `[..., :-1, :]` above).
    actual_fut_states = traj_states[..., 1:, :]
    actual_fut_latent_states = state_encoder(actual_fut_states)

In [None]:
# NOTE(review): repeats the `actual_fut_latent_states` computation from the
# previous cell; re-running it is harmless but redundant.
actual_fut_latent_states = state_encoder(actual_fut_states)

In [None]:
# Actions executed during the rollout (stacked along dim=1 in the rollout cell).
traj_actions

In [None]:
# Trajectory states as built at the end of the rollout cell.
traj_states

In [None]:
# Decode the latent actions back to actions, pairing each with the latent state
# that precedes it (all latent states except the last).
decoder_inputs = (latent_traj_actions, latent_traj_states[..., :-1, :])
reconstructed_traj_actions = action_decoder(decoder_inputs)

In [None]:
# Decoder reconstruction of the executed actions — compare with `traj_actions`.
reconstructed_traj_actions

In [None]:
# Transition-model predictions of future latent states.
predicted_fut_latent_states

In [None]:
# Encoded actual future states — compare with `predicted_fut_latent_states`.
actual_fut_latent_states

In [None]:
# Decoded predicted future states.
predicted_fut_states

In [None]:
# Actual future states — compare with `predicted_fut_states`.
actual_fut_states

In [None]:
# Random samplers over the latent state and latent action spaces.
# NOTE(review): PBallSampler's positional args (4, 1, 2.0) / (2, 1, 1.0) are
# presumably (dim, p, radius) — confirm against legato.sampler.
latent_state_sampler = PBallSampler(4, 1, 2.0, device="cuda")
latent_action_sampler = PBallSampler(2, 1, 1.0, device="cuda")


def latent_action_state_sampler(n):
    """Draw ``n`` latent actions and ``n`` latent states.

    Returns a ``(latent_actions, latent_states)`` tuple, the input form the
    action decoder is called with elsewhere in this notebook.
    """
    return latent_action_sampler(n), latent_state_sampler(n)

In [None]:
# Decode a batch of random latent (action, state) samples; inference only, so
# skip building an autograd graph.
n_action_samples = 4096
with torch.no_grad():
    action_samples = action_decoder(latent_action_state_sampler(n_action_samples))

In [None]:
# Actions decoded from randomly sampled latents.
action_samples

In [None]:
# Dataset actions loaded from data.npz, for comparison with `action_samples`.
actions