In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"
    
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple
from pprint import pprint

import dcargs
import glob
import hydra
import numpy as np
import torch

import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch.utils.data.dataloader import DataLoader

from research.logger import WandBLogger, WandBLoggerConfig
from research.mtm.models.mtm_model import MaskedDP, MTMConfig, make_plots_with_masks
from research.mtm.tokenizers.base import Tokenizer, TokenizerManager
import mediapy as media
from research.mtm.train import RunConfig
import matplotlib.pyplot as plt
from collections import defaultdict
from research.utils.plot_utils import PlotHandler as ph
from pathlib import Path

%matplotlib inline

In [None]:
# discrete
# NOTE(review): the active `path` below is a hard-coded absolute run directory;
# the commented-out assignments are alternative runs kept for manual switching.
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-26_11-54-07/1_+experiments=yoga_discrete,args.mask_patterns=[RANDOM,GOAL,ID,FD]"
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-27_01-26-57/13_+experiments=yoga_discrete,args.mask_patterns=[GOAL]"
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-09_10-17-52/4_+experiments=yoga_discrete,args.learning_rate=0.0003,args.mask_patterns=[RANDOM,GOAL],args.weight_decay=0.001"
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-09_11-13-11/4_+experiments=yoga_discrete,args.learning_rate=0.0003,args.mask_patterns=[FULL_RANDOM,RANDOM,GOAL],args.weight_decay=0.001"
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-09_11-13-11/9_+experiments=yoga_discrete,args.learning_rate=0.0003,args.mask_patterns=[FULL_RANDOM,RANDOM,GOAL,GOAL_N,ID,FD],args.weight_decay=0.001"

In [None]:
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-27_01-26-57/19_+experiments=yoga_discrete,args.mask_patterns=[RANDOM,GOAL,GOAL_N,ID,FD]"

In [None]:
# cont
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-29_07-28-03/1_+experiments=yoga_cont,args.mask_patterns=[RANDOM,GOAL,ID,FD]"

In [None]:
# # # cont goal reaching
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-29_07-28-03/9_+experiments=yoga_cont,args.mask_patterns=[RANDOM,GOAL,GOAL_N,ID,FD]"

In [None]:
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-29_07-28-03/3_+experiments=yoga_cont,args.mask_patterns=[GOAL]"

In [None]:
# all, actions only
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-27_01-26-31/29_+experiments2=yoga_discrete_actions,args.mask_patterns=[RANDOM,GOAL,GOAL_N,ID,FD]"

In [None]:
# all except goal_all, actions only
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-27_01-26-31/21_+experiments2=yoga_discrete_actions,args.mask_patterns=[RANDOM,GOAL,ID,FD]"

In [None]:
# # with full random
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-02_15-24-55/9_+experiments=yoga_discrete,args.mask_patterns=[FULL_RANDOM,RANDOM,GOAL,GOAL_N,ID,FD]"

In [None]:
# # with full random
# path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-02_15-24-55/9_+experiments=yoga_discrete,args.mask_patterns=[FULL_RANDOM,RANDOM,GOAL,GOAL_N,ID,FD]"

In [None]:
# Find the checkpoints in the run directory and select the one with the
# largest training step. Checkpoint files are expected to be named
# "<prefix>_<step>.pt" (e.g. "model_640000.pt").
steps = []
names = []
paths_ = os.listdir(path)
for n in paths_:
    # Match on the file extension: a plain substring test for "pt" also picks
    # up unrelated entries (anything containing "pt") and then fails in int().
    if not n.endswith((".pt", ".pth")):
        continue
    step = n.split("_")[-1].split(".")[0]
    if not step.isdigit():
        # Not a step-numbered checkpoint; skip it instead of crashing.
        continue
    name = os.path.join(path, n)
    steps.append(int(step))
    names.append(name)
    print(name)

if not names:
    # Fail with a readable message rather than an argmax-on-empty error.
    raise FileNotFoundError(f"no step-numbered checkpoint files found in {path}")

ckpt_path = names[np.argmax(steps)]

In [None]:
# ckpt_path = '/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-17_18-22-57/1_+experiments=exorl_continuous_rew_qpos,args.mask_patterns=[RANDOM,GOAL,ID,FD],args.model_config.n_dec_layer=1,args.model_config.n_enc_layer=1,args.model_config.n_head=4/model_640000.pt'

In [None]:
# Show the training step stored in the selected checkpoint.
# map_location="cpu": only an integer is read here, so avoid restoring the
# tensors onto whatever device the checkpoint was saved from.
torch.load(ckpt_path, map_location="cpu")["step"]

In [None]:
# Load the Hydra config stored under ".hydra/config.yaml" in the run directory.
hydra_cfg = OmegaConf.load(os.path.join(path, ".hydra/config.yaml"))

In [None]:
# Display the dataset section of the loaded config.
hydra_cfg.dataset

In [None]:
# make datasets smaller for easier loading
# Overrides the size settings in the loaded config before the datasets are built.
# NOTE(review): presumably these cap the number of samples — confirm in the dataset factory.
hydra_cfg.dataset.train_max_size = 10000000
hydra_cfg.dataset.val_max_size = 10000

In [None]:
# Instantiate the `args` section of the Hydra config and print the result.
cfg = hydra.utils.instantiate(hydra_cfg.args)
pprint(cfg)

In [None]:
# Build the train/validation datasets from the Hydra dataset config, using the
# model's trajectory length as the sequence length.
train_dataset, val_dataset = hydra.utils.call(
    hydra_cfg.dataset, seq_steps=cfg.model_config.traj_length
)
print("Train set size =", len(train_dataset))
print("Validation set size =", len(val_dataset))

# One tokenizer per configured key, each constructed from the training dataset.
tokenizers: Dict[str, Tokenizer] = {}
for tok_key, tok_cfg in hydra_cfg.tokenizers.items():
    tokenizers[tok_key] = hydra.utils.call(
        tok_cfg, key=tok_key, train_dataset=train_dataset
    )
tokenizer_manager = TokenizerManager(tokenizers)

# Per-key flag taken from each tokenizer's `discrete` attribute.
discrete_map: Dict[str, bool] = {
    tok_key: tok.discrete for tok_key, tok in tokenizers.items()
}
print(tokenizers)

# Both loaders share batch size and worker count; only the train loader pins memory.
loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.n_workers)
train_loader = DataLoader(train_dataset, pin_memory=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# Tokenize a single training batch to read off the trailing two dimensions of
# every tokenized entry; these are later passed to the model constructor.
train_batch = next(iter(train_loader))
tokenized = tokenizer_manager.encode(train_batch)
data_shapes = {tok_key: tok_value.shape[-2:] for tok_key, tok_value in tokenized.items()}
print(data_shapes)


In [None]:
# Reuse the environment held by the validation dataset (private attribute `_env`).
env = val_dataset._env

In [None]:
# Reset the environment and display one rendered frame.
env.reset()
image = env.physics.render(480, 640)
plt.imshow(image)

In [None]:
# Draw a sample from the validation dataset; later cells index it by the keys
# "physics", "observations" and "actions".
sample_trajectory_with_metadata = val_dataset.sample(1)

In [None]:
# List the entries available in the sampled trajectory.
sample_trajectory_with_metadata.keys()

In [None]:
# Restore the simulator to the physics state recorded at step `time` of the
# sampled trajectory, then read the observation the task produces for it.
time = 6
state = sample_trajectory_with_metadata["physics"][time]
state0 = sample_trajectory_with_metadata["physics"][0]  # not used in this cell
#env.reset()
with env.physics.reset_context():
    env.physics.set_state(state)
    
obs = env.task.get_observation(env.physics)
print(obs.keys())

In [None]:
# Sanity check: the "orientations" observation must match the first 14 entries
# of the recorded observation at the same step. The last line displays the residual.
env.physics.get_state()  # return value is discarded here
np.testing.assert_allclose(obs["orientations"], sample_trajectory_with_metadata["observations"][time][0:14])
obs["orientations"] - sample_trajectory_with_metadata["observations"][time][0:14]

In [None]:
# List the entries available in the sampled trajectory (same display as earlier).
sample_trajectory_with_metadata.keys()

In [None]:
# Apply the recorded action for step `time` and display the difference between
# the first element of the step result and the recorded observation at `time + 1`.
# Depends on the simulator having been set to step `time` by an earlier cell.
print(time)
action = sample_trajectory_with_metadata["actions"][time]
new_obs = env.step(action)[0]
new_obs - sample_trajectory_with_metadata["observations"][time + 1]

In [None]:
# Difference between the simulator's current physics state and the recorded one at `time + 1`.
env.physics.get_state() - sample_trajectory_with_metadata["physics"][time + 1]

In [None]:
# Apply the next recorded action and compare against the recorded observation at `time + 2`.
# Only meaningful when run exactly once, right after the single-step cell above.
action = sample_trajectory_with_metadata["actions"][time + 1]
new_obs = env.step(action)[0]
new_obs - sample_trajectory_with_metadata["observations"][time + 2]

In [None]:
# Replay the recorded actions from the first recorded physics state and check
# that each resulting observation stays close to the recorded one.
env.reset()
with env.physics.reset_context():
    env.physics.set_state(sample_trajectory_with_metadata["physics"][0])

actions = sample_trajectory_with_metadata["actions"]
_obs = sample_trajectory_with_metadata["observations"]
images = [env.physics.render(480, 640, 0)]
last_idx = len(actions) - 1
for idx, action in enumerate(actions):
    obs = env.step(action)[0]
    image = env.physics.render(480, 640, 0)
    # The final step has no recorded successor observation to compare against.
    if idx != last_idx:
        np.testing.assert_allclose(obs, _obs[idx + 1], rtol=1e-1, atol=1e-1)
    images.append(image)
# media.show_video(images, fps=30)

In [None]:
# Named target poses, each a list of 9 values. One of them is written into
# `env.physics.data.qpos` in the next cell.
lie_back = [ -1.2 ,  0. ,  -1.57,  0, 0. , 0.0, 0, -0.,  0.0]
lie_front = [-1.2, -0, 1.57, 0, 0, 0, 0, 0., 0.]
legs_up = [ -1.24 ,  0. ,  -1.57,  1.57, 0. , 0.0,  1.57, -0.,  0.0]

kneel = [ -0.5 ,  0. ,  0,  0, -1.57, -0.8,  1.57, -1.57,  0.0]
side_angle = [ -0.3 ,  0. ,  0.9,  0, 0, -0.7,  1.87, -1.07,  0.0]
stand_up = [-0.15, 0., 0.34, 0.74, -1.34, -0., 1.1, -0.66, -0.1]

lean_back = [-0.27, 0., -0.45, 0.22, -1.5, 0.86, 0.6, -0.8, -0.4]
boat = [ -1.04 ,  0. ,  -0.8,  1.6, 0. , 0.0, 1.6, -0.,  0.0]
bridge = [-1.1, 0., -2.2, -0.3, -1.5, 0., -0.3, -0.8, -0.4]

head_stand = [-1, 0., -3, 0.6, -1, -0.3, 0.9, -0.5, 0.3]
one_feet = [-0.2, 0., 0, 0.7, -1.34, 0.5, 1.5, -0.6, 0.1]
arabesque = [-0.34, 0., 1.57, 1.57, 0, 0., 0, -0., 0.]

down_with_leg_out = [-1.05549, -0.4248, -2.1923, -0.3573, -1.509, 0.017559  , -0.358, -0.41552893, -0.79436103]

In [None]:
# Choose the goal pose: write the `kneel` configuration into the simulator's qpos.
# Swap in another pose list from the cell above to target a different goal.
env.physics.data.qpos = kneel

In [None]:
# Run forward() after the qpos write, then capture the resulting physics state.
# `phy_state` is the state later rollouts reset the simulator to.
env.physics.forward()
phy_state = env.physics.get_state()

In [None]:
# Output directory (relative to the working directory) for images and GIFs.
folder = Path("files/yoga")
folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Render the goal pose, save it as goal.png and display it.
_img= env.physics.render(480, 640, 0)
media.write_image(folder / "goal.png", _img)
plt.imshow(_img)

In [None]:
# Observation entries for the current (goal) pose, in the mapping's iteration order.
values = list(env.task.get_observation(env.physics).values())

In [None]:
# Flatten the first three observation entries into one state vector. values[1]
# is wrapped in an array before concatenation (presumably a scalar — TODO confirm).
state = np.concatenate([values[0], np.array([values[1]]), values[2]])
# state = np.concatenate([values[0], np.array([values[1]])])
goal_state = state  # same array object, kept under a second name as the goal

In [None]:
# Trajectory length the model was configured with.
t_len = cfg.model_config.traj_length

In [None]:
# Shape of the goal state with a leading axis added.
state[None].shape

In [None]:
# Goal state repeated t_len times along the first axis: every row is the goal.
torch_states  = torch.from_numpy(state[None]).repeat(t_len, 1)

# Create the model

In [None]:
# Build the model from the tokenized data shapes and move it to the configured device.
model = MaskedDP(data_shapes, cfg.model_config)
model.to(cfg.device)

# Load the weights. map_location makes the checkpoint tensors land on
# cfg.device even if the checkpoint was saved from a different device.
checkpoint = torch.load(ckpt_path, map_location=cfg.device)
model.load_state_dict(checkpoint["model"])

# The notebook only runs inference, so go straight to eval mode (the previous
# train() call was immediately overridden by eval() anyway).
model.eval()
print()  # keeps the cell from echoing the module repr returned by eval()

In [None]:
# Shape of the single goal state vector.
state.shape

In [None]:
# Shape of the repeated goal-state tensor.
torch_states.shape

In [None]:
# Single forward pass: states are the goal state repeated over the whole
# trajectory, actions are taken from the sampled trajectory. Both get a
# leading batch axis of size 1.
sampled_actions = torch.from_numpy(sample_trajectory_with_metadata["actions"])
batch_torch = {
    "states": torch_states.to(cfg.device, torch.float32).unsqueeze(0),
    "actions": sampled_actions.to(cfg.device).unsqueeze(0),
}

# goalreaching mask: ones for every state step, zeros for every action step
prediction_steps = t_len - 1
n_state_steps = batch_torch["states"].shape[1]
n_action_steps = batch_torch["actions"].shape[1]
state_mask = torch.ones(n_state_steps)
action_mask = torch.zeros(n_action_steps)
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {}
for mask_key, mask_value in masks.items():
    masks_torch[mask_key] = mask_value.to(cfg.device)

# encode -> model -> decode
encoded_batch = tokenizer_manager.encode(batch_torch)
predicted_trajectories = model(encoded_batch, masks_torch)
decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

In [None]:
# def get_actions():

In [None]:
# Reset the simulator to the goal pose and initialise the open-loop buffers.
# NOTE(review): the following cell repeats these exact statements before
# rolling out, so this cell looks redundant — confirm before removing.
env.reset()
with env.physics.reset_context():
    env.physics.set_state(phy_state)
    
images_open_loop = [env.physics.render(480, 640, 0)]
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()
traj_real_ol = defaultdict(list)
traj_real_ol["states"].append(goal_state)

In [None]:
# Open-loop rollout: start from the goal pose and execute every decoded action
# in order without re-querying the model.
env.reset()
with env.physics.reset_context():
    env.physics.set_state(phy_state)

images_open_loop = [env.physics.render(480, 640, 0)]
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()
traj_real_ol = defaultdict(list)
traj_real_ol["states"].append(goal_state)

for action in execute_actions:
    traj_real_ol["actions"].append(action)
    obs = env.step(action)[0]
    traj_real_ol["states"].append(obs)
    images_open_loop.append(env.physics.render(480, 640, 0))

# Drop the observation produced by the last action so that "states" holds one
# entry per executed action.
del traj_real_ol["states"][-1]

In [None]:
# Debug helper: stack the open-loop frames into one array to inspect its shape below.
asdf = np.array(images_open_loop)

In [None]:
# Shape of the stacked open-loop frames.
asdf.shape

In [None]:
# Preview the open-loop rollout and save it as a GIF.
# fps is passed to write_video explicitly so the saved GIF plays at the same
# rate as the preview and as the other GIFs written by this notebook.
media.show_video(images_open_loop, fps=30)
media.write_video(folder / "open_loop.gif", np.array(images_open_loop), fps=30, codec='gif')

In [None]:

# For each decoded key, plot the first few dimensions of: the batch given to
# the model ("ground truth"), the decoded prediction ("reconstructed"), the
# ground-truth points whose mask entry is 1, and the values recorded during
# the open-loop rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    traj = batch_torch[k][0].detach().cpu().numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    n_plots = min(max_n_plots, traj.shape[-1])
    for i in range(n_plots):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_ol[k])[:, i]
        if mask.dim() == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        visible = mask[:, i].cpu().numpy() == 1
        unmasked_gt_i = gt_i[visible]
        unmasked_gt_i_index = np.arange(len(gt_i))[visible]

        # y-limits span all three series, padded by a fifth of the range
        series = (gt_i, re_i, real_i)
        vmax = max(np.max(values_i) for values_i in series)
        vmin = min(np.min(values_i) for values_i in series)
        margin = (vmax - vmin) / 5
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(unmasked_gt_i_index, unmasked_gt_i, "o", label="unmasked ground truth")
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - margin, vmax + margin)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Closed-loop rollout starting from the goal pose.
# At every step the model is queried again, only the i-th predicted action is
# executed, and the executed action and resulting observation are written back
# into the batch before the next query.
batch_torch = {
    # .clone(): the loop below writes into these tensors in place. On a CPU
    # device .to()/from_numpy can return views of torch_states / the sampled
    # actions, which must not be modified.
    "states": torch_states.to(cfg.device, torch.float32).unsqueeze(0).clone(),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0)
    .clone(),
}

# Masks over time steps: every state entry is 1, every action entry is 0.
# (The plotting cells label entries equal to 1 as "unmasked".)
state_mask = torch.ones(batch_torch["states"].shape[1])
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

env.reset()
with env.physics.reset_context():
    env.physics.set_state(phy_state)

images_close_loop = [env.physics.render(480, 640, 0)]
traj_real_cl = defaultdict(list)
traj_real_cl["states"].append(goal_state)

# Inference only: no autograd graph is needed for the repeated forward passes.
with torch.no_grad():
    for i in range(prediction_steps):
        encoded_batch = tokenizer_manager.encode(batch_torch)
        predicted_trajectories = model(encoded_batch, masks_torch)
        decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

        # Execute only the action predicted for the current step.
        a = decoded_trajectories["actions"][0][i].detach().cpu().numpy()
        # Use cfg.device (not a literal "cuda") so the write-back matches the
        # device the batch actually lives on.
        batch_torch["actions"][0][i] = torch.tensor(a, device=cfg.device)
        traj_real_cl["actions"].append(a)
        obs = env.step(a)[0]
        traj_real_cl["states"].append(obs)
        image = env.physics.render(480, 640, 0)
        images_close_loop.append(image)

        # Mark step i as 1 in both masks and feed the observed state back in
        # as the state for step i + 1.
        masks["states"][i] = 1
        masks["actions"][i] = 1
        masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}
        batch_torch["states"][0][i + 1] = torch.tensor(obs, device=cfg.device)

In [None]:
# Preview the closed-loop rollout and save it as close_loop.gif.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "close_loop.gif", images_close_loop, fps=30, codec='gif')

In [None]:
# Shape of the first recorded closed-loop entry for key `k`.
# NOTE(review): `k` is whatever value the last `for k ...` loop left behind in
# an earlier cell, so this depends on execution order.
np.array(traj_real_cl[k])[0].shape

In [None]:

# For each decoded key, plot the first few dimensions of the reference
# trajectory ("ground truth"), the last decoded prediction ("reconstructed"),
# the reference points whose mask entry is 1, and the values recorded during
# the closed-loop rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    # NOTE(review): the repeated goal state is used as the reference for every
    # key, including "actions" — confirm this is intended.
    traj = torch_states.numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_cl[k])[:, i]
        if len(mask.shape) == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # Include the executed values in the limits (as the open-loop plots
        # do); otherwise "real" points outside the reference range are clipped.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - y_range / 5, vmax + y_range / 5)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Difference between the goal state and the last state recorded in the closed-loop rollout.
diff_cl = goal_state - traj_real_cl["states"][-1]

In [None]:
# Sum of squared differences to the goal (closed loop).
np.sum(diff_cl**2)

In [None]:
# Difference between the goal state and the last state kept from the open-loop rollout.
diff_ol = goal_state - traj_real_ol["states"][-1]

In [None]:
# Sum of squared differences to the goal (open loop).
np.sum(diff_ol**2)

In [None]:
# Shape of the second recorded closed-loop state (the first one returned by env.step).
traj_real_cl["states"][1].shape

In [None]:
# Closed-loop rollout where only the first three and last three state steps
# start with mask value 1. All action steps start at 0.
batch_torch = {
    # .clone(): the loop below writes into these tensors in place. On a CPU
    # device .to()/from_numpy can return views of torch_states / the sampled
    # actions, which must not be modified.
    "states": torch_states.to(cfg.device, torch.float32).unsqueeze(0).clone(),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0)
    .clone(),
}

# goalreaching mask
state_mask = torch.zeros(batch_torch["states"].shape[1])
state_mask[:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

env.reset()
with env.physics.reset_context():
    env.physics.set_state(phy_state)

images_close_loop = [env.physics.render(480, 640, 0)]
traj_real_cl = defaultdict(list)
# NOTE(review): the simulator starts from `phy_state`, but the first recorded
# state is the sampled trajectory's observation 0 — confirm this is intended.
traj_real_cl["states"].append(sample_trajectory_with_metadata["observations"][0])

# Inference only: no autograd graph is needed for the repeated forward passes.
with torch.no_grad():
    for i in range(prediction_steps):
        encoded_batch = tokenizer_manager.encode(batch_torch)
        predicted_trajectories = model(encoded_batch, masks_torch)
        decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

        # Execute only the action predicted for the current step.
        a = decoded_trajectories["actions"][0][i].detach().cpu().numpy()
        # Use cfg.device (not a literal "cuda") so the write-back matches the
        # device the batch actually lives on.
        batch_torch["actions"][0][i] = torch.tensor(a, device=cfg.device)
        traj_real_cl["actions"].append(a)
        time_step = env.step(a)[0]
        traj_real_cl["states"].append(time_step)
        image = env.physics.render(480, 640, 0)
        images_close_loop.append(image)

        # Mark step i as 1 in both masks and feed the observed state back in
        # as the state for step i + 1.
        masks["states"][i] = 1
        masks["actions"][i] = 1
        masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}
        batch_torch["states"][0][i + 1] = torch.tensor(time_step, device=cfg.device)

In [None]:
# Preview the rollout and save it as close_loop_goal.gif.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "close_loop_goal.gif", images_close_loop, fps=30, codec='gif')

In [None]:
# For each decoded key, plot the first few dimensions of the reference
# trajectory ("ground truth"), the last decoded prediction ("reconstructed"),
# the reference points whose mask entry is 1, and the values recorded during
# the closed-loop rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    # NOTE(review): the repeated goal state is used as the reference for every
    # key, including "actions" — confirm this is intended.
    traj = torch_states.numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_cl[k])[:, i]
        if len(mask.shape) == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # Include the executed values in the limits (as the open-loop plots
        # do); otherwise "real" points outside the reference range are clipped.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - y_range / 5, vmax + y_range / 5)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Sum of squared differences between the goal state and the last recorded closed-loop state.
diff_cl = goal_state - traj_real_cl["states"][-1]
np.sum(diff_cl**2)

In [None]:
# Shape of the goal state vector.
goal_state.shape

In [None]:
# Reach the goal from a state of the sampled trajectory, open loop:
# one forward pass, then all decoded actions are executed in order.
goal_reach_states = torch_states.clone()
loc = 0
# Rows 0..2 are set to recorded observations here, but the interpolation two
# statements below rebuilds every row from row 0 and the last row only.
goal_reach_states[:3] = torch.tensor(sample_trajectory_with_metadata["observations"][loc:loc+3, :])
physics_start = sample_trajectory_with_metadata["physics"][loc]
alpha = torch.linspace(0, 1, t_len)
# Linear interpolation from the first row (recorded observation at `loc`) to
# the last row (goal state), one row per time step.
goal_reach_states = goal_reach_states[0] + (goal_reach_states[-1] - goal_reach_states[0]) * alpha[:, None]


# run closed loop
# NOTE(review): despite the comment above and the variable name
# `images_close_loop`, this cell queries the model once and the saved file is
# named "open_loop_reach_from_random.gif".
batch_torch = {
    "states": goal_reach_states.to(cfg.device, torch.float32).unsqueeze(0),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0),
}


# goalreaching mask
# first three and last three state steps are 1, all action steps are 0
state_mask = torch.zeros(batch_torch["states"].shape[1])
state_mask[:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

encoded_batch = tokenizer_manager.encode(batch_torch)
predicted_trajectories = model(encoded_batch, masks_torch)
decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

# start the simulator from the recorded physics state at `loc`
env.reset()
with env.physics.reset_context():
    env.physics.set_state(physics_start)

images_close_loop = [env.physics.render(480, 640, 0)]
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()
traj_real_cl = defaultdict(list)
traj_real_cl["states"].append(sample_trajectory_with_metadata["observations"][0])



for idx, action in enumerate(execute_actions):
    traj_real_cl["actions"].append(action)
    obs = env.step(action)[0]
    traj_real_cl["states"].append(obs)
    image = env.physics.render(480, 640, 0)
    images_close_loop.append(image)
    
# drop the observation produced by the last action
traj_real_cl["states"] = traj_real_cl["states"][:-1]

In [None]:
# Preview the rollout and save it as open_loop_reach_from_random.gif.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "open_loop_reach_from_random.gif", images_close_loop, fps=30, codec='gif')

In [None]:

# For each decoded key, plot the first few dimensions of the interpolated
# reference trajectory ("ground truth"), the decoded prediction
# ("reconstructed"), the reference points whose mask entry is 1, and the
# values recorded during the rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    # NOTE(review): the interpolated state trajectory is used as the reference
    # for every key, including "actions" — confirm this is intended.
    traj = goal_reach_states.numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_cl[k])[:, i]
        if len(mask.shape) == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # Include the executed values in the limits (as the open-loop plots
        # do); otherwise "real" points outside the reference range are clipped.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - y_range / 5, vmax + y_range / 5)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Reach the goal from a state of the sampled trajectory, closed loop.
goal_reach_states = torch_states.clone()
loc = 0
# Rows 0..2 are set to recorded observations, but the interpolation below
# rebuilds every row from row 0 and the last row only.
goal_reach_states[:3] = torch.tensor(sample_trajectory_with_metadata["observations"][loc:loc+3, :])
physics_start = sample_trajectory_with_metadata["physics"][loc]
alpha = torch.linspace(0, 1, t_len)
# Linear interpolation from the recorded observation at `loc` to the goal state.
goal_reach_states = goal_reach_states[0] + (goal_reach_states[-1] - goal_reach_states[0]) * alpha[:, None]

batch_torch = {
    # .clone(): the loop below writes into these tensors in place. On a CPU
    # device .to()/from_numpy can return views of goal_reach_states / the
    # sampled actions, which must not be modified (goal_reach_states is
    # plotted afterwards as the reference).
    "states": goal_reach_states.to(cfg.device, torch.float32).unsqueeze(0).clone(),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0)
    .clone(),
}

# goalreaching mask: first three and last three state steps are 1, all action steps 0
state_mask = torch.zeros(batch_torch["states"].shape[1])
state_mask[:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

# Start the simulator from the recorded physics state at `loc`.
env.reset()
with env.physics.reset_context():
    env.physics.set_state(physics_start)

images_close_loop = [env.physics.render(480, 640, 0)]
traj_real_cl = defaultdict(list)
traj_real_cl["states"].append(sample_trajectory_with_metadata["observations"][0][:])

# Inference only: no autograd graph is needed for the repeated forward passes.
with torch.no_grad():
    # Initial prediction; kept so `decoded_trajectories` is defined even if
    # the loop below runs zero times.
    encoded_batch = tokenizer_manager.encode(batch_torch)
    predicted_trajectories = model(encoded_batch, masks_torch)
    decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

    for i in range(prediction_steps):
        encoded_batch = tokenizer_manager.encode(batch_torch)
        predicted_trajectories = model(encoded_batch, masks_torch)
        decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

        # Execute only the action predicted for the current step.
        a = decoded_trajectories["actions"][0][i].detach().cpu().numpy()
        # Use cfg.device (not a literal "cuda") so the write-back matches the
        # device the batch actually lives on.
        batch_torch["actions"][0][i] = torch.tensor(a, device=cfg.device)
        traj_real_cl["actions"].append(a)
        time_step = env.step(a)[0]
        traj_real_cl["states"].append(time_step)
        image = env.physics.render(480, 640, 0)
        images_close_loop.append(image)

        # Mark step i as 1 in both masks and feed the observed state back in
        # as the state for step i + 1.
        masks["states"][i] = 1
        masks["actions"][i] = 1
        masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}
        batch_torch["states"][0][i + 1] = torch.tensor(time_step, device=cfg.device)

In [None]:

# For each decoded key, plot the first few dimensions of the interpolated
# reference trajectory ("ground truth"), the last decoded prediction
# ("reconstructed"), the reference points whose mask entry is 1, and the
# values recorded during the closed-loop rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    # NOTE(review): the interpolated state trajectory is used as the reference
    # for every key, including "actions" — confirm this is intended.
    traj = goal_reach_states.numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_cl[k])[:, i]
        if len(mask.shape) == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # Include the executed values in the limits (as the open-loop plots
        # do); otherwise "real" points outside the reference range are clipped.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - y_range / 5, vmax + y_range / 5)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()

In [None]:
# Preview the rollout and save it as closed_loop_reach_from_random.gif.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "closed_loop_reach_from_random.gif", images_close_loop, fps=30, codec='gif')

In [None]:
# List the entries available in the sampled trajectory.
sample_trajectory_with_metadata.keys()

In [None]:
# Single forward pass on the recorded trajectory itself: states are the
# recorded observations, actions the recorded actions, each with a leading
# batch axis of size 1.
batch_torch = {
    "states": torch.from_numpy(sample_trajectory_with_metadata["observations"])
    .to(cfg.device)
    .unsqueeze(0),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0),
}


# goalreaching mask
# first three and last three state steps are 1, all action steps are 0
state_mask = torch.zeros(batch_torch["states"].shape[1])
state_mask[0:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

encoded_batch = tokenizer_manager.encode(batch_torch)
predicted_trajectories = model(encoded_batch, masks_torch)
decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

In [None]:
# Put the simulator into the first recorded physics state and initialise the
# open-loop buffers with the first recorded observation.
# NOTE(review): this cell calls env.physics.reset() where the other rollout
# cells call env.reset() — confirm the difference is intended.
env.physics.reset()
with env.physics.reset_context():
    env.physics.set_state(sample_trajectory_with_metadata["physics"][0])
    
images_open_loop = [env.physics.render(480, 640, 0)]
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()
traj_real_ol = defaultdict(list)
traj_real_ol["states"].append(sample_trajectory_with_metadata["observations"][0])

In [None]:
# Execute every decoded action in order, recording actions, observations and frames.
for action in execute_actions:
    traj_real_ol["actions"].append(action)
    obs = env.step(action)[0]
    traj_real_ol["states"].append(obs)
    images_open_loop.append(env.physics.render(480, 640, 0))

# Drop the observation produced by the last action so that "states" holds one
# entry per executed action.
del traj_real_ol["states"][-1]

In [None]:
# Preview the open-loop rollout at 10 fps.
media.show_video(images_open_loop, fps=10)

In [None]:

# For each decoded key, plot the first few dimensions of: the batch given to
# the model ("ground truth"), the decoded prediction ("reconstructed"), the
# ground-truth points whose mask entry is 1, and the values recorded during
# the open-loop rollout ("real").
max_n_plots = 3
for k, decoded in decoded_trajectories.items():
    traj = batch_torch[k][0].detach().cpu().numpy()
    pred_traj = decoded[0].detach().cpu().numpy()
    mask = masks[k]
    n_plots = min(max_n_plots, traj.shape[-1])
    for i in range(n_plots):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = np.array(traj_real_ol[k])[:, i]
        if mask.dim() == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        visible = mask[:, i].cpu().numpy() == 1
        unmasked_gt_i = gt_i[visible]
        unmasked_gt_i_index = np.arange(len(gt_i))[visible]

        # y-limits span all three series, padded by a fifth of the range
        series = (gt_i, re_i, real_i)
        vmax = max(np.max(values_i) for values_i in series)
        vmin = min(np.min(values_i) for values_i in series)
        margin = (vmax - vmin) / 5
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(unmasked_gt_i_index, unmasked_gt_i, "o", label="unmasked ground truth")
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - margin, vmax + margin)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Display the open-loop record for key `k`.
# NOTE(review): `k` is whatever the preceding plotting loop left behind, and
# this dumps the whole list.
traj_real_ol[k]

In [None]:
# Show the first key of the training dataset's private `_episodes` mapping.
next(iter(train_dataset._episodes.keys()))

In [None]:
# File paths for walker poses 1 through 8; used below as keys into
# train_dataset._episodes.
poses = [
    f'/checkpoint/aravraj/mtm_data/dmc_pose_dataset_dec7/dmc_walker_pose{i}-v1.pickle0'
    for i in range(1, 8+1)
]

In [None]:
# Display the generated pose paths.
poses

In [None]:
# for p in poses:
#     train_dataset._episodes[p]

In [None]:
# train_dataset._episodes

In [None]:
# # failure cases: 2, 3, 4, 5, 
# # works 0, 1, 6, 7

# for i in range(8):
#     episode = train_dataset._episodes[poses[i]]

#     tl = train_dataset._traj_length
#     # tl = 64

#     idx=0

#     obs = episode["observations"][idx : idx + tl]
#     action = episode["actions"][idx : idx + tl]
#     reward = episode["rewards"][idx : idx + tl]
#     timestep = np.arange(idx, idx + tl)[:, np.newaxis]
#     physics = episode["states"][idx: idx + tl]
#     physics_ = [p["internal_state"] for p in physics]
#     sample_trajectory_with_metadata = {
#         "observations": obs.astype(np.float32),
#         "actions": action.astype(np.float32),
#         "rewards": reward.astype(np.float32).reshape(-1, 1),
#         "timestep": 0,
#         "physics": physics_,
#     }

#     env.reset()
#     with env.physics.reset_context():
#         env.physics.set_state(sample_trajectory_with_metadata["physics"][0])

#     vis = [env.physics.render(480, 640, 0)]

#     for idx, action in enumerate(sample_trajectory_with_metadata["actions"]):
#         env.step(action)
#         image = env.physics.render(480, 640, 0)
#         vis.append(image)

#     #media.show_video(vis, fps=10)
#     media.write_video(f"gt_{i}.gif", vis, fps=30, codec='gif')

In [None]:
# not working, 4, 5
# Cut a window of `train_dataset._traj_length` steps, starting at step `idx`,
# out of the episode stored under poses[7], and bundle it with the per-step
# simulator states.
episode = train_dataset._episodes[poses[7]]
tl = train_dataset._traj_length
idx = 0
window = slice(idx, idx + tl)

obs = episode["observations"][window]
action = episode["actions"][window]
reward = episode["rewards"][window]
# NOTE(review): `timestep` is computed here, but the dict below stores the
# constant 0 under "timestep" — confirm which of the two is intended.
timestep = np.arange(idx, idx + tl)[:, None]
physics = episode["states"][window]
physics_ = [state["internal_state"] for state in physics]

sample_trajectory_with_metadata = dict(
    observations=obs.astype(np.float32),
    actions=action.astype(np.float32),
    rewards=reward.astype(np.float32).reshape(-1, 1),
    timestep=0,
    physics=physics_,
)

In [None]:
# env.reset()

# vis = []
# for p in sample_trajectory_with_metadata["physics"]:
#     with env.physics.reset_context():
#         env.physics.set_state(p)
#     vis.append(env.physics.render(480, 640, 0))
    
# media.show_video(vis, fps=10)

In [None]:
# Replay the recorded actions, starting from the first recorded simulator
# state, and collect one rendered frame before the first step and one after
# every step.
env.reset()
initial_state = sample_trajectory_with_metadata["physics"][0]
with env.physics.reset_context():
    env.physics.set_state(initial_state)

render_args = (480, 640, 0)  # forwarded positionally to env.physics.render
vis = [env.physics.render(*render_args)]

for idx, action in enumerate(sample_trajectory_with_metadata["actions"]):
    env.step(action)
    image = env.physics.render(*render_args)
    vis.append(image)

media.show_video(vis, fps=10)

In [None]:
# Display the `tokenizers` entry of `hydra_cfg`.
hydra_cfg.tokenizers

In [None]:
# import copy
# # tokenizer_info = hydra_cfg.tokenizers
# # tokenizer_info = copy.copy(tokenizer_info)
# # tokenizer_info["actions"].num_bins = 64
# hydra_cfg.tokenizers

# tokenizers_: Dict[str, Tokenizer] = {
#     k: hydra.utils.call(v, key=k, train_dataset=train_dataset)
#     for k, v in tokenizer_info.items()
# }
# tm = TokenizerManager(tokenizers_)

# env.reset()
# with env.physics.reset_context():
#     env.physics.set_state(sample_trajectory_with_metadata["physics"][0])
    
# batch_torch = {
#     "states": torch.from_numpy(sample_trajectory_with_metadata["observations"])
#     .to(cfg.device)
#     .unsqueeze(0),
#     "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
#     .to(cfg.device)
#     .unsqueeze(0),
# }

# encoded_batch = tm.encode(batch_torch)
# decoded_trajectories = tm.decode(encoded_batch)
# act = decoded_trajectories["actions"][0].detach().cpu().numpy()

# vis = [env.physics.render(480, 640, 0)]
# for idx, action in enumerate(act):
#     env.step(action)
#     image = env.physics.render(480, 640, 0)
#     vis.append(image)
    
# media.show_video(vis, fps=10)

In [None]:
# fig, axs = plt.subplots(2, 3)
# for r in range(2):
#     for j in range(3):
#         i = r * 3 + j
#         axs[r, j].plot(act[:, i], "o", label="discrete")
#         axs[r, j].plot(sample_trajectory_with_metadata["actions"][:, i], ".", label="original")
# axs[r, j].legend()
# plt.show()


In [None]:
# Wrap the sampled trajectory as a batch of size one on cfg.device.
batch_torch = {
    name: torch.from_numpy(sample_trajectory_with_metadata[source])
    .to(cfg.device)
    .unsqueeze(0)
    for name, source in (("states", "observations"), ("actions", "actions"))
}

# goalreaching mask: value 1 on the first three and last three state steps,
# value 0 on every action step.
n_state_steps = batch_torch["states"].shape[1]
n_action_steps = batch_torch["actions"].shape[1]
state_mask = torch.zeros(n_state_steps)
state_mask[:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(n_action_steps)
masks = dict(states=state_mask, actions=action_mask)
masks_torch = {name: m.to(cfg.device) for name, m in masks.items()}

# Encode the batch, run the model with the masks, and decode its output.
encoded_batch = tokenizer_manager.encode(batch_torch)
predicted_trajectories = model(encoded_batch, masks_torch)
decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

In [None]:
# Put the simulator into the trajectory's first recorded state and start the
# open-loop bookkeeping (rendered frames plus the states/actions actually seen).
# NOTE(review): only env.physics is reset here, whereas the ground-truth replay
# cell calls env.reset() — confirm the difference is intended.
env.physics.reset()
start_state = sample_trajectory_with_metadata["physics"][0]
with env.physics.reset_context():
    env.physics.set_state(start_state)

images_open_loop = [env.physics.render(480, 640, 0)]
# Predicted actions of the single batch element, as a NumPy array.
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()
traj_real_ol = defaultdict(
    list, states=[sample_trajectory_with_metadata["observations"][0]]
)

In [None]:
# Execute the predicted actions one after another; observations are recorded
# but never fed back to the model.
for idx, action in enumerate(execute_actions):
    traj_real_ol["actions"].append(action)
    step_result = env.step(action)
    obs = step_result[0]
    traj_real_ol["states"].append(obs)
    image = env.physics.render(480, 640, 0)
    images_open_loop.append(image)

# Drop the observation produced by the last action.
del traj_real_ol["states"][-1:]
    # compare obs against data
#     _obs = sample_trajectory_with_metadata["observations"]
#     np.testing.assert_allclose(obs, _obs[idx], atol=1e-5)

In [None]:
# Play back the frames collected in `images_open_loop` at 10 fps.
media.show_video(images_open_loop, fps=10)

In [None]:

# Open-loop comparison, one figure per plotted dimension: ground truth, model
# reconstruction, the ground-truth points where the mask equals 1, and the
# values recorded from the environment.
max_n_plots = 3
for k in decoded_trajectories:
    traj = batch_torch[k][0].detach().cpu().numpy()
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    mask = masks[k]
    n_plots = min(max_n_plots, traj.shape[-1])
    for i in range(n_plots):
        gt_i, re_i = traj[:, i], pred_traj[:, i]
        real_i = np.array(traj_real_ol[k])[:, i]
        if mask.ndim == 1:
            # only along time dimension: repeat across the given dimension
            mask = mask[:, None].repeat(1, traj.shape[1])
        select_mask = mask[:, i].cpu().numpy()
        is_unmasked = select_mask == 1
        unmasked_gt_i = gt_i[is_unmasked]
        unmasked_gt_i_index = np.arange(len(gt_i))[is_unmasked]

        # Shared y-limits covering all three curves, padded by 20% of the range.
        curves = (gt_i, re_i, real_i)
        vmax = max(np.max(c) for c in curves)
        vmin = min(np.min(c) for c in curves)
        y_range = vmax - vmin
        margin = y_range / 5

        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index, unmasked_gt_i, "o", label="unmasked ground truth"
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - margin, vmax + margin)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()
#             eval_logs[
#                 f"{eval_name}/batch={batch_idx}|{i}_{k}"
#             ] = wandb.Image(ph.plot_as_image(fig))


In [None]:
# closed loop

In [None]:
# Build a fresh batch of size one for the closed-loop rollout.
# torch.tensor always copies its input. The closed-loop cell writes into
# batch_torch in place, and torch.from_numpy(...).to(cfg.device) would share
# memory with the NumPy arrays whenever cfg.device is the CPU, so those writes
# would silently overwrite the ground truth in sample_trajectory_with_metadata.
batch_torch = {
    "states": torch.tensor(
        sample_trajectory_with_metadata["observations"], device=cfg.device
    ).unsqueeze(0),
    "actions": torch.tensor(
        sample_trajectory_with_metadata["actions"], device=cfg.device
    ).unsqueeze(0),
}

# goal reaching mask: value 1 on the first three and last three state steps,
# value 0 on every action step.
state_mask = torch.zeros(batch_torch["states"].shape[1])
state_mask[:3] = 1
state_mask[-3:] = 1
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

# Initial prediction from the goal-reaching mask alone.
encoded_batch = tokenizer_manager.encode(batch_torch)
predicted_trajectories = model(encoded_batch, masks_torch)
decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)



In [None]:
# Put the simulator into the trajectory's first recorded state and start the
# closed-loop bookkeeping (rendered frames plus the states/actions actually seen).
# NOTE(review): only env.physics is reset here, whereas the ground-truth replay
# cell calls env.reset() — confirm the difference is intended.
env.physics.reset()
with env.physics.reset_context():
    env.physics.set_state(sample_trajectory_with_metadata["physics"][0])

images_close_loop = [env.physics.render(480, 640, 0)]
traj_real_cl = defaultdict(list)
# Store a real copy of the first observation: `[:]` on a NumPy array only
# creates a view that still aliases the source array.
traj_real_cl["states"].append(
    sample_trajectory_with_metadata["observations"][0].copy()
)

In [None]:
# Closed-loop rollout: at every step re-predict the trajectory, execute only
# the i-th predicted action, then write the executed action and the resulting
# observation back into the batch and mark both as 1 in the masks.
n_state_steps = batch_torch["states"].shape[1]
for i in range(prediction_steps):
    # Inference only, so skip building an autograd graph on every step.
    with torch.no_grad():
        encoded_batch = tokenizer_manager.encode(batch_torch)
        predicted_trajectories = model(encoded_batch, masks_torch)
        decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

    # Write the prediction back directly; no host round-trip is needed.
    predicted_action = decoded_trajectories["actions"][0][i].detach()
    batch_torch["actions"][0][i] = predicted_action
    a = predicted_action.cpu().numpy()
    traj_real_cl["actions"].append(a)

    time_step = env.step(a)[0]
    traj_real_cl["states"].append(time_step)
    image = env.physics.render(480, 640, 0)
    images_close_loop.append(image)

    masks["actions"][i] = 1
    # The observation after the final step has no slot in the batch when
    # prediction_steps equals the trajectory length, so skip the write-back.
    if i + 1 < n_state_steps:
        masks["states"][i + 1] = 1
        # Use the configured device instead of a hard-coded "cuda".
        batch_torch["states"][0][i + 1] = torch.tensor(time_step, device=cfg.device)
    masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

In [None]:
# Display the number of closed-loop steps that were executed.
# NOTE(review): `prediction_steps` is defined outside this part of the notebook.
prediction_steps

In [None]:
# Play back the frames collected in `images_close_loop` at 10 fps.
media.show_video(images_close_loop, fps=10)

In [None]:
# Show the last frame collected in `images_close_loop`.
media.show_image(images_close_loop[-1])

In [None]:
# Closed-loop comparison, one figure per plotted dimension: dataset ground
# truth, the model's last reconstruction, the ground-truth points where the
# mask equals 1, and the values recorded from the environment.
max_n_plots = 3
# Ground truth per modality, read from the original sample because the rollout
# overwrote batch_torch in place. Each key is paired with its own array, so
# "actions" is no longer compared against observations.
gt_by_key = {
    "states": sample_trajectory_with_metadata["observations"],
    "actions": sample_trajectory_with_metadata["actions"],
}
for k in decoded_trajectories:
    traj = gt_by_key[k]
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    real_traj = np.array(traj_real_cl[k])
    mask = masks[k]
    if len(mask.shape) == 1:
        # only along time dimension: repeat across the given dimension
        mask = mask[:, None].repeat(1, traj.shape[1])
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = real_traj[:, i]
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # Include the executed values in the limits (as the open-loop plot
        # does) so the "real" curve is never clipped.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        with ph.plot_context() as (fig, ax):

            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(
                re_i, "-o", label="reconstructed", markerfacecolor="none"
            )
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(
                real_i, ".", label="real"
            )
            ax.set_ylim(
                vmin - y_range / 5,
                vmax + y_range / 5,
            )
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()
#             eval_logs[
#                 f"{eval_name}/batch={batch_idx}|{i}_{k}"
#             ] = wandb.Image(ph.plot_as_image(fig))