Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes issues with PushT diffusion #41

Merged
merged 9 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 270 additions & 88 deletions .github/poetry/cpu/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions .github/poetry/cpu/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
dm = "^1.3"
dm-control = "^1.0.16"
robomimic = "0.2.0"
huggingface-hub = "^0.21.4"


Expand Down
2 changes: 1 addition & 1 deletion lerobot/common/envs/aloha/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _step(self, tensordict: TensorDict):
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
},
Expand Down
24 changes: 0 additions & 24 deletions lerobot/common/envs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,3 @@ def _make_env(seed):
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)


# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env
26 changes: 23 additions & 3 deletions lerobot/common/envs/pusht/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections import deque
from typing import Optional

import cv2
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -59,12 +61,30 @@ def _make_env(self):

self._env = PushTImageEnv(render_size=self.image_size)

def render(self, mode="rgb_array", width=384, height=384):
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
"""
with_marker adds a cursor showing the targeted action for the controller.
"""
if width != height:
raise NotImplementedError()
tmp = self._env.render_size
self._env.render_size = width
out = self._env.render(mode)
if width != self._env.render_size:
self._env.render_cache = None
self._env.render_size = width
out = self._env.render(mode).copy()
if with_marker and self._env.latest_action is not None:
action = np.array(self._env.latest_action)
coord = (action / 512 * self._env.render_size).astype(np.int32)
marker_size = int(8 / 96 * self._env.render_size)
thickness = int(1 / 96 * self._env.render_size)
cv2.drawMarker(
out,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self._env.render_size = tmp
return out

Expand Down
15 changes: 0 additions & 15 deletions lerobot/common/envs/pusht/pusht_image_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import cv2
import numpy as np
from gym import spaces

Expand Down Expand Up @@ -28,20 +27,6 @@ def _get_obs(self):
img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}

# draw action
if self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self.render_cache = img

return obs
Expand Down
44 changes: 42 additions & 2 deletions lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,44 @@
"""Code from the original diffusion policy project.

Notes on how to load a checkpoint from the original repository:

In the original repository, run the eval and use a breakpoint to extract the policy weights.

```
torch.save(policy.state_dict(), "weights.pt")
```

In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:

```
loaded = torch.load("weights.pt")
aligned = {}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.pool"
our_prefix = "obs_encoder.key_model_map.image.pool"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.nets.3"
our_prefix = "obs_encoder.key_model_map.image.out"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# Note: here you are loading into the ema model.
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
assert all('_dummy_variable' in k for k in missing_keys)
assert len(unexpected_keys) == 0
```

Then in that same runtime you can also save the weights with the new aligned state_dict:

```
policy.save("weights.pt")
```

Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.

"""

from typing import Dict

import torch
Expand Down Expand Up @@ -190,11 +231,10 @@ def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T

# run sampling
nsample = self.conditional_sample(
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
)

action_pred = nsample[..., :action_dim]

# get action
start = n_obs_steps - 1
end = start + self.n_action_steps
Expand Down
55 changes: 40 additions & 15 deletions lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
import copy
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax

from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules


class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""

def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
self.relu = nn.ReLU() if relu else nn.Identity()

def forward(self, x):
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))


class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
Expand All @@ -24,7 +49,7 @@ def __init__(
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
norm_mean_std: Optional[tuple[float, float]] = None,
):
"""
Assumes rgb input: B,C,H,W
Expand Down Expand Up @@ -98,10 +123,9 @@ def __init__(
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
if norm_mean_std is not None:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
mean=norm_mean_std[0], std=norm_mean_std[1]
)

this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
Expand All @@ -124,6 +148,17 @@ def __init__(
def forward(self, obs_dict):
batch_size = None
features = []

# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)

# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
Expand Down Expand Up @@ -161,16 +196,6 @@ def forward(self, obs_dict):
feature = self.key_model_map[key](img)
features.append(feature)

# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)

# concatenate all features
result = torch.cat(features, dim=-1)
return result
Expand Down
28 changes: 23 additions & 5 deletions lerobot/common/policies/diffusion/policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import logging
import time

import hydra
Expand All @@ -7,7 +8,7 @@
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.utils import get_safe_torch_device


Expand Down Expand Up @@ -39,7 +40,10 @@ def __init__(
self.cfg = cfg

noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
Expand All @@ -66,11 +70,13 @@ def __init__(
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)

self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
model=self.ema_diffusion,
)

self.optimizer = hydra.utils.instantiate(
Expand All @@ -94,14 +100,20 @@ def __init__(

@torch.no_grad()
def select_actions(self, observation, step_count):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
# TODO(rcadene): remove unused step_count
del step_count

obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
action = out["action"]
return action

Expand Down Expand Up @@ -191,4 +203,10 @@ def save(self, fp):

def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0
Loading
Loading