In [None]:
import os
import sys

# Choose the GPU *before* torch (or anything else that may touch CUDA) is
# imported: CUDA only reads these variables when its runtime initialises, so
# setting them first guarantees they take effect in a fresh kernel.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import plotly.graph_objects as go
import torch
from pytorch3d.ops import sample_farthest_points

In [None]:
def toDisplay(x, target_dim = 2):
    """Drop leading dimensions of a tensor and return it as a NumPy array.

    Repeatedly takes index 0 along the first axis until the tensor has at most
    `target_dim` dimensions. The result is then detached from autograd, moved
    to the CPU, and converted with `.numpy()`.

    Args:
        x: a torch tensor, e.g. a batched point cloud.
        target_dim: number of trailing dimensions to keep (default 2).

    Returns:
        numpy.ndarray with `min(x.dim(), target_dim)` dimensions.
    """
    surplus = x.dim() - target_dim
    trimmed = x
    # Each [0] removes exactly one leading axis; a non-positive surplus skips the loop.
    for _ in range(max(surplus, 0)):
        trimmed = trimmed[0]
    return trimmed.detach().cpu().numpy()

In [None]:
def load_data(num_points, clouds, classes, action_class, anchor_class, device='cuda'):
    """Split a labelled point cloud into action and anchor clouds centred on the action.

    Both clouds are translated by the centroid of the action points. The action
    cloud therefore becomes zero-mean, and the anchor keeps its position relative
    to it. Each cloud is then reduced to `num_points` points by `subsample`.

    Args:
        num_points: number of points each returned cloud is reduced to.
        clouds: NumPy array of points, presumably (N, 3) — TODO confirm.
        classes: NumPy array of per-point labels aligned with `clouds`; compared
            with the class ids to build boolean masks.
        action_class: label selecting the action object's points.
        anchor_class: label selecting the anchor object's points.
        device: torch device for the returned tensors. Defaults to 'cuda', which
            matches the previous hard-coded `.cuda()`. Pass 'cpu' to run without
            a GPU.

    Returns:
        (points_action, points_anchor): float32 tensors with a leading batch
        dimension of size 1, on `device`.

    Raises:
        NotImplementedError: propagated from `subsample` when either cloud has
            fewer than `num_points` points.
    """
    # Boolean-mask indexing already returns fresh arrays, so no extra .copy() is needed.
    points_action_np = clouds[classes == action_class]
    points_anchor_np = clouds[classes == anchor_class]

    # Shift both clouds by the *action* centroid so their relative pose is preserved.
    action_centroid = points_action_np.mean(axis=0)
    points_action_np = points_action_np - action_centroid
    points_anchor_np = points_anchor_np - action_centroid

    points_action = torch.from_numpy(points_action_np).float().unsqueeze(0)
    points_anchor = torch.from_numpy(points_anchor_np).float().unsqueeze(0)
    points_action, points_anchor = subsample(num_points, points_action, points_anchor)
    return points_action.to(device), points_anchor.to(device)

def subsample(num_points,points_action,points_anchor):
    """Bring the action and anchor clouds to exactly `num_points` points each.

    Dimension 1 of each tensor is treated as the point count. A cloud with more
    points is reduced with pytorch3d farthest point sampling. It uses
    `random_start_point=True`, so repeated calls may select different points.
    A cloud with exactly `num_points` points is returned unchanged. The action
    cloud is handled fully before the anchor cloud.

    Raises:
        NotImplementedError: if either cloud has fewer than `num_points` points.
    """
    # Action cloud: reject if too small, downsample if too large.
    if points_action.shape[1] < num_points:
        raise NotImplementedError(f'Action point cloud is smaller than cloud size ({points_action.shape[1]} < {num_points})')
    if points_action.shape[1] != num_points:
        points_action = sample_farthest_points(
            points_action, K=num_points, random_start_point=True)[0]

    # Anchor cloud: same treatment.
    if points_anchor.shape[1] < num_points:
        raise NotImplementedError(f'Anchor point cloud is smaller than cloud size ({points_anchor.shape[1]} < {num_points})')
    if points_anchor.shape[1] != num_points:
        points_anchor = sample_farthest_points(
            points_anchor, K=num_points, random_start_point=True)[0]

    return points_action, points_anchor

In [None]:
def plot(points_action, points_anchor):
    """Show the action and anchor clouds together in one interactive 3D scatter.

    Each tensor is converted with `toDisplay`, and its first three columns are
    plotted as x, y, z. The action cloud is drawn in muted blue and the anchor
    in safety orange, with 1-pixel circle markers. The scene uses
    aspectmode='data', so the axes keep true proportions.
    """
    palette = (
        '#1f77b4',  # muted blue    -> action
        '#ff7f0e',  # safety orange -> anchor
    )
    stride = 1  # plot every point

    traces = []
    for cloud, color in zip((points_action, points_anchor), palette):
        xyz = toDisplay(cloud)
        traces.append(go.Scatter3d(
            x=xyz[::stride, 0],
            y=xyz[::stride, 1],
            z=xyz[::stride, 2],
            mode='markers',
            marker=dict(size=1, color=color, symbol='circle'),
        ))

    fig = go.Figure(data=traces, layout=go.Layout(scene=dict(aspectmode='data')))
    fig.show()

def plot_multi(plist, skip=1):
    """Overlay any number of point clouds in a single interactive 3D scatter.

    Args:
        plist: sequence of point-cloud tensors. Each is converted with
            `toDisplay`, and its first three columns are plotted as x, y, z.
        skip: plot every `skip`-th point (default 1 = all points).

    Colours come from a fixed 10-entry palette in list order. They wrap around,
    so passing more than ten clouds no longer raises IndexError.
    """
    palette = [
        '#1f77b4',  # muted blue
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#e377c2',  # raspberry yogurt pink
        '#8c564b',  # chestnut brown
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf',  # blue-teal
    ]

    traces = []
    for i, cloud in enumerate(plist):
        xyz = toDisplay(cloud)
        # Named `trace`, not `plot`, so the module-level plot() is not shadowed.
        trace = go.Scatter3d(
            x=xyz[::skip, 0],
            y=xyz[::skip, 1],
            z=xyz[::skip, 2],
            mode='markers',
            marker=dict(size=1, color=palette[i % len(palette)], symbol='circle'),
        )
        traces.append(trace)

    layout = go.Layout(scene=dict(aspectmode='data'))
    fig = go.Figure(data=traces, layout=layout)
    fig.show()

In [None]:
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3
from pytorch3d.transforms import Transform3d, Rotate
import numpy as np
import torch
import os  # also imported in the first cell; repeated so this cell stands alone

# --- Configuration ------------------------------------------------------------
num_classes = 3  # NOTE(review): not referenced in the visible cells
data_idx = 0
# Snapshot names; only cloud_type ('init') and 'teleport' are loaded below.
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'init'
NUM_POINTS = 1024  # points kept per cloud after farthest point sampling
ACTION_CLASS = 0   # label of the action object in the npz 'classes' array
ANCHOR_CLASS = 1   # label of the anchor object in the npz 'classes' array
# Defaults to the original machine's path; set NDF_PLACE_TEST_DIR to run elsewhere.
data_path = Path(os.environ.get(
    'NDF_PLACE_TEST_DIR',
    '/home/exx/Documents/ndf_robot/src/ndf_robot/place_test_0_my_model_may12_dgcnn'))

# --- Initial-stage clouds -----------------------------------------------------
# NOTE(review): allow_pickle=True can execute code embedded in a crafted file;
# only load npz files you produced yourself.
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
points_action, points_anchor = load_data(
    num_points=NUM_POINTS, clouds=point_data['clouds'], classes=point_data['classes'],
    action_class=ACTION_CLASS, anchor_class=ANCHOR_CLASS)

# --- Predicted transform T_pred -----------------------------------------------
# Read once: each NpzFile['key'] access reads the array from the archive again.
T_pred_np = point_data['T_pred']
# pytorch3d's Transform3d uses the row-vector convention (points @ M), hence the
# transpose. This presumes T_pred is stored as a column-vector 4x4 homogeneous
# matrix with translation in the last column — TODO confirm.
T_pred = torch.from_numpy(T_pred_np.T).unsqueeze(0).cuda().float()
T_pred_transform = Transform3d(matrix=T_pred)

# Translation-negated version of T_pred. The explicit .copy() leaves T_pred_np
# untouched. This is NOT the inverse unless the rotation block is the identity:
# the true inverse translation is -R^-1 t. T_pred_negative_transform below uses
# the exact inverse.
T_pred_negative_mat = T_pred_np.copy()
T_pred_negative_mat[:3, -1] *= -1
mat_copy = T_pred_negative_mat  # alias retained; not referenced in the visible cells
T_pred_negative = torch.from_numpy(T_pred_negative_mat.T).unsqueeze(0).cuda().float()
print(T_pred_negative_mat)
T_pred_negative_transform = T_pred_transform.inverse()
print(T_pred_np)

# --- Teleport stage -----------------------------------------------------------
point_data_tele = np.load(data_path / f'{data_idx}_teleport_obj_points.npz', allow_pickle = True)
# NOTE(review): load_data centres each snapshot on its *own* action centroid.
# The init and teleport clouds are therefore in differently shifted frames, and
# T_pred is applied below to centred (not world-frame) init points. Confirm this
# is the intended comparison.
points_action_tele, points_anchor_tele = load_data(
    num_points=NUM_POINTS, clouds=point_data_tele['clouds'], classes=point_data_tele['classes'],
    action_class=ACTION_CLASS, anchor_class=ANCHOR_CLASS)
plot(points_action_tele, points_anchor_tele)
# Plain aliases (no transform applied); not referenced in the visible cells.
points_action_trans_tele = points_action_tele
points_anchor_trans_tele = points_anchor_tele

In [None]:
# Init-stage action cloud moved by T_pred, overlaid on the unmoved init anchor.
plot_multi(plist=[
    T_pred_transform.transform_points(points_action),
    points_anchor,
])

In [None]:
# Unmoved init action cloud, overlaid on the init anchor moved by the inverse of T_pred.
plot_multi(plist=[
    points_action,
    T_pred_negative_transform.transform_points(points_anchor),
])

In [None]:
# Init action moved by T_pred plus init anchor, compared against both teleport-stage clouds.
plot_multi(plist=[
    T_pred_transform.transform_points(points_action),
    points_anchor,
    points_action_tele,
    points_anchor_tele,
])

In [None]:
# Init action plus init anchor moved by inverse(T_pred), compared against both teleport-stage clouds.
plot_multi(plist=[
    points_action,
    T_pred_negative_transform.transform_points(points_anchor),
    points_action_tele,
    points_anchor_tele,
])

In [None]:
# Everything at once: the T_pred-moved init action, the init anchor, both
# teleport clouds, and the untransformed init clouds.
plot_multi(plist=[
    T_pred_transform.transform_points(points_action),
    points_anchor,
    points_action_tele,
    points_anchor_tele,
    points_action,
    points_anchor,
])