In [None]:
import sys
import os
import torch

# Restrict this process to a single GPU. CUDA_DEVICE_ORDER='PCI_BUS_ID' makes
# device indices follow PCI bus order, so '1' below refers to that physical card.
# CUDA reads these variables when it is first initialized, so they must be set
# before any tensor or model is moved to the GPU.
os.environ.update(
    CUDA_DEVICE_ORDER='PCI_BUS_ID',
    CUDA_VISIBLE_DEVICES='1',
)

import numpy as np
import plotly.graph_objects as go

In [None]:
def toDisplay(x, target_dim = 2):
    """Reduce a tensor to `target_dim` dimensions and return it as a NumPy array.

    Surplus leading dimensions are removed by repeatedly taking index 0, so
    for a batched tensor only the first element of each leading axis is kept.
    The result is detached from autograd and copied to host memory.
    """
    surplus = max(x.dim() - target_dim, 0)
    for _ in range(surplus):
        x = x[0]
    return x.detach().cpu().numpy()

In [None]:
from equivariant_pose_graph.models.transformer_flow import ResidualFlow_DiffEmb, ResidualFlow_DiffEmbTransformer
from equivariant_pose_graph.training.flow_equivariance_training_module import EquivarianceTrainingModule
from pytorch3d.ops import sample_farthest_points

In [None]:
def load_data(num_points, clouds, classes, action_class, anchor_class):
    """Split a labelled point cloud into action/anchor clouds, centre and subsample them.

    Args:
        num_points: number of points each returned cloud should have.
        clouds: array of points, row-indexed by a boolean mask over `classes`
            (presumably shape (N, 3) -- confirm against the .npz producer).
        classes: per-point class labels aligned with the rows of `clouds`.
        action_class: label selecting the action object's points.
        anchor_class: label selecting the anchor object's points.

    Returns:
        (points_action, points_anchor): float32 CUDA tensors with a leading
        batch dimension of 1. Both clouds are shifted by the *action* cloud's
        mean, so the action cloud is centred at the origin and the anchor keeps
        its position relative to it.

    Raises:
        NotImplementedError: propagated from `subsample` when either cloud has
            fewer than `num_points` points.
    """
    def _select(label):
        return clouds[classes == label].copy()

    action_np = _select(action_class)
    centre = action_np.mean(axis=0)
    # Centre both clouds on the action object so their relative pose is preserved.
    action_np = action_np - centre
    anchor_np = _select(anchor_class) - centre

    action_t, anchor_t = (
        torch.from_numpy(arr).float().unsqueeze(0) for arr in (action_np, anchor_np)
    )
    action_t, anchor_t = subsample(num_points, action_t, anchor_t)
    return action_t.cuda(), anchor_t.cuda()

def subsample(num_points,points_action,points_anchor):
    """Bring the action and anchor clouds to exactly `num_points` points each.

    Clouds with more than `num_points` points are reduced by farthest point
    sampling with a random start point, so the result is not deterministic.
    Clouds that already have `num_points` points are returned unchanged.

    Args:
        num_points: target number of points per cloud.
        points_action, points_anchor: batched clouds whose dimension 1 is the
            point axis.

    Returns:
        (points_action, points_anchor) after subsampling.

    Raises:
        NotImplementedError: if either cloud has fewer than `num_points` points
            (upsampling is not supported). The action cloud is checked first.
    """
    def _fit(cloud, label):
        count = cloud.shape[1]
        if count < num_points:
            raise NotImplementedError(f'{label} point cloud is smaller than cloud size ({count} < {num_points})')
        if count > num_points:
            cloud, _ = sample_farthest_points(cloud, K=num_points, random_start_point=True)
        return cloud

    # Tuple elements evaluate left to right: the action cloud is processed first.
    return _fit(points_action, 'Action'), _fit(points_anchor, 'Anchor')

In [None]:
network = ResidualFlow_DiffEmb(emb_nn='dgcnn')

model = EquivarianceTrainingModule(
    network)
model.cuda()
checkpoint_file = '/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual0_attn_trans0.1_rot5_meancenter_diffembnn/equiv_dcpflow/version_2/checkpoints/epoch=14-step=1875.ckpt'
# Alternative checkpoint (rot180, diffembnntrans):
# checkpoint_file='/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual0_attn_trans0.1_rot180_diffembnntrans_dgcnn/equiv_dcpflow/version_2/saved_ckpts/epoch=34-step=4375.ckpt'

# Load the tensors on the CPU. Without map_location, torch.load tries to restore
# each tensor onto the CUDA device it was saved from, and that index may not
# exist under CUDA_VISIBLE_DEVICES. load_state_dict then copies the weights into
# the model's (already CUDA-resident) parameters.
checkpoint = torch.load(checkpoint_file, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

# This notebook only runs inference. eval() disables training-mode behaviour
# (e.g. batch-norm statistic updates and dropout, if the network has any).
model.eval()
place_model = model

In [None]:
# Fixed 10-colour palette shared by all point-cloud plots; trace i uses colour i.
_PLOT_COLORS = (
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
    '#d62728',  # brick red
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
    '#e377c2',  # raspberry yogurt pink
    '#7f7f7f',  # middle gray
    '#bcbd22',  # curry yellow-green
    '#17becf',  # blue-teal
)


def _plot_clouds(*clouds, skip=1, marker_size=1):
    """Render any number of point clouds in one interactive plotly 3D scatter.

    Each cloud goes through `toDisplay`, which drops leading (batch) dimensions
    by taking index 0. Each cloud gets its own colour from `_PLOT_COLORS`,
    cycling after ten clouds. The scene uses aspectmode='data' so the axes are
    scaled in proportion to the data ranges.

    Args:
        *clouds: tensors whose last two dimensions are (num_points, >=3);
            columns 0, 1 and 2 are plotted as x, y and z.
        skip: plot every `skip`-th point (1 plots all points).
        marker_size: plotly marker size used for every trace.
    """
    traces = []
    for idx, cloud in enumerate(clouds):
        pts = toDisplay(cloud)[::skip]
        colour = _PLOT_COLORS[idx % len(_PLOT_COLORS)]
        traces.append(
            go.Scatter3d(
                x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
                mode='markers',
                marker=dict(size=marker_size, color=colour, symbol='circle'),
            )
        )
    layout = go.Layout(scene=dict(aspectmode='data'))
    # Show the figure without returning it, so a cell ending in plot(...)
    # does not display it a second time.
    go.Figure(data=traces, layout=layout).show()


def plot(points_action, points_anchor):
    """Show the action cloud (blue) and the anchor cloud (orange) together."""
    _plot_clouds(points_action, points_anchor)


def plot_multi(p1, p2, p3, p4):
    """Show four clouds together, coloured blue, orange, green and red in argument order."""
    _plot_clouds(p1, p2, p3, p4)

In [None]:
 
# from pathlib import Path
# import torch.nn.functional as F
# from equivariant_pose_graph.utils.se3 import random_se3
# from pytorch3d.transforms import Transform3d, Rotate

# num_classes = 3
# data_idx = 0
# cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
# cloud_type = 'init'
# # data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/place_test_0_my_model_may11_dgcnn_diffembnntrans_rot180') 
# data_path = Path('/home/exx/Documents/ndf_robot/train_data_3/renders/61_teleport_obj_points.npz')
# point_data = np.load(data_path, allow_pickle = True)
 
# points_action, points_anchor = load_data(num_points=1024, clouds = point_data['clouds'] ,classes = point_data['classes'], action_class= 0, anchor_class= 1)
# plot(points_action, points_anchor)

# points_action_trans = points_action
# points_anchor_trans = points_anchor
# pred_T_action, pred_T_anchor, pred_T  = place_model(points_action, points_anchor)
 
# pred_T_action_points_action = pred_T_action.transform_points(points_action_trans)
# pred_T_action_points_anchor = pred_T_anchor.transform_points(points_anchor_trans)
 
 

# ## What the predicted transform in plotly looks like:
# plot(pred_T_action_points_action, points_anchor)

# pred_T_action_points_action_iterate = pred_T_action_points_action
# for i in range(20):
#     pred_T_action, pred_T_anchor, pred_T  = place_model(pred_T_action_points_action_iterate, points_anchor)
#     pred_T_action_points_action_iterate = pred_T_action.transform_points(pred_T_action_points_action_iterate)
    
#     torch.cuda.empty_cache() 
# plot_multi(points_action, points_anchor, pred_T_action_points_action, points_anchor)

In [None]:
# plot_multi(points_action, points_anchor, pred_T_action_points_action_iterate, points_anchor)

In [None]:
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3
from pytorch3d.transforms import Transform3d, Rotate

num_classes = 3
data_idx = 0
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'init'
N_REFINE_STEPS = 30  # extra model applications after the first prediction

data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/place_test_0_my_model_may12_dgcnn_diffemb')
# allow_pickle=True can execute code embedded in the file; only load .npz files you produced.
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)

# Transform stored in the rollout file (first element along axis 0).
# The .float() cast is a no-op for float32 data. It guards against a float64
# matrix being multiplied with the float32 clouds from load_data in
# transform_points (presumably that would raise a dtype mismatch; verify
# against the saved dtype).
pred_T_action_init = torch.from_numpy(point_data['pred_T_action_init'][0]).float().cuda()
pred_T_action_pulled = Transform3d(matrix=pred_T_action_init)

points_action, points_anchor = load_data(num_points=1024, clouds = point_data['clouds'] ,classes = point_data['classes'], action_class= 0, anchor_class= 1)

# Inference only: no_grad stops autograd graphs from being built and kept alive.
with torch.no_grad():
    points_action_trans = points_action
    points_anchor_trans = points_anchor
    pred_T_action, pred_T_anchor, pred_T = place_model(points_action, points_anchor)
    pred_T_action_points_action = pred_T_action.transform_points(points_action_trans)
    pred_T_action_points_anchor = pred_T_anchor.transform_points(points_anchor_trans)

## What the initial point cloud looks like:
plot(points_action, points_anchor)

# Re-run the model on its own output: feed the moved action cloud back in with
# the fixed anchor, applying each predicted transform to the current iterate.
# The per-step transforms use separate names so that pred_T_action keeps the
# first prediction, which the next cell compares against pred_T_action_pulled.
# Under no_grad each step's graph is not chained to the previous one, so memory
# stays flat without calling torch.cuda.empty_cache().
with torch.no_grad():
    pred_T_action_points_action_iterate = pred_T_action_points_action
    for _step in range(N_REFINE_STEPS):
        step_T_action, _, _ = place_model(pred_T_action_points_action_iterate, points_anchor)
        pred_T_action_points_action_iterate = step_T_action.transform_points(pred_T_action_points_action_iterate)
 

In [None]:
# Print the transform pulled from the rollout file next to the transform
# currently held in pred_T_action, each preceded by a label line.
for label, transform in (
    ("pred_T_action_pulled.get_matrix()", pred_T_action_pulled),
    ("pred_T_action.get_matrix()", pred_T_action),
):
    print(label)
    print(transform.get_matrix())

In [None]:
# Action cloud moved by the model's predicted transform (blue) with the anchor (orange).
plot(points_action=pred_T_action_points_action, points_anchor=points_anchor)

In [None]:
# Single-step prediction (blue) vs. iterated refinement (green); the anchor is
# passed twice, so it appears in orange and red.
plot_multi(
    p1=pred_T_action_points_action,
    p2=points_anchor,
    p3=pred_T_action_points_action_iterate,
    p4=points_anchor,
)

In [None]:
# Action cloud after iterated refinement (blue) with the anchor (orange).
plot(points_action=pred_T_action_points_action_iterate, points_anchor=points_anchor)

In [None]:
 
# Action cloud moved by the transform pulled from the rollout file, with the anchor.
pulled_points_action = pred_T_action_pulled.transform_points(points_action)
plot(points_action=pulled_points_action, points_anchor=points_anchor)

In [None]:
# Pulled-transform result (blue) vs. the model's single-step prediction (green);
# the anchor is passed twice, so it appears in orange and red.
plot_multi(
    p1=pred_T_action_pulled.transform_points(points_action),
    p2=points_anchor,
    p3=pred_T_action_points_action,
    p4=points_anchor,
)

In [None]:
## Clouds captured in pybullet at the 'teleport' stage, and the model's prediction on them
cloud_type = 'teleport'

# Reuses data_path and data_idx from the 'init' cell above.
# allow_pickle=True can execute code embedded in the file; only load .npz files you produced.
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)

points_action_tele, points_anchor_tele = load_data(num_points=1024, clouds = point_data['clouds'] ,classes = point_data['classes'], action_class= 0, anchor_class= 1)
plot(points_action_tele , points_anchor_tele )

points_action_trans_tele = points_action_tele
points_anchor_trans_tele = points_anchor_tele
# Inference only: avoid building autograd graphs.
with torch.no_grad():
    pred_T_action_tele, pred_T_anchor_tele, pred_T_tele = place_model(points_action_trans_tele, points_anchor_trans_tele)
    pred_T_action_points_action_tele = pred_T_action_tele.transform_points(points_action_trans_tele)
    pred_T_action_points_anchor_tele = pred_T_anchor_tele.transform_points(points_anchor_trans_tele)

# Optional: view the model's prediction on the teleport clouds.
# plot(pred_T_action_points_action_tele , points_anchor_tele )

In [None]:
# Pulled-transform result on the 'init' clouds (blue action, orange anchor) vs.
# the clouds captured at the 'teleport' stage (green action, red anchor).
plot_multi(
    p1=pred_T_action_pulled.transform_points(points_action),
    p2=points_anchor,
    p3=points_action_tele,
    p4=points_anchor_tele,
)

In [None]:
# Deliberate halt so that "Run All" stops here. The cells below are older
# exploratory runs; the last one reloads data and overwrites points_action,
# points_anchor and the pred_T_* variables used by the cells above.
raise RuntimeError('Intentional stop: cells below are exploratory and overwrite earlier results.')

In [None]:
# Current action (blue) and anchor (orange) clouds, before any predicted transform.
plot(points_action=points_action, points_anchor=points_anchor)

In [None]:
# Action cloud held in pred_T_action_points_action (blue) with the anchor (orange).
plot(points_action=pred_T_action_points_action, points_anchor=points_anchor)

In [None]:
# Load Teleport: run the model on the 'teleport' cloud from the may10_overfit run.
# NOTE: relies on `data_idx` from an earlier cell. It overwrites points_action,
# points_anchor, pred_T_action, pred_T_action_points_action and
# pred_T_action_points_action_iterate from the cells above.
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3

num_classes = 3
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'teleport'
# Earlier data locations:
# data_path = Path('/home/bokorn/src/ndf_robot/notebooks')
# point_data = np.load(data_path / f'{data_idx}_obj_points.npz')
# data_path = Path('/home/exx/Documents/ndf_robot/train_new_data_3/renders')
# data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/debug_place_shapeid_34ae0b61b0d8aaf2d7b20fded0142d7a')
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/place_test_0_my_model_may10_overfit')
# allow_pickle=True can execute code embedded in the file; only load .npz files you produced.
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
print(point_data['shapenet_id'])
points_action, points_anchor = load_data(num_points=1024, clouds = point_data['clouds'] ,classes = point_data['classes'], action_class= 0, anchor_class= 1)

# Inference only: no_grad stops autograd graphs from being built and kept alive.
with torch.no_grad():
    points_action_trans = points_action
    points_anchor_trans = points_anchor
    pred_T_action, pred_T_anchor, pred_T = place_model(points_action_trans, points_anchor_trans)
    pred_T_action_points_action = pred_T_action.transform_points(points_action_trans)
    pred_T_action_points_anchor = pred_T_anchor.transform_points(points_anchor_trans)
plot(points_action, points_anchor)

# Re-apply the model 20 more times to its own output. The iterate has its own
# name, so pred_T_action_points_action keeps the single-step result and
# pred_T_action keeps the first prediction.
with torch.no_grad():
    pred_T_action_points_action_iterate = pred_T_action_points_action
    for _step in range(20):
        step_T_action, _, _ = place_model(pred_T_action_points_action_iterate, points_anchor)
        pred_T_action_points_action_iterate = step_T_action.transform_points(pred_T_action_points_action_iterate)
plot(pred_T_action_points_action_iterate, points_anchor)