In [None]:
import os
import h5py 
import torch 
import numpy as np
from torch.nn.functional import interpolate
from pathlib import Path

input_dir = '/home/mim-server/datasets/pushT/224/'
output_dir = '/home/mim-server/datasets/pushT/sharded/'
input_files = os.listdir(input_dir)
shard_file = input_files[0]  

input_data_path = os.path.join(input_dir, shard_file)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

def loadPushTData(data_path):
    with h5py.File(data_path, 'r') as f:
        data = f
        actions = data['actions'][:]
        images = data['cam1'][:]
    images = torch.from_numpy(images).float()
    images = interpolate(images, size=(256,256), mode='bilinear', align_corners=False)
    images = images.permute(0, 2, 3, 1)  # NHWC to NCHW
    images = (images*255).to(torch.uint8)
    return actions[None], images[None].numpy()

def saveShardedData(output_path, actions, images):
    with h5py.File(output_path, 'w') as f:
        comp_args = {'compression': 'gzip', 'compression_opts': 4}
        image_shape = images.shape[2:]
        action_shape = actions.shape[2:]
        images_ds = f.create_dataset('images', shape=images.shape, dtype=np.uint8, chunks=(1, 128, *image_shape), **comp_args)
        actions_ds = f.create_dataset('actions', shape=actions.shape, dtype=np.float32, chunks=(1, 128, *action_shape), **comp_args)
        f.create_dataset('episode_lengths', data=np.array([images.shape[1]], dtype=np.int32))
        f.attrs['num_episodes'] = images.shape[1]
        images_ds[:] = images
        actions_ds[:] = actions

In [32]:
for i, data_file in enumerate(input_files):
    print(f"Processing file {i+1}/{len(input_files)}: {data_file}")
    input_data_path = os.path.join(input_dir, data_file)
    output_data_path = os.path.join(output_dir, f"shard_{i:04d}.h5")
    actions, images = loadPushTData(input_data_path)
    print(f"  Loaded data shape - actions: {actions.shape}, images: {images.shape}")
    saveShardedData(output_data_path, actions, images)

Processing file 1/13: episode_3.h5
  Loaded data shape - actions: (1, 6609, 3), images: (1, 6609, 256, 256, 3)
Processing file 2/13: episode_7.h5
  Loaded data shape - actions: (1, 6013, 3), images: (1, 6013, 256, 256, 3)
Processing file 3/13: episode_9.h5
  Loaded data shape - actions: (1, 4513, 3), images: (1, 4513, 256, 256, 3)
Processing file 4/13: episode_11.h5
  Loaded data shape - actions: (1, 6014, 3), images: (1, 6014, 256, 256, 3)
Processing file 5/13: episode_8.h5
  Loaded data shape - actions: (1, 6014, 3), images: (1, 6014, 256, 256, 3)
Processing file 6/13: episode_4.h5
  Loaded data shape - actions: (1, 6183, 3), images: (1, 6183, 256, 256, 3)
Processing file 7/13: episode_12.h5
  Loaded data shape - actions: (1, 6013, 3), images: (1, 6013, 256, 256, 3)
Processing file 8/13: episode_5.h5
  Loaded data shape - actions: (1, 5631, 3), images: (1, 5631, 256, 256, 3)
Processing file 9/13: episode_1.h5
  Loaded data shape - actions: (1, 6310, 3), images: (1, 6310, 256, 256, 3)

In [33]:
str(output_dir)

'/home/mim-server/datasets/pushT/sharded'

In [34]:
from dreamerv4.datasets import ShardedHDF5Dataset

dataset = ShardedHDF5Dataset(str(output_dir), window_size=192, stride=1)

Train split: 0 windows from 0 episodes


In [37]:
import torch 
from dreamerv4.datasets import ShardedHDF5Dataset
from dreamerv4.models.utils import load_tokenizer
from dreamerv4.models.utils import load_denoiser
from dreamerv4.models.dynamics import DenoiserWrapper
from dreamerv4.models.tokenizer import TokenizerWrapper
from hydra import initialize, compose
from omegaconf import OmegaConf
with initialize(version_base=None, config_path="scripts/config"):
    cfg = compose(config_name="dynamics/pushT.yaml")

model = DenoiserWrapper(cfg)
# model = TokenizerWrapper(cfg)



In [38]:
import torch

ckpt_path = "/home/mim-server/projects/rooholla/dreamer-v4-draft/checkpoints/dynamics_ckpts/video-pushT-110M.pt"
state = torch.load(ckpt_path, map_location='cpu')
if 'dyn' in state.keys():
    sd = state['dyn']
else:
    sd = state['model']
    
to_delete = []
clean_sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
for k in clean_sd.keys():
    if k.endswith("cos_emb") or k.endswith("sin_emb") or "temporal_mask_full" in k:
        to_delete.append(k)
for k in to_delete:
    del clean_sd[k]

spell_corrrected = {}
for key in clean_sd.keys():
    val = clean_sd[key]
    if 'diffuion' in key:
        key = key.replace('diffuion', 'diffusion')
        print(key)

    spell_corrrected[key]=val

clean_sd = spell_corrrected

state['model'] = clean_sd
if 'model' not in state.keys():
    clean_sd = {'model.'+k: v for k, v in clean_sd.items()}
    state['model'] = clean_sd
if 'dyn' in state.keys():
    del state['dyn']



print(state.keys())
print(state['model'].keys())

dict_keys(['epoch', 'global_update', 'model', 'optim', 'scheduler', 'wandb_run_id', 'log_dir'])
dict_keys(['model.register_tokens', 'model.action_tokens', 'model.diffusion_embedder.embeddings', 'model.shortcut_embedder.embeddings', 'model.layers.0.layers.0.attn.attn.W_q.weight', 'model.layers.0.layers.0.attn.attn.W_k.weight', 'model.layers.0.layers.0.attn.attn.W_v.weight', 'model.layers.0.layers.0.attn.attn.W_o.weight', 'model.layers.0.layers.0.norm1.weight', 'model.layers.0.layers.0.norm2.weight', 'model.layers.0.layers.0.ffn.up.weight', 'model.layers.0.layers.0.ffn.gate.weight', 'model.layers.0.layers.0.ffn.down.weight', 'model.layers.0.layers.1.attn.attn.W_q.weight', 'model.layers.0.layers.1.attn.attn.W_k.weight', 'model.layers.0.layers.1.attn.attn.W_v.weight', 'model.layers.0.layers.1.attn.attn.W_o.weight', 'model.layers.0.layers.1.norm1.weight', 'model.layers.0.layers.1.norm2.weight', 'model.layers.0.layers.1.ffn.up.weight', 'model.layers.0.layers.1.ffn.gate.weight', 'model.layers

In [39]:
model.load_state_dict(state['model'], strict=True)


<All keys matched successfully>

In [40]:
torch.save(state, ckpt_path)