In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os


# from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
import imageio 
import torch
import torch.nn as nn
import torch.nn.functional as F

import h5py

from diffusion_policy.pusht_data_utils import get_data_stats, normalize_data, unnormalize_data, PushTImageDatasetFromHDF5
from diffusion_policy.vision_model import ResidualBlock, ResNetFe, replace_bn_with_gn
from diffusion_policy.noise_predictor_model import ConditionalUnet1D
from diffusion_policy.myddpm import MyScheduler, MyDDPM

  from .autonotebook import tqdm as notebook_tqdm


pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html


  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Horizon lengths (in timesteps) handed to the dataset below.
pred_horizon = 16   # length of the action sequence in every sample (batch['action'] is (B, 16, 2))
obs_horizon = 2     # number of observation frames per sample (batch['image'] is (B, 2, 3, 96, 96))
action_horizon = 8  # presumably the number of predicted actions executed before re-planning — TODO confirm

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device 

device(type='cuda')

In [3]:
# HDF5 file holding the demonstrations; the path is relative to the notebook's working directory.
hdf5_file_name='data/pusht/pusht_v7_zarr_206.hdf5'
dataset = PushTImageDatasetFromHDF5(
    hdf5_file_name=hdf5_file_name,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    # NOTE(review): "f50" presumably selects a subset of the demos in the file — confirm against the dataset class.
    hdf5_filter_key="f50"
)
# Statistics exposed by the dataset; presumably the min/max used to (un)normalize data — TODO confirm.
stats=dataset.stats

In [4]:
# Wrap the dataset in a shuffling, multi-worker loader that yields batches of 64 samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # pinned host memory accelerates CPU-to-GPU transfer
    pin_memory=True,
    # don't kill the worker processes after each epoch
    persistent_workers=True
)

# Pull one batch and display the shape of each entry as a sanity check.
batch = next(iter(dataloader))
batch['image'].shape, batch['agent_pos'].shape, batch['action'].shape

(torch.Size([64, 2, 3, 96, 96]),
 torch.Size([64, 2, 2]),
 torch.Size([64, 16, 2]))

In [5]:
# Build the image feature extractor from the project's ResNet implementation.
vision_encoder = ResNetFe(ResidualBlock, [2, 2]) 
# Going by the helper's name, this swaps BatchNorm layers for GroupNorm — confirm in diffusion_policy.vision_model.
vision_encoder = replace_bn_with_gn(vision_encoder)

In [6]:
# Per-frame observation size: image feature vector plus the low-dimensional agent position.
vision_feature_dim = 512
lowdim_obs_dim = 2

obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 2

# The U-Net is conditioned on all obs_horizon frames flattened together (514 * 2 = 1028 values).
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# Keep both networks in one ModuleDict so they can be moved (and later saved/optimized) together.
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})

# Assign to `_` so the cell does not print the whole module tree.
_ = nets.to(device)

num_diffusion_iters = 100
# NOTE(review): sample_shape is not used by any later cell visible in this notebook.
sample_shape=(pred_horizon, action_dim) 

# Project-local scheduler and DDPM wrapper, both configured with T diffusion steps on `device`.
noise_scheduler=MyScheduler(T=num_diffusion_iters, device=device)
ddpm=MyDDPM(noise_scheduler, nets['noise_pred_net'], device=device)

ConditionalUnet1D: number of parameters: 7.994727e+07


In [7]:
# Draw one (shuffled) batch to drive the step-by-step walkthrough below.
nbatch = next(iter(dataloader))

In [8]:
# Keep the first obs_horizon observation frames and move every tensor to the compute device.
nimage = nbatch['image'][:,:obs_horizon].to(device)
nagent_pos = nbatch['agent_pos'][:,:obs_horizon].to(device)
naction = nbatch['action'].to(device)
# Batch size, read from the leading dimension.
B = nagent_pos.shape[0]

In [9]:
# Shape check for the image observations.
nimage.shape

torch.Size([64, 2, 3, 96, 96])

In [10]:
# Merge the batch and frame axes so the encoder sees a plain batch of images.
image_features = nets['vision_encoder'](nimage.flatten(end_dim=1))
# Split the leading axis back into (B, obs_horizon) and keep the feature axis last.
image_features = image_features.reshape(*nimage.shape[:2],-1)
# (B,obs_horizon,D)
image_features.shape, nagent_pos.shape

  return F.conv2d(input, weight, bias, self.stride,


(torch.Size([64, 2, 512]), torch.Size([64, 2, 2]))

In [11]:
# concatenate vision feature and low-dim obs along the feature axis
obs_features = torch.cat([image_features, nagent_pos], dim=-1)
# flatten the frame and feature axes into a single conditioning vector per sample
obs_cond = obs_features.flatten(start_dim=1)
# (B, obs_horizon * obs_dim)
obs_cond.shape

torch.Size([64, 1028])

In [12]:
# Draw one diffusion step per batch element, uniformly from [0, noise_scheduler.T).
timesteps = torch.randint(0, noise_scheduler.T,(B,), device=device).long()
# get_xt returns two tensors shaped like naction; presumably the noised actions and the noise that was added — TODO confirm.
noisy_actions , eps= noise_scheduler.get_xt(naction, timesteps)
noisy_actions.shape, eps.shape, timesteps.shape

(torch.Size([64, 16, 2]), torch.Size([64, 16, 2]), torch.Size([64]))

In [13]:
# Run the network on the noisy actions, conditioned on the diffusion step and the observation features.
eps_theta = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
# The output shape matches the shape of noisy_actions.
eps_theta.shape

torch.Size([64, 16, 2])

In [14]:
# Inspect the sub-module that embeds the diffusion step (sinusoidal embedding followed by a two-layer MLP).
noise_pred_net.diffusion_step_encoder

Sequential(
  (0): SinusoidalPosEmb()
  (1): Linear(in_features=256, out_features=1024, bias=True)
  (2): Mish()
  (3): Linear(in_features=1024, out_features=256, bias=True)
)

In [15]:
# Inspect the upsampling stages of the U-Net.
noise_pred_net.up_modules

ModuleList(
  (0): ModuleList(
    (0): ConditionalResidualBlock1D(
      (blocks): ModuleList(
        (0): Conv1dBlock(
          (block): Sequential(
            (0): Conv1d(2048, 512, kernel_size=(5,), stride=(1,), padding=(2,))
            (1): GroupNorm(8, 512, eps=1e-05, affine=True)
            (2): Mish()
          )
        )
        (1): Conv1dBlock(
          (block): Sequential(
            (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
            (1): GroupNorm(8, 512, eps=1e-05, affine=True)
            (2): Mish()
          )
        )
      )
      (cond_encoder): Sequential(
        (0): Mish()
        (1): Linear(in_features=1284, out_features=1024, bias=True)
        (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
      )
      (residual_conv): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
    )
    (1): ConditionalResidualBlock1D(
      (blocks): ModuleList(
        (0-1): 2 x Conv1dBlock(
          (block): Sequential(
            (0): Con

In [16]:
# Inspect the downsampling stages of the U-Net.
noise_pred_net.down_modules

ModuleList(
  (0): ModuleList(
    (0): ConditionalResidualBlock1D(
      (blocks): ModuleList(
        (0): Conv1dBlock(
          (block): Sequential(
            (0): Conv1d(2, 256, kernel_size=(5,), stride=(1,), padding=(2,))
            (1): GroupNorm(8, 256, eps=1e-05, affine=True)
            (2): Mish()
          )
        )
        (1): Conv1dBlock(
          (block): Sequential(
            (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,))
            (1): GroupNorm(8, 256, eps=1e-05, affine=True)
            (2): Mish()
          )
        )
      )
      (cond_encoder): Sequential(
        (0): Mish()
        (1): Linear(in_features=1284, out_features=512, bias=True)
        (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
      )
      (residual_conv): Conv1d(2, 256, kernel_size=(1,), stride=(1,))
    )
    (1): ConditionalResidualBlock1D(
      (blocks): ModuleList(
        (0-1): 2 x Conv1dBlock(
          (block): Sequential(
            (0): Conv1d(256

In [17]:
# Inspect the output head that maps 256 channels back to the 2 action channels.
noise_pred_net.final_conv

Sequential(
  (0): Conv1dBlock(
    (block): Sequential(
      (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): GroupNorm(8, 256, eps=1e-05, affine=True)
      (2): Mish()
    )
  )
  (1): Conv1d(256, 2, kernel_size=(1,), stride=(1,))
)

In [18]:
# Inspect the bottleneck blocks that sit between the down and up paths.
noise_pred_net.mid_modules

ModuleList(
  (0-1): 2 x ConditionalResidualBlock1D(
    (blocks): ModuleList(
      (0-1): 2 x Conv1dBlock(
        (block): Sequential(
          (0): Conv1d(1024, 1024, kernel_size=(5,), stride=(1,), padding=(2,))
          (1): GroupNorm(8, 1024, eps=1e-05, affine=True)
          (2): Mish()
        )
      )
    )
    (cond_encoder): Sequential(
      (0): Mish()
      (1): Linear(in_features=1284, out_features=2048, bias=True)
      (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
    )
    (residual_conv): Identity()
  )
)

In [46]:
# Shapes of the three inputs passed to the noise-prediction network.
noisy_actions.shape, timesteps.shape, obs_cond.shape

(torch.Size([64, 16, 2]), torch.Size([64]), torch.Size([64, 1028]))

In [47]:
# Same network call as earlier in the notebook, re-run on the current inputs.
eps_theta = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
eps_theta.shape

torch.Size([64, 16, 2])

In [51]:
# Random stand-ins for one batch of network inputs. Sizes come from the
# notebook's configuration (B, pred_horizon, action_dim, num_diffusion_iters,
# obs_dim, obs_horizon) instead of repeating the literals 64/16/2/100/1028,
# and the tensors are created directly on `device` to avoid a CPU-to-GPU copy.
noisy_actions = torch.randn(B, pred_horizon, action_dim, device=device)
timesteps = torch.randint(0, num_diffusion_iters, (B,), device=device)
obs_cond = torch.randn(B, obs_dim * obs_horizon, device=device)

# The network only needs correctly shaped inputs; the output shape matches noisy_actions.
eps_theta = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
eps_theta.shape

torch.Size([64, 16, 2])

### Forward pass of `ConditionalUnet1D`, step by step

In [19]:
# The call that the following cells unroll by hand:
# eps_theta = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
# eps_theta.shape

# Bind the call's arguments to the parameter names used inside forward().
sample=noisy_actions
timestep=timesteps
global_cond=obs_cond

sample.shape, timestep.shape, global_cond.shape

(torch.Size([64, 16, 2]), torch.Size([64]), torch.Size([64, 1028]))

In [20]:
# Swap the last two axes so the channel axis comes before the time axis, as Conv1d expects.
# (B,T,C)
sample = sample.moveaxis(-1,-2)
# (B,C,T)
sample.shape

torch.Size([64, 2, 16])

In [21]:
# Evaluate the two conditions forward() uses to normalize `timestep`: is it a tensor, and is it a 0-dim tensor?
torch.is_tensor(timesteps), torch.is_tensor(timesteps) and len(timesteps.shape) == 0

(True, False)

In [22]:
# 1. time
timesteps = timestep 
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# (the tensor already has one entry per sample here, so the shape is unchanged)
timesteps = timesteps.expand(sample.shape[0])
timesteps.shape

torch.Size([64])

In [37]:
# Embed each diffusion step into a 256-dimensional vector.
global_feature = noise_pred_net.diffusion_step_encoder(timesteps)
global_feature.shape

torch.Size([64, 256])

In [38]:
# Shapes of the two tensors that are concatenated in the next cell.
global_feature.shape, global_cond.shape

(torch.Size([64, 256]), torch.Size([64, 1028]))

In [39]:
# Append the observation conditioning to the step embedding: 256 + 1028 = 1284 features per sample.
# NOTE: this rebinds global_feature, so re-running the cell without re-running the previous one keeps widening it.
global_feature = torch.cat([global_feature, global_cond], axis=-1)
global_feature.shape 

torch.Size([64, 1284])

In [40]:
# x is the running activation threaded through the U-Net in the following cells.
x = sample
x.shape

torch.Size([64, 2, 16])

In [42]:
# Down path: each stage applies two conditioned residual blocks, then a downsample module.
# The activation before downsampling is pushed onto h for the skip connections of the up path.
h = []
for idx, (resnet, resnet2, downsample) in enumerate(noise_pred_net.down_modules):
    x = resnet(x, global_feature)
    x = resnet2(x, global_feature)
    h.append(x)
    x = downsample(x)

x.shape

torch.Size([64, 1024, 4])

In [43]:
# Bottleneck: conditioned residual blocks that leave the shape unchanged.
for mid_module in noise_pred_net.mid_modules:
    x = mid_module(x, global_feature)

x.shape

torch.Size([64, 1024, 4])

In [44]:
# Up path: pop the most recent skip activation, concatenate it on the channel axis,
# apply two conditioned residual blocks, then the stage's upsample module.
# NOTE: h is consumed here, so the down-path cell must be re-run before repeating this one.
for idx, (resnet, resnet2, upsample) in enumerate(noise_pred_net.up_modules):
    x = torch.cat((x, h.pop()), dim=1)
    x = resnet(x, global_feature)
    x = resnet2(x, global_feature)
    x = upsample(x)

x.shape

torch.Size([64, 256, 16])

In [45]:
# Output head: project the 256 channels back to the 2 action channels.
x = noise_pred_net.final_conv(x)
x.shape

torch.Size([64, 2, 16])

In [30]:
# Undo the initial axis swap so the result has the same layout as the input.
# (B,C,T)
x = x.moveaxis(-1,-2)
# (B,T,C)
x.shape

torch.Size([64, 16, 2])

In [32]:
# Alias the network as `self`, presumably so method bodies written against `self` can be run line by line here.
self= noise_pred_net

In [33]:
# Same sub-module as noise_pred_net.diffusion_step_encoder, reached through the alias.
self.diffusion_step_encoder

Sequential(
  (0): SinusoidalPosEmb()
  (1): Linear(in_features=256, out_features=1024, bias=True)
  (2): Mish()
  (3): Linear(in_features=1024, out_features=256, bias=True)
)

In [None]:

def forward(self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        global_cond: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Run the conditional 1D U-Net on a batch of sequences.

    sample: (B,T,input_dim)
    timestep: (B,) tensor, 0-dim tensor, or Python number; diffusion step.
        Tensors are moved to sample's device if they are not already there.
    global_cond: (B,global_cond_dim), or None for no extra conditioning
    output: (B,T,input_dim)
    """
    # The 1D convolutions expect channels before time: (B,T,C) -> (B,C,T)
    x = sample.moveaxis(-1, -2)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
    else:
        if timesteps.dim() == 0:
            # promote a scalar tensor to shape (1,) so it can be expanded below
            timesteps = timesteps[None]
        # No-op when already co-located; otherwise prevents a device mismatch
        # inside the step encoder (previously only 0-dim tensors were moved).
        timesteps = timesteps.to(x.device)
    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(x.shape[0])

    # 2. conditioning vector: step embedding, optionally followed by global_cond
    global_feature = self.diffusion_step_encoder(timesteps)
    if global_cond is not None:
        global_feature = torch.cat([global_feature, global_cond], dim=-1)

    # 3. down path: keep each stage's pre-downsample activation for the skip connections
    skips = []
    for resnet, resnet2, downsample in self.down_modules:
        x = resnet(x, global_feature)
        x = resnet2(x, global_feature)
        skips.append(x)
        x = downsample(x)

    # 4. bottleneck
    for mid_module in self.mid_modules:
        x = mid_module(x, global_feature)

    # 5. up path: consume the skips in reverse order, concatenating on the channel axis
    for resnet, resnet2, upsample in self.up_modules:
        x = torch.cat((x, skips.pop()), dim=1)
        x = resnet(x, global_feature)
        x = resnet2(x, global_feature)
        x = upsample(x)

    x = self.final_conv(x)

    # Restore the caller's layout: (B,C,T) -> (B,T,C)
    return x.moveaxis(-1, -2)
