In [1]:
from itertools import product
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Conversion inputs and outputs. These are machine-specific paths: edit them before running
# this notebook anywhere else.
PATH_TO_ORIGINAL_WEIGHTS = "/home/amandeep/latest_unet.ckpt"
# PATH_TO_CONFIG = "/home/amandeep/lerobot/lerobot/configs/default.yaml"
PATH_TO_CONFIG = "/home/amandeep/lerobot/unet_config.yaml"
PATH_TO_SAVE_NEW_WEIGHTS = "../"

cfg = init_hydra_config(PATH_TO_CONFIG)

# Build the target policy. The dataset is created only to hand its `stats` to the policy factory.
policy = make_policy(cfg, dataset_stats=make_dataset(cfg).stats)

# NOTE: torch.load unpickles the file, so only point it at a checkpoint you trust.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a machine that
# does not have that device; the tensors are later renamed and passed to load_state_dict.
state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS, map_location="cpu")['state_dicts']['model']

Fetching 222 files: 100%|██████████| 222/222 [00:00<00:00, 3436.30it/s]


In [3]:
# Drop every checkpoint entry whose name begins with one of the prefixes below.

start_removals = ["normalizer.", "obs_encoder.obs_nets.rgb.backbone.nets.0.nets.0"]

# str.startswith accepts a tuple of candidates, so a single pass over the names is enough.
removal_prefixes = tuple(start_removals)
# Collect the names first: the dict must not change size while it is being iterated.
doomed_names = [name for name in state_dict if name.startswith(removal_prefixes)]
for name in doomed_names:
    del state_dict[name]

In [40]:
# Rename checkpoint entries by prefix so their names line up with those of `policy.diffusion`.
# Each pair is (prefix found in the checkpoint, prefix substituted for it); pairs are applied
# in list order.
# Finer-grained renames for `model.up_modules`, `model.down_modules` and `model.mid_modules`
# used to be listed here as commented-out entries; only the blanket `model.` -> `net.` rename
# is active.

start_replacements = [
    ("obs_encoder.obs_nets.image.backbone.nets", "rgb_encoder.backbone"),
    ("obs_encoder.obs_nets.image.pool", "rgb_encoder.pool"),
    ("obs_encoder.obs_nets.image.nets.3", "rgb_encoder.out"),
    ("model.", "net."),
]

for old_prefix, new_prefix in start_replacements:
    # Snapshot the matching names up front because the dict is mutated during the rename.
    matching_names = [name for name in state_dict if name.startswith(old_prefix)]
    for old_name in matching_names:
        new_name = new_prefix + old_name[len(old_prefix):]
        state_dict[new_name] = state_dict[old_name]
        del state_dict[old_name]

# Build a single `pos_grid` entry from the two pooling position tensors: each one is
# transposed, then they are concatenated along dim 1 with x first and y second.
pos_x = state_dict['rgb_encoder.pool.pos_x'].t()
pos_y = state_dict['rgb_encoder.pool.pos_y'].t()
pos_grid = torch.cat([pos_x, pos_y], dim=1)
state_dict['rgb_encoder.pool.pos_grid'] = pos_grid

# strict=False: name mismatches are returned for inspection instead of raising.
missing_keys, unexpected_keys = policy.diffusion.load_state_dict(state_dict, strict=False)

In [41]:
# Diagnostic pass over the result of load_state_dict: report keys the policy is missing and
# checkpoint keys that were left unloaded and are not on the allow-list below.
unexpected_keys = set(unexpected_keys)

# Allow-list of checkpoint keys that may be left unloaded. It is written as a structured set
# display (no eval() of a string): every entry is either one of two dummy variables or lives
# under `obs_encoder.obs_nets.image.nets`.
_IMAGE_NETS = "obs_encoder.obs_nets.image.nets"
_BLOCK_PARAMS = ("conv1.weight", "bn1.weight", "bn1.bias", "conv2.weight", "bn2.weight", "bn2.bias")
_DOWNSAMPLE_PARAMS = ("downsample.0.weight", "downsample.1.weight", "downsample.1.bias")

allowed_unexpected_keys = {
    "_dummy_variable",
    "mask_generator._dummy_variable",
    # Entries directly under `nets.0.nets.0` and `nets.0.nets.1`.
    f"{_IMAGE_NETS}.0.nets.0.weight",
    f"{_IMAGE_NETS}.0.nets.1.weight",
    f"{_IMAGE_NETS}.0.nets.1.bias",
    # Entries under `nets.1`.
    *(
        f"{_IMAGE_NETS}.1.{leaf}"
        for leaf in ("nets.weight", "nets.bias", "pos_x", "pos_y", "temperature")
    ),
    # `nets.0.nets.{4..7}.{0,1}`: six parameters each.
    *(
        f"{_IMAGE_NETS}.0.nets.{stage}.{block}.{param}"
        for stage in (4, 5, 6, 7)
        for block in (0, 1)
        for param in _BLOCK_PARAMS
    ),
    # `nets.0.nets.{5,6,7}.0` additionally carries downsample parameters.
    *(
        f"{_IMAGE_NETS}.0.nets.{stage}.0.{param}"
        for stage in (5, 6, 7)
        for param in _DOWNSAMPLE_PARAMS
    ),
}
# 2 dummy + 3 + 5 + 48 + 9 entries; guards against an accidental edit of the patterns above.
assert len(allowed_unexpected_keys) == 67

if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if unexpected_keys != allowed_unexpected_keys:
    print("UNEXPECTED KEYS")
    # Only list the offenders, in a stable order.
    for key in sorted(unexpected_keys - allowed_unexpected_keys):
        print(key)

UNEXPECTED KEYS
rgb_encoder.pool.pos_y
rgb_encoder.pool.temperature
net._dummy_variable
rgb_encoder.pool.pos_x


In [42]:
# Gate: only save the converted policy when load_state_dict reported no missing keys and the
# unloaded checkpoint keys are exactly the allow-list below.
unexpected_keys = set(unexpected_keys)

# Allow-list written as a structured set display (no eval() of a string). It consists of four
# renamed entries that the policy does not consume, two dummy variables, and the entries that
# remain under `obs_encoder.obs_nets.image.nets`.
_IMAGE_NETS = "obs_encoder.obs_nets.image.nets"
_BLOCK_PARAMS = ("conv1.weight", "bn1.weight", "bn1.bias", "conv2.weight", "bn2.weight", "bn2.bias")
_DOWNSAMPLE_PARAMS = ("downsample.0.weight", "downsample.1.weight", "downsample.1.bias")

allowed_unexpected_keys = {
    "rgb_encoder.pool.pos_y",
    "rgb_encoder.pool.temperature",
    "net._dummy_variable",
    "rgb_encoder.pool.pos_x",
    "_dummy_variable",
    "mask_generator._dummy_variable",
    # Entries directly under `nets.0.nets.0` and `nets.0.nets.1`.
    f"{_IMAGE_NETS}.0.nets.0.weight",
    f"{_IMAGE_NETS}.0.nets.1.weight",
    f"{_IMAGE_NETS}.0.nets.1.bias",
    # Entries under `nets.1`.
    *(
        f"{_IMAGE_NETS}.1.{leaf}"
        for leaf in ("nets.weight", "nets.bias", "pos_x", "pos_y", "temperature")
    ),
    # `nets.0.nets.{4..7}.{0,1}`: six parameters each.
    *(
        f"{_IMAGE_NETS}.0.nets.{stage}.{block}.{param}"
        for stage in (4, 5, 6, 7)
        for block in (0, 1)
        for param in _BLOCK_PARAMS
    ),
    # `nets.0.nets.{5,6,7}.0` additionally carries downsample parameters.
    *(
        f"{_IMAGE_NETS}.0.nets.{stage}.0.{param}"
        for stage in (5, 6, 7)
        for param in _DOWNSAMPLE_PARAMS
    ),
}
# 4 renamed + 2 dummy + 3 + 5 + 48 + 9 entries; guards against an accidental edit above.
assert len(allowed_unexpected_keys) == 71

if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if unexpected_keys != allowed_unexpected_keys:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

if len(missing_keys) != 0 or unexpected_keys != allowed_unexpected_keys:
    # Raise instead of calling exit(): in a notebook exit() takes the kernel down, whereas an
    # exception fails this cell and leaves the session available for inspection.
    raise RuntimeError("Failed due to mismatch in state dicts.")

# Persist the converted weights and the config that was used to build the policy.
torch.save(policy.state_dict(), "/tmp/policy.pt")
policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)
OmegaConf.save(cfg, Path(PATH_TO_SAVE_NEW_WEIGHTS) / "config.yaml")

In [45]:
# Import the environment factory and the evaluation entry point, then build the environment
# described by `cfg`.
# NOTE(review): the evaluation cell that follows repeats these three statements verbatim, so
# `env` is rebuilt there; this cell looks redundant — confirm before removing it.
from lerobot.common.envs.factory import make_env
from lerobot.scripts.eval import eval_policy
env = make_env(cfg)

In [47]:
# Evaluate `policy` in an environment built from `cfg` and print the aggregated metrics.
from lerobot.common.envs.factory import make_env
from lerobot.scripts.eval import eval_policy

env = make_env(cfg)

# Keyword options forwarded unchanged to eval_policy.
eval_options = {
    "max_episodes_rendered": 10,
    "videos_dir": Path("./"),  # "./" is the current working directory
    "start_seed": cfg.seed,
}

# Gradients are not needed for evaluation.
with torch.no_grad():
    info = eval_policy(env, policy, cfg.eval.n_episodes, **eval_options)
print(info["aggregated"])

{'avg_sum_reward': 10.811818440497525, 'avg_max_reward': 0.09880081223534598, 'pc_success': 0.0, 'eval_s': 110.97826099395752, 'eval_ep_s': 2.219565234184265}
