In [None]:
import os
# Set before `torch` is imported below, presumably so the setting is picked up when
# torch / TorchScript reads it.
# NOTE(review): judging by the name, this selects the NNC fuser instead of nvFuser —
# confirm against the installed PyTorch version.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
# NOTE(review): several imports in this cell (e.g. typing names, warnings, math, beartype,
# datetime, plotly, tqdm, DataLoader, Compose, o3, FeaturedPoints, train_utils,
# DiffusionEdfTrainer, visualize_pose) are not used in the cells shown in this notebook.
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface.data import PointCloud, SE3, DemoDataset, TargetPoseDemo
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# Notebook configuration: the device to run on, which task's models to sample
# ('pick' or 'place'), the directory holding agent.yaml / preprocess.yaml, and
# the demo dataset the test scene is drawn from.
device = 'cuda:0'
task_type = 'place'
config_root_dir = 'configs/sapien'

# Fail fast on a typo here. Otherwise the mistake only surfaces later as an
# opaque KeyError (config lookup) or a wrong index (demo selection).
if task_type not in ('pick', 'place'):
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")

testset = DemoDataset(dataset_dir='demo/sapien_demo_20230625')

In [None]:
# Read the per-task model kwargs from agent.yaml and the (pre/un)processing configs
# from preprocess.yaml, then build the agent on `device`.
agent_yaml_path = os.path.join(config_root_dir, 'agent.yaml')
with open(agent_yaml_path) as agent_yaml_file:
    agent_yaml = yaml.load(agent_yaml_file, Loader=yaml.FullLoader)
model_kwargs_list = agent_yaml['model_kwargs'][f"{task_type}_models_kwargs"]

preprocess_yaml_path = os.path.join(config_root_dir, 'preprocess.yaml')
with open(preprocess_yaml_path) as preprocess_yaml_file:
    preprocess_yaml = yaml.load(preprocess_yaml_file, Loader=yaml.FullLoader)
# Keep the two sub-configs under distinct names rather than overwriting the parsed document.
unprocess_config = preprocess_yaml['unprocess_config']
preprocess_config = preprocess_yaml['preprocess_config']

agent = DiffusionEdfAgent(
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
    device=device,
)

# Initialize Input Data and Initial Pose

In [None]:
# Select the demo for the current task from the first testset entry: index 0 for
# 'pick', index 1 for 'place'. The original conditional expression used its error
# message string as the index for an unknown task_type; raise explicitly instead.
_task_demo_index = {'pick': 0, 'place': 1}
if task_type not in _task_demo_index:
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")
demo: TargetPoseDemo = testset[0][_task_demo_index[task_type]].to(device)
scene_pcd: PointCloud = demo.scene_pcd
grasp_pcd: PointCloud = demo.grasp_pcd

# Initial pose: a (1, 4) block followed by a (1, 3) block, concatenated into one (1, 7) row.
# NOTE(review): presumably the first block is a unit quaternion in (w, x, y, z) order
# (identity here) and the second a translation — confirm against SE3's pose convention.
init_quaternion = torch.tensor([[1., 0., 0., 0.]], device=device)
init_translation = torch.tensor([[0., 0., 0.8]], device=device)
T0 = torch.cat([init_quaternion, init_translation], dim=-1)
Ts_init = SE3(poses=T0).to(device)


In [None]:
# Run the agent's sampler starting from `Ts_init`. The schedules below are passed through
# unchanged; their per-entry meaning is defined by DiffusionEdfAgent.sample.
sampling_kwargs = dict(
    N_steps_list=[[500, 500], [500, 1000]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperature_list=[1., 1.],
)
Ts_out, scene_proc, grasp_proc = agent.sample(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    **sampling_kwargs,
)

In [None]:
# Show one sample's pose sequence (every `frame_stride`-th entry along the first axis
# of Ts_out, plus the last one) over the processed scene/grasp point clouds.
# NOTE(review): the first axis of Ts_out is presumably the sampling-step axis and the
# second the sample axis — confirm against DiffusionEdfAgent.sample.
sample_idx = 0
frame_stride = 10
sample_trajectory = Ts_out[:, sample_idx]
shown_poses = torch.cat([sample_trajectory[::frame_stride], sample_trajectory[-1:]], dim=0)

visualization = agent.unprocess_fn(
    TargetPoseDemo(
        target_poses=SE3(poses=shown_poses),
        scene_pcd=scene_proc,
        grasp_pcd=grasp_proc,
    )
).to('cpu')
visualization.show()