In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
import imageio 
import torch
import torch.nn as nn
import torch.nn.functional as F

pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
seed=1000

env = PushTImageEnv() 

env.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
 
obs = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()
 
obs, reward, terminated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("obs['image'].shape:", obs['image'].shape, "float32, [0,1]")
    print("obs['agent_pos'].shape:", obs['agent_pos'].shape, "float32, [0,512]")
    print("action.shape: ", action.shape, "float32, [0,512]")

obs['image'].shape: (3, 96, 96) float32, [0,1]
obs['agent_pos'].shape: (2,) float32, [0,512]
action.shape:  (2,) float32, [0,512]


In [3]:
def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data

In [4]:
# # dataset
class PushTImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):

        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')

        # float32, [0,1], (N,96,96,3)
        train_image_data = dataset_root['data']['img'][:]
        train_image_data = np.moveaxis(train_image_data, -1,1)
        # (N,3,96,96)
        train_image_data = train_image_data.astype(np.float32)  #for ns data

        # (N, D)
        train_data = {
            # first two dims of state vector are agent (i.e. gripper) locations
            'agent_pos': dataset_root['data']['state'][:,:2],
            'action': dataset_root['data']['action'][:]
        }
        episode_ends = dataset_root['meta']['episode_ends'][:]

        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]
        return nsample


In [5]:
dataset_path = "/home/carl_lab/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"


# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
print("batch['action'].shape", batch['action'].shape)

batch['image'].shape: torch.Size([64, 2, 3, 96, 96])
batch['agent_pos'].shape: torch.Size([64, 2, 2])
batch['action'].shape torch.Size([64, 16, 2])


In [8]:
from my_unet import ConditionalUnet1D
from my_vision_encoder import ResidualBlock, ResNetFe
from my_ddpm import MyScheduler, MyDDPM

In [None]:
t_img = torch.rand(128, 3, 96, 96).float().to('cuda')
# fe=vision_encoder(t_img)
# fe.shape

In [17]:
class VisionEncoder(nn.Module):
    def __init__(self ):
        super(VisionEncoder, self).__init__() 
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) # (x+2*3-7)/2+1=(x-1)/2+1=48 : 64*48*48
        self.gn1 = nn.GroupNorm(32, 64) #  num_groups, num_channels
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)       #(x+2*1-3)/2+1=(x-1)/2+1=24  : 64*24*24

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 512)   
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.conv3(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
    
vision_encoder = VisionEncoder()

In [18]:
vision_feature_dim = 512
lowdim_obs_dim = 2
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 2
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})
device = torch.device('cuda')
_ = nets.to(device)
n_params= sum(p.numel() for p in nets.parameters())
print(f"Number of parameters: {n_params:,}")

Number of parameters: 43,699,522


In [20]:
def rollout(env, nets,  seed, max_steps=200):
    
    nets.eval() 
    env.seed(200+seed)
    obs = env.reset()

    # keep a queue of last 2 steps of observations
    obs_deque = collections.deque(
        [obs] * obs_horizon, maxlen=obs_horizon)
    # save visualization and rewards
    imgs = [env.render(mode='rgb_array')]
    rewards = list()
    done = False
    step_idx = 0
    success=False
    with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar:
        while not done:
            B = 1
            # stack the last obs_horizon number of observations
            images = np.stack([x['image'] for x in obs_deque])
            agent_poses = np.stack([x['agent_pos'] for x in obs_deque])

            # normalize observation
            nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
            # images are already normalized to [0,1]
            nimages = images

            # device transfer
            nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
            # (2,3,96,96)
            nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)
            # (2,2)

            # infer action
            with torch.no_grad(): 
                image_features = nets['vision_encoder'](nimages)
                # (2,512) 
                # concat with low-dim observations
                obs_features = torch.cat([image_features, nagent_poses], dim=-1) 
                # reshape observation to (B,obs_horizon*obs_dim)
                obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

                naction,xts=ddpm.sample_ddpm(1, sample_shape, obs_cond)
                

            # unnormalize action
            naction = naction.detach().to('cpu').numpy()
            # (B, pred_horizon, action_dim)
            naction = naction[0]
            action_pred = unnormalize_data(naction, stats=stats['action'])

            # only take action_horizon number of actions
            start = obs_horizon - 1
            end = start + action_horizon
            action = action_pred[start:end,:]
            # (action_horizon, action_dim)

            # execute action_horizon number of steps
            # without replanning
            for i in range(len(action)):
                obs, reward, done, info = env.step(action[i])
                obs_deque.append(obs)
                
                rewards.append(reward)
                imgs.append(env.render(mode='rgb_array'))

                # update progress bar
                step_idx += 1
                pbar.update(1)
                pbar.set_postfix(reward=reward)
                if step_idx > max_steps:
                    done = True
                if done:
                    success=True
                    break

    return max(rewards) , success, imgs

In [21]:
num_diffusion_iters = 100
noise_scheduler=MyScheduler(T=num_diffusion_iters, device=device)

In [22]:
ddpm=MyDDPM(noise_scheduler, nets['noise_pred_net'], device=device)
sample_shape=(pred_horizon, action_dim) 

In [23]:
num_epochs = 400

 
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    lr=1e-4, weight_decay=1e-6)
 
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

for epoch_idx in tqdm(range(num_epochs)): 
    nets.train()
    epoch_loss = list() 
    for nbatch in dataloader: 
        nimage = nbatch['image'][:,:obs_horizon].to(device)
        nagent_pos = nbatch['agent_pos'][:,:obs_horizon].to(device)
        naction = nbatch['action'].to(device)
        B = nagent_pos.shape[0]

        image_features = nets['vision_encoder'](
            nimage.flatten(end_dim=1))
        image_features = image_features.reshape(
            *nimage.shape[:2],-1) 
        obs_features = torch.cat([image_features, nagent_pos], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)
        
        timesteps = torch.randint(0, noise_scheduler.T,(B,), device=device).long()

        noisy_actions , noise= noise_scheduler.get_xt(naction, timesteps)
        noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)

        loss = nn.functional.mse_loss(noise_pred, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()


        loss_cpu = loss.item()
        epoch_loss.append(loss_cpu)  

    if epoch_idx>0 and epoch_idx %100==0:
        rewards = [rollout(env, nets, i*10, 200)[0] for i in range(10)]
        mean_r = np.mean(rewards)
        std_r = np.std(rewards)
        print(f"epoch: {epoch_idx}, mean_r: {mean_r} std_r: {std_r}") 
ema_nets = nets 

  0%|          | 0/400 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

epoch: 100, mean_r: 0.257153445802162 std_r: 0.355503436579649


Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

epoch: 200, mean_r: 0.2743077471567591 std_r: 0.35803039778118817


Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

epoch: 300, mean_r: 0.34614798918778167 std_r: 0.40980096509227143


In [24]:
seed=40

env = PushTImageEnv() 

env.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) 

rewards = [rollout(env, nets, i*10, 200)[0] for i in range(50)]
mean_r = np.mean(rewards)
std_r = np.std(rewards)
print(f"mean_r: {mean_r} std_r: {std_r}") 

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

mean_r: 0.2791431659402337 std_r: 0.3822333886469687


In [25]:
rewards

[0.6675509240218409,
 1.0,
 0.0,
 0.7148659401375723,
 0.0,
 0.013112212060235072,
 0.0,
 0.6499646066605059,
 0.0,
 0.14637952405206542,
 0.44613792852071044,
 0.0,
 0.0,
 0.09992370615264241,
 0.0,
 0.2143484919286783,
 0.0,
 0.01515027519432415,
 0.9578623244168099,
 0.0,
 0.18676337302929955,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.1405256693879443,
 0.24902533602767093,
 1.0,
 0.0,
 0.6009592977325259,
 1.0,
 0.15261273427524308,
 0.0,
 0.22264780225258565,
 0.0,
 0.8992340744381976,
 0.014389533178426566,
 1.0,
 0.9287178051743542,
 0.0,
 0.6369867383700522,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [26]:
# noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
noisy_actions.shape, timesteps.shape, obs_cond.shape

(torch.Size([16, 16, 2]), torch.Size([16]), torch.Size([16, 1028]))

In [None]:
# torch.save(nets.state_dict(), 'pusht_enc_ms_worked.pth')