In [None]:
import pytorch3d
from pytorch3d.transforms import Transform3d, Rotate, Translate, \
    rotation_6d_to_matrix, axis_angle_to_matrix, so3_rotation_angle
# from pytorch3d.transforms import se3_exp_map

import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
def random_se3(N, rot_var=np.pi/180*120, trans_var=0.1, device=None, fix_random=False):
    """Sample a batch of N random rigid transforms as a pytorch3d Transform3d.

    One random axis-angle vector is drawn per sample. The whole batch is then
    scaled by a single shared factor ``u * rot_var / max_norm`` (u ~ U[0, 1)),
    so the largest rotation angle in the batch is at most ``rot_var``.
    Translations are scaled the same way against ``trans_var``.

    Args:
        N: number of transforms in the batch.
        rot_var: upper bound on the rotation angle (radians).
        trans_var: upper bound on the translation norm.
        device: device for the sampled tensors.
        fix_random: accepted for API compatibility; currently unused.

    Returns:
        Transform3d holding the rotation followed by the translation.
    """
    # Rotation part: random directions, one shared random magnitude.
    rot_dirs = torch.randn(N, 3, device=device)
    rot_scale = torch.rand(1).item() * rot_var / torch.norm(rot_dirs, dim=1).max().item()
    rotation = axis_angle_to_matrix(rot_scale * rot_dirs)

    # Translation part: same scheme, bounded by trans_var.
    trans_dirs = torch.randn(N, 3, device=device)
    trans_bound = trans_var / torch.norm(trans_dirs, dim=1).max().item()
    translation = torch.rand(1).item() * trans_bound * trans_dirs

    return Rotate(rotation, device=device).translate(translation)
def pure_translation_se3(N, t, device=None):
    """Build a batch of N identical pure-translation transforms.

    Args:
        N: batch size of the returned Transform3d.
        t: translation with exactly 3 elements. A torch tensor, a numpy array
            or a plain sequence is accepted; shapes such as (3,) or (1, 3) are
            flattened to (3,).
        device: device for the returned transform.

    Returns:
        pytorch3d Transform3d with identity rotation and translation ``t`` for
        every batch element.

    Raises:
        RuntimeError: if ``t`` does not contain exactly 3 elements.
    """
    # torch.as_tensor lets numpy centroids (as returned by load_data) through;
    # the previous torch-only code failed on them with AttributeError.
    t = torch.as_tensor(t, dtype=torch.float32).reshape(3)
    R = torch.eye(3, device=device).repeat(N, 1, 1)  # (N, 3, 3) identity rotation
    t = t.unsqueeze(0).repeat(N, 1).to(device)  # (N, 3)
    return Rotate(R, device=device).translate(t)
def get_rt(T):
    """Summarise each transform in a batch by rotation angle and translation size.

    Args:
        T: pytorch3d Transform3d; the translation is read from row 3 of its
            4x4 matrix.

    Returns:
        (angle, t_norm): rotation angle in DEGREES (callers name it
        ``angle_rad_*``, but it is converted with 180/pi) and the Euclidean
        norm of the translation, each of shape (B,).
    """
    mat = T.get_matrix()
    angle_deg = so3_rotation_angle(mat[:, :3, :3], eps=1e-2) * 180 / np.pi
    trans_norm = torch.norm(mat[:, 3, :3], dim=1)  # (B,)
    return angle_deg, trans_norm
# Seed torch's RNG so the sampled transforms are reproducible across kernel
# restarts; T2 and T3 are reused by hand-tuned compositions in later cells.
SEED = 0
torch.manual_seed(SEED)

# get_rt returns the rotation angle in degrees (despite the `_rad` names).
T1 = random_se3(1, rot_var = 40*np.pi/180, trans_var = 0.1).cuda()
angle_rad_T, t_norm = get_rt(T1)
print(angle_rad_T,t_norm)
T2 = random_se3(1, rot_var = 130*np.pi/180, trans_var = 0.1).cuda()
angle_rad_T2, t_norm2 = get_rt(T2)
print(angle_rad_T2,t_norm2)
T3 = random_se3(1).cuda()
angle_rad_T3, t_norm3 = get_rt(T3)
print(angle_rad_T3,t_norm3)


In [None]:
import sys
import os
import torch

# Pin the process to physical GPU 0 (PCI bus order).
# NOTE(review): these variables only take effect if set before CUDA is
# initialised. The first cell already calls `.cuda()`, so in a top-to-bottom
# run this has no effect; move it above the first torch/CUDA use if GPU
# selection matters.
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'

import numpy as np
import plotly.graph_objects as go

In [None]:
def toDisplay(x, target_dim = 2):
    """Return x as a detached CPU numpy array for plotting.

    Leading dimensions are dropped (keeping index 0 each time) until only
    ``target_dim`` dimensions remain; e.g. a (1, N, 3) cloud becomes (N, 3).
    """
    extra_dims = x.dim() - target_dim
    for _ in range(max(extra_dims, 0)):
        x = x[0]
    return x.detach().cpu().numpy()

In [None]:
import sys
sys.path.insert(1, '/home/exx/Documents/equivariant_pose_graph/python')
from equivariant_pose_graph.models.transformer_flow import ResidualFlow, ResidualFlow_DiffEmbTransformerVis
 
from equivariant_pose_graph.training.flow_equivariance_training_module_nocentering_vis import EquivarianceTrainingModule
# from ndf_robot.eval.test_trained_model_place import load_data
from pytorch3d.ops import sample_farthest_points

In [None]:
from equivariant_pose_graph.models.transformer_flow import ResidualFlow_DiffEmbTransformer
from equivariant_pose_graph.training.flow_equivariance_training_module_nocentering_eval_init import EquivarianceTestingModule
checkpoint_file='/home/exx/media/DataDrive/singularity_chuerp/epg_results/residual_flow/residual_flow_occlusion_mr_refinement/2022-06-14_130306/residual_flow_occlusion_mr_refinement/dysnpvtn/checkpoints/epoch_127_global_step_16000.ckpt'

# Network architecture; these hyper-parameters must match the checkpoint.
network = ResidualFlow_DiffEmbTransformerVis(
    emb_dims=4,
                    emb_nn='dgcnn', return_flow_component=False, center_feature=True,
                    inital_sampling_ratio=1, pred_weight= True, residual_on=True)
model = EquivarianceTrainingModule(
    network,
    lr=1e-4,
    image_log_period=100,
    weight_normalize='l1',
    smoothness_weight = 0.1,
    consistency_weight = 1,
    sigmoid_on=True,
    
)

model.cuda()
# This notebook only runs inference: switch off training-mode behaviour
# (e.g. dropout / batch-norm batch statistics, if the network has such layers).
model.eval()
# NOTE: torch.load unpickles the checkpoint file; only load checkpoints you trust.
model.load_state_dict(torch.load(checkpoint_file)['state_dict'])
 


In [None]:
def load_data(num_points, point_data, action_class, anchor_class):
    """Split a labelled point cloud into action/anchor clouds centred on the action.

    Args:
        num_points: size both clouds are brought to via ``subsample`` (which
            raises NotImplementedError if a cloud has fewer points).
        point_data: mapping with a 'clouds' array of points and a 'classes'
            array of per-point labels (e.g. a loaded .npz file).
        action_class: label of the action object's points.
        anchor_class: label of the anchor object's points.

    Returns:
        (points_action, points_anchor, points_action_mean_np): both clouds as
        float32 CUDA tensors with a leading batch dim of 1, each translated by
        minus the action centroid, plus that centroid as a numpy array.
    """
    clouds = point_data['clouds']
    labels = point_data['classes']

    # Boolean indexing already returns fresh arrays.
    action_np = clouds[labels == action_class]
    action_centroid = action_np.mean(axis=0)
    anchor_np = clouds[labels == anchor_class]

    # Both clouds are expressed relative to the *action* centroid.
    action = torch.from_numpy(action_np - action_centroid).float().unsqueeze(0)
    anchor = torch.from_numpy(anchor_np - action_centroid).float().unsqueeze(0)

    action, anchor = subsample(num_points, action, anchor)
    return action.cuda(), anchor.cuda(), action_centroid

def load_data2(num_points, point_data, action_class, anchor_class, device="cuda"):
    """Return the stored predicted action cloud with the action centroid added back.

    The centroid of the ``action_class`` points of ``point_data['clouds']`` is
    added to ``point_data["pred_T_action_transformed"]``.
    NOTE(review): ``load_data`` returns clouds *minus* this same centroid, so
    plots that mix the two functions' outputs mix frames. Confirm which frame
    the stored prediction is in.

    Args:
        num_points: unused; kept for call compatibility with load_data.
            (Previously both clouds were farthest-point subsampled here and the
            result discarded.)
        point_data: mapping with 'clouds', 'classes' and
            'pred_T_action_transformed' arrays.
        action_class: label of the action object's points.
        anchor_class: unused; kept for call compatibility.
        device: device of the returned tensor (default CUDA, as before).

    Returns:
        torch tensor of the shifted prediction (dtype of the stored array).
    """
    clouds = point_data['clouds']
    classes = point_data['classes']

    points_action_mean_np = clouds[classes == action_class].mean(axis=0)
    print("points_action_mean_np.shape:{}".format(points_action_mean_np.shape))
    pred_T_action_transformed = point_data["pred_T_action_transformed"] + points_action_mean_np
    return torch.from_numpy(pred_T_action_transformed).to(device)

def subsample(num_points,points_action,points_anchor):
    """Bring both clouds to exactly ``num_points`` points.

    Dim 1 of each tensor is the point count (callers pass (1, N, 3) clouds).
    Larger clouds are reduced with farthest point sampling from a random start
    point. Clouds that already have ``num_points`` points are returned
    unchanged. Smaller clouds raise NotImplementedError. The action cloud is
    handled first.
    """
    def _fit(cloud, label):
        n_pts = cloud.shape[1]
        if n_pts < num_points:
            raise NotImplementedError(f'{label} point cloud is smaller than cloud size ({n_pts} < {num_points})')
        if n_pts == num_points:
            return cloud
        sampled, _ = sample_farthest_points(cloud, K=num_points, random_start_point=True)
        return sampled

    return _fit(points_action, 'Action'), _fit(points_anchor, 'Anchor')

In [None]:
def plot(points_action, points_anchor):
    """Show the action cloud (muted blue) and anchor cloud (orange) in one 3D plot."""
    palette = [
        '#1f77b4',  # muted blue -> action
        '#ff7f0e',  # safety orange -> anchor
    ]
    stride = 1
    traces = []
    for cloud, color in ((points_action, palette[0]), (points_anchor, palette[1])):
        xyz = toDisplay(cloud)
        traces.append(go.Scatter3d(
            x=xyz[::stride, 0], y=xyz[::stride, 1], z=xyz[::stride, 2],
            mode='markers', marker=dict(size=1, color=color, symbol='circle')))

    # aspectmode='data' keeps the true metric proportions of the clouds.
    fig = go.Figure(data=traces, layout=go.Layout(scene=dict(aspectmode='data')))
    fig.show()
    
def toDisplay(x, target_dim = 2):
    """Return x as a detached CPU numpy array, indexing [0] into leading
    dimensions until only ``target_dim`` dimensions remain.

    NOTE(review): identical redefinition of the toDisplay from an earlier
    cell; one of the two copies can be deleted.
    """
    if x.dim() > target_dim:
        return toDisplay(x[0], target_dim)
    return x.detach().cpu().numpy()

def plot_multi(plist, idx=None):
    """Show several point clouds together in one interactive plotly 3D scene.

    Args:
        plist: sequence of point clouds. Each goes through toDisplay, and its
            columns 0, 1, 2 are used as x, y, z.
        idx: optional palette index per cloud (entries beyond len(plist) are
            ignored). Defaults to 0, 1, 2, ... Indices wrap around the palette,
            so plotting more clouds than there are colours no longer raises
            IndexError.
    """
    colors = [
        '#ee6677',  # light red
        '#bb5566',  # dark red
        '#66ccee',  # cyan
        '#1f77b4',  # muted blue
        '#bb5566',  # dark red (duplicate kept so existing idx values keep their colour)
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#8c564b',  # chestnut brown
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf',  # blue-teal
    ]
    if idx is None:
        idx = list(range(len(plist)))
    skip = 1
    go_data = []
    for i, cloud in enumerate(plist):
        p_dp = toDisplay(cloud)
        color = colors[idx[i] % len(colors)]
        trace = go.Scatter3d(x=p_dp[::skip, 0], y=p_dp[::skip, 1], z=p_dp[::skip, 2],
                             mode='markers', marker=dict(size=1, color=color, symbol='circle'))
        go_data.append(trace)

    # 2D cartesian axes: no grid, zero line or tick numbers.
    hidden_axis = dict(showgrid=False, zeroline=False, visible=False)
    # 3D panes: transparent background, white grid/zero lines, no tick labels.
    pane = dict(
        backgroundcolor="rgba(0, 0, 0,0)",
        gridcolor="white",
        showbackground=True,
        zerolinecolor="white",
        showticklabels=False,
    )
    layout = go.Layout(
        xaxis=hidden_axis,
        yaxis=hidden_axis,
        scene=dict(aspectmode='data', xaxis=pane, yaxis=pane, zaxis=pane),
    )
    fig = go.Figure(data=go_data, layout=layout)
    fig.show()
def xyz2homo(xyz):
    """Convert a batched (1, num_points, 3) cloud to homogeneous coordinates.

    Returns a detached CPU tensor of shape (num_points, 4) whose last column is 1.
    """
    num_points = xyz.shape[1]
    coords = xyz.squeeze(0).detach().cpu()
    ones_col = torch.ones(num_points, 1)
    return torch.cat((coords, ones_col), dim=-1)
def transform(T,points, device="cuda"):
    """Apply a 4x4 homogeneous transform to homogeneous points.

    Args:
        T: (4, 4) matrix (numpy array or torch tensor) acting on column
            vectors, i.e. p' = T @ p.
        points: (num_points, 4) homogeneous points, e.g. from xyz2homo.
        device: device to compute on (default CUDA, as before).

    Returns:
        (num_points, 3) tensor of transformed xyz coordinates. The dtype is the
        promotion of T's and points' dtypes. A float64 matrix can therefore be
        applied to float32 points; previously the matmul raised a dtype error.
    """
    T_t = torch.as_tensor(T)
    dtype = torch.promote_types(T_t.dtype, points.dtype)
    T_t = T_t.to(device=device, dtype=dtype)
    pts = points.to(device=device, dtype=dtype).transpose(-1, -2)  # (4, num_points)
    moved = (T_t @ pts).transpose(-1, -2)  # (num_points, 4)
    return moved[..., :3]

In [None]:
from pathlib import Path
batch_idx= 0
data_path = Path('/home/exx/Documents/equivariant_pose_graph/graph')
# NOTE: allow_pickle=True lets np.load unpickle objects; only use on trusted files.
point_data = np.load(data_path / f'{batch_idx}_sim.npz', allow_pickle = True)
sim = point_data['points_colors']
print(sim.shape)
# The first 3 columns are used as positions, the remaining ones as marker colours.
sim_points = sim[:, :3]
sim_colors = sim[:, 3:]

skip = 1
sim_trace = go.Scatter3d(
    x=sim_points[::skip, 0], y=sim_points[::skip, 1], z=sim_points[::skip, 2],
    mode='markers', marker=dict(size=1, color=sim_colors, symbol='circle'))
go_data = [sim_trace]

layout = go.Layout(scene=dict(aspectmode='data'))
fig = go.Figure(data=go_data, layout=layout)

# Transparent axis panes with white grid/zero lines and no tick labels.
pane_style = dict(
    backgroundcolor="rgba(0, 0, 0,0)",
    gridcolor="white",
    showbackground=True,
    zerolinecolor="white",
    showticklabels=False,
)
fig.update_layout(scene=dict(xaxis=pane_style, yaxis=pane_style, zaxis=pane_style))
fig.show()

In [None]:
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3

num_classes = 3
data_idx = 8
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'init'
data_path = Path('/home/exx/Documents/ndf_robot/train_data_ndf_mug_place_0/renders') 
# data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/test_grasp_place_demo_10_only_test_ids_upright_seed0_loop5')
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
print(point_data['shapenet_id'])
points_mug, points_rack, points_action_mean = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
# points_mug, points_gripper,points_action_mean = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 2)
# Nudge the rack (the anchor cloud of this cell) by +0.03 along x and -0.03
# along the last axis. These lines used to modify `points_anchor`, which is
# not defined yet at this point on a fresh kernel (NameError). Otherwise they
# shifted a stale cloud that is neither plotted nor fed to the model here.
points_rack[:,:,0]+=0.03
points_rack[:,:,-1]-=0.03
plot_multi([points_mug, points_rack],[3,0,5] )

In [None]:
# Inference only: no_grad avoids building an autograd graph (and keeping its
# activations on the GPU) for outputs that are only detached/plotted later.
with torch.no_grad():
    x_action, x_anchor, action_embedding, anchor_embedding  = model.model(points_mug, points_rack)

In [None]:
from equivariant_pose_graph.utils.se3 import dualflow2pose
def _split_flow_and_weight(x):
    """Split a (B, N, C) network output into flow (channels 0-2) and an optional
    per-point weight (sigmoid of channel 3, only when C > 3; otherwise None)."""
    flow = x[:, :, :3]
    weight = torch.sigmoid(x[:, :, 3]) if x.shape[2] > 3 else None
    return flow, weight


pred_flow_action, pred_w_action = _split_flow_and_weight(x_action)
pred_flow_anchor, pred_w_anchor = _split_flow_and_weight(x_anchor)

# `normalization_scehme` (sic) is kept exactly as passed before; presumably it
# matches dualflow2pose's keyword spelling, so check before "fixing" it.
pred_T_action = dualflow2pose(xyz_src=points_mug, xyz_tgt=points_rack,
                              flow_src=pred_flow_action, flow_tgt=pred_flow_anchor,
                              weights_src=pred_w_action, weights_tgt=pred_w_anchor,
                              return_transform3d=True, normalization_scehme='l1',
                              temperature=1)

pred_points_action = pred_T_action.transform_points(points_mug)

In [None]:
# Input mug and rack, plus the mug moved by the predicted transform.
plot_multi(plist=[points_mug, points_rack, pred_points_action], idx=[3, 0, 5])

In [None]:
def _embedding_to_rgb(emb):
    """Map a (channels, num_points) embedding to per-point RGB values.

    Each point's embedding is L2-normalised over channels (dim 0). The first 3
    channels, now in [-1, 1], are then mapped to [0, 255]. Returns the
    normalised embedding and the (num_points, 3) numpy colour array.
    """
    unit = F.normalize(emb, dim=0)
    rgb = unit[:3].T.detach().cpu().numpy()
    return unit, 255*(rgb + 1)/2.


action_emb, color = _embedding_to_rgb(action_embedding[0])
annchor_emb, color_trans = _embedding_to_rgb(anchor_embedding[0])

In [None]:
# Shapes of the per-point colour array and the normalised action embedding.
print(color.shape, action_emb.shape, sep="\n")

In [None]:
# NOTE(review): `points_action` is first assigned by a later cell (load_data
# with action_class=1), so this raises NameError on a fresh kernel; the action
# cloud loaded above is `points_mug`.
points_action.shape


In [None]:
# Hand-tuned placement: T3, then an extra axis-angle rotation, then a shift of
# -0.4 along y (Transform3d.compose applies the transforms left to right).
# NOTE(review): `points_action` / `points_anchor` are first assigned by later
# cells, so this raises NameError on a fresh kernel; the rotation values also
# appear tuned to one particular random draw of T3.
t=  pure_translation_se3(1, torch.Tensor([0,-0.4,0]), device=points_action.device)
R_x = Rotate(axis_angle_to_matrix(torch.tensor([-1.6,-1.2,0.3])), device=points_action.device)
T_acc = T3.compose(R_x).compose(t)
plot_multi([T_acc.transform_points(points_action),points_action,points_anchor],idx=[3,3,0])

In [None]:
 
# Hand-tuned placement: undo T2, rotate, then shift +0.4 along y.
# Uses T2's device (CUDA) instead of `points_action.device`; `points_action`
# is only assigned by the next cell, so this used to fail on a fresh kernel.
t1 =  pure_translation_se3(1, torch.Tensor([0,0.4,0]), device=T2.device)
R_x1 = Rotate(axis_angle_to_matrix(torch.tensor([1.2,0.9,0.3])), device=T2.device)
T_acc1 = T2.inverse().compose(R_x1).compose(t1)

In [None]:
# Reload with roles swapped: class 1 (the rack above) as action, class 0 (the
# mug) as anchor, both centred on the rack's centroid.
points_action, points_anchor, points_action_mean = load_data(
    num_points=1000, point_data=point_data, action_class=1, anchor_class=0)
plot_multi(plist=[points_action, T_acc1.transform_points(points_action), points_anchor],
           idx=[0, 0, 3])

In [None]:
# Alternative hand-tuned placement: undo T2, rotate, then shift +0.3 along y.
rot_axis_angle1 = torch.tensor([0.8, 1.2, 0])
t1 = pure_translation_se3(1, torch.Tensor([0, 0.3, 0]), device=points_action.device)
R_x1 = Rotate(axis_angle_to_matrix(rot_axis_angle1), device=points_action.device)
T_acc1 = T2.inverse().compose(R_x1).compose(t1)

In [None]:
points_action, points_anchor, points_action_mean = load_data(
    num_points=1000, point_data=point_data, action_class=1, anchor_class=0)
# Only the moved rack and the mug.
plot_multi(plist=[T_acc1.transform_points(points_action), points_anchor], idx=[0, 3])

In [None]:
# Same reload as the earlier cell. Farthest point sampling starts from a random
# point, so the subsampled clouds can differ between runs.
points_action, points_anchor, points_action_mean = load_data(
    num_points=1000, point_data=point_data, action_class=1, anchor_class=0)
plot_multi(plist=[points_action, T_acc1.transform_points(points_action), points_anchor],
           idx=[0, 0, 3])

In [None]:
# Anchor before/after T_acc1 (palette 3) with the moved action cloud (palette 0).
plot_multi(plist=[T_acc1.transform_points(points_anchor), points_anchor, T_acc1.transform_points(points_action)],
           idx=[3, 3, 0])

In [None]:
# Original action cloud next to both clouds moved by T_acc1.
plot_multi(plist=[points_action, T_acc1.transform_points(points_anchor), T_acc1.transform_points(points_action)],
           idx=[0, 3, 0])

In [None]:
# Both clouds after applying T_acc1.
plot_multi(plist=[T_acc1.transform_points(points_anchor), T_acc1.transform_points(points_action)],
           idx=[3, 0])

In [None]:
# pure_translation_se3 documents `t` as a torch tensor (it calls `t.unsqueeze`),
# but load_data returns the centroid as a numpy array, which raised
# AttributeError here; convert it first.
T_trans = pure_translation_se3(
                1, torch.as_tensor(points_action_mean, dtype=torch.float32).squeeze(), device=points_action.device)

In [None]:
# Anchor before and after undoing T_acc, together with the action cloud.
plot_multi(plist=[points_anchor, T_acc.inverse().transform_points(points_anchor), points_action])

In [None]:
# pytorch3d stores transforms with the translation in the last row of the 4x4
# matrix (get_rt reads it from row 3).
print(T_acc.get_matrix())

In [None]:
print(T3.get_matrix())

# Apply T1 then T2 to the action cloud. This assignment had been commented
# out, leaving `points_action_trans` undefined for the plot below.
points_action_trans = T1.compose(T2).transform_points(points_action)
plot_multi([points_action_trans,points_anchor])

In [None]:
# Apply T1, re-express both clouds relative to the transformed action centroid,
# apply T2 in that centred frame, then shift everything back.
action_after_T1 = T1.transform_points(points_action)
T1_action_mean = action_after_T1.mean(1)
T1_action_centered = action_after_T1 - T1_action_mean
points_anchor_centered = points_anchor - T1_action_mean
points_action_target = T2.transform_points(T1_action_centered)
plot_multi([points_action_target, points_anchor_centered])
points_action_final = points_action_target + T1_action_mean
points_anchor_final = points_anchor_centered + T1_action_mean

In [None]:
T_trans = pure_translation_se3(
    1, T1_action_mean.squeeze(), device=points_action.device)

# Same result as the centroid-frame cell, as a single transform: T1, then
# (move centroid to origin, T2, move back). compose applies left to right.
T2_about_centroid = T_trans.inverse().compose(T2.compose(T_trans))
T2_in_points_action_frame = T1.compose(T2_about_centroid)
points_action_trans_in_original_frame = T2_in_points_action_frame.transform_points(points_action)
plot_multi([points_action_trans_in_original_frame, points_anchor])

In [None]:
# Centroid-frame result vs. single-transform result, with both anchors.
plot_multi(plist=[points_action_final, points_action_trans_in_original_frame,
                  points_anchor, points_anchor_final])

In [None]:
# Original-frame result next to the centred-frame clouds.
plot_multi(plist=[points_action_trans_in_original_frame, points_anchor,
                  points_action_target, points_anchor_centered])

In [None]:
# NOTE(review): requires `points_action_trans` (T1.compose(T2) applied to
# points_action) from an earlier cell; confirm it is defined on a fresh kernel.
plot_multi([points_action_target,points_anchor_centered,points_action_trans,points_anchor])

In [None]:
# Sanity check: shape of the action cloud after applying T1.
T1.transform_points(points_action).shape

In [None]:
# NOTE(review): compose() is called with no transforms to append, and
# `T_applied_once` is never used afterwards; this looks like an unfinished cell.
T_applied_once = T1.compose()

In [None]:
num_classes = 3
data_idx = 0
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'teleport'
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/heyo1') 
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
# print(point_data['shapenet_id'])
# load_data returns (action, anchor, action centroid); unpacking only two
# values raised "ValueError: too many values to unpack".
points_action14, points_anchor, points_action_mean14 = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
 
plot_multi([points_action14,points_anchor])

pred_T_action_transformed = load_data2(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
 
 

In [None]:
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3

num_classes = 3
data_idx = 0
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'teleport'
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/newtest5') 
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
# print(point_data['shapenet_id'])
# load_data returns (action, anchor, action centroid); unpacking only two
# values raised "ValueError: too many values to unpack".
points_action5, points_anchor, points_action_mean5 = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
pred_T_action_transformed = load_data2(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
 
# pred_T_action_transformed = torch.from_numpy(point_data["pred_T_action_transformed"]).cuda()-point_data["pred_T_action_transformed"]
# print(pred_T_action_transformed.shape)
# pred_T_action_transformed[:,:,[0,1]] = pred_T_action_transformed[:,:,[1,0]]
plot_multi([points_action5,points_anchor,pred_T_action_transformed])
 

In [None]:
num_classes = 3
data_idx = 0
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'init'
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/newtest3') 
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
pred_T_action_transformed_pybullet = torch.from_numpy(point_data["pred_T_action_transformed"]).cuda()
pred_T_action_mat = point_data["pred_T_action_mat"]
points_action_pybullet =torch.from_numpy(point_data["points_action"]).cuda()
points_anchor_pybullet = torch.from_numpy(point_data["points_anchor"]).cuda()
# load_data returns (action, anchor, action centroid); unpacking only two
# values raised "ValueError: too many values to unpack".
points_action_init, points_anchor, points_action_mean_init = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
plot_multi([pred_T_action_transformed_pybullet,points_anchor_pybullet,points_action_pybullet])

In [None]:
num_classes = 3
data_idx = 0
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'teleport'
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/newtest3') 
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
# load_data returns three values. The teleport anchor gets its own name:
# load_data centres the anchor on *this* state's action centroid, and the cells
# below feed `points_anchor` to the model together with `points_action_init`,
# so `points_anchor` must stay the init-state anchor.
points_action_action3, points_anchor_teleport, points_action_mean3 = load_data(num_points=1000, point_data=point_data, action_class= 0, anchor_class= 1)
# points_action_init[:,:,1]*=-1
# points_anchor[:,:,1]*=-1
# points_action_init[:,:,[1,2]]=points_action_init[:,:,[2,1]]
# points_anchor[:,:,[1,2]]=points_anchor[:,:,[2,1]]
# print(pred_T_action_mat)
# pred_T_action_mat[[1,2]] = pred_T_action_mat[[2,1]]
# print(pred_T_action_mat)
# pred_T_action_mat[:3,-1]=pred_T_action_mat[:3,:3].T@pred_T_action_mat[:3,-1]+ pred_T_action_mat[:3,-1]
# print(pred_T_action_mat)

# The stored matrix's dtype may differ from the float32 homogeneous points;
# cast it so the matmul inside `transform` sees matching dtypes.
apply_here_xyz = transform(np.asarray(pred_T_action_mat, dtype=np.float32), xyz2homo(points_action_init))
plot_multi([apply_here_xyz,points_anchor_teleport, points_action_init,points_action_action3])

In [None]:
# Inference only: no autograd graph is needed for the predicted transform.
with torch.no_grad():
    ans = model.get_transform(
        points_action_init, points_anchor)  # 1, 4, 4
pred_T_action_init = ans["pred_T_action"]

In [None]:
# Init-state action cloud moved by the model's predicted transform, with the anchor.
pred_T_action_transformed = pred_T_action_init.transform_points(points_action_init)
plot_multi(plist=[pred_T_action_transformed, points_anchor, points_action_init])

In [None]:
# get_matrix() keeps the translation in the last row; transposing gives the
# column-vector form that `transform` applies as T @ p.
pred_T_action_mat = pred_T_action_init.get_matrix()[0].T.detach().cpu().numpy()
homo_points_init = xyz2homo(points_action_init)
apply_here_xyz = transform(pred_T_action_mat, homo_points_init)
plot_multi(plist=[apply_here_xyz, points_anchor, points_action_init])

In [None]:
# NOTE(review): points_action0/1/3/4/7 are not assigned anywhere in this
# notebook (only points_action5 and points_action14 are), so this raises
# NameError on a fresh kernel; they presumably came from deleted cells.
plot_multi([points_action0,points_anchor,points_action1,points_action3,points_action4,points_action5, points_action7])

In [None]:
# NOTE(review): points_action3 and points_action4 are not assigned anywhere in
# this notebook, so this raises NameError on a fresh kernel.
plot_multi([points_action3,points_anchor,points_action4,points_action5])

In [None]:
# Deliberate halt for "Run All": the cells below are older experiments that
# use names not defined anywhere above (e.g. `place_model`).
raise RuntimeError("Stopping here: the cells below are legacy experiments and are not meant to be run.")

In [None]:
# Current action / anchor clouds.
plot(points_action=points_action, points_anchor=points_anchor)

In [None]:
# NOTE(review): `pred_T_action_points_action` is only assigned by a later cell
# (the teleport refinement loop), so this raises NameError on a fresh kernel.
plot(pred_T_action_points_action, points_anchor)

In [None]:
# Load Teleport
from pathlib import Path
import torch.nn.functional as F
from equivariant_pose_graph.utils.se3 import random_se3

num_classes = 3
cloud_types = ['init', 'pre_grasp', 'post_grasp', 'teleport', 'post_place', 'final']
cloud_type = 'teleport'
# data_path = Path('/home/bokorn/src/ndf_robot/notebooks')
# point_data = np.load(data_path / f'{data_idx}_obj_points.npz')
# data_path = Path('/home/exx/Documents/ndf_robot/train_new_data_3/renders')
# data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/debug_place_shapeid_34ae0b61b0d8aaf2d7b20fded0142d7a') 
data_path = Path('/home/exx/Documents/ndf_robot/src/ndf_robot/place_test_0_my_model_may10_overfit') 
# NOTE(review): `data_idx` is not set in this cell; it is whatever an earlier
# cell assigned last.
point_data = np.load(data_path / f'{data_idx}_{cloud_type}_obj_points.npz', allow_pickle = True)
print(point_data['shapenet_id'])
# load_data takes the loaded npz as `point_data` and returns three values. The
# previous call passed `clouds=`/`classes=` keywords it does not accept (TypeError).
points_action, points_anchor, points_action_mean = load_data(num_points=1024, point_data=point_data, action_class= 0, anchor_class= 1)
 
points_action_trans = points_action
points_anchor_trans = points_anchor
# NOTE(review): `place_model` is not defined anywhere in this notebook. It is
# expected to return three objects with `.transform_points`; define or load it
# before running this cell.
pred_T_action, pred_T_anchor, pred_T  = place_model(points_action_trans, points_anchor_trans)
 
pred_T_action_points_action = pred_T_action.transform_points(points_action_trans)
pred_T_action_points_anchor = pred_T_anchor.transform_points(points_anchor_trans)
plot(points_action, points_anchor)
 
# Iterative refinement: run the model 20 more times on its own output, each
# time moving the action cloud by the newly predicted action transform.
for i in range(20):
    pred_T_action, pred_T_anchor, pred_T  = place_model(pred_T_action_points_action, points_anchor)
    pred_T_action_points_action = pred_T_action.transform_points(pred_T_action_points_action)
    
    torch.cuda.empty_cache() 
plot(pred_T_action_points_action, points_anchor)