In [None]:
import numpy as np
from os.path import join
import os
# Restrict the visible CUDA devices to GPU 0; this is set before `torch` is imported on the next line.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import pdb, sys, torch
# Add the current working directory to the module search path.
# NOTE(review): presumably the notebook is launched from the repo root so that `diffuser` is importable — confirm.
sys.path.append('.')
# Allow TF32 in CUDA matmul and in cuDNN, and request deterministic cuDNN behavior.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.deterministic = True
from diffuser.guides.policies import Policy
import diffuser.datasets as datasets
import diffuser.utils as utils
from diffuser.models import GaussianDiffusionPB
from datetime import datetime
import os.path as osp
from importlib import reload
from diffuser.utils.jupyter_utils import suppress_stdout
# Print numpy arrays with 3 digits of precision.
np.set_printoptions(precision=3)

## You need to download 'kuka7d-base' from the OneDrive link in README.md before launching this notebook.
## Parser subclass whose `config` field points at the Kuka-7D experiment config file.
class Parser(utils.Parser):
    config: str = "config/kuka7d/kuka_exp.py"

#---------------------------------- setup ----------------------------------#

## Parse two argument sets from the same config:
## 'diffusion' (settings used at training time) and 'plan' (settings used for planning).
args_train = Parser().parse_args('diffusion', from_jupyter=True)
args = Parser().parse_args('plan', from_jupyter=True)

## Nothing is written to disk from this notebook through `args.savepath`.
args.savepath = None
args.load_unseen_maze = True

## `args.dataset` is a string: the name of the env
print('args.dataset', type(args.dataset), args.dataset)
print('args.dataset_eval', type(args.dataset_eval), args.dataset_eval)
use_normed_wallLoc = args_train.dataset_config.get('use_normed_wallLoc', False)

## Switch between the training envs (False) and the evaluation envs (True)
load_unseen_maze = args.load_unseen_maze

with suppress_stdout():
    ## ---------- environments + dataset normalizer ----------
    train_env_list = datasets.load_environment(args.dataset, is_eval=load_unseen_maze)
    train_normalizer = utils.load_datasetNormalizer(
        train_env_list.dataset_url,
        args_train,
        train_env_list,
    )

    ## ---------- trained diffusion model ----------
    ## The loaded experiment exposes at least `.ema` and `.epoch`, both used in this notebook.
    ld_config = {'env_instance': train_env_list}
    diffusion_experiment = utils.load_potential_diffusion_model(
        args.logbase,
        args.dataset,
        args_train.exp_name,
        epoch=args.diffusion_epoch,
        ld_config=ld_config,
    )
    diffusion = diffusion_experiment.ema

### load eval problems

In [None]:
from diffuser.guides.kuka_plan_utils import load_eval_problems_pb
## Fetch the evaluation problems that match the currently selected env list.
eval_problem_cfg = {'load_unseen_maze': load_unseen_maze}
problems_dict = load_eval_problems_pb(train_env_list, eval_problem_cfg)
probs: np.ndarray = problems_dict['problems'] # (300, 200, 2, 7)
print('loaded epoch:', diffusion_experiment.epoch)

In [None]:
## Reload `jupyter_utils` first, then import `get_all_suc_trajs` from the freshly reloaded module.
import diffuser.utils.jupyter_utils as vu; reload(vu)
from diffuser.utils.jupyter_utils import get_all_suc_trajs
## Show the `condition_guidance_w` attribute of the loaded diffusion model.
print('condition_guidance_w', diffusion.condition_guidance_w )

In [None]:
## Example: create a single env from the env list (here with index 17) and get its wall locations.
env_single = train_env_list.create_single_env(17)
## Last expression of the cell, so its value is shown as the cell output.
env_single.wall_locations

In [None]:
## Sample a batch of trajectories for one planning problem and keep the successful ones.
diffusion: GaussianDiffusionPB
bs = 100 ## number of trajectories sampled in one batch
policy = Policy(diffusion, train_normalizer, use_ddim=True)
test_hor = 48 ## planning horizon; the goal is conditioned at index test_hor-1
diffusion.horizon = test_hor

## 'dataset': use an env and a start/goal problem that come with the dataset
## 'custom' : define your own wall locations (and your own start/goal)
wloc_source = 'dataset'
if wloc_source == 'dataset':
    ## which env and which problem inside that env
    env_id, p_id = 87, 8

    env_single = train_env_list.create_single_env(env_id)
    prob = probs[env_id, p_id, ] # (2, 7)
    ## tile the single start / goal across the batch
    start = prob[None, 0].repeat( bs, 0 ) # (bs, 7)
    goal =  prob[None, 1].repeat( bs, 0 )

elif wloc_source == 'custom':
    ## Define your own wall locations
    env_id = 0 ## will overwrite env0
    xyz_list = np.array([
        [-0.492,  0.177,  0.846],
        [-0.045, -0.433,  0.291],
        [ 0.431,  0.476,  0.409],
        [ 0.469, -0.03 ,  0.886]])
    assert len(xyz_list) == 4, 'since w/o using composition, should equal to training number.'

    env_single = train_env_list.create_env_by_pos(env_id, xyz_list)

    ## Define the start and goal joint states here, each as an array of shape (bs, 7).
    ## They are intentionally left unset: fail with a clear message instead of a
    ## NameError further down when `cond` is built.
    start, goal = None, None
    if start is None or goal is None:
        raise ValueError(
            "wloc_source='custom' requires `start` and `goal` to be defined "
            "as arrays of shape (bs, 7) before planning."
        )

else:
    raise ValueError(f"unknown wloc_source: {wloc_source!r}; expected 'dataset' or 'custom'")

## prepare wloc: flatten the wall locations of the chosen env and tile them across the batch
wloc_np = train_env_list.wallLoc_list[env_id]
wloc_tensor = utils.to_torch(wloc_np).reshape(1, -1).repeat( (bs, 1) )
print( 'wloc_np: ', wloc_np.shape )

## condition the first and the last step of the horizon on start and goal
cond = {0: start,
        test_hor-1: goal}

samples = policy(cond, batch_size=-1, wall_locations=wloc_tensor, use_normed_wallLoc=use_normed_wallLoc, return_diffusion=False)
## samples[0] is dummy, can be ignored
unnm_traj = samples[1].observations # 3d (B, H, 7)

## keep only the trajectories reported as successful by `get_all_suc_trajs`
unnm_suc_traj, suc_idxs = get_all_suc_trajs( env_single, unnm_traj, goal=goal)
num_suc_trajs = len(unnm_suc_traj)
print(f'num_suc_trajs: {num_suc_trajs}') # a batch of traj
print(f'unnm_traj: {len(unnm_traj)}, {unnm_traj[0].shape}') # a batch of traj


In [None]:
## Print the first row of the batched start / goal states in a copy-paste friendly form.
for _label, _batched_state in (('start:', start), ('goal:', goal)):
    print(_label, np.array2string(_batched_state[0], separator=', '))

### Do visualization

In [None]:
## Notebook-level constants read as globals inside visualize_kuka_traj_luo_web.
h = w = 300 # render resolution
## When the number of frames rendered so far equals `pfr` during the final rotation,
## the function prints the camera view matrix and yaw of that frame.
pfr = 77
import pybullet as p
import matplotlib.pyplot as plt
from tqdm import tqdm; 
from pb_diff_envs.utils.kuka_utils_luo import add_start_end_marker

'''
generate kuka video for website and ppt slides
'''
def visualize_kuka_traj_luo_web(env, traj, lock=None, is_ee3d=False, is_debug=False, is_mpo=False):
    '''
    Render a Kuka trajectory as a list of frames: first the robot stepping through
    `traj` seen from a fixed camera, then a camera move followed by a full 360-degree
    rotation around the final scene.

    The difference with visualize_traj is that this function directly receives
    [a list of 1d numpy] or [2d np] as trajectory, not using a traj_agent.

    Reads the notebook-level globals `h`, `w` (render resolution) and `pfr`
    (frame count at which the camera of the rotation is printed).

    Args:
        env: environment providing `objects`, `load(GUI=...)`, `robot.set_config(...)`,
            `robot_id` and `unload_env()`.
        traj: sequence of robot configurations; must be at least as long as every
            object's trajectory waypoints.
        lock: optional lock held while `env.load(GUI=False)` runs. When None the env
            is loaded with `GUI=is_debug`.
        is_ee3d: deprecated, must be False.
        is_debug: forwarded as `GUI` to `env.load` when no lock is given.
        is_mpo: if True, the last waypoint of `traj` is dropped before rendering.

    Returns:
        (gifs, ds, vis_dict): the rendered frames, one duration per frame
        (NOTE(review): presumably milliseconds, the original comment calls 2000 "2 sec"),
        and a dict with `ncoll_frame` (number of frames with robot contact points)
        and `se_valid` (always None here).

    Raises:
        AssertionError: if `is_ee3d` is True or `traj` is shorter than an object trajectory.
    '''
    ## fail fast, before anything is loaded into the simulator
    assert not is_ee3d, 'depricated, prevent bugs'

    ## the objects' trajectories, should be 2 for static obj
    max_len_traj = max((len(obj.trajectory.waypoints) for obj in env.objects), default=0)
    max_len = max(len(traj), max_len_traj)
    assert max_len == len(traj)
    print('[vis traj] max_len:', max_len)

    se_valid = None

    ## ---- for web high-quality visualization ------
    ## Hard-coded camera record: (width, height, view matrix, projection matrix, ..., yaw, pitch, dist, target).
    ## NOTE(review): looks like the output of pybullet's getDebugVisualizerCamera — confirm.
    k7d_camera_info = (h, w, (0.886519193649292, 0.2863903045654297, -0.36340656876564026, 0.0, -0.4626918435096741, 0.5487248301506042, -0.6962882280349731, 0.0, 0.0, 0.7854180932044983, 0.6189655065536499, 0.0, -0.0, 1.1920928955078125e-07, -2.101985216140747, 1.0), (1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0000200271606445, -1.0, 0.0, 0.0, -0.02000020071864128, 0.0), (0.0, 0.0, 1.0), (0.36340656876564026, 0.6962882280349731, -0.6189655065536499), (17730.3828125, -9253.8369140625, 0.0), (5727.80615234375, 10974.49609375, 15708.36328125), -27.56093978881836, -38.24062728881836, 2.101985216140747, (-0.0, -0.0, 0.0))
    fixed_view_mat = k7d_camera_info[2]
    proj_mat = k7d_camera_info[3]

    ## camera target and distance used for the rotation around the final scene;
    ## defined up front so they exist even if no waypoint is rendered
    cam_pos = [0, 0, 0.3]
    dist = 1.95 # original exp settings:1.8

    if lock is not None:
        ## release the lock even if loading fails
        lock.acquire()
        try:
            env.load(GUI=False)
        finally:
            lock.release()
    else:
        env.load(GUI=is_debug) # ok

    ## from here on the env is loaded: always unload it, also on errors
    try:
        ## one color per waypoint; computed from the full length, before `is_mpo` truncation
        colors = plt.cm.jet(np.linspace(0, 1, max_len)) # color of the dot
        vshape_id = p.createVisualShape(p.GEOM_SPHERE, radius=0.02)
        num_collision_frame = 0
        add_start_end_marker(env, traj)
        if is_mpo:
            traj = traj[:-1]
            max_len -= 1

        ## ---------- 1. robot moves along the trajectory, fixed camera ----------
        gifs = []
        for i_t in tqdm(range(max_len)):
            env.robot.set_config(traj[i_t])
            ## drop a colored sphere at the position of link 6 to trace the path
            new_pos = p.getLinkState(env.robot_id, 6)[0]
            ## slow
            tmp_id = p.createMultiBody(baseMass=0, basePosition=new_pos, baseVisualShapeIndex=vshape_id)
            p.changeVisualShape(tmp_id, linkIndex=-1, rgbaColor=colors[i_t])

            p.performCollisionDetection()
            c_pts = p.getContactPoints(env.robot_id)
            ## very important: c_pts may be None, which also means no collision
            if c_pts is not None and len(c_pts) > 0:
                num_collision_frame += 1

            img = p.getCameraImage(
                width=k7d_camera_info[0],  # Image width
                height=k7d_camera_info[1],  # Image height
                viewMatrix=fixed_view_mat,  # View matrix
                projectionMatrix=proj_mat,  # Projection matrix
                flags=p.ER_NO_SEGMENTATION_MASK,
                renderer=p.ER_BULLET_HARDWARE_OPENGL,
            )[2]
            gifs.append(img)

        ## stop at the last one for 2 sec
        ds = ([150] * (len(gifs) - 1) + [2000]) if gifs else []

        def process_elem(st, end, n_interp, idx):
            '''Linear interpolation between `st` and `end` at step `idx` of `n_interp`.'''
            assert 0 <= idx <= n_interp
            return st + (end - st) * idx / n_interp

        def process_list(start_list, end_list, n_interp, idx):
            '''Interpolate every entry (float, or tuple element-wise) of two parallel lists.
            n_interp: how many in total between
            '''
            r_list = []
            for st, end in zip(start_list, end_list):
                if isinstance(st, tuple):
                    result = tuple(process_elem(s, e, n_interp, idx) for s, e in zip(st, end))
                elif isinstance(st, float):
                    result = process_elem(st, end, n_interp, idx)
                else:
                    raise NotImplementedError()
                r_list.append(result)
            return r_list

        def render_rot_frame(view_mat):
            '''Render one frame of the camera move / rotation with the given view matrix.'''
            return p.getCameraImage(width=h, height=w, lightDirection=[1, 1, 1], shadow=0,
                                    renderer=p.ER_TINY_RENDERER,
                                    viewMatrix=view_mat,
                                    projectionMatrix=proj_mat,
                                    flags=p.ER_NO_SEGMENTATION_MASK,
                                    )[2]

        end_yaw = -260
        rot_gap = 3
        rot_pitch = - 40
        rot_frames = []

        ## ---------- 2. move the camera from the fixed pose to the start of the rotation ----------
        start_list = list( k7d_camera_info[-4:] ) # yaw, pitch, dist, target of the fixed camera
        end_list = [end_yaw, rot_pitch, dist, cam_pos]
        yaw_diff = end_list[0] - start_list[0]
        n_interp = 36 if abs(yaw_diff) < 120 else int(abs(yaw_diff // 3))
        print('XX:', yaw_diff)
        for i_itp in range(n_interp):
            i_y, i_p, i_d, i_cp = process_list(start_list, end_list, n_interp, i_itp)
            view_mat = p.computeViewMatrixFromYawPitchRoll(
                    cameraTargetPosition=i_cp, distance=i_d,
                    yaw=i_y, pitch=i_p, roll=0, upAxisIndex=2)
            rot_frames.append(render_rot_frame(view_mat))

        ## ---------- 3. full rotation around the scene ----------
        for yaw_offset in range(0, 360, rot_gap):
            view_mat = p.computeViewMatrixFromYawPitchRoll(
                    cameraTargetPosition=cam_pos, distance=dist,
                    yaw=(end_yaw+yaw_offset), pitch=rot_pitch, roll=0, upAxisIndex=2)
            rot_frames.append(render_rot_frame(view_mat))

            ## report the camera of the frame whose overall index is `pfr`
            if len(rot_frames) + len(gifs) == pfr:
                print(f'camera {pfr}', view_mat, )
                print(f'yaw: {(end_yaw+yaw_offset)}')

        print(f'view_mat: {view_mat}')
        print(f'[vis traj] num_collision_frame: {num_collision_frame}; start end valid: {se_valid}')

        ## hold the very last frame longer
        ds.extend( [150,] * (len(rot_frames) - 1) )
        ds.append(6000)
        gifs.extend(rot_frames)

        ## hold the very first frame longer as well
        ds[0] = 1200
        assert len(ds) == len(gifs)
        vis_dict = dict(ncoll_frame=num_collision_frame, se_valid=se_valid)
    finally:
        env.unload_env()

    return gifs, ds, vis_dict


In [None]:
from IPython.display import HTML
import base64, datetime, pickle
from diffuser.utils import KukaRenderer
import diffuser.utils.luo_utils as xxx; reload(xxx)
from diffuser.utils.luo_utils import save_gif_ethucy
from pb_diff_envs.utils.kuka_utils_luo import visualize_kuka_traj_luo


## Render one successful trajectory to a GIF on disk and show it inline.
rd = KukaRenderer(train_env_list, is_eval=True)
rootdir = './visualization/'
os.makedirs(rootdir, exist_ok=True)
gifname = f'{rootdir}/kuka7d_base.gif'


## Index into `unnm_suc_traj` of the path to render.
## e.g. `path_idx = 0` renders the first successful path.
path_idx = 10

## unnm_suc_traj list[np.ndarray]: (n_trajs, H, dim)
## The number of successful trajectories varies from run to run, so check the index
## explicitly and report how many are available.
n_suc_available = len(unnm_suc_traj)
if not -n_suc_available <= path_idx < n_suc_available:
    raise IndexError(
        f'path_idx={path_idx} is out of range: only {n_suc_available} successful trajectories are available'
    )
gifs, ds, vis_dict = visualize_kuka_traj_luo_web( env_single, unnm_suc_traj[path_idx] )


save_gif_ethucy( gifs, gifname, duration=ds) # duration=100)
## Embed the GIF as base64 so that it is displayed inline; close the file handle after reading.
with open(gifname, 'rb') as gif_file:
    b64 = base64.b64encode(gif_file.read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{b64}" />'))