# Imports

In [1]:
import os
import sys
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv

In [2]:
# Smoke test for the PushT image environment: build it, take a single random
# step, and report the shapes of the resulting observation and action.
env = PushTImageEnv()

# Pin the initial state. Per the original note, seeds 0-200 are the ones used
# for the demonstration dataset.
seed = 0
env.seed(seed)

# The environment has to be reset once before it can be stepped.
obs = env.reset()

# Draw a random action from the 2D positional action space ([0,512]).
action = env.action_space.sample()

# Advance the environment by one step.
# NOTE(review): a 4-value return looks like the classic gym tuple
# (obs, reward, done, info), in which case `terminated` would hold `done` and
# `truncated` would hold the info dict -- confirm against PushTImageEnv.step
# before relying on these two names.
obs, reward, terminated, truncated = env.step(action)

# One row per reported quantity: (label, shape, dtype/range annotation).
shape_report = (
    ("obs['image'].shape:", obs["image"].shape, "float32, [0,1]"),
    ("obs['agent_pos'].shape:", obs["agent_pos"].shape, "float32, [0,512]"),
    ("action.shape: ", action.shape, "float32, [0,512]"),
)
with np.printoptions(precision=4, suppress=True, threshold=5):
    for label, shape, note in shape_report:
        print(label, shape, note)

obs['image'].shape: (3, 96, 96) float32, [0,1]
obs['agent_pos'].shape: (2,) float32, [0,512]
action.shape:  (2,) float32, [0,512]


In [39]:
# Render `num_states` environment states obtained from `env.random_state` and
# write the resulting images to disk as .npy shards of at most `batch_size`
# images each.
num_states = 100_000 # 500_000
batch_size = 10_000

# Output layout: <root>/random_<num_states>/<YYYYmmdd-HHMMSS>/<first>-<last>.npy
root_save_dir = "diffusion_policy/data/pusht/images"
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
tag = f"random_{num_states}"
save_dir = os.path.join(root_save_dir, tag, timestamp)
# exist_ok=True replaces the check-then-create pattern, which can fail if the
# directory appears between the check and the makedirs call.
os.makedirs(save_dir, exist_ok=True)

imgs = []
start_idx = 0  # index of the first image in the shard currently being filled
for cur_idx in trange(num_states):
    # NOTE(review): presumably samples a fresh random state; the meaning of
    # the `None` argument is not visible from this cell -- confirm.
    state = env.random_state(None)
    env._set_state(state)
    img = env._get_obs()["image"]
    imgs.append(img)

    # Write a shard when it is full, and also on the very last iteration so
    # that a trailing partial shard is saved instead of silently dropped when
    # num_states is not a multiple of batch_size.
    if len(imgs) == batch_size or cur_idx == num_states - 1:
        # The file name carries the inclusive index range of the shard.
        fname = f"{start_idx}-{cur_idx}.npy"
        np.save(os.path.join(save_dir, fname), np.array(imgs))
        imgs = []
        start_idx = cur_idx + 1

  0%|          | 0/100000 [00:00<?, ?it/s]