Add feature to drop last n keyframes for delta_timestamps #129

Closed · wants to merge 3 commits
8 changes: 3 additions & 5 deletions lerobot/common/datasets/factory.py
```diff
@@ -16,15 +16,12 @@
 import logging

 import torch
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


-def make_dataset(
-    cfg,
-    split="train",
-):
+def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
     if cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
@@ -43,6 +40,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=delta_timestamps,
+        n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
     )

     if cfg.get("override_dataset_stats"):
```
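For context, a minimal sketch of how this `eval` call resolves the string-valued config entry (the config values below are illustrative, mirroring the diffusion.yaml change further down):

```python
from omegaconf import OmegaConf

# Illustrative config mirroring lerobot/configs/policy/diffusion.yaml.
cfg = OmegaConf.create(
    {
        "policy": {"horizon": 16, "n_action_steps": 8, "n_obs_steps": 2},
        "training": {
            # OmegaConf resolves the ${...} interpolations into a plain string;
            # the arithmetic itself is left unevaluated.
            "n_end_keyframes_dropped": "${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1",
        },
    }
)

expr = cfg.training.get("n_end_keyframes_dropped", "0")  # -> "16 - 8 - 2 + 1"
print(eval(expr))  # -> 7, i.e. drop the last 7 keyframes of each episode
```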
34 changes: 31 additions & 3 deletions lerobot/common/datasets/lerobot_dataset.py
```diff
@@ -44,7 +44,26 @@ def __init__(
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        n_end_keyframes_dropped: int = 0,
     ):
+        """
+        Args:
+            delta_timestamps: A dictionary mapping data keys to lists of relative times (Δt). When a frame
+                is sampled from the underlying dataset, we treat it as a "keyframe" and load multiple
+                frames according to the list of Δt's. For example, {"action": [-0.05, 0, 0.05]} indicates
+                that we want to load the current keyframe's action, as well as one from 50 ms ago and one
+                50 ms into the future. The action key then contains a (3, action_dim) tensor (whereas
+                without `delta_timestamps` there would just be an (action_dim,) tensor). When the Δt's
+                demand that frames outside of an episode boundary are retrieved, a copy padding strategy
+                is used. See `load_previous_and_future_frames` for more details.
+            n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy
+                when used in combination with `delta_timestamps` when, for example, the Δt's demand
+                multiple future frames, but we want to avoid introducing too much copy padding into the
+                data distribution. For example, if
+                `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]}` and we sample the
+                last frame in the episode, we would end up padding with 6 frames' worth of copies.
+                Instead, we might want no padding (in which case we need n=6), or we might be okay with
+                up to 2 frames of padding (in which case we need n=4).
+        """
         super().__init__()
         self.repo_id = repo_id
         self.version = version
@@ -65,6 +84,12 @@ def __init__(
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        # If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset.
+        # If `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the
+        # hf_dataset where we drop those indices pertaining to the last n frames of each episode.
+        self.index = []
+        for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
+            self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))

     @property
     def fps(self) -> int:
```
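To make the new indexing concrete, here is a small sketch of what the constructor loop computes, assuming a toy dataset with two episodes of 10 frames each (the `episode_data_index` values are made up for illustration):

```python
import torch

# Made-up episode boundaries in the same "from"/"to" format as `episode_data_index`:
# episode 0 covers frames [0, 10), episode 1 covers frames [10, 20).
episode_data_index = {
    "from": torch.tensor([0, 10]),
    "to": torch.tensor([10, 20]),
}
n_end_keyframes_dropped = 6

index = []
for from_ix, to_ix in zip(*episode_data_index.values(), strict=True):
    index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))

# The last 6 frames of each episode are never used as keyframes:
print(index)  # [0, 1, 2, 3, 10, 11, 12, 13]

# __getitem__(idx) then reads hf_dataset[self.index[idx]], so e.g. dataset[4]
# loads frame 10, the first frame of the second episode.
```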
radekosmulski (Contributor) commented on May 27, 2024:
This will leave `self.hf_dataset["index"]` and `self.hf_dataset["episode_index"]` in an inconsistent state with `self.index` and `self.episode_data_index`.

One problem is that we are duplicating information that already exists on the hf_dataset, namely `hf_dataset["index"]`, so that we can mutate it. Instead of LeRobotDataset being a lightweight wrapper around an hf_dataset, we now introduce new state in between (`self.index`) that maintainers and users will have to know and remember about.

Another problem is what `self.episode_data_index` should now reflect. It currently holds info for the data in the hf_dataset but not for `self.index`, which can be confusing.

And the dependencies on this newly introduced state cascade down: `num_samples` no longer depends directly on `self.hf_dataset`, etc.

So the intentions here are very good -- we want to make life better for the users (so that they can train better models), and in some sense this implementation adds the new functionality in as straightforward a way as possible. But the side effect is that the complexity of LeRobotDataset grows quite significantly, and it is no longer clear what depends on what. This can lead to subtle bugs and will make extending the functionality much harder down the road. It will also make it harder for newcomers to the codebase to understand how things work -- for instance, what is the difference between `self.hf_dataset["index"]` and `self.index`, and under what conditions are they different? I know the answers to these questions, but only because I've spent a good bit of time with this code now; for new folks that might be quite challenging.

One potential way of tackling this might be via sampling -- in the same way as we have `drop_last` on the PyTorch DataLoader, we could have a custom sampler, passed to the dataloader, that drops the last n examples from each episode based on `episode_data_index`.

The advantage is that we would not be adding complexity to LeRobotDataset. It would remain truer to its original intention: not a way to modify data, but a way to add to a HuggingFace dataset in order to cater to the needs of training policies for controlling robots. The dependency would be one way -- LeRobotDataset would depend on the hf_dataset it gets passed, but the dependency would not go the other way. For all intents and purposes the data contained in the hf_dataset would remain immutable; we would not be modifying it, nor would we be creating state standing in for aspects of it. Our users would always know where the single source of truth is.

The full functionality would also be localized to a Sampler. Yes, users would need to know the Sampler exists in order to make use of this functionality (maybe a good place to introduce it would be as an example in the docstring for DiffusionPolicy), but we would not be adding extra complexity to LeRobotDataset, and arguably we might be fighting complexity in the codebase by reducing coupling and doing something like dependency injection at the level of the dataloader.

Sorry, these are just some loose thoughts -- I understand that this problem is not easy to solve, but maybe some of this might be helpful.

I'll keep thinking about this and if anything comes to mind will comment again in this PR 🙂

radekosmulski (Contributor) commented on May 27, 2024:
Just for reference, here is `SubsetRandomSampler` from PyTorch.

Might be nice to inherit from it and instantiate it with `episode_data_index` and `n_end_keyframes_dropped`. A minimal implementation might be a utils function that returns an index from a LeRobotDataset with the last `n_end_keyframes_dropped` frames of each episode excluded.

Anyhow, sorry 🙂 Not sure if any of this is helpful; I might be completely off the mark about the intention behind this PR.

Either way, it is very good to know of this difference vs the original training on pusht, so thank you for pointing me to this PR, @alexander-soare! 🙌
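For illustration, a minimal sketch of the suggested sampler approach (the helper name is made up, and `episode_data_index` is assumed to be the usual "from"/"to" mapping; this is not code from the PR):

```python
import torch
from torch.utils.data import SubsetRandomSampler


def make_drop_n_sampler(episode_data_index: dict, n_end_keyframes_dropped: int) -> SubsetRandomSampler:
    """Build a sampler that never yields the last n frames of each episode.

    Sketch only: assumes `episode_data_index` maps "from"/"to" to 1D tensors
    of episode start/end frame indices, as in LeRobotDataset.
    """
    indices = []
    for from_ix, to_ix in zip(episode_data_index["from"], episode_data_index["to"]):
        indices.extend(range(int(from_ix), int(to_ix) - n_end_keyframes_dropped))
    return SubsetRandomSampler(indices)


# Usage: the dataset itself stays untouched; only the DataLoader changes.
# sampler = make_drop_n_sampler(dataset.episode_data_index, n_end_keyframes_dropped=6)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)
```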

alexander-soare (Collaborator, Author) commented:

@radekosmulski I find your arguments very convincing :) Would you be interested in having a crack at the sampler approach, by any chance? cc @Cadene, who had this on his TODO list at some point but it got pushed back by other priorities.

radekosmulski (Contributor) commented:
sure thing @alexander-soare, will give it a go! 🙂


```diff
@@ -107,8 +132,11 @@ def video_frame_keys(self) -> list[str]:

     @property
     def num_samples(self) -> int:
-        """Number of samples/frames."""
-        return len(self.hf_dataset)
+        """Number of possible samples in the dataset.
+
+        This is equivalent to the number of frames in the dataset minus `n_end_keyframes_dropped` frames per episode.
+        """
+        return len(self.index)

     @property
     def num_episodes(self) -> int:
@@ -128,7 +156,7 @@ def __len__(self):
         return self.num_samples

     def __getitem__(self, idx):
-        item = self.hf_dataset[idx]
+        item = self.hf_dataset[self.index[idx]]

         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
```
10 changes: 10 additions & 0 deletions lerobot/configs/policy/diffusion.yaml
```diff
@@ -39,11 +39,21 @@ training:
   adam_weight_decay: 1.0e-6
   online_steps_between_rollouts: 1

+  # For each training batch we want (consider n_obs_steps=2, horizon=16):
+  # t           | -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
+  # action      |  a, a, a, a, a, a, a, a, a, a, a,  a,  a,  a,  a,  a
+  # observation |  o, o,  ,  ,  ,  ,  ,  ,  ,  ,  ,   ,   ,   ,   ,
+  # Note that at rollout we only use some of the actions (consider n_action_steps=8):
+  # action used |   , a, a, a, a, a, a, a, a,  ,  ,   ,   ,   ,   ,
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
+
+  # The original implementation doesn't sample keyframes for the last 7 steps. This is because, as
+  # described above, the last 7 actions from the diffusion model are not used.
+  n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1

 eval:
   n_episodes: 50
   batch_size: 50
```
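As a sanity check on the expressions above, here is what they evaluate to once interpolated (fps=10 is assumed purely for illustration; the real value comes from the env config):

```python
fps = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8

observation_dts = [i / fps for i in range(1 - n_obs_steps, 1)]
action_dts = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]

print(observation_dts)                # [-0.1, 0.0] -> the two 'o' columns at t=-1 and t=0
print(len(action_dts))                # 16 -> the 'a' columns at t=-1..14
print(action_dts[0], action_dts[-1])  # -0.1 1.4

# Number of end keyframes to drop so the unused padded actions never appear:
print(horizon - n_action_steps - n_obs_steps + 1)  # 7
```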
1 change: 1 addition & 0 deletions tests/test_datasets.py
```diff
@@ -115,6 +115,7 @@ def test_compute_stats_on_xarm():

     # reduce size of dataset sample on which stats compute is tested to 10 frames
     dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+    dataset.index = [i for i in dataset.index if i < 10]

     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
```