In [None]:
import os
# Request MuJoCo's EGL rendering backend (offscreen rendering, no display needed).
# It is set here, before the MuJoCo-dependent imports below.
# NOTE(review): confirm that the MuJoCo binding behind env.sim honours MUJOCO_GL.
os.environ["MUJOCO_GL"] = "egl"
    
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple
from pprint import pprint

import dcargs
import glob
import hydra
import numpy as np
import torch

import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch.utils.data.dataloader import DataLoader

from research.logger import WandBLogger, WandBLoggerConfig
from research.mtm.models.mtm_model import MaskedDP, MTMConfig, make_plots_with_masks
from research.mtm.tokenizers.base import Tokenizer, TokenizerManager
import mediapy as media
from research.mtm.train import RunConfig
import matplotlib.pyplot as plt
from collections import defaultdict
from research.utils.plot_utils import PlotHandler as ph
from pathlib import Path

%matplotlib inline

In [None]:
# cheetah
# NOTE(review): this cell, the hopper cell and the walker2d cell all assign `path`.
# Only the last one executed takes effect (the walker2d cell when run top to bottom).
# This is an absolute, user-specific path; adjust it for your machine.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-05_22-24-33/98_+experiments=d4rl_discrete,args.mask_patterns=[FULL_RANDOM,RANDOM,GOAL,GOAL_N,ID,FD],dataset.env_name=halfcheetah-expert-v2"

In [None]:
# hopper
# NOTE(review): overwrites the `path` set by the cheetah cell. When the notebook
# runs top to bottom, the walker2d cell overwrites it in turn.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2022-12-05_18-37-25/56_+experiments=d4rl_mixed,args.mask_patterns=[FULL_RANDOM,RANDOM,FD],dataset.env_name=hopper-expert-v2"

In [None]:
# walker2d medium-expert run (task=rcbc, per the run directory name).
# This is the last cell that assigns `path`, so this run is the one inspected
# when the notebook executes top to bottom.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2023-01-04_09-08-01/0_args.traj_length=2,dataset.env_name=walker2d-medium-expert-v2,model_config.task=rcbc,wandb.project=rcbc_med_exp"

In [None]:
# Find the checkpoint files ("<prefix>_<step>.pt") in the run directory and
# select the one with the largest training step.
steps = []
names = []
# Sorting makes the listing (and argmax tie-breaking) deterministic.
for fname in sorted(os.listdir(path)):
    # Match on the ".pt" extension. A bare substring test ("pt" in name) also
    # matches unrelated files whose names merely contain "pt".
    if not fname.endswith(".pt"):
        continue
    step_str = os.path.splitext(fname)[0].rsplit("_", 1)[-1]
    if not step_str.isdigit():
        # A .pt file without a trailing step number: skip it instead of crashing in int().
        continue
    name = os.path.join(path, fname)
    steps.append(int(step_str))
    names.append(name)
    print(name)

if not names:
    raise FileNotFoundError(f"no '<name>_<step>.pt' checkpoints found in {path}")
ckpt_path = names[int(np.argmax(steps))]

In [None]:
# Training steps parsed from the checkpoint filenames found above.
steps

In [None]:
# ckpt_path = '/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-17_18-22-57/1_+experiments=exorl_continuous_rew_qpos,args.mask_patterns=[RANDOM,GOAL,ID,FD],args.model_config.n_dec_layer=1,args.model_config.n_enc_layer=1,args.model_config.n_head=4/model_640000.pt'

In [None]:
# Directory for the rendered GIFs, relative to the current working directory.
folder = Path("files") / "d4rl"
folder.mkdir(exist_ok=True, parents=True)

In [None]:
# Training step stored in the selected checkpoint.
# map_location="cpu" stops torch.load from restoring every tensor onto the device
# it was saved from (which may be absent here) just to read this one value.
torch.load(ckpt_path, map_location="cpu")["step"]

In [None]:
# The run's composed Hydra config lives at <run dir>/.hydra/config.yaml.
hydra_cfg = OmegaConf.load(os.path.join(path, ".hydra", "config.yaml"))

In [None]:
# Dataset section of the run config; it is passed to hydra.utils.call further down.
hydra_cfg.dataset

In [None]:
# Instantiate the `args` node of the run config and print it. This is the `cfg`
# used below (traj_length, batch_size, n_workers, device, model_config, ...).
cfg = hydra.utils.instantiate(hydra_cfg.args)
pprint(cfg)

In [None]:
# Full run config, for reference.
hydra_cfg

In [None]:
# Build the train/val datasets from the run config and fit one tokenizer per
# modality on the training set. Then infer each modality's token shape
# (last two dims) from one tokenized training batch; the model is built from these.
train_dataset, val_dataset = hydra.utils.call(hydra_cfg.dataset, seq_steps=cfg.traj_length)
print("Train set size =", len(train_dataset))
print("Validation set size =", len(val_dataset))

tokenizers: Dict[str, Tokenizer] = {}
for key, tokenizer_cfg in hydra_cfg.tokenizers.items():
    tokenizers[key] = hydra.utils.call(tokenizer_cfg, key=key, train_dataset=train_dataset)
tokenizer_manager = TokenizerManager(tokenizers)
discrete_map: Dict[str, bool] = {key: tok.discrete for key, tok in tokenizers.items()}
print(tokenizers)

# Neither loader shuffles (DataLoader's default), so the first batch is reproducible.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
)

train_batch = next(iter(train_loader))
tokenized = tokenizer_manager.encode(train_batch)
data_shapes = {key: tokens.shape[-2:] for key, tokens in tokenized.items()}
print(data_shapes)


In [None]:
# Simulator environment attached to the validation data.
# NOTE(review): val_dataset appears to wrap the underlying dataset via `.dataset`; confirm.
env = val_dataset.dataset.env

In [None]:
env.reset()
# Render from the "track" camera. flipud reverses the row order (vertical flip) for display.
image = np.flipud(env.sim.render(640, 480, camera_name="track"))
plt.imshow(image)

In [None]:
# First validation trajectory: a dict of per-modality arrays. The cells below use
# its "states", "actions" and "rewards" entries.
sample_trajectory_with_metadata = val_dataset[0]

In [None]:
# Inspect the state sequence of the sample trajectory.
sample_trajectory_with_metadata["states"]

In [None]:
# Show the docstring of set_state_from_flattened, to check the flat-state layout
# that the next cell builds.
help(env.sim.set_state_from_flattened)

In [None]:
# Flat simulator state = two leading zeros followed by the first observation.
# NOTE(review): the layout is presumably [time, root x-position, obs...], with the
# observation omitting the root x coordinate; confirm with the help() output above.
phys_state = np.concatenate([np.zeros(2), sample_trajectory_with_metadata["states"][0]])
env.sim.set_state_from_flattened(phys_state)
env.sim.forward()

In [None]:
# Render the restored state from the "track" camera, flipped vertically for display.
image = np.flipud(env.sim.render(640, 480, camera_name="track"))
plt.imshow(image)

# Create the model and load the checkpoint weights

In [None]:
# Build the model for the per-modality token shapes computed above, then restore
# the trained weights from the selected checkpoint.
model = MaskedDP(data_shapes, cfg.model_config)
model.to(cfg.device)

# map_location="cpu" deserializes on the CPU instead of on the device the checkpoint
# was saved from (which may be absent here). load_state_dict then copies the weights
# into the parameters, which already live on cfg.device.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["model"])
# Inference only from here on. The train/eval mode does not affect loading, so
# eval() is the only mode switch needed.
model.eval()

t_len = cfg.model_config.traj_length
# Number of simulator steps taken by the closed-loop rollout below.
prediction_steps = t_len - 1

In [None]:
# Ground-truth rollout: start from the first dataset state, replay the dataset's
# recorded actions in the simulator, and render a frame after every step.
env.reset()
env.sim.set_state_from_flattened(phys_state)
env.sim.forward()

execute_actions = sample_trajectory_with_metadata["actions"]
images_open_loop = [env.sim.render(640, 480, camera_name="track")[::-1]]
for action in execute_actions:
    obs, *_ = env.step(action)
    image = env.sim.render(640, 480, camera_name="track")[::-1]
    images_open_loop.append(image)

In [None]:
# Preview the ground-truth (dataset-action) rollout and save it as gt.gif.
# fps is passed explicitly so the file plays at the preview rate and at the same
# rate as close_loop.gif. Without it, write_video uses mediapy's own default frame rate.
media.show_video(images_open_loop, fps=30)
media.write_video(folder / "gt.gif", np.array(images_open_loop), fps=30, codec='gif')

In [None]:
# Closed-loop rollout. At every step:
#   1. re-run the model on the current context,
#   2. execute the action it predicts for step i in the simulator,
#   3. write the executed action and the observed next state back into the context.
# torch.tensor(...) always copies. torch.from_numpy(...).to(...) returns a view of the
# numpy array when no device/dtype change is needed (e.g. on CPU), and the in-place
# writes below would then silently modify sample_trajectory_with_metadata.
batch_torch = {
    "states": torch.tensor(
        sample_trajectory_with_metadata["states"], dtype=torch.float32, device=cfg.device
    ).unsqueeze(0),
    "actions": torch.tensor(
        sample_trajectory_with_metadata["actions"], device=cfg.device
    ).unsqueeze(0),
    "rewards": torch.tensor(
        sample_trajectory_with_metadata["rewards"], device=cfg.device
    ).unsqueeze(0),
}

# Masks have one entry per timestep: 1 = visible to the model, 0 = to be predicted.
# All states are visible. Actions and rewards start hidden, and actions are
# revealed as they are executed.
state_mask = torch.ones(batch_torch["states"].shape[1])
# state_mask[1:-1] = 0  # goal-reaching variant: hide the intermediate states
action_mask = torch.zeros(batch_torch["actions"].shape[1])
reward_mask = torch.zeros(batch_torch["rewards"].shape[1])
masks = {"states": state_mask, "actions": action_mask, "rewards": reward_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

# Restart the episode from the first dataset state, as the ground-truth rollout does.
env.reset()
env.sim.set_state_from_flattened(phys_state)
env.sim.forward()

# Render with the same "track" camera as the ground-truth video so the GIFs are comparable.
images_close_loop = [env.sim.render(640, 480, camera_name="track")[::-1]]
traj_real_cl = defaultdict(list)
traj_real_cl["states"].append(sample_trajectory_with_metadata["states"][0])

# Inference only: no autograd graph needs to be built at every step.
with torch.no_grad():
    for i in range(prediction_steps):
        encoded_batch = tokenizer_manager.encode(batch_torch)
        predicted_trajectories = model(encoded_batch, masks_torch)
        decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

        a = decoded_trajectories["actions"][0][i].detach().cpu().numpy()
        # Write back onto cfg.device so this also works when the model runs on CPU.
        batch_torch["actions"][0][i] = torch.as_tensor(a, device=cfg.device)
        traj_real_cl["actions"].append(a)
        # Clip to [-1, 1] before stepping. NOTE(review): presumably the env's action
        # bounds; confirm.
        ret = env.step(np.clip(a, -1, 1))
        obs = ret[0]
        rew = ret[1]
        traj_real_cl["rewards"].append([rew])
        traj_real_cl["states"].append(obs)
        image = env.sim.render(640, 480, camera_name="track")[::-1]
        images_close_loop.append(image)

        # Reveal the executed step to the model and feed back the observed next state.
        masks["states"][i] = 1
        masks["actions"][i] = 1
        masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}
        batch_torch["states"][0][i + 1] = torch.as_tensor(obs, device=cfg.device)

In [None]:
# Preview the closed-loop rollout, then write it to close_loop.gif.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "close_loop.gif", images_close_loop, codec="gif", fps=30)

In [None]:
# Modalities predicted by the model; the plotting cells below iterate over these keys.
decoded_trajectories.keys()

In [None]:

# Diagnostic plots for the closed-loop rollout, covering the first `max_n_plots`
# dimensions of each modality. Each plot shows:
#   * the dataset ground truth,
#   * the model's last reconstruction,
#   * the values the model was conditioned on (mask == 1),
#   * the trajectory actually reached in the simulator ("real").
# The closed-loop cell overwrites batch_torch in place with the executed actions and
# observed states. The dataset reference therefore comes from
# sample_trajectory_with_metadata, not from batch_torch.
max_n_plots = 3
for k in decoded_trajectories:
    traj = batch_torch[k].cpu().numpy()[0]  # model context after the rollout
    gt_traj = np.asarray(sample_trajectory_with_metadata[k])
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    real_traj = np.array(traj_real_cl[k])
    mask = masks[k]
    if len(mask.shape) == 1:
        # only along time dimension: repeat across the given dimension
        mask = mask[:, None].repeat(1, traj.shape[1])
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = gt_traj[:, i]
        re_i = pred_traj[:, i]
        real_i = real_traj[:, i]
        context_i = traj[:, i]
        select_mask = mask[:, i].cpu().numpy()
        unmasked_i = context_i[select_mask == 1]
        unmasked_i_index = np.arange(len(context_i))[select_mask == 1]
        # The y-limits cover every plotted series, since the real rollout can leave
        # the ground-truth/prediction range, plus a 20% margin. If all values
        # coincide, use a fixed margin so the limits are never identical.
        all_values = np.concatenate([gt_i, re_i, real_i, unmasked_i])
        vmax = np.max(all_values)
        vmin = np.min(all_values)
        y_range = vmax - vmin
        pad = y_range / 5 if y_range > 0 else max(abs(vmax), 1.0) / 5
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_i_index,
                unmasked_i,
                "o",
                label="unmasked (model input)",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - pad, vmax + pad)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()

In [None]:
# Open-loop rollout of the model. The whole action sequence is predicted once from
# the dataset context (all states visible, actions and rewards hidden). Those
# actions are then executed in the simulator without re-planning.
# The frames go into `images_close_loop` (the name reused from the closed-loop cell),
# even though this rollout is open-loop.
batch_torch = {
    "states": torch.from_numpy(sample_trajectory_with_metadata["states"])
    .to(cfg.device, torch.float32)
    .unsqueeze(0),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0),
    "rewards": torch.from_numpy(sample_trajectory_with_metadata["rewards"])
    .to(cfg.device)
    .unsqueeze(0),
}

# Masks have one entry per timestep: 1 = visible to the model, 0 = to be predicted.
# All states are visible; actions and rewards are hidden.
state_mask = torch.ones(batch_torch["states"].shape[1])
# state_mask[1:-1] = 0  # goal-reaching variant: hide the intermediate states
action_mask = torch.zeros(batch_torch["actions"].shape[1])
reward_mask = torch.zeros(batch_torch["rewards"].shape[1])
masks = {"states": state_mask, "actions": action_mask, "rewards": reward_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

with torch.no_grad():  # inference only: no autograd graph needed
    encoded_batch = tokenizer_manager.encode(batch_torch)
    predicted_trajectories = model(encoded_batch, masks_torch)
    decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

# Restart the episode from the first dataset state, as the ground-truth rollout does.
env.reset()
env.sim.set_state_from_flattened(phys_state)
env.sim.forward()

# Render with the same "track" camera as the ground-truth video so the GIFs are comparable.
images_close_loop = [env.sim.render(640, 480, camera_name="track")[::-1]]
traj_real_cl = defaultdict(list)
traj_real_cl["states"].append(sample_trajectory_with_metadata["states"][0])
for act in decoded_trajectories["actions"][0].detach().cpu().numpy():
    traj_real_cl["actions"].append(act)
    # Clip to [-1, 1] before stepping. NOTE(review): presumably the env's action
    # bounds; confirm.
    ret = env.step(np.clip(act, -1, 1))
    obs = ret[0]
    rew = ret[1]
    traj_real_cl["rewards"].append([rew])
    traj_real_cl["states"].append(obs)
    image = env.sim.render(640, 480, camera_name="track")[::-1]
    images_close_loop.append(image)

In [None]:
# Preview and save the model's open-loop rollout from the previous cell.
# That cell stores its frames in `images_close_loop`. `images_open_loop` holds the
# dataset-action replay, which is already saved as gt.gif, so it must not be used here.
# fps=30 matches the preview and the other GIFs.
media.show_video(images_close_loop, fps=30)
media.write_video(folder / "open_loop.gif", np.array(images_close_loop), fps=30, codec='gif')

In [None]:

# Diagnostic plots for the model's open-loop rollout, covering the first
# `max_n_plots` dimensions of each modality. Each plot shows:
#   * the dataset trajectory (batch_torch is not modified by the open-loop cell),
#   * the model's prediction,
#   * the values the model was conditioned on (mask == 1),
#   * the trajectory reached in the simulator ("real").
max_n_plots = 3
for k in decoded_trajectories:
    traj = batch_torch[k].cpu().numpy()[0]
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    real_traj = np.array(traj_real_cl[k])
    mask = masks[k]
    if len(mask.shape) == 1:
        # only along time dimension: repeat across the given dimension
        mask = mask[:, None].repeat(1, traj.shape[1])
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = real_traj[:, i]
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        # The y-limits cover every plotted series, including the real rollout, which
        # can drift outside the ground-truth/prediction range, plus a 20% margin.
        # If all values coincide, use a fixed margin so the limits are never identical.
        all_values = np.concatenate([gt_i, re_i, real_i])
        vmax = np.max(all_values)
        vmin = np.min(all_values)
        y_range = vmax - vmin
        pad = y_range / 5 if y_range > 0 else max(abs(vmax), 1.0) / 5
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - pad, vmax + pad)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()

In [None]:
# Forward prediction from the initial state. Only the first state is visible, and
# the model predicts the remaining states plus all actions and rewards in one pass.
# No simulator rollout happens here.
batch_torch = {
    "states": torch.from_numpy(sample_trajectory_with_metadata["states"])
    .to(cfg.device, torch.float32)
    .unsqueeze(0),
    "actions": torch.from_numpy(sample_trajectory_with_metadata["actions"])
    .to(cfg.device)
    .unsqueeze(0),
    "rewards": torch.from_numpy(sample_trajectory_with_metadata["rewards"])
    .to(cfg.device)
    .unsqueeze(0),
}

# Masks have one entry per timestep: 1 = visible to the model, 0 = to be predicted.
state_mask = torch.ones(batch_torch["states"].shape[1])
state_mask[1:] = 0  # reveal only the first state
action_mask = torch.zeros(batch_torch["actions"].shape[1])
reward_mask = torch.zeros(batch_torch["rewards"].shape[1])
masks = {"states": state_mask, "actions": action_mask, "rewards": reward_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

with torch.no_grad():  # inference only: no autograd graph needed
    encoded_batch = tokenizer_manager.encode(batch_torch)
    predicted_trajectories = model(encoded_batch, masks_torch)
    decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

In [None]:
# Plot the forward prediction against the dataset trajectory for the first
# `max_n_plots` dimensions of each modality. The plots also highlight the values the
# model was conditioned on (mask == 1, which here is only the first state).
max_n_plots = 3

for k in decoded_trajectories:
    traj = batch_torch[k].cpu().numpy()[0]
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    mask = masks[k]
    if len(mask.shape) == 1:
        # only along time dimension: repeat across the given dimension
        mask = mask[:, None].repeat(1, traj.shape[1])
    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        select_mask = mask[:, i].cpu().numpy()
        unmasked_gt_i = gt_i[select_mask == 1]
        unmasked_gt_i_index = np.arange(len(gt_i))[select_mask == 1]
        vmax = max(np.max(gt_i), np.max(re_i))
        vmin = min(np.min(gt_i), np.min(re_i))
        y_range = vmax - vmin
        # 20% margin. For a constant series, use a fixed margin instead: identical
        # limits make matplotlib warn and pick its own range.
        pad = y_range / 5 if y_range > 0 else max(abs(vmax), 1.0) / 5
        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.set_ylim(vmin - pad, vmax + pad)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()

In [None]:
# TODO: ensemble evaluation
# - load 5 models
# - feed the same history into every model
# - pick the top-k predictions based on rewards