# Load Workspace, policy, and dataset

In [1]:
import sys
import os
from typing import Dict, Callable, Tuple, List

# # use line-buffering for both stdout and stderr
# sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
# sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

# Add project root to Python path.
# The project root is taken to be the parent of the kernel's initial working
# directory. ROOT_DIR is computed only once per kernel: this cell also chdir()s
# into ROOT_DIR, so deriving it from os.getcwd() again on a re-run would climb
# one directory higher every time the cell is executed.
if 'ROOT_DIR' not in globals():
    ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

# Change working directory to project root
os.chdir(ROOT_DIR)
print(f"Working directory: {os.getcwd()}")

import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset
from diffusion_policy.common.pytorch_util import dict_apply

# Register an "eval" resolver so configs can embed ${eval:''} expressions.
# NOTE(review): this runs arbitrary Python found in the config files, so only
# trusted configs should be loaded.
OmegaConf.register_new_resolver("eval", eval, replace=True)

temp_output_dir = "./temp_output_dir"

torch.set_num_threads(1)

with initialize(
    version_base=None,
    config_path='../diffusion_policy/config',
    job_name="test_app",
):
    print("Test Script: starting.")
    cfg = compose(config_name="train_spec_workspace")
    # Resolve every interpolation up front so that all ${now:} resolvers
    # observe the same timestamp.
    print("Test Script: resolving config.")
    OmegaConf.resolve(cfg)

    # The workspace class is named by the config's _target_ entry; the policy
    # under test is the workspace's model.
    print("Test Script: initializing workspace.")
    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)
    policy = workspace.model

    # Instantiate the task dataset from the config and wrap it in a DataLoader
    # using the config's dataloader settings.
    print("Test Script: configuring dataset.")
    dataset: BaseImageDataset = hydra.utils.instantiate(cfg.task.dataset)
    assert isinstance(dataset, (BaseImageDataset, BaseDataset))
    print("Test Script: Creating dataloader.")
    train_dataloader = DataLoader(dataset, **cfg.dataloader)
    print(f"train dataset: {len(dataset)} train dataloader: {len(train_dataloader)}")
    


Working directory: /home/robotlab/ACP/adaptive_compliance_policy/PyriteML
Test Script: starting.
Test Script: resolving config.
Test Script: initializing workspace.


  from .autonotebook import tqdm as notebook_tqdm

Initializing zero-element tensors is a no-op

vit will use the CLS token. feature_aggregation (attention_pool_2d) is ignored!


rgb keys:          ['rgb_0']
low_dim_keys keys: ['robot0_eef_pos', 'robot0_eef_rot_axis_angle']
==> reduce pretrained obs_encorder's lr
==> rgb keys:  ['rgb_0']
obs_encorder params: 151
Test Script: configuring dataset.
[VirtualTargetDataset] loading data into store
[ReplayBuffer] checking chunk size and compressor.
 checking:  episode_1727235551
 checking:  episode_1727235569
 checking:  episode_1727235583
 checking:  episode_1727235597
 checking:  episode_1727235613
 checking:  episode_1727235630
 checking:  episode_1727235645
 checking:  episode_1727235661
 checking:  episode_1727235681
 checking:  episode_1727235696
 checking:  episode_1727235710
 checking:  episode_1727235725
 checking:  episode_1727235744
 checking:  episode_1727235765
 checking:  episode_1727235779
 checking:  episode_1727235795
 checking:  episode_1727235810
 checking:  episode_1727235825
 checking:  episode_1727235839
 checking:  episode_1727236158
 checking:  episode_1727236172
 checking:  episode_1727236186


In [2]:
# Fetch the sample at index 5 from the dataset. Presumably a smoke test of
# dataset indexing: test_data is not used by the later cells visible here.
test_data = dataset[5]


# Compute Normalizers

In [3]:
# Compute the sparse normalizer from the dataset and attach it to the policy.
# The two *_normalizer_path variables are only defined here; this cell does not
# write anything to disk.
print("Test Script: Computing normalizer.")
sparse_normalizer_path, dense_normalizer_path = (
    os.path.join(temp_output_dir, filename)
    for filename in ('sparse_normalizer.pkl', 'dense_normalizer.pkl')
)
sparse_normalizer = dataset.get_normalizer()
policy.set_normalizer(sparse_normalizer)

# Remember which device the policy lives on.
device = policy.device
print("Test Script: done")


Test Script: Computing normalizer.


iterating dataset to get normalization: 100%|██████████| 209/209 [00:03<00:00, 60.01it/s] 


data_cache_sparse['action'] (213424, 19)
Test Script: done


# Sweep dataloader parameters

In [4]:
# Sweep DataLoader settings (batch size x number of workers) and time one pass
# over the dataset for each combination, including the copy of every batch to
# `device`.
import time
from tqdm import trange, tqdm


# batch_sizes = [16, 32, 64, 128, 256]
# num_workers = [4,8,16,32]
batch_sizes = [16]
num_workers = [32]
# Maps (batch_size, num_worker) -> seconds for one pass; only combinations that
# were not abandoned are recorded.
timings = {}

# A combination is abandoned once more than `time_threshold` seconds have
# elapsed while fewer than `check_ratio` of its batches have been loaded.
time_threshold = 20 + 10
check_ratio = 0.3
# Fall back to the CPU when CUDA is unavailable so the sweep can still run.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('starting')
print('device:', device)
for batch_size in batch_sizes:
    for num_worker in num_workers:
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_worker)

        start_time = time.time()
        # dataset.get_normalizer(batch_size, num_worker)
        finished = True
        with tqdm(dataloader, desc=f'iterating dataset with batch_size={batch_size}, num_workers={num_worker}.') as tepoch:
            for batch_idx, batch in enumerate(tepoch):
                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                if time.time() - start_time > time_threshold:
                    # Over budget: give up only if progress is also below the
                    # required ratio; otherwise let the pass finish.
                    if batch_idx / len(tepoch) < check_ratio:
                        print(f"Time exceeded threshold of {time_threshold} seconds, but data loaded is less than {check_ratio}. Exiting.")
                        finished = False
                        break
        end_time = time.time()
        elapsed_time = end_time - start_time
        if finished:
            timings[(batch_size, num_worker)] = elapsed_time

# min() raises ValueError on an empty mapping, which happens when every
# combination was abandoned, so report that case explicitly instead.
# Note: despite its name, fastest_batch_size is a (batch_size, num_worker) key.
if timings:
    fastest_batch_size = min(timings, key=timings.get)
    print("Fastest batch size:", fastest_batch_size, ", Fastest time:", timings[fastest_batch_size])
else:
    fastest_batch_size = None
    print("No configuration finished within the time threshold; nothing to compare.")

starting
device: cuda:0


iterating dataset with batch_size=16, num_workers=32.:   0%|          | 0/834 [00:04<?, ?it/s]


AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


# Read a batch of data, test normalizer

In [None]:

# Read one batch of data from the training dataloader.
batch = next(iter(train_dataloader))

print(batch.keys())
sparse_obs = batch['obs']['sparse']
for key in sparse_obs:
    attr = sparse_obs[key]
    print("   obs.sparse.key: ", key, attr.shape)
#for key, attr in batch['obs']['dense'].items():
    #print("   obs.dense.key: ", key, attr.shape)

# Normalize the sparse observations and the sparse actions with the normalizer
# computed earlier in the notebook.
sparse_actions = batch['action']['sparse']
nobs_sparse = sparse_normalizer.normalize(sparse_obs)
nactions_sparse = sparse_normalizer['action'].normalize(sparse_actions)
# nactions_dense = dense_normalizer['action'].normalize(batch['action']['dense'])

# Test the normalizer: print the first action of the first sample before and
# after normalization.
print("policy debug: batch['action']['sparse'][0,0,:]: ", sparse_actions[0,0,:])
print("policy debug: nactions_sparse[0,0,:]: ", nactions_sparse[0,0,:])

# batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))


dict_keys(['obs', 'action'])
   obs.sparse.key:  rgb_0 torch.Size([128, 2, 3, 224, 224])
   obs.sparse.key:  robot0_eef_pos torch.Size([128, 3, 3])
   obs.sparse.key:  robot0_eef_rot_axis_angle torch.Size([128, 3, 6])
   obs.sparse.key:  robot0_eef_wrench torch.Size([128, 7000, 6])
policy debug: batch['action']['sparse'][0,0,:]:  tensor([ 0.0000e+00, -2.9802e-08, -1.4901e-08,  1.0000e+00,  3.4119e-08,
        -5.3585e-08,  3.4119e-08,  1.0000e+00,  3.5782e-08,  0.0000e+00,
        -2.9802e-08, -1.4901e-08,  1.0000e+00,  8.4664e-06, -1.0927e-06,
        -8.4497e-06,  1.0000e+00,  9.7696e-06,  5.0000e+03])
policy debug: nactions_sparse[0,0,:]:  tensor([ 1.3410e-01, -2.0990e-01,  1.9732e-01,  1.0000e+00,  3.4119e-08,
        -5.3585e-08,  3.4119e-08,  1.0000e+00,  3.5782e-08,  1.1233e-01,
        -1.4039e-01,  1.9732e-01,  1.0000e+00,  8.4664e-06, -1.0927e-06,
        -8.4497e-06,  1.0000e+00,  9.7696e-06,  1.0000e+00],
       grad_fn=<SelectBackward0>)



## Test compute_loss


In [None]:

# Call the policy on the batch with both flags enabled and print the value it
# returns.
flag = dict(
    start_training_dense=True,
    dense_traj_cond_use_gt=True,
)
raw_loss = policy(batch, flag)
print(raw_loss)

tensor(1.1602, grad_fn=<MseLossBackward0>)


## Test predict_action


In [None]:

import torch
def log_action_mse(step_log, category, pred_action, gt_action, action_dim=19):
    """Record MSE metrics between predicted and ground-truth sparse actions.

    ``pred_action['sparse']`` and ``gt_action['sparse']`` are both viewed as
    ``(B, T, -1, action_dim)``, where ``B`` and ``T`` are the first two
    dimensions of the prediction. Three entries are written into ``step_log``:

    * ``{category}_sparse_action_mse_error``     -- all channels
    * ``{category}_sparse_action_mse_error_pos`` -- channels ``[:3]``
    * ``{category}_sparse_action_mse_error_rot`` -- channels ``[3:9]``

    Args:
        step_log: dict that receives the metrics; modified in place.
        category: prefix used for the metric keys (e.g. ``'train'``).
        pred_action: mapping with a 3-dimensional tensor under ``'sparse'``.
        gt_action: mapping with a tensor under ``'sparse'`` holding the same
            number of elements as the prediction.
        action_dim: size of one action vector in the last dimension.
            Defaults to 19, the value that was previously hard-coded.

    Raises:
        ValueError: if prediction and ground truth disagree in shape after
            reshaping. Without this check ``mse_loss`` would broadcast the two
            tensors and return a meaningless number.
    """
    B, T, _ = pred_action['sparse'].shape
    pred_action_sparse = pred_action['sparse'].view(B, T, -1, action_dim)
    gt_action_sparse = gt_action['sparse'].view(B, T, -1, action_dim)
    if pred_action_sparse.shape != gt_action_sparse.shape:
        raise ValueError(
            "sparse action shape mismatch: "
            f"pred {tuple(pred_action['sparse'].shape)} vs "
            f"gt {tuple(gt_action['sparse'].shape)}"
        )

    mse_loss = torch.nn.functional.mse_loss
    step_log[f'{category}_sparse_action_mse_error'] = mse_loss(pred_action_sparse, gt_action_sparse)
    step_log[f'{category}_sparse_action_mse_error_pos'] = mse_loss(pred_action_sparse[..., :3], gt_action_sparse[..., :3])
    step_log[f'{category}_sparse_action_mse_error_rot'] = mse_loss(pred_action_sparse[..., 3:9], gt_action_sparse[..., 3:9])
    # step_log[f'{category}_sparse_action_mse_error_width'] = torch.nn.functional.mse_loss(pred_action_sparse[..., 9], gt_action_sparse[..., 9])

    # Dense-action metrics are currently disabled:
    #B, T, _, _= pred_action['dense'].shape
    #pred_action_dense = pred_action['dense'].view(B, T, -1, 9)
    #gt_action_dense = gt_action['dense'].view(B, T, -1, 9)
    #step_log[f'{category}_dense_action_mse_error'] = torch.nn.functional.mse_loss(pred_action_dense, gt_action_dense)
    #step_log[f'{category}_dense_action_mse_error_pos'] = torch.nn.functional.mse_loss(pred_action_dense[..., :3], gt_action_dense[..., :3])
    #step_log[f'{category}_dense_action_mse_error_rot'] = torch.nn.functional.mse_loss(pred_action_dense[..., 3:9], gt_action_dense[..., 3:9])
    # step_log[f'{category}_dense_action_mse_error_width'] = torch.nn.functional.mse_loss(pred_action_dense[..., 9], gt_action_dense[..., 9])
                
# sample trajectory from training set, and evaluate difference
gt_action = batch['action']
# The prediction is only used for metrics, so no gradients are needed; running
# it under no_grad avoids building an autograd graph during inference.
with torch.no_grad():
    pred_action = policy.predict_action(batch['obs'])
print("gt_action['sparse'].shape: ", gt_action['sparse'].shape)
print("pred_action['sparse'].shape: ", pred_action['sparse'].shape)
#print("gt_action['dense'].shape: ", gt_action['dense'].shape)
#print("pred_action['dense'].shape: ", pred_action['dense'].shape)

# Collect the MSE metrics for this batch under the 'train' prefix and show them.
step_log = {}
log_action_mse(step_log, 'train', pred_action, gt_action)
print(step_log)