In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"
    
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple
from pprint import pprint

import dcargs
import glob
import hydra
import numpy as np
import torch

import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch.utils.data.dataloader import DataLoader

from research.logger import WandBLogger, WandBLoggerConfig
from research.mtm.models.mtm_model import MaskedDP, MTMConfig, make_plots_with_masks
from research.mtm.models.mlp_model import MLPConfig, MLP_BC
from research.mtm.tokenizers.base import Tokenizer, TokenizerManager
import mediapy as media
from research.mtm.train_mlp import RunConfig
import matplotlib.pyplot as plt
from collections import defaultdict
from research.utils.plot_utils import PlotHandler as ph
from pathlib import Path
from torch.optim.lr_scheduler import LambdaLR

from research.mtm.datasets.sequence_dataset import evaluate

%matplotlib inline

In [None]:
# Candidate run directory (walker2d-medium-expert, task=rcbc, traj_length=2).
# NOTE: later cells reassign `path`, so on a top-to-bottom run this value is not the one used.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2023-01-04_09-08-01/0_args.traj_length=2,dataset.env_name=walker2d-medium-expert-v2,model_config.task=rcbc,wandb.project=rcbc_med_exp"

In [None]:
# Candidate run directory (d4rl_discrete MLP experiment, task=rcbc, traj_length=2).
# NOTE: a later cell reassigns `path`, so on a top-to-bottom run this value is not the one used.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2023-01-04_17-31-35/0_+exp_mlp=d4rl_discrete,args.traj_length=2,dataset.env_name=walker2d-medium-expert-v2,model_config.task=rcbc,wandb.project=debug_rcbc"

In [None]:
# Run directory (d4rl_cont MLP experiment, task=bc, traj_length=1).
# This is the last assignment to `path`, so it is the directory whose checkpoints and
# .hydra config the following cells read on a top-to-bottom run.
# Hard-coded absolute path: adjust it when running on another machine.
path = "/private/home/philippwu/mtm/outputs/mtm_mae/2023-01-04_17-51-25/2_+exp_mlp=d4rl_cont,args.traj_length=1,dataset.env_name=walker2d-medium-expert-v2,model_config.task=bc,wandb.project=debug_rcbc"

In [None]:
# find checkpoints in the directory
# A checkpoint is a file named `<prefix>_<step>.pt` (e.g. `model_640000.pt`).
# Matching on the real extension and on a numeric step suffix avoids picking up
# unrelated entries that merely contain the substring "pt", which would make
# `int(step)` raise.
steps = []
names = []
paths_ = os.listdir(path)
for n in paths_:
    stem, ext = os.path.splitext(n)
    step = stem.split("_")[-1]
    if ext != ".pt" or not step.isdigit():
        continue
    steps.append(int(step))
    names.append(os.path.join(path, n))
    #print(name)

# Fail with a clear message instead of an opaque argmax-on-empty error.
if not names:
    raise FileNotFoundError(f"No '<prefix>_<step>.pt' checkpoints found in {path!r}")

# Use the checkpoint with the largest training step.
ckpt_path = names[int(np.argmax(steps))]

In [None]:
# steps

In [None]:
# ckpt_path = '/private/home/philippwu/mtm/outputs/mtm_mae/2022-11-17_18-22-57/1_+experiments=exorl_continuous_rew_qpos,args.mask_patterns=[RANDOM,GOAL,ID,FD],args.model_config.n_dec_layer=1,args.model_config.n_enc_layer=1,args.model_config.n_head=4/model_640000.pt'

In [None]:
# Output directory (relative to the current working directory); created along
# with any missing parents, and left untouched if it already exists.
folder = Path("files/d4rl")
folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Display the training step stored in the selected checkpoint.
# map_location="cpu" lets this inspection run without the device the checkpoint
# was saved on (e.g. a GPU checkpoint opened on a CPU-only machine).
torch.load(ckpt_path, map_location="cpu")["step"]

In [None]:
# Location of the Hydra config saved inside the run directory.
p = os.path.join(path, ".hydra/config.yaml")

In [None]:
# Display the resolved config path.
p 

In [None]:
# Load the run's saved config as an OmegaConf object.
hydra_cfg = OmegaConf.load(p)

In [None]:
# Display the dataset section of the loaded config.
hydra_cfg.dataset

In [None]:
# Instantiate the `args` section of the config into an object and print it.
# Later cells read batch_size, n_workers, device, learning_rate, weight_decay,
# warmup_steps and num_train_steps from `cfg`.
cfg = hydra.utils.instantiate(hydra_cfg.args)
pprint(cfg)

In [None]:
# Build the train/validation datasets from the run's dataset config.
# The sequence length is pinned to 1 here instead of using cfg.traj_length.
train_dataset, val_dataset = hydra.utils.call(hydra_cfg.dataset, seq_steps=1)
print("Train set size =", len(train_dataset))
print("Validation set size =", len(val_dataset))

# One tokenizer per configured key; each is constructed with the training set.
tokenizers: Dict[str, Tokenizer] = {}
for k, v in hydra_cfg.tokenizers.items():
    tokenizers[k] = hydra.utils.call(v, key=k, train_dataset=train_dataset)
tokenizer_manager = TokenizerManager(tokenizers)

# Record, per key, whether its tokenizer is discrete.
discrete_map: Dict[str, bool] = {
    name: tokenizer.discrete for name, tokenizer in tokenizers.items()
}
print(tokenizers)

# Data loaders: neither loader passes `shuffle`, so both use the DataLoader default.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.n_workers,
)

# Tokenize one training batch and keep the last two dimensions of every
# tokenized tensor; these shapes are what the model constructor receives.
train_batch = next(iter(train_loader))
tokenized = tokenizer_manager.encode(train_batch)
data_shapes = {}
for k, v in tokenized.items():
    data_shapes[k] = v.shape[-2:]
print(data_shapes)


In [None]:

####################################
################## Model
####################################
# Rebuild the model from the run's saved model config; it is sized from the
# tokenized data shapes and the configured trajectory length.
model_config = hydra.utils.instantiate(hydra_cfg.model_config)

model = MLP_BC(data_shapes, hydra_cfg.args.traj_length, model_config)
model.to(cfg.device)

# load weights
# map_location remaps the stored tensors onto cfg.device, so the load also works
# when the checkpoint was saved on a different device than the one used here.
model.load_state_dict(torch.load(ckpt_path, map_location=cfg.device)["model"])
# Inference only from here on (the former model.train() call was immediately
# overridden by this eval() and has been dropped).
model.eval()

####################################
################## Model
####################################

In [None]:
# Environment attached to the dataset object wrapped by val_dataset.
env = val_dataset.dataset.env

In [None]:
# Reset the environment and render a frame from the "track" camera
# (640 and 480 are passed as the size arguments).
# [::-1] reverses the first axis of the returned array - presumably a vertical
# flip of the image; confirm against the renderer.
env.reset()
image = env.sim.render(640, 480, camera_name="track")[::-1]
plt.imshow(image)

In [None]:
# First item of the validation dataset; later cells read its "states" entry.
sample_trajectory_with_metadata = val_dataset[0]

In [None]:
# Display the states stored in the sampled item.
sample_trajectory_with_metadata["states"]

In [None]:
# Exploration aid: prints the docstring of set_state_from_flattened, which the
# next cell calls. Safe to remove once the expected state layout is known.
help(env.sim.set_state_from_flattened)

In [None]:
# Build a flattened simulator state from the first dataset state and load it
# into the simulator. The vector is two entries longer than the dataset state;
# the two leading entries stay zero and the dataset state fills the rest.
# NOTE(review): presumably the two extra entries are the simulation time and a
# root position coordinate that the observation omits - confirm against the env.
phys_state = np.zeros(len(sample_trajectory_with_metadata["states"][0]) + 2)
phys_state[2:] = sample_trajectory_with_metadata["states"][0]
env.sim.set_state_from_flattened(phys_state)
# forward() is called after setting the state, before the next cell renders.
env.sim.forward()

In [None]:
# Render the simulator again (same camera and size arguments as before) to view
# the state that was just loaded; [::-1] reverses the first axis of the array.
image = env.sim.render(640, 480, camera_name="track")[::-1]
plt.imshow(image)

In [None]:
# Histogram (100 bins) of index 0 along the last axis of values_segmented.
returns = train_dataset.values_segmented[:, : , 0]
# NOTE: this bare expression is not the last one in the cell, so nothing is displayed.
returns.shape
plt.hist(returns.flatten(), bins=100)
plt.show()

In [None]:
# Total number of values that went into the histogram above.
returns.flatten().shape

In [None]:
# Medium dataset
# d_dataset = hydra_cfg.dataset
# d_dataset.env_name = "walker2d-medium-v2"
# m_, _ = hydra.utils.call(
#     d_dataset, seq_steps=cfg.traj_length
# )
# plt.hist(m_.values_segmented.flatten(), bins=100)
# plt.show()


# Expert dataset
# d_dataset = hydra_cfg.dataset
# d_dataset.env_name = "walker2d-expert-v2"
# m_, _ = hydra.utils.call(
#     d_dataset, seq_steps=cfg.traj_length
# )
# plt.hist(m_.values_segmented.flatten(), bins=100)
# plt.show()

In [None]:
# d_dataset = hydra_cfg.dataset
# d_dataset.env_name = "walker2d-random-v2"
# m_, _ = hydra.utils.call(
#     d_dataset, seq_steps=cfg.traj_length
# )
# plt.hist(m_.values_segmented.flatten(), bins=100)
# plt.show()

In [None]:
# Inspect the shape of the segmented rewards array.
train_dataset.rewards_segmented.shape

In [None]:
# Show the dataset's path lengths from index 3 onward.
train_dataset.path_lengths[3:]

In [None]:
# Observation dimensionality; used below as the MLP's first constructor argument.
train_dataset.observation_dim

In [None]:

from research.mtm.models.mlp_model import MLP

In [None]:
# Show the optimizer hyper-parameters taken from the run config; the AdamW
# optimizer in the next cell is built from these two values.
print(cfg.learning_rate)
print(cfg.weight_decay)

In [None]:
# Fresh MLP policy. NOTE: this rebinds `model`, replacing the MLP_BC instance
# loaded from the checkpoint earlier in the notebook.
# The first two arguments are the dataset's observation and action dimensions;
# 256 and 3 are presumably the hidden width and layer count - confirm against MLP.
model = MLP(train_dataset.observation_dim, train_dataset.action_dim, 256, 3)

# AdamW over all MLP parameters, using the hyper-parameters from the run config.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)


def _schedule(step):
    """Learning-rate multiplier for LambdaLR: linear warmup, then cosine decay.

    Args:
        step: Step index supplied by the scheduler.

    Returns:
        Factor applied to the optimizer's base learning rate. It ramps linearly
        from 0 to 1 over ``cfg.warmup_steps``, follows half a cosine down to 0
        at ``cfg.num_train_steps``, and stays at 0 for any later step.

    Raises:
        AssertionError: If ``cfg.num_train_steps`` is not greater than
            ``cfg.warmup_steps`` once the warmup phase is over.
    """
    # linear warmup for cfg.warmup_steps steps
    if step < cfg.warmup_steps:
        return step / cfg.warmup_steps

    # then cosine decay
    assert cfg.num_train_steps > cfg.warmup_steps
    progress = (step - cfg.warmup_steps) / (cfg.num_train_steps - cfg.warmup_steps)
    # The cosine is periodic: without the clamp the multiplier rises again as
    # soon as `step` passes cfg.num_train_steps, which can happen because the
    # training loop runs a fixed number of epochs rather than num_train_steps.
    progress = min(progress, 1.0)
    return 0.5 * (1 + np.cos(progress * np.pi))

# Per-step learning-rate schedule driven by _schedule.
scheduler = LambdaLR(optimizer, lr_lambda=_schedule)

# optimizer = torch.optim.AdamW(
#     model.parameters(), lr=1e-3, weight_decay=0.0001
# )

# Train on the GPU.
model.cuda()
model.train()

# Bookkeeping: `losses` gets one entry per optimizer step, `step` counts them.
return_list, losses = [], []
step = 0

# When True, batches go through the tokenizer manager before being fed to the MLP.
use_tokenizer = True

# Behaviour cloning: regress the action at index 0 from the state at index 0
# with a mean-squared-error loss, for 15 passes over the training loader.
for i in range(15):
    for batch in train_loader:
        if use_tokenizer:
            # Tokenized tensors are indexed with one more leading-zero axis than raw ones.
            batch = tokenizer_manager.encode(batch)
            actions = batch["actions"][:, 0, 0, :]
            observations = batch["states"][:, 0, 0, :]
        else:
            actions = batch["actions"][:, 0, :]
            observations = batch["states"][:, 0, :]

        actions, observations = actions.to("cuda"), observations.to("cuda")

        pred_a = model(observations)
        loss = ((actions - pred_a) ** 2).mean()

        # Standard update: clear gradients, backprop, optimizer step, LR step.
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        step += 1
    # Report the loss of the last batch of this pass.
    print(i, loss.item())


In [None]:
# Shape of the action tensor from the last training batch; `actions` is a
# variable left over from the training loop cell, so that cell must run first.
actions.shape

In [None]:

# Training curve: one loss value per optimizer step, in the order recorded.
plt.plot(losses)
plt.show()

In [None]:
@torch.inference_mode()
def sample_action_bc(
    observation: np.ndarray,
    traj,
):
    """Predict an action for one observation with the notebook-global `model`.

    Reads the globals `use_tokenizer`, `model` and `tokenizer_manager`, and
    runs on the "cuda" device.

    Args:
        observation (np.ndarray): single observation.
        traj: unused here; accepted because the caller passes it.

    Returns:
        np.ndarray: the predicted action, moved to the CPU.
    """
    if not use_tokenizer:
        # Raw path: add a batch axis, run the model, drop the batch axis.
        raw_obs = torch.tensor(observation, device="cuda")[None]
        return model(raw_obs)[0].cpu().numpy()

    # Tokenized path: add two leading axes, encode the states, predict, then
    # decode the prediction back through the tokenizer manager.
    states = torch.from_numpy(observation)[None, None].to("cuda")
    encoded_states = tokenizer_manager.encode({"states": states})["states"]
    predicted = model(encoded_states)
    decoded = tokenizer_manager.decode({"actions": predicted})
    return decoded["actions"][-1].detach().cpu().numpy()[0]



In [None]:
# Roll out the policy defined above with the project's `evaluate` helper.
# Positional arguments: the policy callable, the dataset's environment, the
# integer 20 (presumably the number of episodes - confirm against `evaluate`),
# then the observation and action shapes as 1-tuples. No videos are requested.
e_d = evaluate(
    sample_action_bc,
    train_dataset.env,
    20,
    (train_dataset.observation_dim, ),
    (train_dataset.action_dim, ),
    num_videos=0,
)
# Display the first element of the returned result.
e_d[0]

In [None]:
@torch.inference_mode()
def sample_action_bc(
    observation: np.ndarray,
    traj,
):
    """Sample an action from the notebook-global `model`.

    This definition shadows the earlier `sample_action_bc` in this notebook.
    It calls the model as ``model(encoded, discrete_map, compute_loss=False)``
    and expects a 2-tuple back.

    NOTE(review): an earlier cell rebinds `model` to a plain MLP that is called
    with a single tensor; verify that the `model` in scope when this runs is
    the one that accepts this call signature.

    Args:
        observation (np.ndarray): observation
        traj (Trajectory): traj (unused here)

    Returns:
        The last entry of the decoded action array, as a NumPy value.
    """
    # Add two leading axes and move to the GPU before tokenizing.
    states = torch.from_numpy(observation)[None, None].to("cuda")
    encoded = tokenizer_manager.encode({"states": states})

    logits, _ = model(encoded, discrete_map, compute_loss=False)

    # Decode the prediction and return the final action.
    decoded = tokenizer_manager.decode({"actions": logits})
    action_seq = decoded["actions"][-1].detach().cpu().numpy()
    return action_seq[-1]


# Roll out the policy defined just above with the project's `evaluate` helper.
# Same call as the earlier evaluation cell except for the integer argument
# (10 here; presumably the number of episodes - confirm against `evaluate`).
e_d = evaluate(
    sample_action_bc,
    train_dataset.env,
    10,
    (train_dataset.observation_dim, ),
    (train_dataset.action_dim, ),
    num_videos=0,
)
# Display the first element of the returned result.
e_d[0]