# Load the checkpoint

In [1]:
import sys
import os
from typing import Dict, Callable, Tuple, List

# Add project root to Python path
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

# Change working directory to project root
os.chdir(ROOT_DIR)
print(f"Working directory: {os.getcwd()}")

import numpy as np
import torch
import time
import dill
import hydra
from torch.utils.data import DataLoader

from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset
from diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace

data_path = "/home/robotlab/ACP/training_outputs/"
ckpt_path = data_path + "2025.11.04_10.05.05_Wipe_single_arm_Wipe_single_arm/checkpoints/latest.ckpt"

device = torch.device('cpu')

# load checkpoint
if not ckpt_path.endswith('.ckpt'):
    ckpt_path = os.path.join(ckpt_path, 'checkpoints', 'latest.ckpt')
payload = torch.load(open(ckpt_path, 'rb'), map_location='cpu', pickle_module=dill)
cfg = payload['cfg']

# Inspect config structure
print("Config keys:", cfg.keys())
print("\nDataset path:", cfg.task.dataset.dataset_path if 'task' in cfg else "N/A")

cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg)
workspace: BaseWorkspace
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.num_inference_steps = cfg.policy.num_inference_steps

policy.eval().to(device)
policy.reset()

# Get the single normalizer (handles both sparse and dense)
normalizer = policy.get_normalizer()
print(f"\nNormalizer keys: {list(normalizer.params_dict.keys())}")

shape_meta = cfg.task.shape_meta
print(f"\nLoaded checkpoint successfully!")

Working directory: /home/robotlab/ACP/adaptive_compliance_policy/PyriteML


  from .autonotebook import tqdm as notebook_tqdm


Config keys: dict_keys(['name', 'output_dir', '_target_', 'task_name', 'shape_meta', 'exp_name', 'policy', 'ema', 'dataloader', 'val_dataloader', 'optimizer', 'training', 'logging', 'checkpoint', 'multi_run', 'task'])

Dataset path: /home/robotlab/ACP/data/real_processed/wipe_single_arm



Initializing zero-element tensors is a no-op

vit will use the CLS token. feature_aggregation (attention_pool_2d) is ignored!


rgb keys:          ['rgb_0']
low_dim_keys keys: ['robot0_eef_pos', 'robot0_eef_rot_axis_angle']
==> reduce pretrained obs_encorder's lr
==> rgb keys:  ['rgb_0']
obs_encorder params: 151

Normalizer keys: ['action', 'rgb_0', 'robot0_eef_pos', 'robot0_eef_rot_axis_angle', 'robot0_eef_wrench']

Loaded checkpoint successfully!


# Load a dataset

In [2]:
# # load the dataset used in training
# dataset: BaseImageDataset
# dataset = hydra.utils.instantiate(cfg.task.dataset)
# assert isinstance(dataset, BaseImageDataset) or isinstance(dataset, BaseDataset)
# print("Test Script: Creating dataloader.")
# train_dataloader = DataLoader(dataset, **cfg.dataloader)
# print('train dataset:', len(dataset), 'train dataloader:', len(train_dataloader))

# load the dataset specified in config
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(
    version_base=None,
    config_path=str('../diffusion_policy/config'),
    job_name="test_app"
):
    cfg = compose(config_name="train_spec_workspace")
    OmegaConf.resolve(cfg)

    print("Test Script: configuring dataset.")
    dataset: BaseImageDataset
    dataset = hydra.utils.instantiate(cfg.task.dataset)
    assert isinstance(dataset, BaseImageDataset) or isinstance(dataset, BaseDataset)
    print("Test Script: Creating dataloader.")
    train_dataloader = DataLoader(dataset, **cfg.dataloader)
    print('train dataset:', len(dataset), 'train dataloader:', len(train_dataloader))

Test Script: configuring dataset.
[VirtualTargetDataset] loading data into store
[ReplayBuffer] checking chunk size and compressor.
 checking:  episode_1727235551
 checking:  episode_1727235569
 checking:  episode_1727235583
 checking:  episode_1727235597
 checking:  episode_1727235613
 checking:  episode_1727235630
 checking:  episode_1727235645
 checking:  episode_1727235661
 checking:  episode_1727235681
 checking:  episode_1727235696
 checking:  episode_1727235710
 checking:  episode_1727235725
 checking:  episode_1727235744
 checking:  episode_1727235765
 checking:  episode_1727235779
 checking:  episode_1727235795
 checking:  episode_1727235810
 checking:  episode_1727235825
 checking:  episode_1727235839
 checking:  episode_1727236158
 checking:  episode_1727236172
 checking:  episode_1727236186
 checking:  episode_1727236202
 checking:  episode_1727236216
 checking:  episode_1727236238
 checking:  episode_1727236253
 checking:  episode_1727236265
 checking:  episode_1727236280


# Run some tests

In [3]:
import torch.nn.functional as F
from einops import rearrange, reduce
import json

def log_action_mse(step_log, category, pred_action, gt_action):
    # Only process keys that exist in both pred and gt
    for key in ['sparse', 'dense']:
        if key not in pred_action or key not in gt_action:
            continue
            
        pred_naction = normalizer['action'].normalize(pred_action[key])
        gt_naction = normalizer['action'].normalize(gt_action[key])
        
        B, T, _ = pred_naction.shape
        pred_naction = pred_naction.view(B, T, -1, 19)  # 19 action dims
        gt_naction = gt_naction.view(B, T, -1, 19)
        
        loss = F.mse_loss(pred_naction, gt_naction, reduction='none')
        loss = loss.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        
        step_log[f'{category}_{key}_naction_mse_error'] = float(loss.detach())
    
# get a batch of data
print('get a batch of data')
batch = next(iter(train_dataloader))

# Inspect what keys are actually present
print("pred_action keys:", pred_action.keys() if 'pred_action' in locals() else "not computed yet")
print("gt_action keys:", batch['action'].keys())


get a batch of data
pred_action keys: not computed yet
gt_action keys: dict_keys(['sparse'])


In [4]:

# test compute loss
print('running policy on batch')
flag = {'dense_traj_cond_use_gt': True}
raw_loss = policy(batch, flag)
print('total loss: ', raw_loss)
print('sparse loss:', policy.sparse_loss)
#print('dense loss:', policy.dense_loss)

# test predict action
gt_action = batch['action']
pred_action = policy.predict_action(batch['obs']) # providing batch will enable gt sparse condition
print("gt_action['sparse'].shape: ", gt_action['sparse'].shape)
print("pred_action['sparse'].shape: ", pred_action['sparse'].shape)
#print("gt_action['dense'].shape: ", gt_action['dense'].shape)
#print("pred_action['dense'].shape: ", pred_action['dense'].shape)

step_log = {}
log_action_mse(step_log, 'train', pred_action, gt_action)
print(json.dumps(step_log, indent=4))

running policy on batch


total loss:  tensor(0.1973, grad_fn=<MseLossBackward0>)
sparse loss: tensor(0.1973, grad_fn=<MseLossBackward0>)
gt_action['sparse'].shape:  torch.Size([128, 16, 19])
pred_action['sparse'].shape:  torch.Size([128, 16, 19])
{
    "train_sparse_naction_mse_error": 0.27118033170700073
}
