In [3]:
import re, os, json, h5py, pickle
from pathlib import Path
from collections import OrderedDict, defaultdict, Counter
import os.path as osp
import numpy as np
import pandas as pd
import open3d as o3d
from fire import Fire
from natsort import natsorted
from loguru import logger
from plyfile import PlyData, PlyElement
from scipy.spatial import cKDTree
from copy import deepcopy
from pxr import Usd, UsdGeom, UsdSkel, UsdPhysics, Gf
from collections import defaultdict

In [12]:
def compute_normals(mesh):
    # Get vertex positions
    points = mesh.GetPointsAttr().Get()
    # Get face indices and counts
    face_vertex_indices = mesh.GetFaceVertexIndicesAttr().Get()
    face_vertex_counts = mesh.GetFaceVertexCountsAttr().Get()

    if not points or not face_vertex_indices or not face_vertex_counts:
        print("Mesh data is incomplete.")
        return None

    # Convert to NumPy for easier computation
    points = np.array([list(p) for p in points])
    face_vertex_indices = np.array(face_vertex_indices)
    face_vertex_counts = np.array(face_vertex_counts)

    # Initialize normals
    normals = np.zeros_like(points)

    # Compute normals per face
    index = 0
    for count in face_vertex_counts:
        if count == 3:  # Only process triangular faces
            v0, v1, v2 = face_vertex_indices[index:index + 3]
            p0, p1, p2 = points[v0], points[v1], points[v2]

            # Compute face normal
            normal = np.cross(p1 - p0, p2 - p0)
            normal = normal / np.linalg.norm(normal)  # Normalize

            # Add to vertex normals
            normals[v0] += normal
            normals[v1] += normal
            normals[v2] += normal

        index += count

    # Normalize vertex normals
    normals = np.array([n / np.linalg.norm(n) if np.linalg.norm(n) > 0 else n for n in normals])
    return normals

class USDAParser:
    def __init__(self, file_path, interaction_as_movement = False, exlude_stuff = False):
        self.file_path = file_path
        self.interaction_as_movement = interaction_as_movement
        self.exlude_stuff = exlude_stuff
        
        self.stage = Usd.Stage.Open(self.file_path)
        self.points = {}
        self.colors = {}
        self.normals = {}
                
        self.mesh_id_to_path = {}
        self.mesh_path_to_id = {}
        
        self.mesh_hierarchy = {}
        self.mesh_offspring_tree = defaultdict(list)
        self.mesh_ancestor_tree = defaultdict(None)
        
        self.is_mov = {}
        self.is_inter = {}
        self.mov_ids = []
        self.inter_ids = []
        self.parse()

    def parse(self):
        """Parses the USD file for meshes, point clouds, and joints."""
        for prim in self.stage.Traverse():
            # 1. Capture Meshes and point clouds
            if prim.IsA(UsdGeom.Mesh):
                points = self._get_pointcloud(prim)
                normals = self._get_normals(prim)
                colors = self._get_color(prim)
                
                prim_path = str(prim.GetPath())
                mesh_id = int(prim_path.split('_')[-1])
                
                self.points[prim_path] = points
                self.colors[prim_path] = colors
                self.normals[prim_path] = normals
                # Extracting ID from the prim path (assuming the ID is at the end)
                
                self.mesh_id_to_path[mesh_id] = prim_path
                self.mesh_path_to_id[prim_path] = mesh_id

                # set movable and interactable flags
                self.is_mov[mesh_id] = False
                self.is_inter[mesh_id] = False
                # 2. Capture Hierarchy
                self._add_hierarchy_edge(prim_path)

        for prim in self.stage.Traverse():    
            if prim.IsA(UsdGeom.Mesh):
                # 3. Capture movable and interactable objects
                if prim.HasAttribute("movable") and bool(prim.GetAttribute("movable").Get()):
                    id = int(self.mesh_path_to_id[str(prim.GetPath())])
                    self.mov_ids.append(id)
                    self.is_mov[id] = True
                if prim.HasAttribute("interactable") and bool(prim.GetAttribute("interactable").Get()):
                    id = int(self.mesh_path_to_id[str(prim.GetPath())])
                    self.inter_ids.append(id)
                    self.is_inter[id] = True
                    
        # get connectivity of articulation parts
        self.parse_connectivity()
                    
    def parse_connectivity(self):
        self.edges = defaultdict(set)
        part_pathes = list(self.mesh_path_to_id.keys())
        # Function to extract the numeric part of a node
        def extract_node_number(node):
            match = re.search(r'_(\d+)$', node)
            if match:
                return int(match.group(1))
            return None

        def extract_root_part_id(path):
            # Split the path into components
            components = path.split('/')
            for component in components:
                # Match the first numeric part after the root
                match = re.search(r'_(\d+)$', component)
                if match:
                    return int(match.group(1))
            return None  # Return None if no numeric ID is found

        # Process the hierarchy and build edges
        for part in part_pathes:
            components = part.split('/')
            current_node = ""
            for component in components:
                if not component:  # Skip empty components (like the first /)
                    continue
                parent_node = current_node
                current_node = f"{parent_node}/{component}" if parent_node else f"/{component}"
                # Get numbers for the parent and current node
                root_part_id = extract_root_part_id(current_node)
                current_number = extract_node_number(current_node)
                if parent_node:
                    parent_number = extract_node_number(parent_node)
                    if parent_number is not None and current_number is not None:
                        self.edges[root_part_id].add((parent_number, current_number))

    def _get_pointcloud(self, mesh_prim):
        """Extracts the point cloud from a UsdGeom.Mesh primitive."""
        mesh = UsdGeom.Mesh(mesh_prim)
        points_attr = mesh.GetPointsAttr()
        if points_attr.HasValue():
            return np.array(points_attr.Get(), dtype=np.float32)
        return np.array([], dtype=np.float32)
    def _get_color(self, mesh_prim):
        """Extracts the color from a UsdGeom.Mesh primitive."""
        mesh = UsdGeom.Mesh(mesh_prim)
        color_attr = mesh.GetDisplayColorAttr()
        if color_attr.HasValue():
            return np.array(color_attr.Get(), dtype=np.float32)
        return np.array([], dtype=np.float32)
    def _get_normals(self, mesh_prim):
        """Extracts the normals from a UsdGeom.Mesh primitive."""
        # mesh = UsdGeom.Mesh(mesh_prim)
        # normals_attr = mesh.GetNormalsAttr()
        # if normals_attr.HasValue():
        #     return np.array(normals_attr.Get(), dtype=np.float32)
        # return np.array([], dtype=np.float32)
        return compute_normals(UsdGeom.Mesh(mesh_prim))

    def _add_hierarchy_edge(self, prim_path_str):
                # Extract numbers using regex and store them in a list
        parts = prim_path_str.split('/')
        numbers = []
        for part in parts:
            # Find all numbers in the part
            found_numbers = re.findall(r'\d+', part)
            numbers.extend(map(int, found_numbers))  # Convert found numbers to integers
        hierarchical_info = []
        for i in range(len(numbers) - 1):
            hierarchical_info.append((numbers[i], numbers[i + 1]))
            self.mesh_offspring_tree[numbers[i]].append(numbers[i + 1])
            self.mesh_ancestor_tree[numbers[i + 1]] = numbers[i]

    def get_mesh_points(self):
        """Returns dictionary of meshes with their point clouds."""
        return self.points

    def get_mesh_id_to_path(self):
        """Returns dictionary of mesh ID to prim path mappings."""
        return self.mesh_id_to_path

    def get_mesh_hierarchy(self):
        """Returns the hierarchy structure with joint types."""
        return self.mesh_hierarchy

    def get_mov_ids(self):
        """Returns list of IDs for movable objects."""
        return self.mov_ids

    def get_inter_ids(self):
        """Returns list of IDs for interactable objects."""
        return self.inter_ids
    
    def get_mesh_ids(self):
        return self.mesh_id_to_path.keys()
    
    def get_pointcloud(self, mesh_id):
        mesh_path = self.mesh_id_to_path[mesh_id]
        mesh_ = self.meshes[mesh_path]
        # get point cloud from mesh
        pointcloud = []
        for i in range(len(mesh_)):
            pointcloud.append(list(mesh_[i]))
        pointcloud = np.array(pointcloud)
    
    def get_articulations(self):
        articulations = []
        inter_ids = self.get_inter_ids()
        for inter_id in inter_ids:
            # if inter_id in self.mov_ids:
            #     articulations.append({
            #         'interactable_id': inter_id,
            #         'movable_id': inter_id,
            #         'is_hierarchy_mov': False,
            #         'trace_list': [inter_id]
            #     })
            # else:
            # trace up to find movable object in the hierarchy
            is_hierarchy_mov = False
            mesh_mov_id = inter_id
            trace_list = [inter_id]
            hierarchy = self.mesh_ancestor_tree.get(inter_id, None)
            # exclude curtain and blind
            inter_name = str(self.mesh_id_to_path[inter_id])
            mov_name = str(self.mesh_id_to_path[mesh_mov_id])
            if 'curtain' in inter_name or 'blind' in inter_name:
                    continue
            while hierarchy is not None:
                trace_list.append(hierarchy)
                if self.is_mov[hierarchy]:
                    is_hierarchy_mov = True
                    mesh_mov_id = hierarchy
                    break
                hierarchy = self.mesh_ancestor_tree.get(hierarchy, None)
            if is_hierarchy_mov: 
                articulations.append({
                    'interactable_id': inter_id,
                    'movable_id': mesh_mov_id,
                    'is_hierarchy_mov': is_hierarchy_mov,
                    'trace_list': trace_list
                    })
            else:
                if not self.exlude_stuff:
                    articulations.append({
                        'interactable_id': inter_id,
                        'movable_id': inter_id,
                        'is_hierarchy_mov': is_hierarchy_mov,
                        'trace_list': [inter_id]
                        })
        return articulations


In [13]:
usda_file =  "/workspace/Mask3D_adapted/data/raw/articulate3d/scans/0a5c013435/0a5c013435.usda"
# usda_file = "/workspace/Mask3D_adapted/data/raw/articulate3d/scans/1b75758486/1b75758486.usda"
usda_parser = USDAParser(usda_file)

In [67]:
parts_list = list(usda_parser.mesh_path_to_id.keys())

In [17]:
normals = usda_parser.normals['/SceneRoot/door_frame_1']

In [20]:
np.unique(np.linalg.norm(normals, axis=1))

array([1., 1., 1., 1.])

In [80]:
parts_list

['/SceneRoot/door_frame_1',
 '/SceneRoot/door_frame_1/door_2',
 '/SceneRoot/door_frame_1/door_3',
 '/SceneRoot/door_frame_1/door_3/handle_base_4',
 '/SceneRoot/door_frame_1/door_3/handle_base_4/handle_5',
 '/SceneRoot/light_switch_base_6',
 '/SceneRoot/light_switch_base_6/light_switch_7',
 '/SceneRoot/cabinet_8',
 '/SceneRoot/cabinet_8/door_9',
 '/SceneRoot/cabinet_8/door_10',
 '/SceneRoot/cabinet_11',
 '/SceneRoot/cabinet_11/door_12',
 '/SceneRoot/cabinet_11/door_12/handle_15',
 '/SceneRoot/cabinet_11/door_13',
 '/SceneRoot/cabinet_11/door_13/handle_14',
 '/SceneRoot/cabinet_16',
 '/SceneRoot/cabinet_16/door_17',
 '/SceneRoot/cabinet_16/door_17/handle_19',
 '/SceneRoot/cabinet_16/door_18',
 '/SceneRoot/cabinet_16/door_18/handle_20',
 '/SceneRoot/microwave_21',
 '/SceneRoot/microwave_21/microwave_door_22',
 '/SceneRoot/microwave_21/microwave_door_22/microwave_handle_23',
 '/SceneRoot/microwave_21/button_base_24',
 '/SceneRoot/microwave_21/button_base_24/button_25',
 '/SceneRoot/microwa

In [81]:
import re
from collections import defaultdict
# List of scene parts
# Initialize edges list
edges = defaultdict(set)
nodes = set()
id_to_cate = {}
# Function to extract the numeric part of a node
def extract_node_number(node):
    match = re.search(r'_(\d+)$', node)
    if match:
        return int(match.group(1))
    return None

def extract_root_part_id(path):
    # Split the path into components
    components = path.split('/')
    for component in components:
        # Match the first numeric part after the root
        match = re.search(r'_(\d+)$', component)
        if match:
            return int(match.group(1))
    return None  # Return None if no numeric ID is found

def get_category_and_id(path):
    # Extract the last component of the path
    last_part = path.split('/')[-1]
    # Split by the underscore to separate category and ID
    if "_" in last_part:
        category, part_id = last_part.rsplit("_", 1)  # Split from the right
        return category, int(part_id)  # Convert part_id to integer
    else:
        return last_part, None  # If no underscore, return the whole as category and None for ID


# # Process the hierarchy and build edges
# for part in parts_list:
#     components = part.split('/')
#     current_node = ""
#     for component in components:
#         if not component:  # Skip empty components (like the first /)
#             continue
#         parent_node = current_node
#         current_node = f"{parent_node}/{component}" if parent_node else f"/{component}"

#         # Add the current node to nodes
#         cur_node_id = extract_node_number(component)
#         if cur_node_id is not None:
#             nodes.add(cur_node_id)
        
#         # Get numbers for the parent and current node
#         root_part_id = extract_root_part_id(current_node)
#         current_number = extract_node_number(current_node)
#         if parent_node:
#             parent_number = extract_node_number(parent_node)
#             if parent_number is not None and current_number is not None:
#                 edges[root_part_id].add((parent_number, current_number))

# Process the hierarchy and build edges
for part in parts_list:
    components = part.split('/')
    current_node = ""
    for component in components:
        if not component:  # Skip empty components (like the first /)
            continue
        parent_node = current_node
        current_node = f"{parent_node}/{component}" if parent_node else f"/{component}"
        # Get numbers for the parent and current node
        root_part_id = extract_root_part_id(current_node)
        current_number = extract_node_number(current_node)
        category, part_id = get_category_and_id(current_node)
        if part_id is not None:
            id_to_cate[part_id] = category
        if parent_node:
            parent_number = extract_node_number(parent_node)
            if parent_number is not None and current_number is not None:
                edges[root_part_id].add((parent_number, current_number))

# Print the edges
print("\nEdges (Hierarchies as Tuples):")
# for edge in edges:
#     print(edge)
print(edges)



Edges (Hierarchies as Tuples):
defaultdict(<class 'set'>, {1: {(4, 5), (1, 2), (1, 3), (3, 4)}, 6: {(6, 7)}, 8: {(8, 9), (8, 10)}, 11: {(11, 12), (11, 13), (13, 14), (12, 15)}, 16: {(18, 20), (17, 19), (16, 17), (16, 18)}, 21: {(21, 26), (21, 22), (22, 23), (24, 25), (21, 24)}, 27: {(28, 29), (27, 28)}, 30: {(30, 31), (31, 64)}, 32: {(32, 33), (33, 65)}, 34: {(35, 37), (34, 35), (34, 36), (36, 38)}, 40: {(40, 41)}, 42: {(44, 45), (42, 47), (42, 43), (43, 44), (42, 46)}, 48: {(72, 61), (48, 49), (49, 72)}, 50: {(50, 57), (50, 58)}, 51: {(51, 52)}, 53: {(53, 54)}, 55: {(55, 56)}, 59: {(59, 60)}, 62: {(62, 63)}, 67: {(67, 68)}, 71: {(71, 70)}})


In [40]:
usda_parser.points["/SceneRoot/door_frame_1/door_2/handle_base_5/handle_6"].shape

(303, 3)

In [9]:
mesh = UsdGeom.Mesh(usda_parser.stage.GetPrimAtPath("/SceneRoot/cabinet_8"))
normals = np.array(mesh.GetNormalsAttr())
points = np.array(mesh.GetPointsAttr())

RuntimeError: Accessed schema on invalid prim

In [75]:
usda_parser.colors["/SceneRoot/books_137"]

array([[0.40784314, 0.45882353, 0.4627451 ],
       [0.4117647 , 0.45490196, 0.45490196],
       [0.40784314, 0.45882353, 0.4627451 ],
       ...,
       [0.45882353, 0.46666667, 0.44705883],
       [0.34901962, 0.3882353 , 0.36862746],
       [0.41568628, 0.4392157 , 0.4392157 ]], dtype=float32)

In [76]:
usda_parser.points["/SceneRoot/books_137"]

array([[0.99570364, 1.4000267 , 0.7301998 ],
       [0.99498713, 1.401158  , 0.7343674 ],
       [0.9921936 , 1.4023616 , 0.72912335],
       ...,
       [1.0171105 , 1.1367188 , 0.96253556],
       [1.0062737 , 1.1485891 , 0.981486  ],
       [1.0126009 , 1.1421669 , 1.0013292 ]], dtype=float32)

In [71]:
usda_parser.colors["/SceneRoot/books_137"].min()

0.0

In [65]:
usda_parser.colors.keys()

dict_keys(['/SceneRoot/door_frame_1', '/SceneRoot/door_frame_1/door_2', '/SceneRoot/door_frame_1/door_3', '/SceneRoot/door_frame_1/door_3/handle_base_4', '/SceneRoot/door_frame_1/door_3/handle_base_4/handle_5', '/SceneRoot/light_switch_base_6', '/SceneRoot/light_switch_base_6/light_switch_7', '/SceneRoot/cabinet_8', '/SceneRoot/cabinet_8/door_9', '/SceneRoot/cabinet_8/door_10', '/SceneRoot/cabinet_11', '/SceneRoot/cabinet_11/door_12', '/SceneRoot/cabinet_11/door_12/handle_15', '/SceneRoot/cabinet_11/door_13', '/SceneRoot/cabinet_11/door_13/handle_14', '/SceneRoot/cabinet_16', '/SceneRoot/cabinet_16/door_17', '/SceneRoot/cabinet_16/door_17/handle_19', '/SceneRoot/cabinet_16/door_18', '/SceneRoot/cabinet_16/door_18/handle_20', '/SceneRoot/microwave_21', '/SceneRoot/microwave_21/microwave_door_22', '/SceneRoot/microwave_21/microwave_door_22/microwave_handle_23', '/SceneRoot/microwave_21/button_base_24', '/SceneRoot/microwave_21/button_base_24/button_25', '/SceneRoot/microwave_21/button_26

In [82]:
id_to_cate

{1: 'door_frame',
 2: 'door',
 3: 'door',
 4: 'handle_base',
 5: 'handle',
 6: 'light_switch_base',
 7: 'light_switch',
 8: 'cabinet',
 9: 'door',
 10: 'door',
 11: 'cabinet',
 12: 'door',
 15: 'handle',
 13: 'door',
 14: 'handle',
 16: 'cabinet',
 17: 'door',
 19: 'handle',
 18: 'door',
 20: 'handle',
 21: 'microwave',
 22: 'microwave_door',
 23: 'microwave_handle',
 24: 'button_base',
 25: 'button',
 26: 'button',
 27: 'sink',
 28: 'faucet',
 29: 'faucet_control',
 30: 'soap_dispenser',
 31: 'screw_cap',
 64: 'pump_head',
 133: 'bag',
 73: 'wall',
 77: 'wall',
 82: 'wall',
 32: 'soap_dispenser',
 33: 'screw_cap',
 65: 'pump_head',
 34: 'cabinet',
 35: 'door',
 37: 'handle',
 36: 'door',
 38: 'handle',
 83: 'wall',
 40: 'switch_base',
 41: 'switch',
 42: 'window_frame',
 43: 'window',
 44: 'handle_base',
 45: 'handle',
 46: 'hinge',
 47: 'hinge',
 48: 'window_frame',
 49: 'window',
 72: 'blinds_rail',
 61: 'blinds',
 50: 'blinds_rail',
 57: 'blinds',
 58: 'blinds',
 51: 'heater',
 52: