In [1]:
from omegaconf import OmegaConf
from lerobot.common.policies.factory import make_policy, _policy_cfg_from_hydra_cfg
from pathlib import Path
from lerobot.common.envs.factory import make_env
from importlib import import_module
from safetensors import safe_open
from lerobot.scripts.eval import eval_policy
import torch
# Turn autograd off globally for this kernel session: the notebook only
# loads a checkpoint and evaluates it, it never trains.
torch.set_grad_enabled(False)

def load_safetensors(path):
    """Read every tensor stored in a safetensors file into a plain dict.

    Args:
        path: Location of the safetensors file to open.

    Returns:
        A dict mapping each key found in the file to the tensor stored under
        it. The file is opened with the "pt" framework on the "cpu" device.
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}

# Device the policy is moved to. On a CUDA machine this stays the GPU ordinal 0
# (same value as before); on a CPU-only machine fall back to 'cpu' so the
# notebook still runs instead of failing when the policy is moved.
device = 0 if torch.cuda.is_available() else 'cpu'

In [13]:
# Task to evaluate. The name is interpolated into the checkpoint path and the
# dataset repo id in the cells below. Alternative value:
# task = 'insertion'
task = 'transfer_cube'

In [14]:
# YAML configuration to load, and the per-task safetensors checkpoint
# (both relative to the notebook's working directory).
config_path = './configs/arp.yaml'
ckpt_path = f'./weights/model.{task}.safetensors'

In [15]:
cfg = OmegaConf.load(config_path)

# Gym environment id for each supported task. An explicit table makes a
# mistyped `task` fail loudly here, instead of silently falling through to the
# transfer-cube environment while the dataset id carries the typo.
env_task_by_name = {
    'insertion': 'AlohaInsertion-v0',
    'transfer_cube': 'AlohaTransferCube-v0',
}
if task not in env_task_by_name:
    raise ValueError(
        f'unknown task {task!r}; expected one of {sorted(env_task_by_name)}'
    )

# Point the config at the dataset and simulation environment for this task.
cfg['dataset_repo_id'] = f'lerobot/aloha_sim_{task}_human'
cfg.env.task = env_task_by_name[task]

In [16]:
# Locate the policy package named in the config, then pick its config and
# policy classes: the ARP-specific class names are preferred, with the generic
# `Config` / `Policy` names as the fallback.
prefix = 'lerobot.common.policies.' + cfg.policy['name']
print(f'importing policy module from: {prefix}')

config_mod = import_module(prefix + '.configuration')
if hasattr(config_mod, 'ARPConfig'):
    Config = config_mod.ARPConfig
else:
    Config = config_mod.Config

modeling_mod = import_module(prefix + '.modeling')
if hasattr(modeling_mod, 'ARPPolicy'):
    Policy = modeling_mod.ARPPolicy
else:
    Policy = modeling_mod.Policy

importing policy module from: lerobot.common.policies.autoregressive_policy


In [17]:
# Build the policy configuration from the config class selected above and the
# loaded cfg.
# NOTE(review): `_policy_cfg_from_hydra_cfg` is an underscore-prefixed helper
# of lerobot's policy factory, i.e. presumably not public API — confirm it
# exists in the lerobot version this notebook is run against.
config = _policy_cfg_from_hydra_cfg(Config, cfg)



In [19]:
# Instantiate the policy and restore its weights from the checkpoint.
# load_state_dict is deliberately the last expression: its return value (the
# key-matching report) is what the cell displays.
policy = Policy(config)
policy.load_state_dict(load_safetensors(ckpt_path))



<All keys matched successfully>

In [20]:
# Move the policy onto the device selected in the first cell.
policy = policy.to(device)

In [21]:
# Create the simulation environment described by cfg.
# NOTE(review): the second positional argument is presumably the number of
# parallel environments — confirm against make_env's signature.
env = make_env(cfg, 1)

In [22]:
# Evaluate the policy for a fixed number of episodes. The same count is passed
# as `max_episodes_rendered`, and the videos directory is derived from the
# dataset repo id. The call stays the last expression so its result is the
# cell output.
n_episodes = 30
videos_dir = Path('./outputs/demo/' + cfg['dataset_repo_id'])
eval_policy(
    env,
    policy,
    n_episodes,
    max_episodes_rendered=n_episodes,
    videos_dir=videos_dir,
    enable_progbar=True,
    enable_inner_progbar=True,
)

Stepping through eval batches: 100%|██████████| 30/30 [02:08<00:00,  4.29s/it, running_success_rate=80.0%]


{'per_episode': [{'episode_ix': 0,
   'sum_reward': 242.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 1,
   'sum_reward': 264.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 2,
   'sum_reward': 252.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 3,
   'sum_reward': 223.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 4,
   'sum_reward': 35.0,
   'max_reward': 2.0,
   'success': False,
   'seed': None},
  {'episode_ix': 5,
   'sum_reward': 290.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 6,
   'sum_reward': 102.0,
   'max_reward': 2.0,
   'success': False,
   'seed': None},
  {'episode_ix': 7,
   'sum_reward': 278.0,
   'max_reward': 4.0,
   'success': True,
   'seed': None},
  {'episode_ix': 8,
   'sum_reward': 64.0,
   'max_reward': 1.0,
   'success': False,
   'seed': None},
  {'episode_ix': 9,
   'sum_reward': 242.0,
   '