In [19]:
from dino_wm.utils.get_model import get_model
from dino_wm.planning.objectives import create_objective_fn
from dino_wm.planning.mpc import MPCPlanner
from dino_wm.planning.cem import CEMPlanner
from einops import rearrange, repeat
import einops
from dino_wm.utils.utils import move_to_device
from dino_wm.env.venv import SubprocVectorEnv
import gymnasium as gym
import gym_pusht
from gym_pusht.envs import PushTEnv
import torch
import numpy as np
from dino_wm.models.visual_world_model import VWorldModel
from dino_wm.utils.preprocessor import Preprocessor
import imageio
from torchvision import transforms





In [20]:
n_envs = 1
frameskip = 5
horizon = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

In [21]:
wm, dataset, data_preprocessor = get_model("/home/ianchuang/dino_wm/outputs/checkpoints", "pusht", device)
wm : VWorldModel = wm
wm.to(device)

Loaded 18685 rollouts
Loaded 21 rollouts
Resuming from epoch 2: /home/ianchuang/dino_wm/outputs/checkpoints/outputs/pusht/checkpoints/model_latest.pth


Using cache found in /home/ianchuang/.cache/torch/hub/facebookresearch_dinov2_main


num_action_repeat: 1
num_proprio_repeat: 1
proprio encoder: ProprioceptiveEmbedding(
  (patch_embed): Conv1d(4, 10, kernel_size=(1,), stride=(1,))
)
action encoder: ProprioceptiveEmbedding(
  (patch_embed): Conv1d(10, 10, kernel_size=(1,), stride=(1,))
)
proprio_dim: 10, after repeat: 10
action_dim: 10, after repeat: 10
emb_dim: 404
Model emb_dim:  404


VWorldModel(
  (encoder): DinoV2Encoder(
    (base_model): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-11): 12 x NestedTensorBlock(
          (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=1536, out_features=38

In [22]:
# gym.vector.SyncVectorEnv
env = gym.vector.AsyncVectorEnv(
    [
        lambda: gym.make(
            "gym_pusht/PushT-v0", 
            disable_env_checker=True, 
            relative=True,
            legacy=False,
            obs_type="visual_proprio", 
            render_mode="rgb_array",
            observation_width=224,
            observation_height=224,
        )
        for _ in range(n_envs)
    ]
)
obs, info = env.reset()

In [46]:
obs, actions, state, info = dataset[0]

obs = {
    k: v[-(frameskip*horizon):]
    for k, v in obs.items()
}
actions = actions[-(frameskip*horizon):]
state = state[-(frameskip*horizon):]
print(f"obs['visual'] shape: {obs['visual'].shape}")
print(f"obs['proprio'] shape: {obs['proprio'].shape}")
print(f"actions shape: {actions.shape}")
print(f"state shape: {state.shape}")

start_obs = {}
start_obs['visual'] = obs['visual'][:1].unsqueeze(0).repeat(n_envs, 1, 1, 1, 1)
start_obs['proprio'] = obs['proprio'][:1].unsqueeze(0).repeat(n_envs, 1, 1)
end_obs = {}
end_obs['visual'] = obs['visual'][-1:].unsqueeze(0).repeat(n_envs, 1, 1, 1, 1)
end_obs['proprio'] = obs['proprio'][-1:].unsqueeze(0).repeat(n_envs, 1, 1)
print(f"start_obs['visual'] shape: {start_obs['visual'].shape}")
print(f"start_obs['proprio'] shape: {start_obs['proprio'].shape}")
print(f"end_obs['visual'] shape: {end_obs['visual'].shape}")
print(f"end_obs['proprio'] shape: {end_obs['proprio'].shape}")

actions = einops.rearrange(actions, "(h f) a -> h (f a)", f=frameskip, h=horizon).unsqueeze(0).repeat(n_envs, 1, 1)
print(f"actions shape: {actions.shape}")


obs['visual'] shape: torch.Size([25, 3, 224, 224])
obs['proprio'] shape: torch.Size([25, 4])
actions shape: torch.Size([25, 2])
state shape: torch.Size([25, 7])
start_obs['visual'] shape: torch.Size([1, 1, 3, 224, 224])
start_obs['proprio'] shape: torch.Size([1, 1, 4])
end_obs['visual'] shape: torch.Size([1, 1, 3, 224, 224])
end_obs['proprio'] shape: torch.Size([1, 1, 4])
actions shape: torch.Size([1, 5, 10])


In [47]:
def get_inverse_normalize(mean, std):
    inv_std = [1.0/s for s in std]
    inv_mean = [-m/s for m, s in zip(mean, std)]
    return transforms.Normalize(mean=inv_mean, std=inv_std)
inverse_normalize = get_inverse_normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

imageio.mimsave(
    "dataset.mp4",
    einops.rearrange(inverse_normalize(obs['visual']) * 255, "b c h w -> b h w c").cpu().numpy().astype(np.uint8),
    fps=10,
)

In [48]:
images = []
env.reset(options={
    'reset_to_state': state[0],
})
env_actions = einops.rearrange(actions, "b h (f a) -> b (h f) a", f=frameskip)
env_actions = data_preprocessor.denormalize_actions(env_actions) * 100.0
for i in range(env_actions.shape[1]):
    obs, reward, terminated, truncated, info = env.step(env_actions[:, i])
    images.append(obs['visual'])

imageio.mimsave(
    "env.mp4",
    einops.rearrange(np.array(images), "n b h w c -> n h (b w) c"),
    fps=10,
)

In [None]:
start_obs = {
    k: v.to(device)
    for k, v in start_obs.items()
}
actions = actions.to(device)

z_obs, z = wm.rollout(start_obs, actions)

In [None]:
wm_obs, diff = wm.decode_obs(z_obs)
wm_images = wm_obs['visual'].squeeze(0)
wm_images = torch.clamp(wm_images, 0, 1)

imageio.mimsave(
    "imagination.mp4",
    einops.rearrange(inverse_normalize(wm_images) * 255, "b c h w -> b h w c").detach().cpu().numpy().astype(np.uint8),
    fps=10 // 5,
)

tensor(1., device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.5000, device='cuda:0', grad_fn=<MinBackward1>)


In [None]:

def sample_traj_segment_from_dset(self, traj_len):
        states = []
        actions = []
        observations = []
        env_info = []

        # Check if any trajectory is long enough
        valid_traj = [
            self.dset[i][0]["visual"].shape[0]
            for i in range(len(self.dset))
            if self.dset[i][0]["visual"].shape[0] >= traj_len
        ]
        if len(valid_traj) == 0:
            raise ValueError("No trajectory in the dataset is long enough.")

        # sample init_states from dset
        for i in range(self.n_evals):
            max_offset = -1
            while max_offset < 0:  # filter out traj that are not long enough
                traj_id = random.randint(0, len(self.dset) - 1)
                obs, act, state, e_info = self.dset[traj_id]
                max_offset = obs["visual"].shape[0] - traj_len
            state = state.numpy()
            offset = random.randint(0, max_offset)
            obs = {
                key: arr[offset : offset + traj_len]
                for key, arr in obs.items()
            }
            state = state[offset : offset + traj_len]
            act = act[offset : offset + self.frameskip * self.goal_H]
            actions.append(act)
            states.append(state)
            observations.append(obs)
            env_info.append(e_info)
        return observations, states, actions, env_info


# update env config from val trajs
observations, states, actions, env_info = (
    self.sample_traj_segment_from_dset(traj_len=self.frameskip * self.goal_H + 1)
)
self.env.update_env(env_info)

# get states from val trajs
init_state = [x[0] for x in states]
init_state = np.array(init_state)
actions = torch.stack(actions)
if self.goal_source == "random_action":
    actions = torch.randn_like(actions)
wm_actions = rearrange(actions, "b (t f) d -> b t (f d)", f=self.frameskip)
exec_actions = self.data_preprocessor.denormalize_actions(actions)
# replay actions in env to get gt obses
rollout_obses, rollout_states = self.env.rollout(
    self.eval_seed, init_state, exec_actions.numpy()
)
self.obs_0 = {
    key: np.expand_dims(arr[:, 0], axis=1)
    for key, arr in rollout_obses.items()
}
self.obs_g = {
    key: np.expand_dims(arr[:, -1], axis=1)
    for key, arr in rollout_obses.items()
}
self.state_0 = init_state  # (b, d)
self.state_g = rollout_states[:, -1]  # (b, d)
self.gt_actions = wm_actions

    

In [7]:
batch = {
    k: np.expand_dims(v, axis=1).astype(np.float32)
    for k, v in obs.items()
}
batch = data_preprocessor.transform_obs(batch)
batch = {
    k: v.to(device)
    for k, v in batch.items()
}

z_obs_g = wm.encode_obs(batch)