In [1]:
import os

import gymnasium as gym
import mediapy as media
import numpy as np
import torch
from tqdm import tqdm

from nets import *

In [2]:
# Set mujoco to use EGL
os.environ["MUJOCO_GL"] = "egl"

In [3]:
# Load models
state_encoder = torch.load('trained_net_params/state_encoder.pt')
action_encoder = torch.load('trained_net_params/action_encoder.pt')
transiton_model = torch.load('trained_net_params/transition_model.pt')
state_decoder = torch.load('trained_net_params/state_decoder.pt')
action_decoder = torch.load('trained_net_params/action_decoder.pt')

In [4]:
# Load data from data.npz
data = np.load("data.npz")
torchify = lambda *xs: [torch.tensor(x, dtype=torch.float32, device="cuda") for x in xs]
# torchify = lambda *xs: [torch.tensor(x, dtype=torch.float32) for x in xs]
observations, actions = torchify(data["observations"], data["actions"])
# observations, actions = observations[:4096], actions[:4096]

In [5]:
action_space_size = 1.75

In [6]:
# Now let's optimize a trajectory

actions = torch.randn(32, 64, 2, device="cuda")
actions = actions / torch.norm(actions, p=1, dim=-1, keepdim=True)
actions = actions * torch.rand((32, 64, 1), device="cuda") * action_space_size

actor = ActorPolicy(
    2,
    action_space_size,
    state_encoder,
    transiton_model,
    action_decoder,
    horizon=128,
    iters=32,
).cuda()

env = gym.make("PointMaze_Open-v3", render_mode="rgb_array")
obs, info = env.reset()
state = torch.tensor(obs["observation"], dtype=torch.float32, device="cuda")
latent_action_plan = None

initial_state = state
target_state = torch.tensor(
    [*obs["desired_goal"], 0, 0], dtype=torch.float32, device="cuda"
)

actions = []
states = [initial_state]

frames = [env.render()]
for i in tqdm(range(1024), disable=True):
    next_action, latent_action_plan = actor(
        state[None], target_state[None], latent_action_plan
    )
    obs, rew, term, trunc, info = env.step(next_action[0].cpu().numpy())
    state = torch.tensor(obs["observation"], dtype=torch.float32, device="cuda")

    actions.append(next_action[0].cpu().numpy())
    states.append(state)

    frames.append(env.render())
    print(f"Iteration {i}, action: {next_action[0].cpu().numpy()}")

traj_actions = torch.tensor(actions, dtype=torch.float32, device="cuda")
traj_states = torch.stack(states, dim=0)

  return F.mse_loss(input, target, reduction=self.reduction)


0.0016476123128086329
0.0002711997367441654
Iteration 0, action: [-0.10121667  0.09673527]
0.00023550081823486835
0.0001568745938129723
Iteration 1, action: [-0.47230813  0.28171405]
0.00020006258273497224
0.00012555134890135378
Iteration 2, action: [0.01495984 0.03982478]
0.0001688539341557771
0.00013091758592054248
Iteration 3, action: [-0.02575747  0.6828878 ]
0.00016572405002079904
0.0001425862719770521
Iteration 4, action: [ 0.07523281 -0.07610372]
0.0001398722524754703
0.00012603081995621324
Iteration 5, action: [-0.01858988 -0.13432772]
0.00011534995428519323
0.00011078581883339211
Iteration 6, action: [0.09451473 0.440294  ]
0.00012193428119644523
9.205375681631267e-05
Iteration 7, action: [0.05704945 0.1328618 ]
0.00011481788533274084
8.686818910064176e-05
Iteration 8, action: [ 0.3117628 -0.4455597]
0.00010486644168850034
0.00010575151100056246
Iteration 9, action: [ 0.07011877 -0.06403125]
0.00022357077978085726
0.00011882802937179804
Iteration 10, action: [-0.24105424 -0.39

  traj_actions = torch.tensor(actions, dtype=torch.float32, device="cuda")


In [7]:
# Show the video
media.show_video(frames, fps=30)

0
This browser does not support the video tag.


In [8]:
latent_traj_states = state_encoder(traj_states)
latent_traj_actions = action_encoder(torch.cat([traj_actions, latent_traj_states[..., 1:, :]], dim=-1))
latent_initial_state = state_encoder(initial_state[None])

predicted_fut_latent_states = transiton_model(latent_initial_state, latent_traj_actions[None])
predicted_fut_states = state_decoder(predicted_fut_latent_states)
actual_fut_states = traj_states[1:]

In [9]:
predicted_fut_states

tensor([[[0.6651, 1.6712, 0.0901, 0.4258],
         [1.0710, 1.7534, 0.0493, 0.5042],
         [0.4225, 1.6327, 0.3684, 0.5143],
         ...,
         [0.3544, 1.6384, 0.7449, 0.4960],
         [0.5483, 1.6898, 0.6057, 0.4980],
         [0.6040, 1.7182, 0.4634, 0.4689]]], device='cuda:0',
       grad_fn=<ViewBackward0>)

In [10]:
actual_fut_states

tensor([[ 1.2041, -0.7539, -0.0241,  0.0230],
        [ 1.2028, -0.7530, -0.1365,  0.0901],
        [ 1.2015, -0.7520, -0.1326,  0.0993],
        ...,
        [ 2.3988,  0.9625, -0.0146,  4.8989],
        [ 2.3987,  1.0108, -0.0101,  4.8338],
        [ 2.3985,  1.0608, -0.0153,  4.9980]], device='cuda:0')