# Imports

In [7]:
import collections
import datetime
import functools
import math
import os
import sys
from typing import Callable, Optional

if "PyTorch_VAE" not in sys.path:
    sys.path.append("PyTorch_VAE")

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from tqdm.notebook import tqdm, trange

from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.model.diffusion import conditional_unet1d
from inverse_dynamics import InverseDynamicsCNN
from state_diffusion import Diffusion
from utils import normalize_pn1, normalize_standard_normal
from vae import VanillaVAE

# Args

In [11]:
# Torch device that every model and normalizer tensor in this notebook is moved to.
device = "cpu"
# Location passed to PushTImageDataset when the dataset is constructed.
data_path = (
    "/nas/ucb/ebronstein/lsdp/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
)

# VAE
# Checkpoint whose state dict is loaded into the VanillaVAE.
vae_path = "models/pusht_vae/vae_32_20240403.pt"
# Passed to VanillaVAE as latent_dim; also reused as the diffusion model's input_dim.
vae_latent_dim = 32

# Diffusion
# Directory from which the saved latent normalization statistics (latent_*.npy) are read.
diffusion_load_dir = "models/diffusion/pusht_unet1d_img_128_256_512_1024_edim_256_obs_8_pred_8_bs_256_lr_0.0003_e_250_ema_norm_latent_uniform/2024-05-06_01-09-27"
# Multiplied by the latent size to obtain the U-Net's global_cond_dim.
n_obs_history = 8
# Selects the U-Net down_dims; only the values 1, 4 and 8 are handled.
n_pred_horizon = 8
# Forwarded unchanged to ConditionalUnet1D.
diffusion_step_embed_dim = 256
# "uniform" -> normalize_pn1 with the saved min/max;
# "standard_normal" -> normalize_standard_normal with the saved mean/std.
normalize_latent = "uniform"
# Forwarded unchanged to Diffusion.
use_ema_helper = True
# Handed to Diffusion through optim_kwargs.
lr = 3e-4

# Inverse dynamics
# Checkpoint whose state dict is loaded into the InverseDynamicsCNN.
inv_dyn_path = "models/inverse_dynamics/pusht_cnn-img-obs_5-bs_256-lr_0.0001-epochs_10-train_on_recon_True-latent_dim_32/2024-05-06_16-57-18/inverse_dynamics_final.pt"
# Passed positionally to InverseDynamicsCNN after the action dimension.
inv_dyn_n_obs_history = 5

# Load dataset

In [3]:
# Construct the PushT image dataset from `data_path`; the model-loading cells
# read image and action arrays from its `replay_buffer`.
dataset = PushTImageDataset(data_path)

# Load models

In [4]:
# Load the VAE.
# The image array is permuted so that its last axis moves to position 1 and the
# shape unpacks as (N, C, H, W); C, H and W then size the VAE.
img_data = (
    torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2).float()
)
N, C, H, W = img_data.shape
vae = VanillaVAE(in_channels=C, in_height=H, in_width=W, latent_dim=vae_latent_dim).to(
    device
)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved from a GPU still deserializes when running on CPU.
# Kept as the last expression so the cell displays the key-matching result.
vae.load_state_dict(torch.load(vae_path, map_location=device))

<All keys matched successfully>

In [5]:
# Build the diffusion model and the normalizer applied to its observations.

# down_dims to use for each supported prediction horizon.
_DOWN_DIMS_BY_HORIZON = {
    1: [128, 256],
    4: [128, 256, 512],
    8: [128, 256, 512, 1024],
}
if n_pred_horizon not in _DOWN_DIMS_BY_HORIZON:
    raise NotImplementedError()
down_dims = _DOWN_DIMS_BY_HORIZON[n_pred_horizon]

# The diffusion model operates on VAE latents, conditioned on a history of them.
STATE_DIM = vae_latent_dim
global_cond_dim = STATE_DIM * n_obs_history
diff_model = conditional_unet1d.ConditionalUnet1D(
    input_dim=STATE_DIM,
    down_dims=down_dims,
    diffusion_step_embed_dim=diffusion_step_embed_dim,
    global_cond_dim=global_cond_dim,
)
diff_model = diff_model.to(device)


def _load_latent_stat(filename):
    """Read one saved latent statistic array from `diffusion_load_dir`."""
    return np.load(os.path.join(diffusion_load_dir, filename))


def _as_float_tensor(values):
    """Copy `values` into a float32 tensor placed on `device`."""
    return torch.tensor(values, dtype=torch.float32, device=device)


# Make the observation normalizer from the statistics saved next to the model.
if normalize_latent not in ("uniform", "standard_normal"):
    raise NotImplementedError()

if normalize_latent == "uniform":
    latent_min = _load_latent_stat("latent_min.npy")
    latent_max = _load_latent_stat("latent_max.npy")
    obs_normalizer = functools.partial(
        normalize_pn1,
        min_val=_as_float_tensor(latent_min),
        max_val=_as_float_tensor(latent_max),
    )
else:
    latent_mean = _load_latent_stat("latent_mean.npy")
    latent_std = _load_latent_stat("latent_std.npy")
    obs_normalizer = functools.partial(
        normalize_standard_normal,
        mean=_as_float_tensor(latent_mean),
        std=_as_float_tensor(latent_std),
    )

# No train/test data and zero epochs are supplied to the Diffusion wrapper here.
optim_kwargs = {"lr": lr}
diffusion = Diffusion(
    train_data=None,
    test_data=None,
    obs_normalizer=obs_normalizer,
    model=diff_model,
    n_epochs=0,
    optim_kwargs=optim_kwargs,
    device=device,
    use_ema_helper=use_ema_helper,
)

In [12]:
# Load the inverse dynamics model.

hidden_dims = None
# The image array's shape unpacks as (N, H, W, C) here; N is then rebound to the
# first dimension of the action array.
N, H, W, C = dataset.replay_buffer["img"].shape
N, action_dim = dataset.replay_buffer["action"].shape
inv_dyn_model = InverseDynamicsCNN(
    C, H, W, action_dim, inv_dyn_n_obs_history, hidden_dims=hidden_dims
).to(device)

# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved from a GPU still deserializes when running on CPU.
# Kept as the last expression so the cell displays the key-matching result.
inv_dyn_model.load_state_dict(torch.load(inv_dyn_path, map_location=device))

<All keys matched successfully>