In [1]:
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
from threading import Lock

import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled=False)


In [3]:
config_name="default"
config_path="../configs"


from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)  
cfg = compose(config_name=config_name, overrides=["policy=vqbet", "env=pusht", "dataset_repo_id=lerobot/pusht", "device=cuda"])

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path=config_path)


In [4]:
# eval_env = make_env(cfg)

In [5]:
device = get_safe_torch_device(cfg.device, log=True)
device

device(type='cuda')

In [6]:
set_global_seed(cfg.seed) 
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

In [7]:
offline_dataset = make_dataset(cfg)

Fetching 212 files: 100%|██████████| 212/212 [00:00<00:00, 196593.51it/s]


In [8]:
offline_dataset

LeRobotDataset(
  Repository ID: 'lerobot/pusht',
  Split: 'train',
  Number of Samples: 25650,
  Number of Episodes: 206,
  Type: video (.mp4),
  Recorded Frames per Second: 10,
  Camera Keys: ['observation.image'],
  Video Frame Keys: ['observation.image'],
  Transformations: None,
  Codebase Version: v1.6,
)

In [9]:
cfg.training.get("delta_timestamps")

{'observation.image': [-0.4, -0.3, -0.2, -0.1, 0.0], 'observation.state': [-0.4, -0.3, -0.2, -0.1, 0.0], 'action': [-0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]}

In [10]:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
    offline_dataset,
    num_workers=cfg.training.num_workers,
    batch_size=cfg.training.batch_size,
    shuffle=shuffle,
    sampler=sampler,
    pin_memory=device.type != "cpu",
    drop_last=False,
)
dl_iter = cycle(dataloader)

In [11]:
batch = next(dl_iter)
batch.keys()

dict_keys(['observation.image', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'next.success', 'index', 'observation.image_is_pad', 'observation.state_is_pad', 'action_is_pad'])

In [12]:
batch['observation.image'].shape, batch['action'].shape

(torch.Size([64, 5, 3, 96, 96]), torch.Size([64, 15, 2]))

In [13]:
policy = make_policy(
    hydra_cfg=cfg,
    dataset_stats=offline_dataset.stats if not cfg.resume else None,
    pretrained_policy_name_or_path=None,
)

number of parameters: 26.00M


In [14]:
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg) 
grad_scaler = GradScaler(enabled=cfg.use_amp) 

  grad_scaler = GradScaler(enabled=cfg.use_amp)


In [15]:
batch = next(dl_iter) 
for key in batch:
    batch[key] = batch[key].to(device, non_blocking=True)

In [None]:
# train_info = update_policy(
#     policy,
#     batch,
#     optimizer,
#     cfg.training.grad_clip_norm,
#     grad_scaler=grad_scaler,
#     lr_scheduler=lr_scheduler,
#     use_amp=cfg.use_amp,
# )

In [20]:
output_dict = policy.forward(batch)
output_dict.keys()

dict_keys(['loss', 'n_different_codes', 'n_different_combinations', 'recon_l1_error'])

In [21]:
output_dict['loss']

tensor(0.3762, device='cuda:0', grad_fn=<AddBackward0>)