In [18]:
!pwd

/home/akopyane/rl/lumos


In [1]:
from pathlib import Path

project_root = (Path.home() / "rl" / "lumos").resolve()
assert (project_root / "config").is_dir(), project_root  # optional sanity check


In [2]:
# Cell 1 â”€ load the Hydra config just like scripts/train_wm.py
import os, sys
from pathlib import Path
from hydra import initialize_config_dir, compose
from omegaconf import OmegaConf

project_root = Path("/home/akopyane/rl/lumos").resolve()
sys.path.insert(0, project_root.as_posix())

with initialize_config_dir(version_base="1.3", config_dir=str(project_root / "config")):
    cfg = compose(
        config_name="train_wm",
        overrides = [
            "datamodule.root_data_dir=dataset/task_D_D",
            "datamodule.batch_size=1",
            "datamodule.seq_len=10",
            "datamodule.datasets.vision_dataset.use_cached_data=false",
            "datamodule.datasets.vision_dataset.num_workers=0",
            "trainer.accelerator=gpu",
            "trainer.devices=1",
            # "trainer.max_steps=500",
            "trainer.max_epochs=1000",
            "trainer.log_every_n_steps=1",
            "trainer.check_val_every_n_epoch=10",
            "+trainer.overfit_batches=1.0",
            "+trainer.fast_dev_run=false",  # make sure this is here
            "datamodule.datasets.vision_dataset._target_=lumos.datasets.vision_wm_disk_dataset.VisionWMDiskDataset",
        ]
    )


print(OmegaConf.to_yaml(cfg.datamodule))


datasets:
  vision_dataset:
    _target_: lumos.datasets.vision_wm_disk_dataset.VisionWMDiskDataset
    key: vis
    save_format: npz
    use_cached_data: false
    reset_prob: ${datamodule.reset_prob}
    batch_size: ${datamodule.batch_size}
    min_window_size: ${datamodule.seq_len}
    max_window_size: ${datamodule.seq_len}
    proprio_state: ${datamodule.proprioception_dims}
    obs_space: ${datamodule.observation_space}
    pad: false
    for_wm: true
    lang_folder: ''
    num_workers: 0
transforms:
  train:
    rgb_static:
    - _target_: lumos.utils.transforms.ScaleImageTensor
    - _target_: torchvision.transforms.Normalize
      mean:
      - 0.5
      std:
      - 0.5
    rgb_gripper:
    - _target_: lumos.utils.transforms.ScaleImageTensor
    - _target_: torchvision.transforms.Normalize
      mean:
      - 0.5
      std:
      - 0.5
    out_rgb:
    - _target_: lumos.utils.transforms.UnNormalizeImageTensor
      mean:
      - 0.5
      std:
      - 0.5
    robot_obs:
    -

In [3]:
# Cell 2 â”€ instantiate the transforms exactly like BaseDataModule.setup
import hydra
import torchvision
from lumos.datasets.utils.episode_utils import load_dataset_statistics

import torch.nn.functional as F
from torchvision import transforms

train_dir = project_root / cfg.datamodule.root_data_dir / "training"
val_dir = project_root / cfg.datamodule.root_data_dir / "validation"

transforms_cfg = load_dataset_statistics(train_dir, val_dir, cfg.datamodule.transforms)

def build_tfms(branch):
    return {
        cam: torchvision.transforms.Compose([hydra.utils.instantiate(t) for t in transforms_cfg[branch][cam]])
        for cam in transforms_cfg[branch]
    }

train_tfms = build_tfms("train")
val_tfms = build_tfms("val")


resize34 = transforms.Lambda(
    lambda x: F.interpolate(x, size=(34, 34), mode="bilinear", align_corners=False)
)

def inject_resize(pipe):
    steps = list(pipe.transforms)
    return transforms.Compose([steps[0], resize34, *steps[1:]])

train_tfms["rgb_static"] = inject_resize(train_tfms["rgb_static"])
train_tfms["rgb_gripper"] = inject_resize(train_tfms["rgb_gripper"])
val_tfms["rgb_static"] = inject_resize(val_tfms["rgb_static"])
val_tfms["rgb_gripper"] = inject_resize(val_tfms["rgb_gripper"])




In [4]:
# Cell 3 â”€ build the dataset/dataloader and inspect shapes
from torch.utils.data import DataLoader
from lumos.datasets.vision_wm_disk_dataset import VisionWMDiskDataset
from lumos.utils.nn_utils import transpose_collate_wm

dataset_cfg = cfg.datamodule.datasets.vision_dataset

train_ds = VisionWMDiskDataset(
    datasets_dir=train_dir,
    obs_space=cfg.datamodule.observation_space,
    proprio_state=cfg.datamodule.proprioception_dims,
    key=dataset_cfg.key,
    lang_folder=dataset_cfg.lang_folder,
    num_workers=dataset_cfg.num_workers,
    transforms=train_tfms,
    batch_size=cfg.datamodule.batch_size,
    min_window_size=cfg.datamodule.seq_len,
    max_window_size=cfg.datamodule.seq_len,
    pad=dataset_cfg.pad,
    for_wm=dataset_cfg.for_wm,
    reset_prob=cfg.datamodule.reset_prob,
    save_format=dataset_cfg.save_format,
    use_cached_data=False,
)

sample = train_ds[0]
print("Sequence keys:", list(sample.keys()))
for k, v in sample.items():
    if hasattr(v, "shape"):
        print(f"{k:15s} {tuple(v.shape)} {v.dtype}")

loader = DataLoader(
    train_ds,
    batch_size=cfg.datamodule.batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=transpose_collate_wm,
)

batch = next(iter(loader))
print("\nBatch shapes:")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"{k:15s} {tuple(v.shape)} {v.dtype}")


Sequence keys: ['robot_obs', 'rgb_obs', 'depth_obs', 'actions', 'state_info', 'lang', 'reset', 'frame', 'idx']
robot_obs       (10, 15) torch.float32
lang            (0,) torch.float32
reset           (10,) torch.bool
frame           (10,) torch.int32

Batch shapes:
robot_obs       (10, 1, 15) torch.float32
lang            (1, 0) torch.float32
reset           (10, 1) torch.bool
frame           (10, 1) torch.int32
idx             (1,) torch.int64


In [5]:
batch["actions"]["rel_actions"].shape

torch.Size([10, 1, 7])

In [16]:
batch["robot_obs"].shape

torch.Size([10, 1, 15])

In [6]:
batch["rgb_obs"]["rgb_static"].shape

torch.Size([10, 1, 3, 34, 34])

In [21]:
batch.keys()

dict_keys(['robot_obs', 'rgb_obs', 'depth_obs', 'actions', 'state_info', 'lang', 'reset', 'frame', 'idx'])

In [1]:
sample["rgb_obs"]["rgb_static"].shape

NameError: name 'sample' is not defined

In [6]:
import pytorch_lightning as pl

class DebugDataModule(pl.LightningDataModule):
    def __init__(self, train_loader, val_loader):
        super().__init__()
        self._train_loader = train_loader
        self._val_loader = val_loader

    def train_dataloader(self):
        return {"vis": self._train_loader}

    def val_dataloader(self):
        return {"vis": self._val_loader}

debug_dm = DebugDataModule(loader, loader)
debug_dm.train_transforms = {"out_rgb": train_tfms["out_rgb"]}
debug_dm.val_transforms = {"out_rgb": val_tfms["out_rgb"]}



  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Cell 4 â”€ instantiate DreamerV2 + Lightning trainer on CPU
import contextlib
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
from pytorch_lightning import Trainer, seed_everything
from lumos.utils.info_utils import setup_logger, setup_callbacks


from datetime import datetime
from copy import deepcopy
from omegaconf import OmegaConf

from copy import deepcopy
from omegaconf import OmegaConf


def build_hybrid_world_model_cfg(cfg):
    """
    Massage cfg.world_model into a DreamerV2Hybrid-compatible config.
    Returns a new DictConfig (original cfg is untouched).
    """
    # Resolve and clone the world-model block
    wm = OmegaConf.create(deepcopy(OmegaConf.to_container(cfg.world_model, resolve=True)))
    OmegaConf.set_struct(wm.loss, False)
    wm.loss["state_weight"] = wm.loss.get("state_weight", 1.0)
    OmegaConf.set_struct(wm.loss, True)
    OmegaConf.set_struct(wm, False)

    # Target the hybrid module
    wm._target_ = "lumos.world_models.dreamer_v2_hybrid.DreamerV2"

    # Drop legacy DreamerV2-only keys
    wm.pop("batch_size", None)
    wm.pop("with_proprio", None)
    wm.pop("robot_dim", None)
    wm.pop("gripper_control", None)

    # Convenience handles
    obs_cfg = cfg.datamodule.observation_space
    proprio_cfg = cfg.datamodule.proprioception_dims
    robot_dim = proprio_cfg.n_state_obs
    scene_dim = 0  # extend if your observation space adds scene state
    has_gripper_cam = "rgb_gripper" in obs_cfg.rgb_obs

    # --- encoder -----------------------------------------------------------------
    enc = wm.encoder
    enc._target_ = "lumos.world_models.encoders.cnn_mlp_encoder.CnnMLPEncoder"
    enc.use_gripper_camera = has_gripper_cam
    enc.stride = enc.get("stride", 2)
    enc.robot_dim = robot_dim
    enc.scene_dim = scene_dim
    enc.state_out_dim = enc.get("state_out_dim", 256)
    enc.state_mlp_layers = enc.get("state_mlp_layers", [512, 512])

    # --- decoder -----------------------------------------------------------------
    dec = wm.decoder
    dec._target_ = "lumos.world_models.decoders.cnn_mlp_decoder.CnnMLPDecoder"
    dec.stride = dec.get("stride", 2)
    dec.layer_norm = dec.get("layer_norm", True)
    dec.activation = dec.get("activation", "elu")
    dec.mlp_layers = dec.get("mlp_layers", 0)
    dec.state_mlp_layers = dec.get("state_mlp_layers", [512, 512])
    dec.use_gripper_camera = has_gripper_cam
    dec.robot_dim = robot_dim
    dec.scene_dim = scene_dim
    dec.in_dim = enc.cnn_depth * 32  # visual embedding size
    dec.in_dim += enc.state_out_dim   # plus state-MLP output

    # The hybrid module sets rssm.cell.embed_dim internally; no manual tweak needed.

    # --- training/AMP knobs -------------------------------------------------------
    wm.train_batch_size = cfg.datamodule.batch_size
    wm.val_batch_size = cfg.datamodule.batch_size
    wm.use_gripper_camera = has_gripper_cam  # positional arg in DreamerV2Hybrid.__init__

    OmegaConf.set_struct(wm.amp, False)
    wm.amp.autocast._target_ = "contextlib.nullcontext"
    wm.amp.autocast.pop("enabled", None)
    wm.amp.scaler.enabled = False
    OmegaConf.set_struct(wm.amp, True)

    OmegaConf.set_struct(wm, True)
    return wm



model_cfg = build_hybrid_world_model_cfg(cfg)
model = hydra.utils.instantiate(model_cfg)
model = model.to("cuda")

OmegaConf.set_struct(cfg, False)
cfg.exp_dir = str(project_root / "logs" / "notebook" / datetime.now().strftime("%Y%m%d_%H%M%S"))
cfg.logger = None
OmegaConf.set_struct(cfg, True)


from pytorch_lightning.loggers import CSVLogger
cfg.logger = None  # disable WandB
csv_logger = CSVLogger(save_dir=str(project_root / "logs"), name="overfit_debug")
trainer = Trainer(logger=csv_logger, **cfg.trainer)



ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [12]:
model

DreamerV2(
  (encoder): CnnMLPEncoder(
    (activation): ELU(alpha=1.0)
    (encoder_static): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(48, 96, kernel_size=(4, 4), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2))
      (7): ELU(alpha=1.0)
      (8): Flatten(start_dim=1, end_dim=-1)
    )
    (encoder_gripper): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(48, 96, kernel_size=(4, 4), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2))
      (7): ELU(alpha=1.0)
      (8): Flatten(start_dim=1, end_dim=-1)
    )
    (fuse): Sequential(
      (0): Linear(in_features=3072,

In [11]:
batch.keys()

dict_keys(['robot_obs', 'rgb_obs', 'depth_obs', 'actions', 'state_info', 'lang', 'reset', 'frame', 'idx'])

In [13]:
batch["rgb_obs"].keys()

dict_keys(['rgb_static', 'rgb_gripper'])

In [17]:
batch["robot_obs"].shape

torch.Size([10, 1, 15])

In [14]:
batch["state_obs"]

KeyError: 'state_obs'

In [16]:
batch["state_info"].keys()

dict_keys(['robot_obs', 'pre_robot_obs'])

In [10]:
model

DreamerV2(
  (encoder): CnnEncoder(
    (activation): ELU(alpha=1.0)
    (encoder_static): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ELU(alpha=1.0)
      (2): Conv2d(48, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ELU(alpha=1.0)
      (4): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ELU(alpha=1.0)
      (6): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ELU(alpha=1.0)
      (8): Flatten(start_dim=1, end_dim=-1)
    )
    (encoder_gripper): Sequential(
      (0): Conv2d(3, 48, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ELU(alpha=1.0)
      (2): Conv2d(48, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ELU(alpha=1.0)
      (4): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ELU(alpha=1.0)
      (6): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7):

In [None]:
trainer.fit(model, datamodule=debug_dm)

In [None]:
print(12)

In [None]:
!pip install pandas

In [None]:
metrics