In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"
    
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple
from pprint import pprint

import dcargs
import glob
import hydra
import numpy as np
import torch

import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch.utils.data.dataloader import DataLoader

from research.logger import WandBLogger, WandBLoggerConfig
from research.mtm.models.mtm_model import MaskedDP, MTMConfig, make_plots_with_masks
from research.mtm.tokenizers.base import Tokenizer, TokenizerManager
import mediapy as media
from research.mtm.train import RunConfig
import matplotlib.pyplot as plt
from collections import defaultdict
from research.utils.plot_utils import PlotHandler as ph


%matplotlib inline

In [None]:
# Run directory for the `exorl_continuous_rew_qpos` experiment
# (mask patterns [RANDOM,GOAL,ID,FD], 2 encoder / 2 decoder layers, 4 heads).
# The next cell re-assigns `path` to another run -- run only the one you want to load.
path = (
    "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-17_18-23-27/"
    "1_+experiments=exorl_continuous_rew_qpos,"
    "args.mask_patterns=[RANDOM,GOAL,ID,FD],"
    "args.model_config.n_dec_layer=2,"
    "args.model_config.n_enc_layer=2,"
    "args.model_config.n_head=4"
)

In [None]:
# Run directory for the `exorl_discrete_rew_qpos` experiment
# (mask patterns [RANDOM,GOAL,ID,FD], 2 encoder / 1 decoder layer).
# Running this cell overrides any `path` set by the previous cell.
path = (
    "/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-18_22-20-17/"
    "1_+experiments=exorl_discrete_rew_qpos,"
    "args.mask_patterns=[RANDOM,GOAL,ID,FD],"
    "args.model_config.n_dec_layer=1,"
    "args.model_config.n_enc_layer=2"
)

In [None]:
# Find the checkpoints saved in the run directory and select the one with the highest
# training step. Filenames are expected to look like "<prefix>_<step>.pt": the step is
# the text after the last "_" with the extension stripped.
steps = []
names = []
# sorted() makes the order of `steps`/`names` deterministic (os.listdir order is arbitrary).
for fname in sorted(os.listdir(path)):
    # Match on the extension only: a substring test like `"pt" in name` also picks up any
    # entry that merely contains "pt", and int() below then fails on it.
    if not fname.endswith((".pt", ".pth")):
        continue
    step = os.path.splitext(fname)[0].split("_")[-1]
    steps.append(int(step))
    names.append(os.path.join(path, fname))

if not names:
    raise FileNotFoundError(f"no '*.pt' checkpoints found in {path}")
ckpt_path = names[int(np.argmax(steps))]

In [None]:
# Display the training steps parsed from the checkpoint filenames found above.
steps

In [None]:
# Reload the Hydra config saved with the training run and instantiate the run arguments.
hydra_cfg = OmegaConf.load(os.path.join(path, ".hydra", "config.yaml"))

# Cap the dataset sizes so loading stays quick in the notebook
# (presumably a maximum number of samples -- confirm against the dataset loader).
MAX_DATASET_SIZE = 10_000
hydra_cfg.dataset.train_max_size = MAX_DATASET_SIZE
hydra_cfg.dataset.val_max_size = MAX_DATASET_SIZE

# Drop the W&B logger config before instantiation (presumably so no W&B run is set up here).
# pop() with a default also tolerates configs saved without this key, where `del` would fail.
hydra_cfg.args.pop("wandb_config", None)
cfg = hydra.utils.instantiate(hydra_cfg.args)
pprint(cfg)

In [None]:
# Rebuild the train/val datasets from the run's dataset config, using the model's
# trajectory length as the sequence length.
train_dataset, val_dataset = hydra.utils.call(
    hydra_cfg.dataset, seq_steps=cfg.model_config.traj_length
)
print("Train set size =", len(train_dataset))
print("Validation set size =", len(val_dataset))

# One tokenizer per configured key, each built with access to the training set.
tokenizers: Dict[str, Tokenizer] = {
    key: hydra.utils.call(tok_cfg, key=key, train_dataset=train_dataset)
    for key, tok_cfg in hydra_cfg.tokenizers.items()
}
tokenizer_manager = TokenizerManager(tokenizers)
# Per-key flag read from each tokenizer's `discrete` attribute.
discrete_map: Dict[str, bool] = {key: tok.discrete for key, tok in tokenizers.items()}
print(tokenizers)

# Neither loader shuffles (DataLoader default); only the training loader pins memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
)

# Tokenize one training batch just to read off the trailing two dims of every
# token tensor; the model constructor below takes these shapes.
train_batch = next(iter(train_loader))
tokenized = tokenizer_manager.encode(train_batch)
data_shapes = {key: tokens.shape[-2:] for key, tokens in tokenized.items()}
print(data_shapes)


In [None]:
# Reuse the validation dataset's environment instance for the rendering and rollouts below.
# NOTE(review): `_env` is a private attribute of the dataset object -- may break if that class changes.
env = val_dataset._env

In [None]:
# Sanity check: reset the environment and show one 480x640 render from the default camera.
env.reset()
image = env.physics.render(height=480, width=640)
plt.imshow(image)

In [None]:
# Draw a sample from the validation set; the keys used below are "observations",
# "rewards", "actions" and "physics" (recorded simulator states).
# NOTE(review): the meaning of the arguments (22, 18) is not visible here -- presumably a
# trajectory index and/or start offset; confirm against the dataset's `sample` signature.
sample_trajectory_with_metadata = val_dataset.sample(22, 18)

In [None]:
# Inspect which fields the sampled trajectory provides.
sample_trajectory_with_metadata.keys()

# Create the model

In [None]:
# Peek at the training step stored in the selected checkpoint. map_location="cpu" avoids
# materialising every saved tensor on its original (possibly GPU) device just to read one value.
torch.load(ckpt_path, map_location="cpu")["step"]

In [None]:
# Rebuild the model with the training-time data shapes/config and restore the checkpoint weights.
model = MaskedDP(data_shapes, cfg.model_config)
model.to(cfg.device)
# map_location places the saved tensors directly on cfg.device, whatever device they were saved
# from; loading inline avoids keeping the whole checkpoint dict alive in the notebook namespace.
model.load_state_dict(torch.load(ckpt_path, map_location=cfg.device)["model"])
# Inference mode; the trailing ';' suppresses the module repr in the notebook output.
model.eval();

In [None]:
# Shape of the sampled observation array (the next cell feeds only `[:, :15]` of it as "states").
sample_trajectory_with_metadata["observations"].shape

In [None]:
# Build a batch of size 1 from the sampled trajectory on the model's device.
# NOTE(review): only the first 15 observation dims are used as "states" -- presumably the
# dims the state tokenizer was trained on; confirm against the run config.
batch_torch = {
    key: torch.from_numpy(array).to(cfg.device).unsqueeze(0)
    for key, array in (
        ("states", sample_trajectory_with_metadata["observations"][:, :15]),
        ("rewards", sample_trajectory_with_metadata["rewards"]),
        ("actions", sample_trajectory_with_metadata["actions"]),
    )
}

# Per-timestep masks: 1 = visible to the model, 0 = hidden (the plotting code labels
# mask == 1 points as "unmasked ground truth"). Every state and reward is visible and every
# action is hidden, so the model infers the whole action sequence from states and rewards.
# NOTE(review): this was previously labelled a "goal-reaching" mask; a goal-reaching pattern
# would presumably reveal only the first/last states -- confirm which pattern was intended.
state_mask = torch.ones(batch_torch["states"].shape[1])
action_mask = torch.zeros(batch_torch["actions"].shape[1])
masks = {"states": state_mask, "rewards": state_mask, "actions": action_mask}
masks_torch = {k: v.to(cfg.device) for k, v in masks.items()}

# Inference only: no_grad skips building an autograd graph for the forward pass.
with torch.no_grad():
    encoded_batch = tokenizer_manager.encode(batch_torch)
    predicted_trajectories = model(encoded_batch, masks_torch)
    decoded_trajectories = tokenizer_manager.decode(predicted_trajectories)

In [None]:
# def get_actions():

In [None]:
# Open-loop setup. First, the model's predicted actions as a numpy array with the batch dim removed.
execute_actions = decoded_trajectories["actions"].squeeze(0).detach().cpu().numpy()

# Restore the simulator to the physics state recorded at the first step of the sample.
env.physics.reset()
with env.physics.reset_context():
    env.physics.set_state(sample_trajectory_with_metadata["physics"][0])

# Both the frame list and the state log start from that restored initial state.
images_open_loop = [env.physics.render(height=480, width=640, camera_id=0)]
traj_real_ol = defaultdict(list, states=[sample_trajectory_with_metadata["observations"][0]])

In [None]:
# Open-loop rollout: execute every predicted action without replanning, logging the executed
# action, the resulting observation and one rendered frame per step.
# NOTE: this cell appends to traj_real_ol / images_open_loop and advances the env, so re-run
# the setup cell above before running it again.
for step_idx, action in enumerate(execute_actions):
    traj_real_ol["actions"].append(action)
    # NOTE(review): assumes the value returned by env.step is indexable by "observation".
    obs = env.step(action)
    traj_real_ol["states"].append(obs["observation"])
    images_open_loop.append(env.physics.render(height=480, width=640, camera_id=0))
    # Optional check that the rollout reproduces the dataset trajectory:
    # np.testing.assert_allclose(obs["observation"], sample_trajectory_with_metadata["observations"][step_idx], atol=1e-5)

# Drop the observation that follows the last action, so states[t] is the state in which
# actions[t] was taken and both logs have the same length.
traj_real_ol["states"] = traj_real_ol["states"][:-1]

In [None]:
# Play the open-loop rollout (restored initial frame + one frame per executed action) at 30 fps.
media.show_video(images_open_loop, fps=30)

In [None]:
# Per-dimension plots comparing, for each modality:
#   - the dataset trajectory (ground truth),
#   - the model's reconstruction,
#   - the ground-truth points the model was shown (mask == 1),
#   - what was actually executed/observed in the open-loop simulator rollout.
max_n_plots = 15  # plot at most this many dimensions per modality
for k in decoded_trajectories:
    if k == "rewards":
        # traj_real_ol only records "states" and "actions", so there is no rollout curve for rewards.
        continue
    traj = batch_torch[k][0].detach().cpu().numpy()
    pred_traj = decoded_trajectories[k][0].detach().cpu().numpy()
    # Converted once per modality instead of once per plotted dimension.
    real_traj = np.asarray(traj_real_ol[k])
    mask = masks[k]
    if len(mask.shape) == 1:
        # Time-only mask: repeat it across every feature dimension.
        mask = mask[:, None].repeat(1, traj.shape[1])
    mask_np = mask.cpu().numpy()

    for i in range(min(max_n_plots, traj.shape[-1])):
        gt_i = traj[:, i]
        re_i = pred_traj[:, i]
        real_i = real_traj[:, i]
        unmasked_gt_i_index = np.flatnonzero(mask_np[:, i] == 1)
        unmasked_gt_i = gt_i[unmasked_gt_i_index]

        # The y-limits cover all plotted curves so rollout points are never clipped off-axis.
        vmax = max(np.max(gt_i), np.max(re_i), np.max(real_i))
        vmin = min(np.min(gt_i), np.min(re_i), np.min(real_i))
        y_range = vmax - vmin
        # 20% padding; a flat signal (zero range) would give identical limits, so pad by 1 instead.
        y_pad = y_range / 5 if y_range > 0 else 1.0

        with ph.plot_context() as (fig, ax):
            ax.plot(gt_i, "-o", label="ground truth")
            ax.plot(re_i, "-o", label="reconstructed", markerfacecolor="none")
            ax.plot(
                unmasked_gt_i_index,
                unmasked_gt_i,
                "o",
                label="unmasked ground truth",
            )
            ax.plot(real_i, ".", label="real")
            ax.set_ylim(vmin - y_pad, vmax + y_pad)
            ax.legend()
            ax.set_title(f"{k}_{i}")
            plt.show()


In [None]:
# Show the first frame of the open-loop rollout (the restored initial physics state).
plt.imshow(images_open_loop[0])
plt.show()

In [None]:
# plt.imshow(images_open_loop[prediction_steps])
# plt.show()
# plt.imshow(images_open_loop[prediction_steps])
# plt.show()
# plt.imshow(images_close_loop[prediction_steps])
# plt.show()

In [None]:
# Signed pixel difference between frame `prediction_steps` of the reference images and of the
# open-loop rollout, min-max normalised to [0, 1] for display.
# NOTE(review): `images_gt` and `prediction_steps` are not defined anywhere in this notebook
# (they look like leftovers from deleted cells) -- define them before running this cell.
# Frames are cast to float first: rendered images are presumably uint8, and uint8
# subtraction wraps around instead of going negative.
diff = np.asarray(images_gt[prediction_steps], dtype=np.float32) - np.asarray(
    images_open_loop[prediction_steps], dtype=np.float32
)
diff_span = np.max(diff) - np.min(diff)
# Identical frames give a zero span; show a blank image instead of dividing by zero (NaNs).
diff = (diff - np.min(diff)) / diff_span if diff_span > 0 else np.zeros_like(diff)
plt.imshow(diff)

In [None]:
# Signed pixel difference between frame `prediction_steps` of the reference images and of the
# closed-loop rollout, min-max normalised to [0, 1] for display.
# NOTE(review): `images_gt`, `images_close_loop` and `prediction_steps` are not defined anywhere
# in this notebook (they look like leftovers from deleted cells) -- define them before running.
# Frames are cast to float first: rendered images are presumably uint8, and uint8
# subtraction wraps around instead of going negative.
diff = np.asarray(images_gt[prediction_steps], dtype=np.float32) - np.asarray(
    images_close_loop[prediction_steps], dtype=np.float32
)
diff_span = np.max(diff) - np.min(diff)
# Identical frames give a zero span; show a blank image instead of dividing by zero (NaNs).
diff = (diff - np.min(diff)) / diff_span if diff_span > 0 else np.zeros_like(diff)
plt.imshow(diff)

In [None]:
# The final observation of the sampled dataset trajectory, used as the goal state that the
# rollouts' last states are compared against in the cells below.
goal_state = sample_trajectory_with_metadata["observations"][-1]

In [None]:
# Per-dimension error between the goal state and the last state of the (presumably
# closed-loop) rollout `traj_real_cl`.
# NOTE(review): `traj_real_cl` is not defined anywhere in this notebook (only the open-loop
# `traj_real_ol` is built) -- this cell fails on a fresh kernel until that rollout is added.
diff_cl = goal_state - traj_real_cl["states"][-1]

In [None]:
# Squared L2 distance between the goal state and the closed-loop final state.
np.square(diff_cl).sum()

In [None]:
# Per-dimension error between the goal state and the last state kept in the open-loop rollout
# (the observation after the final action was trimmed off, so this is the state in which the
# last action was taken).
diff_ol = goal_state - traj_real_ol["states"][-1]

In [None]:
# Squared L2 distance between the goal state and the open-loop final state.
np.square(diff_ol).sum()

In [None]:
# Query the task's observation mapping directly from the current physics state and list its keys.
obs = env.task.get_observation(env.physics)
print(obs.keys())

In [None]:
# Shape of each observation component, in the mapping's iteration order.
[component.shape for component in obs.values()]