In [1]:
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
from threading import Lock

import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_optimizer_and_scheduler(cfg, policy):
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("model.backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if n.startswith("model.backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

        optimizer = VQBeTOptimizer(policy, cfg)
        lr_scheduler = VQBeTScheduler(optimizer, cfg)
    else:
        raise NotImplementedError()

    return optimizer, lr_scheduler

In [3]:
def update_policy(
    policy,
    batch,
    optimizer,
    grad_clip_norm,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
):
    """Returns a dictionary of items for logging."""
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        loss = output_dict["loss"]
    grad_scaler.scale(loss).backward()

    # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
    with lock if lock is not None else nullcontext():
        grad_scaler.step(optimizer)
    # Updates the scale for next iteration.
    grad_scaler.update()

    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    if isinstance(policy, PolicyWithUpdate):
        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
        policy.update()

    info = {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.perf_counter() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }
    info.update({k: v for k, v in output_dict.items() if k not in info})

    return info

In [4]:
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
    update_s = info["update_s"]
    dataloading_s = info["dataloading_s"]

    # A sample is an (observation,action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
    num_samples = (step + 1) * cfg.training.batch_size
    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
    num_episodes = num_samples / avg_samples_per_ep
    num_epochs = num_samples / dataset.num_samples
    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of time all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"loss:{loss:.3f}",
        f"grdn:{grad_norm:.3f}",
        f"lr:{lr:0.1e}",
        # in seconds
        f"updt_s:{update_s:.3f}",
        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_online"] = is_online

    logger.log_dict(info, step, mode="train")

In [5]:
def log_eval_info(logger, info, step, cfg, dataset, is_online):
    eval_s = info["eval_s"]
    avg_sum_reward = info["avg_sum_reward"]
    pc_success = info["pc_success"]

    # A sample is an (observation,action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
    num_samples = (step + 1) * cfg.training.batch_size
    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
    num_episodes = num_samples / avg_samples_per_ep
    num_epochs = num_samples / dataset.num_samples
    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of time all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"∑rwrd:{avg_sum_reward:.3f}",
        f"success:{pc_success:.1f}%",
        f"eval_s:{eval_s:.3f}",
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_online"] = is_online

    logger.log_dict(info, step, mode="eval")

In [None]:
config_name="default"
config_path="../configs"
out_dir="/home/ns1254/lerobot/outputs/train/notebook"
job_name="default"


from hydra import compose, initialize

hydra.core.global_hydra.GlobalHydra.instance().clear()

initialize(config_path=config_path)  
cfg = compose(config_name=config_name, overrides=["policy=vqbet", "env=pusht", "dataset_repo_id=lerobot/pusht", "device=cuda"])


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path=config_path)


In [None]:
# python lerobot/scripts/train.py   policy=vqbet   env=pusht dataset_repo_id=lerobot/pusht   wandb.enable=true   device=cuda

In [7]:
if out_dir is None:
    raise NotImplementedError()
if job_name is None:
    raise NotImplementedError()

init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
 
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)

set_global_seed(cfg.seed)

# Check device is available
device = get_safe_torch_device(cfg.device, log=True)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

INFO 2024-11-07 22:37:19 2069291986.py:7 {'dataset_repo_id': 'lerobot/pusht',
 'device': 'cuda',
 'env': {'action_dim': 2,
         'episode_length': 300,
         'fps': '${fps}',
         'gym': {'obs_type': 'pixels_agent_pos',
                 'render_mode': 'rgb_array',
                 'visualization_height': 384,
                 'visualization_width': 384},
         'image_size': 96,
         'name': 'pusht',
         'state_dim': 2,
         'task': 'PushT-v0'},
 'eval': {'batch_size': 50, 'n_episodes': 50, 'use_async_envs': False},
 'fps': 10,
 'override_dataset_stats': {'observation.image': {'mean': [[[0.485]],
                                                           [[0.456]],
                                                           [[0.406]]],
                                                  'std': [[[0.229]],
                                                          [[0.224]],
                                                          [[0.225]]]}},
 'policy': {'chunk_s

In [8]:
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)

INFO 2024-11-07 22:37:20 2614359162.py:1 make_dataset
Fetching 212 files: 100%|██████████| 212/212 [00:00<00:00, 11243.65it/s]


In [9]:
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.training.eval_freq > 0:
    logging.info("make_env")
    eval_env = make_env(cfg)

INFO 2024-11-07 22:37:21 1892783823.py:6 make_env


In [10]:
logging.info("make_policy")
policy = make_policy(
    hydra_cfg=cfg,
    dataset_stats=offline_dataset.stats if not cfg.resume else None,
    pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)

INFO 2024-11-07 22:37:23 2704595911.py:1 make_policy


In [11]:
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)

step = 0  # number of policy updates (forward + backward + optim)

  grad_scaler = GradScaler(enabled=cfg.use_amp)


In [12]:
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

INFO 2024-11-07 22:37:25 on/logger.py:39 [1m[33mOutput dir:[0m /home/ns1254/lerobot/outputs/train/notebook
INFO 2024-11-07 22:37:25 /195372604.py:5 cfg.env.task='PushT-v0'
INFO 2024-11-07 22:37:25 /195372604.py:6 cfg.training.offline_steps=200000 (200K)
INFO 2024-11-07 22:37:25 /195372604.py:7 cfg.training.online_steps=0
INFO 2024-11-07 22:37:25 /195372604.py:8 offline_dataset.num_samples=25650 (26K)
INFO 2024-11-07 22:37:25 /195372604.py:9 offline_dataset.num_episodes=206
INFO 2024-11-07 22:37:25 195372604.py:10 num_learnable_params=51588994 (52M)
INFO 2024-11-07 22:37:25 195372604.py:11 num_total_params=51589012 (52M)


In [13]:
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step, is_online):
    _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
    step_identifier = f"{step:0{_num_digits}d}"

    if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
        logging.info(f"Eval policy at step {step}")
        with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
            assert eval_env is not None
            eval_info = eval_policy(
                eval_env,
                policy,
                cfg.eval.n_episodes,
                videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
                max_episodes_rendered=4,
                start_seed=cfg.seed,
            )
        log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
        if cfg.wandb.enable:
            logger.log_video(eval_info["video_paths"][0], step, mode="eval")
        logging.info("Resume training")

    if cfg.training.save_checkpoint and (
        step % cfg.training.save_freq == 0
        or step == cfg.training.offline_steps + cfg.training.online_steps
    ):
        logging.info(f"Checkpoint policy after step {step}")
        # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
        # needed (choose 6 as a minimum for consistency without being overkill).
        logger.save_checkpoint(
            step,
            policy,
            optimizer,
            lr_scheduler,
            identifier=step_identifier,
        )
        logging.info("Resume training")


In [14]:
# create dataloader for offline training
if cfg.training.get("drop_n_last_frames"):
    shuffle = False
    sampler = EpisodeAwareSampler(
        offline_dataset.episode_data_index,
        drop_n_last_frames=cfg.training.drop_n_last_frames,
        shuffle=True,
    )
else:
    shuffle = True
    sampler = None
dataloader = torch.utils.data.DataLoader(
    offline_dataset,
    num_workers=cfg.training.num_workers,
    batch_size=cfg.training.batch_size,
    shuffle=shuffle,
    sampler=sampler,
    pin_memory=device.type != "cpu",
    drop_last=False,
)
dl_iter = cycle(dataloader)

In [16]:
offline_dataset

LeRobotDataset(
  Repository ID: 'lerobot/pusht',
  Split: 'train',
  Number of Samples: 25650,
  Number of Episodes: 206,
  Type: video (.mp4),
  Recorded Frames per Second: 10,
  Camera Keys: ['observation.image'],
  Video Frame Keys: ['observation.image'],
  Transformations: None,
  Codebase Version: v1.6,
)

In [None]:
policy.train()
offline_step = 0
for _ in range(step, cfg.training.offline_steps):
    if offline_step == 0:
        logging.info("Start offline training on a fixed dataset")

    start_time = time.perf_counter()
    batch = next(dl_iter)
    dataloading_s = time.perf_counter() - start_time

    for key in batch:
        batch[key] = batch[key].to(device, non_blocking=True)

    train_info = update_policy(
        policy,
        batch,
        optimizer,
        cfg.training.grad_clip_norm,
        grad_scaler=grad_scaler,
        lr_scheduler=lr_scheduler,
        use_amp=cfg.use_amp,
    )

    train_info["dataloading_s"] = dataloading_s

    if step % cfg.training.log_freq == 0:
        log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)

    # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
    # so we pass in step + 1.
    evaluate_and_checkpoint_if_needed(step + 1, is_online=False)

    step += 1
    offline_step += 1  # noqa: SIM113

if cfg.training.online_steps == 0:
    if eval_env:
        eval_env.close()
    logging.info("End of training")