In [None]:
import sys
import os
import torch

# Order GPUs by PCI bus ID and expose only device index 1 to this process.
# NOTE(review): these are set after `import torch` above; presumably they must
# still take effect before CUDA is first initialised — confirm.
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='1'

import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [None]:
def toDisplay(x, target_dim = 2):
    """Convert a tensor into a NumPy array suitable for plotting.

    While the tensor has more than ``target_dim`` dimensions, its leading
    dimension is dropped by keeping only index 0. The remaining tensor is
    detached from autograd, moved to the CPU and returned as a NumPy array.

    Args:
        x: Tensor-like object exposing ``dim()``, indexing, ``detach()``,
            ``cpu()`` and ``numpy()``.
        target_dim: Maximum number of dimensions of the returned array.

    Returns:
        A NumPy array with at most ``target_dim`` dimensions.
    """
    if x.dim() > target_dim:
        # Peel off one leading dimension and try again.
        return toDisplay(x[0], target_dim)
    return x.detach().cpu().numpy()

In [None]:
import sys
# Put the local equivariant_pose_graph checkout on the module search path (at position 1).
# NOTE(review): machine-specific absolute path; presumably needed by the
# equivariant_pose_graph imports that follow — confirm.
sys.path.insert(1, '/home/exx/Documents/equivariant_pose_graph/python')
from equivariant_pose_graph.models.transformer_flow import ResidualFlow, ResidualFlow_V1
from equivariant_pose_graph.training.flow_equivariance_training_module import EquivarianceTrainingModule
 
# from ndf_robot.eval.test_trained_model_place import load_data
from equivariant_pose_graph.utils.display import scatter3d, quiver3d
from pytorch3d.transforms import Transform3d

In [None]:
def plot(points_action, points_anchor):
    """Show two point clouds together in an interactive 3D scatter plot.

    Both inputs are converted with ``toDisplay`` and drawn as size-1 circle
    markers; the action cloud uses the first palette colour and the anchor
    cloud the second. Columns 0, 1 and 2 of each converted array are used as
    x, y and z. The figure is displayed and nothing is returned.

    Args:
        points_action: Point cloud tensor drawn in muted blue.
        points_anchor: Point cloud tensor drawn in safety orange.
    """
    palette = (
        '#1f77b4',  # muted blue
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#8c564b',  # chestnut brown
        '#e377c2',  # raspberry yogurt pink
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf',  # blue-teal
    )
    stride = 1  # plot every stride-th point (1 keeps all points)

    # Convert both clouds up front, then build one trace per cloud.
    clouds = [toDisplay(points_action), toDisplay(points_anchor)]
    traces = []
    for colour_idx, cloud in enumerate(clouds):
        traces.append(
            go.Scatter3d(
                x=cloud[::stride, 0],
                y=cloud[::stride, 1],
                z=cloud[::stride, 2],
                mode='markers',
                marker={'size': 1, 'color': palette[colour_idx], 'symbol': 'circle'},
            )
        )

    # aspectmode='data' is passed so plotly scales the axes from the data ranges.
    figure = go.Figure(data=traces, layout=go.Layout(scene={'aspectmode': 'data'}))
    figure.show()

In [None]:
def plot_multi(p1, p2, p3, p4):
    """Show four point clouds together in an interactive 3D scatter plot.

    Each input is converted with ``toDisplay`` and drawn as size-1 circle
    markers, coloured with the first four palette entries in argument order.
    Columns 0, 1 and 2 of each converted array are used as x, y and z. The
    figure is displayed and nothing is returned.

    Args:
        p1: Point cloud tensor drawn in muted blue.
        p2: Point cloud tensor drawn in safety orange.
        p3: Point cloud tensor drawn in green.
        p4: Point cloud tensor drawn in brick red.
    """
    palette = (
        '#1f77b4',  # muted blue
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#8c564b',  # chestnut brown
        '#e377c2',  # raspberry yogurt pink
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf',  # blue-teal
    )
    stride = 1  # plot every stride-th point (1 keeps all points)

    # Convert every cloud first, then build the traces in the same order.
    clouds = [toDisplay(cloud) for cloud in (p1, p2, p3, p4)]
    traces = [
        go.Scatter3d(
            x=cloud[::stride, 0],
            y=cloud[::stride, 1],
            z=cloud[::stride, 2],
            mode='markers',
            marker={'size': 1, 'color': palette[colour_idx], 'symbol': 'circle'},
        )
        for colour_idx, cloud in enumerate(clouds)
    ]

    # aspectmode='data' is passed so plotly scales the axes from the data ranges.
    figure = go.Figure(data=traces, layout=go.Layout(scene={'aspectmode': 'data'}))
    figure.show()

In [None]:
# Build the flow network, wrap it in the training module and move it to the GPU.
network = ResidualFlow()

model = EquivarianceTrainingModule(
    network)
model.cuda()

# Alternative checkpoints (disabled):
# checkpoint_file = "/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual0/equiv_dcpflow/version_9/checkpoints/epoch=410-step=51375.ckpt"
# checkpoint_file = "/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual0_attn_trans0.1_overfit/equiv_dcpflow/version_3/saved_checkpoints/epoch=30-step=3875.ckpt"
# checkpoint_file = "/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual1_attn/equiv_dcpflow/version_1/checkpoints/epoch=754-step=94375.ckpt"
# NOTE(review): this directory is spelled 'ovefit' while a disabled path above
# uses 'overfit' — confirm the path exists as written.
checkpoint_file = '/home/exx/media/DataDrive/singularity_chuerp/equiv_pgraph_logs/train_test_mr_dcpflow_residual0_attn_trans0.1_ovefit/equiv_dcpflow/version_28/checkpoints/epoch=2-step=375.ckpt'

# Deserialize onto the CPU first: without map_location, torch.load tries to
# restore tensors to the CUDA ordinal they were saved from, which fails when
# that ordinal is not visible in this process (only one GPU is exposed via
# CUDA_VISIBLE_DEVICES). load_state_dict then copies the values into the
# model's parameters, which already live on the GPU.
model.load_state_dict(torch.load(checkpoint_file, map_location='cpu')['state_dict'])

# Switch to evaluation mode; `place_model` is the handle used for inference.
place_model = model.eval()

In [None]:

# Archive of stored transforms and point clouds.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code — only use it on files you trust.
pred_T = np.load('/home/exx/Documents/equivariant_pose_graph/transformations/display_pointclouds.npz', allow_pickle=True)

# Stored transform matrices, wrapped as pytorch3d Transform3d objects.
pred_T_action, pred_T_anchor = (
    torch.from_numpy(pred_T[key]) for key in ('pred_T_action', 'pred_T_anchor')
)
pred_T_action_transform = Transform3d(matrix=pred_T_action)
pred_T_anchor_transform = Transform3d(matrix=pred_T_anchor)

# Stored point clouds, converted from NumPy arrays to tensors.
points_trans_action, points_trans_anchor = (
    torch.from_numpy(pred_T[key]) for key in ('points_trans_action', 'points_trans_anchor')
)
action_transformed, anchor_transformed = (
    torch.from_numpy(pred_T[key]) for key in ('action_transformed', 'anchor_transformed')
)
points_action, points_anchor = (
    torch.from_numpy(pred_T[key]) for key in ('points_action', 'points_anchor')
)

# Re-apply the stored transforms locally so the results can be compared
# against the stored `*_transformed` arrays.
our_action_transformed = pred_T_action_transform.transform_points(points_trans_action)
our_anchor_transformed = pred_T_anchor_transform.transform_points(points_trans_anchor)

In [None]:
 
# Compare the stored transformed action cloud with the locally recomputed one:
# total absolute difference, then an element-wise closeness check (atol=1e-2).
print((action_transformed - our_action_transformed).abs().sum())
print(torch.allclose(action_transformed, our_action_transformed, atol=1e-2))

In [None]:
# First entry of the stored `points_action` and `points_anchor` clouds.
plot(points_action[0], points_anchor[0])

In [None]:
# Action cloud transformed locally with the stored pred_T_action matrix,
# shown against the first `points_trans_anchor` cloud.
plot(our_action_transformed[0], points_trans_anchor[0])

In [None]:
# Stored `action_transformed` cloud from the archive, shown against the same
# `points_trans_anchor` cloud for visual comparison.
plot(action_transformed[0], points_trans_anchor[0])

In [None]:
# Raw points and per-point class labels saved from the data-loading step.
# NOTE(review): allow_pickle=True can execute arbitrary code when loading
# untrusted files — only use it on files you trust.
raw_point_cloud = np.load('/home/exx/Documents/equivariant_pose_graph/transformations/load_data.npz', allow_pickle=True)
points_raw_np, classes_raw_np = (
    raw_point_cloud[key] for key in ('points_raw_np', 'classes_raw_np')
)

# NOTE(review): mid-notebook import; consider moving it to the import cell.
from ndf_robot.eval.test_trained_model_place import load_data

# Split the raw cloud with load_data: 1024 points, class 0 as the action
# object and class 1 as the anchor object.
points_action_loaded, points_anchor_loaded = load_data(
    num_points=1024,
    clouds=points_raw_np,
    classes=classes_raw_np,
    action_class=0,
    anchor_class=1,
)

# Run the loaded model on the stored transformed clouds, then apply its
# predicted action transform to the action cloud.
# NOTE(review): autograd is enabled here; consider torch.no_grad() if the
# outputs are only used for plotting — confirm.
MODEL_pred_T_action, MODEL_pred_T_anchor, MODEL_pred_T = place_model(
    points_trans_action.cuda(),
    points_trans_anchor.cuda(),
)
MODEL_action_transformed = MODEL_pred_T_action.transform_points(points_trans_action.cuda())

In [None]:
# Stored action cloud vs. the action cloud returned by load_data.
plot(points_action[0], points_action_loaded[0])

In [None]:
# Stored anchor cloud vs. the anchor cloud returned by load_data.
plot(points_anchor[0], points_anchor_loaded[0])

In [None]:
# All four clouds in one figure: load_data outputs (action, anchor) followed
# by the stored clouds (action, anchor).
plot_multi(points_action_loaded[0], points_anchor_loaded[0],points_action[0], points_anchor[0])

In [None]:
# Print the shape of a one-element slice of the stored action cloud next to
# the shape of the load_data action cloud.
print(points_action[0:1].shape)
print(points_action_loaded.shape)

In [None]:
# Action cloud after applying the model's predicted action transform, shown
# against the first `points_trans_anchor` cloud.
plot(MODEL_action_transformed[0], points_trans_anchor[0])

In [None]:
# NOTE(review): `stop` is not defined in any visible cell; presumably a
# deliberate NameError to halt "Run All" at this point — confirm intent.
stop