In [1]:
import logging
import time
from contextlib import nullcontext
from pathlib import Path

import torch
from torch.amp import GradScaler
from torch.optim import Optimizer

from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.envs.configs import LiberoEnv
from lerobot.policies.segvla.configuration_segvla import SegVLAConfig
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import get_safe_torch_device, has_method, init_logging

In [2]:
def update_policy(
    train_metrics: MetricsTracker,
    policy,
    batch,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
):
    """Run one optimisation step on ``policy`` and record the step's metrics.

    Sequence: forward pass (under ``torch.autocast`` when ``use_amp`` is set),
    scaled backward pass, gradient unscaling and norm clipping, optimizer step
    (inside ``lock`` when one is given), scaler update, gradient reset, then an
    optional scheduler step and an optional ``policy.update()`` call.

    Args:
        train_metrics: Tracker whose ``loss``, ``grad_norm``, ``lr`` and
            ``update_s`` attributes are overwritten with this step's values.
        policy: Module being trained; ``policy.forward(batch)`` must return a
            ``(loss, output_dict)`` pair.
        batch: Input handed unchanged to ``policy.forward``.
        optimizer: Optimizer stepped through ``grad_scaler``.
        grad_clip_norm: Maximum norm passed to ``clip_grad_norm_``.
        grad_scaler: Scaler used for backward / unscale / step / update.
        lr_scheduler: Optional scheduler, stepped once after the optimizer.
        use_amp: Whether the forward pass runs under autocast.
        lock: Optional context manager held while the optimizer steps.

    Returns:
        The ``(train_metrics, output_dict)`` pair.
    """
    started_at = time.perf_counter()
    param_device = next(policy.parameters()).device
    policy.train()

    # Forward pass, optionally in mixed precision.
    if use_amp:
        forward_ctx = torch.autocast(device_type=param_device.type)
    else:
        forward_ctx = nullcontext()
    with forward_ctx:
        loss, output_dict = policy.forward(batch)

    # Backward on the scaled loss, then unscale so clipping sees real gradients.
    grad_scaler.scale(loss).backward()
    grad_scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # Only the optimizer step itself is guarded by the caller-supplied lock.
    step_guard = nullcontext() if lock is None else lock
    with step_guard:
        grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    # Policies may expose an `update()` hook that runs after each step.
    if has_method(policy, "update"):
        policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = total_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - started_at
    return train_metrics, output_dict

In [3]:
import os
import shutil

# Remove any output directory left over from a previous run of this notebook.
OUTPUT_DIR = Path('outputs/segvla_libero')
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)

# Dataset configuration.
dataset_cfg = DatasetConfig(
    repo_id='HuggingFaceVLA/libero',
    root=None,
    streaming=False,
    use_imagenet_stats=True,
)

# Policy configuration: no pretrained checkpoint, no VLM weights, no AMP.
policy_cfg = SegVLAConfig(
    pretrained_path=None,
    load_vlm_weights=False,
    use_amp=False,
    push_to_hub=False,
)

# Environment configuration.
env_cfg = LiberoEnv(
    task='libero_10',
    camera_name='agentview_image,robot0_eye_in_hand_image',
    init_states=True,
)

# Evaluation settings.
eval_cfg = EvalConfig(n_episodes=1, batch_size=1, use_async_envs=False)

# Pipeline configuration that ties the pieces above together.
cfg = TrainPipelineConfig(
    dataset=dataset_cfg,
    policy=policy_cfg,
    env=env_cfg,
    eval=eval_cfg,
    output_dir=OUTPUT_DIR,
    batch_size=2,
    num_workers=2,
    steps=20,
    log_freq=5,
    eval_freq=0,
    save_checkpoint=False,
    wandb=WandBConfig(enable=False),
)
cfg.validate()

print('Policy:', cfg.policy.type)
print('Dataset repo:', cfg.dataset.repo_id)
print('Env task:', cfg.env.task)



Policy: segvla
Dataset repo: HuggingFaceVLA/libero
Env task: libero_10


In [4]:
init_logging()

# Seed only when the config carries an explicit seed.
if cfg.seed is not None:
    set_seed(cfg.seed)

# Global backend switches: allow TF32 matmuls and enable the cuDNN autotuner.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

device = get_safe_torch_device(cfg.policy.device, log=True)
print('Using device:', device)

Using device: cuda


In [5]:
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle

dataset = make_dataset(cfg)
print(f'Total episodes: {dataset.num_episodes}')
print(f'Total frames: {dataset.num_frames}')

# Policies that declare `drop_n_last_frames` get an episode-aware sampler built
# from the per-episode index ranges. That sampler shuffles on its own, so the
# DataLoader's `shuffle` flag is turned off (DataLoader treats `sampler` and
# `shuffle=True` as mutually exclusive).
if hasattr(cfg.policy, 'drop_n_last_frames'):
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes['dataset_from_index'],
        dataset.meta.episodes['dataset_to_index'],
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )
    shuffle = False
else:
    sampler = None
    shuffle = True

# DataLoader raises ValueError if `prefetch_factor` is given while
# `num_workers == 0`, so only pass it when worker processes are in use. This
# keeps the cell runnable with `num_workers=0` (e.g. while debugging).
loader_kwargs = {}
if cfg.num_workers > 0:
    loader_kwargs['prefetch_factor'] = 2

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=cfg.batch_size,
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    pin_memory=device.type == 'cuda',
    drop_last=False,
    **loader_kwargs,
)

# `cycle` wraps the loader in an iterator; pull one batch to inspect its keys.
dl_iter = cycle(dataloader)
first_batch = next(dl_iter)
print('Batch keys:', first_batch.keys())

Total episodes: 1693
Total frames: 273465
Batch keys: dict_keys(['observation.images.image', 'observation.images.image2', 'observation.state', 'action', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index', 'observation.images.image_is_pad', 'observation.images.image2_is_pad', 'observation.state_is_pad', 'action_is_pad', 'task'])


In [6]:
from lerobot.envs.factory import make_env
from lerobot.envs.utils import close_envs

# Build the evaluation environments only when an env config is present.
if cfg.env is None:
    eval_env = None
else:
    eval_env = make_env(
        cfg.env,
        n_envs=cfg.eval.batch_size,
        use_async_envs=cfg.eval.use_async_envs,
    )
    print('Eval suites:', list(eval_env.keys()))

Creating LIBERO envs | suites=['libero_10'] | n_envs(per task)=1 | init_states=True
[info] using task orders [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Built vec env | suite=libero_10 | task_id=0 | n_envs=1
Built vec env | suite=libero_10 | task_id=1 | n_envs=1
Built vec env | suite=libero_10 | task_id=2 | n_envs=1
Built vec env | suite=libero_10 | task_id=3 | n_envs=1
Built vec env | suite=libero_10 | task_id=4 | n_envs=1
Built vec env | suite=libero_10 | task_id=5 | n_envs=1
Built vec env | suite=libero_10 | task_id=6 | n_envs=1
Built vec env | suite=libero_10 | task_id=7 | n_envs=1
Built vec env | suite=libero_10 | task_id=8 | n_envs=1
Built vec env | suite=libero_10 | task_id=9 | n_envs=1
Eval suites: ['libero_10']


In [7]:
from lerobot.policies.factory import make_policy, make_pre_post_processors

policy = make_policy(cfg.policy, ds_meta=dataset.meta)

# Dataset statistics are handed to the processors in every case except when a
# pretrained path is set AND the run is being resumed.
if not (cfg.policy.pretrained_path and cfg.resume):
    processor_kwargs = {'dataset_stats': dataset.meta.stats}
else:
    processor_kwargs = {}

preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    **processor_kwargs,
)

# Total number of scalar parameters in the policy.
print('Policy parameters:', sum(param.numel() for param in policy.parameters()))

You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


Reducing the number of VLM layers to 16 ...
Policy parameters: 450046176


In [8]:
from lerobot.optim.factory import make_optimizer_and_scheduler

optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

# The scaler is constructed unconditionally; `enabled` follows the policy's AMP flag.
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

# One running average per quantity reported during training.
train_meters = dict(
    loss=AverageMeter('loss', ':.3f'),
    grad_norm=AverageMeter('grdn', ':.3f'),
    lr=AverageMeter('lr', ':0.1e'),
    update_s=AverageMeter('updt_s', ':.3f'),
    dataloading_s=AverageMeter('data_s', ':.3f'),
)
train_tracker = MetricsTracker(
    cfg.batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_meters,
    initial_step=0,
)

# Global step counter advanced by the training loop.
step = 0

INFO 2025-11-21 10:24:09 hedulers.py:105 Auto-scaling LR scheduler: num_training_steps (20) < num_decay_steps (30000). Scaling warmup: 1000 → 0, decay: 30000 → 20 (scale factor: 0.001)


In [9]:
# Run at most five optimisation steps (fewer if cfg.steps is smaller).
MAX_VIS_STEPS = min(cfg.steps, 5)
policy.train()

# NOTE(review): the saved output of this cell ends in
# `NameError: name 'key' is not defined`. The name `key` is not used in this
# cell or in `update_policy`, so the error presumably comes from code called
# from here (preprocessor or policy) — confirm and fix there.
for _ in range(MAX_VIS_STEPS):
    # Fetching plus preprocessing the batch is what gets reported as dataloading_s.
    start_time = time.perf_counter()
    batch = next(dl_iter)
    batch = preprocessor(batch)
    train_tracker.dataloading_s = time.perf_counter() - start_time

    train_tracker, output_dict = update_policy(
        train_metrics=train_tracker,
        policy=policy,
        batch=batch,
        optimizer=optimizer,
        grad_clip_norm=cfg.optimizer.grad_clip_norm,
        grad_scaler=grad_scaler,
        lr_scheduler=lr_scheduler,
        use_amp=cfg.policy.use_amp,
    )

    step += 1
    train_tracker.step()

    # Report on every log_freq-th step and when the step counter hits MAX_VIS_STEPS.
    at_log_interval = step % cfg.log_freq == 0
    if at_log_interval or step == MAX_VIS_STEPS:
        print(train_tracker)
        train_tracker.reset_averages()

seg_logits is none
seg_logits is none


NameError: name 'key' is not defined

In [None]:
from lerobot.scripts.lerobot_eval import eval_policy_all

# Evaluate the policy on every environment built earlier, without gradients and
# under autocast when the policy is configured for AMP.
if eval_env is not None:
    try:
        with torch.no_grad(), torch.autocast(device_type=device.type, enabled=cfg.policy.use_amp):
            eval_info = eval_policy_all(
                envs=eval_env,
                policy=policy,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                n_episodes=cfg.eval.n_episodes,
                videos_dir=cfg.output_dir / 'eval_videos',
                max_episodes_rendered=1,
                start_seed=cfg.seed,
                max_parallel_tasks=cfg.env.max_parallel_tasks,
            )
        print('Eval overall metrics:', eval_info['overall'])
    finally:
        # Close the environments even when evaluation raises, so a failed
        # rollout does not leave them open in the kernel.
        close_envs(eval_env)
else:
    print('Eval env not created.')