# Load the checkpoint

In [None]:
import sys
import os
from typing import Dict, Callable, Tuple, List

import numpy as np
import torch
import time
import dill
import hydra
from torch.utils.data import DataLoader


from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset
from diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace

data_path = "/training_outputs/"
ckpt_path = os.path.join(
    data_path,
    "2024.05.21_22.06.13_flip_up_linear_interpolated_dense_action",
    "checkpoints",
    "epoch=0040-train_loss=0.012.ckpt",
)

device = torch.device('cpu')

# Accept either a .ckpt file or a run directory; for a directory, fall back to
# its checkpoints/latest.ckpt.
if not ckpt_path.endswith('.ckpt'):
    ckpt_path = os.path.join(ckpt_path, 'checkpoints', 'latest.ckpt')

# SECURITY: the checkpoint is unpickled with dill, which can execute arbitrary
# code while loading -- only load checkpoints from trusted sources.
# The `with` block guarantees the file handle is closed after loading.
with open(ckpt_path, 'rb') as ckpt_file:
    payload = torch.load(ckpt_file, map_location='cpu', pickle_module=dill)
cfg = payload['cfg']
print("model_name:", cfg.policy.obs_encoder.model_name)
print("dataset_path:", cfg.task.dataset.dataset_path)

# Rebuild the workspace class recorded in the checkpoint's config, then restore
# its saved state from the payload.
cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Use the EMA copy of the model when the training config enabled it.
policy = workspace.ema_model if cfg.training.use_ema else workspace.model
policy.num_inference_steps = cfg.policy.num_inference_steps  # DDIM inference iterations

policy.eval().to(device)
policy.reset()

# Use the normalizers saved in the policy; get_normalizer() returns a
# (sparse, dense) pair here.
sparse_normalizer, dense_normalizer = policy.get_normalizer()

shape_meta = cfg.task.shape_meta

# Load a dataset

In [None]:
# # load the dataset used in training
# dataset: BaseImageDataset
# dataset = hydra.utils.instantiate(cfg.task.dataset)
# assert isinstance(dataset, BaseImageDataset) or isinstance(dataset, BaseDataset)
# print("Test Script: Creating dataloader.")
# train_dataloader = DataLoader(dataset, **cfg.dataloader)
# print('train dataset:', len(dataset), 'train dataloader:', len(train_dataloader))

# load the dataset specified in config
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the training config from the repo's config directory (not from the
# checkpoint) and build the dataset + dataloader it describes.
# NOTE: this rebinds the notebook-level `cfg`, replacing the checkpoint config
# loaded in the first cell.
with initialize(
    version_base=None,
    config_path='../diffusion_policy/config',
    job_name="test_app",
):
    cfg = compose(config_name="train_diffusion_unet_timm_flip_up_workspace")
    OmegaConf.resolve(cfg)

    print("Test Script: configuring dataset.")
    dataset: BaseImageDataset = hydra.utils.instantiate(cfg.task.dataset)
    assert isinstance(dataset, (BaseImageDataset, BaseDataset))

    print("Test Script: Creating dataloader.")
    train_dataloader = DataLoader(dataset, **cfg.dataloader)
    print('train dataset:', len(dataset), 'train dataloader:', len(train_dataloader))

# Run the policy on one training batch (loss, action prediction, action MSE)

In [None]:
import torch.nn.functional as F
from einops import rearrange, reduce
import json

def _action_mse(pred_naction, gt_naction, split_pos_rot=False):
    """Return MSE statistics between two normalized action tensors.

    The first two dims are treated as (B, T) and the remaining values of each
    (b, t) slice are regrouped into chunks of 9, so every slice must hold a
    multiple of 9 values. ``reshape`` (rather than ``view``) is used so that
    non-contiguous inputs work too.

    Returns a dict mapping a log-key suffix to a Python float: ``''`` for the
    mean squared error over all elements, plus ``'_pos'`` (channels [:3]) and
    ``'_rot'`` (channels [3:9]) of each 9-value chunk when ``split_pos_rot``
    is true.
    """
    batch_size, horizon = pred_naction.shape[:2]
    pred = pred_naction.reshape(batch_size, horizon, -1, 9)
    gt = gt_naction.reshape(batch_size, horizon, -1, 9)
    stats = {'': float(F.mse_loss(pred, gt).detach())}
    if split_pos_rot:
        # Convert to float so step_log stays JSON-serializable.
        stats['_pos'] = float(F.mse_loss(pred[..., :3], gt[..., :3]).detach())
        stats['_rot'] = float(F.mse_loss(pred[..., 3:9], gt[..., 3:9]).detach())
    return stats


def log_action_mse(step_log, category, pred_action, gt_action, *,
                   sparse_norm=None, dense_norm=None, split_pos_rot=False):
    """Record normalized-action MSE for the sparse and dense actions.

    Predicted and ground-truth actions are both normalized with the
    ``'action'`` entry of the matching normalizer before comparison, so the
    logged errors are in normalized units.

    Args:
        step_log: dict updated in place. Float values are written under
            ``f'{category}_sparse_naction_mse_error'`` and
            ``f'{category}_dense_naction_mse_error'``. When ``split_pos_rot``
            is true, ``_pos``/``_rot`` suffixed keys are also written.
        category: key prefix, e.g. ``'train'``.
        pred_action: dict with ``'sparse'`` and ``'dense'`` action tensors.
        gt_action: dict with the same keys as ``pred_action``.
        sparse_norm: normalizer for sparse actions; defaults to the
            module-level ``sparse_normalizer``.
        dense_norm: normalizer for dense actions; defaults to the
            module-level ``dense_normalizer``.
        split_pos_rot: additionally log the MSE of channels [:3] and [3:9]
            of each 9-value action chunk.
    """
    # Globals are only looked up when no explicit normalizer is passed.
    normalizers = {
        'sparse': sparse_normalizer if sparse_norm is None else sparse_norm,
        'dense': dense_normalizer if dense_norm is None else dense_norm,
    }
    for kind, normalizer in normalizers.items():
        action_normalizer = normalizer['action']
        stats = _action_mse(
            action_normalizer.normalize(pred_action[kind]),
            action_normalizer.normalize(gt_action[kind]),
            split_pos_rot=split_pos_rot,
        )
        for suffix, value in stats.items():
            step_log[f'{category}_{kind}_naction_mse_error{suffix}'] = value
    
# Get one batch from the training dataloader; the next cell reuses it.
print('get a batch of data')
try:
    batch = next(iter(train_dataloader))
except StopIteration:
    # An empty dataloader would otherwise surface as a bare StopIteration.
    raise RuntimeError(
        'train_dataloader yielded no batches; check the dataset config'
    ) from None

# print(batch.keys())
# for key, attr in batch['obs']['sparse'].items():
#     print("   obs.sparse.key: ", key, attr.shape)
# for key, attr in batch['obs']['dense'].items():
#     print("   obs.dense.key: ", key, attr.shape)
# for key, attr in batch['action'].items():
#     print("   action.key: ", key, attr.shape)


In [6]:

def _to_device(value, target):
    """Recursively move tensors inside (possibly nested) dicts to ``target``.

    Dicts are rebuilt with the same keys; leaves that are neither tensors nor
    dicts are returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.to(target)
    if isinstance(value, dict):
        return {key: _to_device(item, target) for key, item in value.items()}
    return value


# The policy was moved to `device` when the checkpoint was loaded; move the
# batch as well so this cell still works if `device` is switched to a GPU.
batch = _to_device(batch, device)

# Evaluation only: disable autograd so no graph is kept during the loss
# computation or the iterative sampling in predict_action.
with torch.no_grad():
    # test compute loss
    print('running policy on batch')
    flag = {'dense_traj_cond_use_gt': True}
    raw_loss = policy(batch, flag)
    print('total loss: ', raw_loss)
    print('sparse loss:', policy.sparse_loss)
    print('dense loss:', policy.dense_loss)

    # test predict action; passing the ground-truth actions as the second
    # argument enables the gt sparse condition (per the original note)
    gt_action = batch['action']
    pred_action = policy.predict_action(batch['obs'], batch['action'])

    step_log = {}
    log_action_mse(step_log, 'train', pred_action, gt_action)

print(json.dumps(step_log, indent=4))