In [100]:
# Copyright (c) Gorilla-Lab. All rights reserved.
import os
from os.path import join as opj
import torch
import torch.nn as nn
import numpy as np

import open3d as o3d
from open3d.web_visualizer import draw
import matplotlib.pyplot as plt
import trimesh

# from gorilla.config import Config
# import models
# import loss
from utils import *
from dataset import *
# import argparse

In [101]:
# set_random_seed(0)

data = dict(
    data_root='data_root',
    partial=False,
    category=['grasp', 'contain', 'lift', 'openable', 'layable', 'sittable',
              'support', 'wrap_grasp', 'pourable', 'move', 'displaY', 'pushable', 'pull',
              'listen', 'wear', 'press', 'cut', 'stab']
)

afford_cat = data['category']
data_root = data['data_root']
partial = data['partial']
aff_train_set = AffordNetDataset(data_root, 'train', partial=partial)
aff_val_set = AffordNetDataset(data_root, 'val', partial=partial)
print("Loaded dataset with {} training samples and {} validation samples".format(len(aff_train_set), len(aff_val_set)))


Loaded dataset with 16082 training samples and 2285 validation samples


In [102]:
# pick a random datapoint
while(1):
    rand_index = np.random.randint(0, len(aff_train_set.all_data))#9532
    data_temp = aff_train_set.all_data[rand_index]
    print("Choosing random index: ", rand_index)
    if data_temp['semantic class'] == "Mug":
        break

Choosing random index:  10440
Choosing random index:  15256
Choosing random index:  6628
Choosing random index:  15208
Choosing random index:  4966
Choosing random index:  14320
Choosing random index:  14784
Choosing random index:  6968
Choosing random index:  5236
Choosing random index:  9618
Choosing random index:  6902
Choosing random index:  4765
Choosing random index:  827
Choosing random index:  1739
Choosing random index:  15476
Choosing random index:  8477
Choosing random index:  9944
Choosing random index:  15060
Choosing random index:  11272
Choosing random index:  3997
Choosing random index:  4935
Choosing random index:  8670
Choosing random index:  990
Choosing random index:  9680
Choosing random index:  8996
Choosing random index:  1641
Choosing random index:  2339
Choosing random index:  9254
Choosing random index:  1697
Choosing random index:  11093
Choosing random index:  5125
Choosing random index:  13947
Choosing random index:  3810
Choosing random index:  4079
Choosi

In [103]:
print("Shape id ", data_temp['shape_id'], "\nSemantic class: ", data_temp['semantic class'])
print("All affordances: ", data_temp["affordance"])
print("Num affordances: ", len(data_temp["affordance"]))

Shape id  dcec634f18e12427c2c72e575af174cd 
Semantic class:  Mug
All affordances:  ['grasp', 'contain', 'lift', 'openable', 'layable', 'sittable', 'support', 'wrap_grasp', 'pourable', 'move', 'displaY', 'pushable', 'pull', 'listen', 'wear', 'press', 'cut', 'stab']
Num affordances:  18


In [104]:
pcl = data_temp['data_info']['coordinate']
# Viz pcl with o3d
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pcl)
# color black
pcd.paint_uniform_color([0, 0, 0])
# viz in jupyter nb
print("Drawing pcd with %d points", len(pcl))
draw(pcd)

Drawing pcd with %d points 2048
[Open3D INFO] Window window_32 created.


WebVisualizer(window_uid='window_32')

In [105]:
# get pcls of all labels
labels_list = []
scores_list = []
for key in data_temp['data_info']['label'].keys():
    labels = data_temp['data_info']['label'][key]
    scores = labels.sum()
    print("Label: ", key, " Score: ", scores)
    labels_list.append(labels)
    scores_list.append(scores)    

Label:  grasp  Score:  68.00398
Label:  contain  Score:  235.93713
Label:  lift  Score:  0.0
Label:  openable  Score:  0.0
Label:  layable  Score:  0.0
Label:  sittable  Score:  0.0
Label:  support  Score:  0.0
Label:  wrap_grasp  Score:  199.07465
Label:  pourable  Score:  205.49992
Label:  move  Score:  0.0
Label:  displaY  Score:  0.0
Label:  pushable  Score:  0.0
Label:  pull  Score:  0.0
Label:  listen  Score:  0.0
Label:  wear  Score:  0.0
Label:  press  Score:  0.0
Label:  cut  Score:  0.0
Label:  stab  Score:  0.0


In [106]:
labels_list = np.array(labels_list)
scores_list = np.array(scores_list)
# Get max 4 labels and viz
best_score_ids = np.argsort(scores_list)[-4:]
print("Best labels: ")
for best_score_id in best_score_ids:
    print(list(data_temp['data_info']['label'].keys())[best_score_id])
print("Best scores: ", scores_list[best_score_ids])

Best labels: 
grasp
wrap_grasp
pourable
contain
Best scores:  [ 68.00398 199.07465 205.49992 235.93713]


In [None]:
pcd_0 = o3d.geometry.PointCloud()
pcd_0.points = o3d.utility.Vector3dVector(pcl)
pcd_0.paint_uniform_color([0, 0, 0])
np.asarray(pcd_0.colors)[labels_list[best_score_ids[0]].squeeze() > 0] = [1, 0, 0]
draw([pcd_0], title="Label: " + list(data_temp['data_info']['label'].keys())[best_score_ids[0]])

pcd_1 = o3d.geometry.PointCloud()
pcd_1.points = o3d.utility.Vector3dVector(pcl)
pcd_1.paint_uniform_color([0, 0, 0])
np.asarray(pcd_1.colors)[labels_list[best_score_ids[1]].squeeze() > 0] = [0, 1, 0]
draw([pcd_1], title="Label: " + list(data_temp['data_info']['label'].keys())[best_score_ids[1]])

pcd_2 = o3d.geometry.PointCloud()
pcd_2.points = o3d.utility.Vector3dVector(pcl)
pcd_2.paint_uniform_color([0, 0, 0])
np.asarray(pcd_2.colors)[labels_list[best_score_ids[2]].squeeze() > 0] = [0, 0, 1]
draw([pcd_2], title="Label: " + list(data_temp['data_info']['label'].keys())[best_score_ids[2]])

pcd_3 = o3d.geometry.PointCloud()
pcd_3.points = o3d.utility.Vector3dVector(pcl)
pcd_3.paint_uniform_color([0, 0, 0])
np.asarray(pcd_3.colors)[labels_list[best_score_ids[3]].squeeze() > 0] = [1, 1, 0]
draw([pcd_3], title="Label: " + list(data_temp['data_info']['label'].keys())[best_score_ids[3]])

[Open3D INFO] Window window_33 created.


WebVisualizer(window_uid='window_33')

[Open3D INFO] Window window_34 created.


WebVisualizer(window_uid='window_34')

[Open3D INFO] Window window_35 created.


WebVisualizer(window_uid='window_35')

[Open3D INFO] Window window_36 created.


WebVisualizer(window_uid='window_36')

[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceServers
[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceServers
[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceServers
[Open3D INFO] [Called HTTP API (custom handshake)] /api/call
[Open3D INFO] [Called HTTP API (custom handshake)] /api/call
[Open3D INFO] [Called HTTP API (custom handshake)] /api/call
[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceServers
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] DataChannelObserver::OnStateChange label: ServerDataChannel, state: open, peerid: 0.01588552042400848
[Open3D INFO] DataChannelObserver::OnStateChange label: ClientDataChannel, state: open, peerid: 0.01588552042400848
[Open3D INFO] Sending init frames to window_33.
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate


[1650:798][11089] (stun_port.cc:96): Binding request timed out from 130.83.164.x:53338 (enp5s0)


[Open3D INFO] DataChannelObserver::OnStateChange label: ServerDataChannel, state: open, peerid: 0.05860516265367588
[Open3D INFO] DataChannelObserver::OnStateChange label: ClientDataChannel, state: open, peerid: 0.05860516265367588
[Open3D INFO] Sending init frames to window_34.
[Open3D INFO] [Called HTTP API (custom handshake)] /api/call
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/addIceCandidate
[Open3D INFO] [Called HTTP API (custom handshake)] /api/getIceCandidate
[Open3D INFO] DataChannelObserver::OnStateChange label: ServerDataChannel, state: open, peerid: 0.7369151970921259
[Open3D INFO] DataChannelObserver::OnSta

[1653:270][11089] (stun_port.cc:96): Binding request timed out from 130.83.164.x:39043 (enp5s0)
[1653:276][11089] (stun_port.cc:96): Binding request timed out from 130.83.164.x:41140 (enp5s0)
[1653:279][11089] (stun_port.cc:96): Binding request timed out from 130.83.164.x:47451 (enp5s0)
[1653:285][11089] (stun_port.cc:96): Binding request timed out from 130.83.164.x:38871 (enp5s0)
