# Get Batch — returns a batch of the same type that `next(dl_iter)` yields during training

In [None]:
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

from lerobot.configs import parser
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle

from xhuman.configs.train import TrainPipelineConfigXHUMAN
from xhuman.datasets.factory import make_dataset_xhuman
from xhuman.datasets.utils import split_train_eval_episodes
from xhuman.policies.factory import make_xhuman_policy, make_xhuman_pre_post_processors

: 

In [None]:
# Manual configuration (adjust as needed).
from xhuman.configs.default import LerobotDatasetConfig
from xhuman.policies.pi05.configuration_pi05 import PI05Config

# Dataset config.
dataset_config = LerobotDatasetConfig(
    repo_id="tu_repo_id/aqui",  # TODO: replace with your dataset repo id
    root=None,  # or a local path if the dataset is already downloaded
)

# Policy config.
policy_config = PI05Config(
    pretrained_path=None,  # or a path to a checkpoint, if you have one
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Training pipeline config (batch size / workers / split ratio used by later cells).
cfg = TrainPipelineConfigXHUMAN(
    dataset=dataset_config,
    policy=policy_config,
    batch_size=8,
    num_workers=4,
    split_ratio=0.8,
)
cfg.validate()

# Create the Accelerator the same way train_val_pi05.py does.
# `device` is taken from the accelerator (not from policy_config.device) because
# Accelerate places prepared DataLoader batches on accelerator.device; the processor
# overrides and pin_memory decisions in later cells must agree with that device.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
device = accelerator.device

In [None]:
# Load the dataset following the same pattern as train_val_pi05.py.
# First load the full dataset only to learn how many episodes it has.
dataset = make_dataset_xhuman(cfg)
episodes = list(range(dataset.meta.total_episodes))

# Seed for the train/eval split. NOTE(review): keep in sync with train_val_pi05.py so
# the notebook sees exactly the same training episodes as the real run.
SPLIT_SEED = 42
train_episodes, _ = split_train_eval_episodes(episodes, split_ratio=cfg.split_ratio, seed=SPLIT_SEED)

# Episodes that must never be used for training. NOTE(review): episode 30 is excluded
# with no recorded reason — presumably a known-bad episode; confirm against the
# training script's exclusion list.
EXCLUDED_EPISODES = {30}
train_episodes = [ep for ep in train_episodes if ep not in EXCLUDED_EPISODES]

# Drop the full dataset to free memory before reloading (as in train_val_pi05.py).
del dataset

# Rebuild the dataset restricted to the training episodes only.
cfg.dataset.episodes = train_episodes
dataset = make_dataset_xhuman(cfg)

In [None]:
# Build the policy and the preprocessor.
# In this notebook the policy instance is only read for its config (feature specs and
# normalization mapping) when building preprocessor overrides below.
policy = make_xhuman_policy(cfg=cfg.policy, ds_meta=dataset.meta)

processor_kwargs = {"dataset_stats": dataset.meta.stats}
if cfg.policy.pretrained_path is not None:
    # When starting from a pretrained path, override the device processor's device and
    # the normalizer's stats/features/norm map with values from the current run.
    merged_features = {**policy.config.input_features, **policy.config.output_features}
    normalizer_overrides = {
        "stats": dataset.meta.stats,
        "features": merged_features,
        "norm_map": policy.config.normalization_mapping,
    }
    processor_kwargs["preprocessor_overrides"] = {
        "device_processor": {"device": device.type},
        "normalizer_processor": normalizer_overrides,
    }

# Only the preprocessor is needed to inspect a batch; the postprocessor is discarded.
preprocessor, _ = make_xhuman_pre_post_processors(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    **processor_kwargs,
)

In [None]:
# Build the DataLoader, mirroring train_val_pi05.py.
if hasattr(cfg.policy, "drop_n_last_frames"):
    # Episode-aware frame sampler built from per-episode index boundaries in the
    # dataset metadata. NOTE(review): `dataset.meta` presumably describes every episode
    # in the repository, not only the subset loaded above, so sampling is explicitly
    # restricted to the training episodes; without this, frames from held-out (and
    # excluded) episodes could be drawn. cfg.dataset.episodes is None when no subset
    # was requested, which keeps all episodes.
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes["dataset_from_index"],
        dataset.meta.episodes["dataset_to_index"],
        episode_indices_to_use=cfg.dataset.episodes,
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )
else:
    sampler = None

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=cfg.batch_size,
    # DataLoader forbids shuffle together with a custom sampler, and cannot shuffle
    # an iterable (streaming) dataset.
    shuffle=(sampler is None) and not cfg.dataset.streaming,
    sampler=sampler,
    pin_memory=device.type == "cuda",
    drop_last=False,
    # prefetch_factor is only valid with worker processes.
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)

In [None]:
# Wrap the DataLoader with the accelerator, as train_val_pi05.py does. Only the
# DataLoader is prepared: no policy or optimizer is trained in this notebook.
# The prepared loader replaces `dataloader` and feeds the `cycle` iterator.
dl_iter = cycle(dataloader := accelerator.prepare(dataloader))

In [None]:
# Fetch one batch and run it through the preprocessor, exactly like the
# `next(dl_iter)` step in train_val_pi05.py.
batch = preprocessor(next(dl_iter))

# The batch is ready to use: list its keys, then shape/dtype of every tensor entry.
print("Batch keys:", list(batch.keys()))
tensor_entries = ((name, tensor) for name, tensor in batch.items() if isinstance(tensor, torch.Tensor))
for name, tensor in tensor_entries:
    print(f"  {name}: shape={tensor.shape}, dtype={tensor.dtype}")