In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os


# from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
import imageio 
import torch
import torch.nn as nn
import torch.nn.functional as F

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices

In [3]:
def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result 

In [4]:
# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

In [5]:
class PushTImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')

        # float32, [0,1], (N,96,96,3)
        train_image_data = dataset_root['data']['img'][:]
        train_image_data = np.moveaxis(train_image_data, -1,1)
        # (N,3,96,96)
        train_image_data = train_image_data.astype(np.float32)  #for ns data

        # (N, D)
        train_data = {
            # first two dims of state vector are agent (i.e. gripper) locations
            'agent_pos': dataset_root['data']['state'][:,:2],
            'action': dataset_root['data']['action'][:]
        }
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]
        return nsample


In [8]:
pred_horizon = 16
obs_horizon = 2
action_horizon = 8

In [None]:
dataset_path = "data/pusht/pusht_cchi_v7_replay.zarr"
 
dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

In [7]:
stats

{'agent_pos': {'min': array([13.456424, 32.938293], dtype=float32),
  'max': array([496.14618, 510.9579 ], dtype=float32)},
 'action': {'min': array([12., 25.], dtype=float32),
  'max': array([511., 511.], dtype=float32)}}

In [8]:
len(dataset)

24208

In [9]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

In [10]:
batch = next(iter(dataloader))
batch['image'].shape, batch['agent_pos'].shape, batch['action'].shape

(torch.Size([64, 2, 3, 96, 96]),
 torch.Size([64, 2, 2]),
 torch.Size([64, 16, 2]))

In [11]:
batch['image'].min(), batch['image'].max()

(tensor(65.), tensor(255.))

In [12]:
dataset_root = zarr.open(dataset_path, 'r')

# float32, [0,1], (N,96,96,3)
train_image_data = dataset_root['data']['img'][:]
train_image_data = np.moveaxis(train_image_data, -1,1)
# (N,3,96,96)
train_image_data = train_image_data.astype(np.float32)  #for ns data

# (N, D)
train_data = {
    # first two dims of state vector are agent (i.e. gripper) locations
    'agent_pos': dataset_root['data']['state'][:,:2],
    'action': dataset_root['data']['action'][:]
}
episode_ends = dataset_root['meta']['episode_ends'][:]

In [13]:
# compute start and end of each state-action sequence
# also handles padding
indices = create_sample_indices(
    episode_ends=episode_ends,
    sequence_length=pred_horizon,
    pad_before=obs_horizon-1,
    pad_after=action_horizon-1)

In [14]:
indices.shape

(24208, 4)

In [15]:
# compute statistics and normalized data to [-1,1]
stats = dict()
normalized_train_data = dict()
for key, data in train_data.items():
    stats[key] = get_data_stats(data)
    normalized_train_data[key] = normalize_data(data, stats[key])

# images are already normalized
normalized_train_data['image'] = train_image_data

In [16]:
idx=153
buffer_start_idx, buffer_end_idx,  sample_start_idx, sample_end_idx =  indices[idx]

buffer_start_idx, buffer_end_idx,  sample_start_idx, sample_end_idx

(152, 161, 0, 9)

In [17]:
idx=154
buffer_start_idx, buffer_end_idx,  sample_start_idx, sample_end_idx =  indices[idx]

buffer_start_idx, buffer_end_idx,  sample_start_idx, sample_end_idx

(161, 176, 1, 16)

In [18]:
# get nomralized data using these indices
nsample = sample_sequence(
    train_data= normalized_train_data,
    sequence_length= pred_horizon,
    buffer_start_idx=buffer_start_idx,
    buffer_end_idx=buffer_end_idx,
    sample_start_idx=sample_start_idx,
    sample_end_idx=sample_end_idx
)

nsample['image'].shape, nsample['agent_pos'].shape, nsample['action'].shape

((16, 3, 96, 96), (16, 2), (16, 2))

In [19]:
# discard unused observations
nsample['image'] = nsample['image'][:obs_horizon,:]
nsample['agent_pos'] = nsample['agent_pos'][:obs_horizon,:]

nsample['image'].shape, nsample['agent_pos'].shape, nsample['action'].shape

((2, 3, 96, 96), (2, 2), (16, 2))

In [20]:
nsample['action'][-10:]

array([[-0.8116232 ,  0.12757206],
       [-0.8116232 ,  0.16460907],
       [-0.82364726,  0.21810699],
       [-0.84368736,  0.27572012],
       [-0.8517034 ,  0.33333337],
       [-0.85971946,  0.39506173],
       [-0.86372745,  0.4238683 ],
       [-0.8677355 ,  0.44855964],
       [-0.8717435 ,  0.4773662 ],
       [-0.8717435 ,  0.49794233]], dtype=float32)

### create from hdf5 file

In [5]:
import h5py

In [22]:
hdf5_file_name='data/pusht/pusht_v7_zarr_206.hdf5'

f=h5py.File(hdf5_file_name, 'r')
print(f.keys())
print(f['data'].keys())
print(f['data']['demo_1'].keys())
print(f['data']['demo_1']['obs'].keys())
print(f['data']['demo_1']['obs']['img'].shape)
print(f['data']['demo_1']['obs']['state'].shape)

<KeysViewHDF5 ['data', 'mask']>
<KeysViewHDF5 ['demo_0', 'demo_1', 'demo_10', 'demo_100', 'demo_101', 'demo_102', 'demo_103', 'demo_104', 'demo_105', 'demo_106', 'demo_107', 'demo_108', 'demo_109', 'demo_11', 'demo_110', 'demo_111', 'demo_112', 'demo_113', 'demo_114', 'demo_115', 'demo_116', 'demo_117', 'demo_118', 'demo_119', 'demo_12', 'demo_120', 'demo_121', 'demo_122', 'demo_123', 'demo_124', 'demo_125', 'demo_126', 'demo_127', 'demo_128', 'demo_129', 'demo_13', 'demo_130', 'demo_131', 'demo_132', 'demo_133', 'demo_134', 'demo_135', 'demo_136', 'demo_137', 'demo_138', 'demo_139', 'demo_14', 'demo_140', 'demo_141', 'demo_142', 'demo_143', 'demo_144', 'demo_145', 'demo_146', 'demo_147', 'demo_148', 'demo_149', 'demo_15', 'demo_150', 'demo_151', 'demo_152', 'demo_153', 'demo_154', 'demo_155', 'demo_156', 'demo_157', 'demo_158', 'demo_159', 'demo_16', 'demo_160', 'demo_161', 'demo_162', 'demo_163', 'demo_164', 'demo_165', 'demo_166', 'demo_167', 'demo_168', 'demo_169', 'demo_17', 'demo

In [23]:
f['mask'].keys()

<KeysViewHDF5 ['all', 'f100', 'f150', 'f200', 'f50']>

In [42]:
group_name="all"
demos=f['mask'][group_name]
demos=[d.decode('utf-8') for d in demos]
demos = sorted(demos, key=lambda x: int(x.split('_')[1]))

In [None]:
demos

In [10]:
def get_datas(hdf5_file, demos):
    train_images=[]
    train_agent_pos=[]
    train_actions=[]

    episode_ends=[]
    for demo_name in demos: 
        obs=hdf5_file['data'][demo_name]['obs']
        action=hdf5_file['data'][demo_name]['action'][:]

        img=obs['img'][:]
        agent_pos=obs['agent_pos'][:]

        train_images.extend(img)
        train_agent_pos.extend(agent_pos)
        train_actions.extend(action)

        episode_ends.append(len(train_images)) 
    episode_ends=np.array(episode_ends)
    train_images=np.array(train_images)
    train_agent_pos=np.array(train_agent_pos)
    train_actions=np.array(train_actions)
    return train_images, train_agent_pos, train_actions, episode_ends

In [49]:
train_images, train_agent_pos, train_actions, episode_ends = get_datas(f, demos)

train_images.shape, train_agent_pos.shape, train_actions.shape, episode_ends.shape

((25650, 96, 96, 3), (25650, 2), (25650, 2), (206,))

In [52]:
train_image_data = np.moveaxis(train_images, -1,1)
# (N,3,96,96)
train_image_data = train_image_data.astype(np.float32)  #for ns data
train_image_data.shape

(25650, 3, 96, 96)

In [51]:
# episode_ends 

In [6]:
class PushTImageDatasetFromHDF5(torch.utils.data.Dataset):
    def __init__(self,
                 hdf5_file_name: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int,
                 hdf5_filter_key=None):

        f=h5py.File(hdf5_file_name, 'r')
        if hdf5_filter_key is None:
              demos=f['data'].keys()
        else:
            demos=f['mask'][hdf5_filter_key]
            demos=[d.decode('utf-8') for d in demos]

        demos = sorted(demos, key=lambda x: int(x.split('_')[1]))

        train_images, train_agent_pos, train_actions, episode_ends = get_datas(f, demos)
        f.close()

        train_image_data = np.moveaxis(train_images, -1,1)
        # (N,3,96,96)
        train_image_data = train_image_data.astype(np.float32)  #for ns data
 
        # (N, D)
        train_data = { 
            'agent_pos': train_agent_pos,
            'action': train_actions
        } 

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]
        return nsample 

In [13]:
hdf5_file_name='data/pusht/pusht_v7_zarr_206.hdf5'
dataset = PushTImageDatasetFromHDF5(
    hdf5_file_name=hdf5_file_name,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    hdf5_filter_key=None
)

In [14]:
dataset.stats

{'agent_pos': {'min': array([13.456424, 32.938293], dtype=float32),
  'max': array([496.14618, 510.9579 ], dtype=float32)},
 'action': {'min': array([12., 25.], dtype=float32),
  'max': array([511., 511.], dtype=float32)}}

In [15]:
hdf5_file_name='data/pusht/pusht_v7_zarr_206.hdf5'
dataset = PushTImageDatasetFromHDF5(
    hdf5_file_name=hdf5_file_name,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    hdf5_filter_key="f100"
)
dataset.stats

{'agent_pos': {'min': array([13.456424, 36.492805], dtype=float32),
  'max': array([486.46097, 510.9579 ], dtype=float32)},
 'action': {'min': array([13., 25.], dtype=float32),
  'max': array([487., 511.], dtype=float32)}}