In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os


# from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
import imageio 
import torch
import torch.nn as nn
import torch.nn.functional as F

import h5py

from pusht_data_utils import get_data_stats, normalize_data, unnormalize_data, PushTImageDatasetFromHDF5
from vision_model import ResidualBlock, ResNetFe, replace_bn_with_gn
from noise_predictor_model import ConditionalUnet1D
from myddpm import MyScheduler, MyDDPM

  from .autonotebook import tqdm as notebook_tqdm


pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html


  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Horizons, all counted in environment steps:
#   obs_horizon    - most recent observations stacked as conditioning
#   pred_horizon   - length of each sampled action sequence
#   action_horizon - sampled actions executed before sampling again
obs_horizon = 2
pred_horizon = 16
action_horizon = 8

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
# Demonstrations are read from a single HDF5 file. "f50" is handed to the
# dataset as its filter key; what that key selects is defined by
# PushTImageDatasetFromHDF5 (not visible in this notebook).
hdf5_file_name = 'data/pusht/pusht_v7_zarr_206.hdf5'
dataset = PushTImageDatasetFromHDF5(
    hdf5_file_name=hdf5_file_name,
    hdf5_filter_key="f50",
    obs_horizon=obs_horizon,
    pred_horizon=pred_horizon,
    action_horizon=action_horizon,
)

# Statistics exposed by the dataset; the rollout code uses them to
# normalize agent positions and un-normalize predicted actions.
stats = dataset.stats

In [4]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    # pinned host memory speeds up CPU -> GPU transfer
    pin_memory=True,
    # keep the worker processes alive after each epoch
    persistent_workers=True,
)

# Pull a single batch and display the shape of each entry as a sanity check.
batch = next(iter(dataloader))
tuple(batch[key].shape for key in ('image', 'agent_pos', 'action'))

(torch.Size([64, 2, 3, 96, 96]),
 torch.Size([64, 2, 2]),
 torch.Size([64, 16, 2]))

In [5]:
# Build the ResNet feature extractor and pass it through replace_bn_with_gn
# (going by its name it swaps BatchNorm for GroupNorm -- confirm in vision_model).
vision_encoder = replace_bn_with_gn(ResNetFe(ResidualBlock, [2, 2]))

In [6]:
# Sizes of the pieces that make up one observation and one action.
action_dim = 2
lowdim_obs_dim = 2          # agent position
vision_feature_dim = 512    # expected width of the vision encoder output
obs_dim = lowdim_obs_dim + vision_feature_dim

# The denoiser is conditioned on all obs_horizon observations,
# flattened into a single vector.
noise_pred_net = ConditionalUnet1D(
    global_cond_dim=obs_horizon * obs_dim,
    input_dim=action_dim,
)

# Keep both networks in one container so they are moved and optimized together.
nets = nn.ModuleDict(dict(
    vision_encoder=vision_encoder,
    noise_pred_net=noise_pred_net,
))

_ = nets.to(device)

ConditionalUnet1D: number of parameters: 7.994727e+07


In [7]:
# Number of diffusion timesteps handed to the scheduler as T.
num_diffusion_iters = 100
# Shape of one sampled action sequence, without the batch dimension.
sample_shape = (pred_horizon, action_dim)

noise_scheduler = MyScheduler(device=device, T=num_diffusion_iters)
ddpm = MyDDPM(noise_scheduler, nets['noise_pred_net'], device=device)

In [8]:
num_epochs = 400

# A single optimizer updates both the vision encoder and the noise-prediction net.
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine learning-rate schedule with warmup; the schedule length is expressed
# in optimizer steps (one per batch), not in epochs.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# Select training mode explicitly: the evaluation cell further down calls
# nets.eval(), so without this a re-run of this cell would train in eval mode.
nets.train()

with tqdm(range(num_epochs), desc='Epoch') as tglobal:

    for epoch_idx in tglobal:

        epoch_loss = list()
        for nbatch in dataloader:

            # keep only the first obs_horizon observations of each sample
            nimage = nbatch['image'][:, :obs_horizon].to(device)
            nagent_pos = nbatch['agent_pos'][:, :obs_horizon].to(device)
            naction = nbatch['action'].to(device)
            B = nagent_pos.shape[0]

            # encode every frame independently, then restore the
            # (B, obs_horizon, D) layout
            image_features = nets['vision_encoder'](
                nimage.flatten(end_dim=1))
            image_features = image_features.reshape(
                *nimage.shape[:2], -1)

            # concatenate vision feature and low-dim obs
            obs_features = torch.cat([image_features, nagent_pos], dim=-1)
            obs_cond = obs_features.flatten(start_dim=1)
            # (B, obs_horizon * obs_dim)

            # one random diffusion timestep per sample
            timesteps = torch.randint(
                0, noise_scheduler.T,
                (B,), device=device
            ).long()
            # noised actions plus the noise that was added to them
            noisy_actions, eps = noise_scheduler.get_xt(naction, timesteps)

            # go through `nets` (the module that is optimized and put in
            # train mode above) rather than the bare global
            eps_theta = nets['noise_pred_net'](
                noisy_actions, timesteps, global_cond=obs_cond)

            # the network is trained to predict the added noise
            loss = nn.functional.mse_loss(eps_theta, eps)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # step lr scheduler every batch
            # this is different from standard pytorch behavior
            lr_scheduler.step()

            loss_cpu = loss.item()
            epoch_loss.append(loss_cpu)

        # show the mean batch loss of the epoch that just finished
        tglobal.set_postfix(loss=np.mean(epoch_loss))

  return F.conv2d(input, weight, bias, self.stride,
Epoch: 100%|█████████▉| 399/400 [25:48<00:03,  3.87s/it, loss=0.00815]

In [None]:
import collections
import os
import sys

# Directory that makes the `diffusion_policy` package importable.
# Set the DIFFUSION_POLICY_ROOT environment variable to use a checkout in a
# different location; when it is unset the original path is used.
diffusion_policy_root = os.environ.get(
    'DIFFUSION_POLICY_ROOT', '/home/ns1254/diffusion_policy/')
# Do not stack duplicate sys.path entries when this cell is re-run.
if diffusion_policy_root not in sys.path:
    sys.path.append(diffusion_policy_root)

from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv

In [None]:
# Push-T image environment the trained policy is evaluated in.
env = PushTImageEnv()

# Switch both networks to eval mode; binding the result to `_` keeps the
# notebook from echoing the whole module.
_ = nets.eval()

In [None]:
def rollout(env, nets, seed, max_steps=200):
    """Run one closed-loop evaluation episode.

    Besides its arguments, the function reads these notebook-level names:
    ``obs_horizon``, ``action_horizon``, ``sample_shape``, ``stats``,
    ``device`` and ``ddpm``.

    Args:
        env: environment exposing ``seed``, ``reset``, ``step`` and ``render``;
            observations are dicts with ``'image'`` and ``'agent_pos'`` entries.
        nets: mapping that provides the ``'vision_encoder'`` network.
        seed: episode seed; the environment is seeded with ``200 + seed``.
        max_steps: step budget. The episode is cut off once the step counter
            exceeds this value, so up to ``max_steps + 1`` steps can run.

    Returns:
        ``(max_reward, success, imgs)``: the highest reward seen in the
        episode, whether the environment itself reported ``done`` (assumed to
        signal task completion -- confirm against PushTImageEnv), and the list
        of rendered frames starting with the frame taken right after reset.
        Episodes that end only because the step budget ran out return
        ``success=False``.
    """
    env.seed(200+seed)
    obs = env.reset()

    # keep a queue of the last obs_horizon observations, seeded with the first one
    obs_deque = collections.deque(
        [obs] * obs_horizon, maxlen=obs_horizon)
    # save visualization and rewards
    imgs = [env.render(mode='rgb_array')]
    rewards = list()
    done = False
    success = False
    step_idx = 0
    with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar:
        while not done:
            # stack the last obs_horizon number of observations
            images = np.stack([x['image'] for x in obs_deque])
            agent_poses = np.stack([x['agent_pos'] for x in obs_deque])

            # normalize observation
            nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
            # images are already normalized to [0,1]
            nimages = images

            # device transfer
            nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
            # (obs_horizon,3,96,96)
            nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)
            # (obs_horizon,2)

            # infer action
            with torch.no_grad():
                image_features = nets['vision_encoder'](nimages)
                # (obs_horizon,512)
                # concat with low-dim observations
                obs_features = torch.cat([image_features, nagent_poses], dim=-1)
                # reshape observation to (1, obs_horizon*obs_dim)
                obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

                # second return value is not needed here
                naction, _ = ddpm.sample_ddpm(1, sample_shape, obs_cond)

            # unnormalize action
            naction = naction.detach().to('cpu').numpy()
            # (1, pred_horizon, action_dim)
            naction = naction[0]
            action_pred = unnormalize_data(naction, stats=stats['action'])

            # only take action_horizon number of actions
            start = obs_horizon - 1
            end = start + action_horizon
            action = action_pred[start:end, :]
            # (action_horizon, action_dim)

            # execute action_horizon number of steps
            # without replanning
            for i in range(len(action)):
                obs, reward, done, info = env.step(action[i])
                obs_deque.append(obs)

                rewards.append(reward)
                imgs.append(env.render(mode='rgb_array'))

                # update progress bar
                step_idx += 1
                pbar.update(1)
                pbar.set_postfix(reward=reward)

                # Only a `done` coming from the environment counts as success.
                # It is checked before the step-budget cutoff so a timeout
                # is never reported as a solved episode.
                if done:
                    success = True
                    break
                if step_idx > max_steps:
                    done = True
                    break

    return max(rewards), success, imgs

In [None]:
# Seed the NumPy and PyTorch RNGs before evaluating.
seed = 40
np.random.seed(seed)
torch.manual_seed(seed)

# Per-episode results: best reward, success flag, number of rendered frames.
rewards, success, lengths = [], [], []

# 50 evaluation episodes, each started from its own seed.
for i in range(50):
    reward, suc, imgs = rollout(env, nets, seed+i, 200)
    rewards += [reward]
    success += [suc]
    lengths += [len(imgs)]

for label, values in (
    ('Mean Reward: ', rewards),
    ('Success Rate: ', success),
    ('Mean Length: ', lengths),
):
    print(label, np.mean(values))