In [None]:
import os
# Set before `torch` is imported below.
# NOTE(review): presumably selects the NNC fuser over NVFuser for TorchScript — confirm against the PyTorch version in use.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface.data import PointCloud, SE3, DemoDataset, TargetPoseDemo, preprocess
from edf_interface.utils import manipulation_utils
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

> **Warning**\
> Copy this notebook into the project root directory (which should contain the 'configs' and 'demo' directories) before running it!

## Eval Configuration

In [None]:
# ----------------------------------------------
# Choose your device
# ----------------------------------------------
# device = 'cpu'
device = 'cuda:0'


# ----------------------------------------------
# Choose the task to evaluate
# ----------------------------------------------
# task_type = "pick"
task_type = "place"

# Fail fast on a typo: `task_type` is used below to build config keys and to
# pick the demo stage, so an invalid value would otherwise only surface later
# as a less obvious error.
if task_type not in ("pick", "place"):
    raise ValueError(f"'task_type' must be either 'pick' or 'place', but {task_type} is given.")

# Initialize Models

In [None]:
config_root_dir = 'configs/panda_bowl'
testset = DemoDataset(dataset_dir='demo/panda_bowl_on_dish_test')

# Agent configuration: per-task model kwargs and (optionally) critic kwargs.
# NOTE(review): yaml.FullLoader is kept because the config files may rely on
# tags beyond plain YAML; if they do not, prefer yaml.safe_load. Only load
# config files you trust.
with open(os.path.join(config_root_dir, 'agent.yaml')) as f:
    model_kwargs = yaml.load(f, Loader=yaml.FullLoader)

model_kwargs_list = model_kwargs['model_kwargs'][f"{task_type}_models_kwargs"]
# The critic entry is optional: a missing key yields None instead of an error.
# (A bare `except:` here would also hide unrelated failures.)
critic_kwargs = model_kwargs['model_kwargs'].get(f"{task_type}_critic_kwargs")

# Pre-/un-processing configuration; both sections live in the same file.
with open(os.path.join(config_root_dir, 'preprocess.yaml')) as f:
    preprocess_config = yaml.load(f, Loader=yaml.FullLoader)

unprocess_config = preprocess_config['unprocess_config']
preprocess_config = preprocess_config['preprocess_config']

agent = DiffusionEdfAgent(
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
    device=device,
    critic_kwargs=critic_kwargs
) # Model initialization and warm-up takes 2~3 minutes (very slow due to e3nn codegen and torch jit).

# Configure Denoising settings

In [None]:
# Keyword arguments forwarded verbatim to `agent.sample(...)`.
# Each `*_list` entry has two outer items with 2 and 3 inner stages respectively
# (presumably one outer item per model in `model_kwargs_list` — confirm).
denoising_configs = {
    'N_steps_list': [
        [200, 200],
        [100, 100, 150],
    ],
    'timesteps_list': [
        [0.04, 0.04],
        [0.02, 0.02, 0.02],
    ],
    'temperatures_list': [
        [1., 1.],
        [1., 1., 0.0],
    ],
    'log_t_schedule': True,
    # One [start, end] pair per inner stage.
    'diffusion_schedules_list': [
        [[1., 0.15], [0.15, 0.05]],
        [[0.09, 0.03], [0.03, 0.012], [0.012, 0.012]],
    ],
    'time_exponent_temp': 1.0,
    'time_exponent_alpha': 0.5,
    # Ask `agent.sample` to also return its `info` dict.
    'return_info': True,
}

# Initialize Input Data and Initial Pose

### demo_idx:
* 0,1,2: Default (in Red, Green, Blue order)
* 3,4,5: Default (in Red, Green, Blue order)
* 6,7,8: Unseen Poses (in Red, Green, Blue order)

In [None]:
demo_idx = 0

# Each test-set entry is indexed by stage: 0 for 'pick', 1 for 'place'.
# Raise a clear error for anything else instead of indexing with a message string.
if task_type == 'pick':
    stage_idx = 0
elif task_type == 'place':
    stage_idx = 1
else:
    raise ValueError(f"'task_type' must be either 'pick' or 'place', but {task_type} is given.")

demo: TargetPoseDemo = testset[demo_idx][stage_idx].to(device)
scene_pcd: PointCloud = demo.scene_pcd
grasp_pcd: PointCloud = demo.grasp_pcd

# Sample

In [None]:
if task_type == "pick":
    N_samples = 20 # reduce number of samples if too slow or short of memory
elif task_type == "place":
    N_samples = 10 # reduce number of samples if too slow or short of memory
else:
    raise ValueError(f"'task_type' must be either 'pick' or 'place', but {task_type} is given.")

# Number of samples dropped from the start / end of the sample axis when the
# agent reports energies.
N_TRIM_HEAD, N_TRIM_TAIL = 2, 3

# Initial pose, repeated for every sample: 4 + 3 = 7 numbers per pose.
# NOTE(review): presumably a (w, x, y, z) identity quaternion followed by an
# xyz position 0.3 along z — confirm against the SE3 convention.
T0 = torch.cat([
    torch.tensor([[1., 0., 0.0, 0.]], device=device),
    torch.tensor([[0., 0., 0.3]], device=device)
], dim=-1).repeat(N_samples, 1)
Ts_init = SE3(poses=T0).to(device)


Ts_out_raw, scene_proc, grasp_proc, info = agent.sample(
    scene_pcd=scene_pcd, grasp_pcd=grasp_pcd, Ts_init=Ts_init,
    **denoising_configs
)

if 'energy' in info:
    energy = info['energy']
    # Remove outlier energy poses (axis 1 is the sample axis; presumably the
    # samples are ordered by energy — confirm with `agent.sample`).
    if Ts_out_raw.shape[1] > N_TRIM_HEAD + N_TRIM_TAIL:
        Ts_out = Ts_out_raw[:, N_TRIM_HEAD:-N_TRIM_TAIL]
    else:
        # With so few samples the trim would leave nothing to visualize.
        warnings.warn(
            f"Only {Ts_out_raw.shape[1]} samples available; skipping outlier removal "
            f"(needs more than {N_TRIM_HEAD + N_TRIM_TAIL})."
        )
        Ts_out = Ts_out_raw
else:
    Ts_out = Ts_out_raw

# Visualize Samples

In [None]:
# Final denoising step of every sample, mapped back through the agent's unprocess step.
# NOTE(review): `unprocess_fn` is applied to the poses here and again to the
# whole demo below — confirm this double application is intended.
final_poses = agent.unprocess_fn(SE3(poses=Ts_out[-1]))

# Downsample the raw point clouds (1 cm voxels) for display.
scene_preview = preprocess.downsample(data=scene_pcd, voxel_size=0.01)
grasp_preview = preprocess.downsample(data=grasp_pcd, voxel_size=0.01)

visualization = TargetPoseDemo(
    target_poses=final_poses,
    scene_pcd=scene_preview,
    grasp_pcd=grasp_preview,
)
visualization = agent.unprocess_fn(visualization).to('cpu')
visualization.show(bg_color=[0.3, 0.3, 0.3], width = 1000, height=1000, point_size=1.5)

# Visualize Denoising Trajectory

In [None]:
sample_idx = 0

# Every 10th denoising step of the chosen sample, plus its final step so the
# end pose is always shown.
trajectory_keyframes = Ts_out[::10, sample_idx]
trajectory_end = Ts_out[-1:, sample_idx]
trajectory_poses = torch.cat([trajectory_keyframes, trajectory_end], dim=0)

# Uses the processed point clouds returned by `agent.sample`; the whole demo
# is then passed through the agent's unprocess step before display.
visualization = TargetPoseDemo(
    target_poses=SE3(poses=trajectory_poses),
    scene_pcd=scene_proc,
    grasp_pcd=grasp_proc,
)
visualization = agent.unprocess_fn(visualization).to('cpu')
visualization.show(bg_color=[0.3, 0.3, 0.3], width = 1000, height=1000, point_size=2.5)