In [None]:
import json
import pythreejs as p3
import numpy as np
from pygltflib import GLTF2
from IPython.display import display
from ipywidgets import HTML, VBox, HBox, Accordion, Button, Output, Layout
import pandas as pd
from collections import Counter, defaultdict
import os
from scipy.spatial.transform import Rotation

# Colours assigned to IFC types (indexed modulo the palette length, so more
# than 30 types reuse colours). Three groups of ten hex colours, concatenated.
COLOR_PALETTE = [
    # base colours (matplotlib "tab10")
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
] + [
    # light tints
    '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
    '#c49c94', '#f7b6d3', '#c7c7c7', '#dbdb8d', '#9edae5',
] + [
    # darker shades
    '#393b79', '#637939', '#8c6d31', '#843c39', '#7b4173',
    '#5254a3', '#8ca252', '#bd9e39', '#ad494a', '#a55194',
]

def hex_to_rgb_normalized(hex_color):
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i+2], 16)/255.0 for i in (0, 2, 4))

# Load the JSON metadata (element properties keyed by IFC GlobalId) and the
# storey hierarchy. A missing or unreadable file falls back to {} so the viewer
# still starts.
json_path = 'model_properties.json'
hierarchy_path = 'model_hierarchy.json'

if not os.path.exists(json_path):
    print(f"JSON file not found: {json_path}")
    metadata = {}
else:
    try:
        # JSON is UTF-8 by spec. Without an explicit encoding, open() uses the
        # platform default (e.g. cp1252 on Windows) and can mis-decode
        # non-ASCII IFC names. The '-sig' variant also tolerates a leading BOM.
        with open(json_path, 'r', encoding='utf-8-sig') as f:
            metadata = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        print(f"Failed to read JSON metadata from {json_path}: {e}")
        metadata = {}
    else:
        print(f"Loaded JSON metadata with {len(metadata)} elements")

if not os.path.exists(hierarchy_path):
    print(f"Hierarchy file not found: {hierarchy_path}")
    hierarchy_data = {}
else:
    try:
        with open(hierarchy_path, 'r', encoding='utf-8-sig') as f:
            hierarchy_data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Failed to read hierarchy data from {hierarchy_path}: {e}")
        hierarchy_data = {}
    else:
        if not isinstance(hierarchy_data, dict):
            # Downstream code calls hierarchy_data.get('Storeys', ...).
            print(f"Unexpected hierarchy format in {hierarchy_path}: expected a JSON object")
            hierarchy_data = {}
        else:
            print(f"Loaded hierarchy data with {len(hierarchy_data.get('Storeys', []))} storeys")

# Scene setup
width = 800   # renderer width in pixels
height = 600  # renderer height in pixels
camera = p3.PerspectiveCamera(position=[10, 10, 10], aspect=width/height)
camera.lookAt([0, 0, 0])
scene = p3.Scene(children=[], background='white')

# Lighting
ambient_light = p3.AmbientLight(intensity=1.0)
key_light = p3.DirectionalLight(color='white', position=[3, 5, 1], intensity=0.5)
scene.add([ambient_light, key_light])

# Ground grid. Three.js is Y-up and GridHelper lies in the XZ plane, so this
# grid is the horizontal plane y = 0. This is why the glTF loader converts
# IFC Z-up coordinates to Y-up.
grid = p3.GridHelper(size=20, divisions=20, colorCenterLine='black', colorGrid='gray')
grid.position = [0, 0, 0]
scene.add(grid)

# Global state shared by the loader and the widget callbacks.
gltf_meshes = []            # pythreejs meshes built from the glTF file
type_counter = Counter()    # per-IFC-type counts (replaced by the loader's result)
type_to_color = {}          # IFC type -> hex colour (replaced by the loader's result)
guid_to_mesh = {}           # GUID -> pythreejs mesh (replaced by the loader's result)
all_meshes_in_scene = []    # every mesh that has been added to the scene
visible_guids = set()       # GUIDs currently shown by the tree filter
selected_guid = None        # GUID of the selected element, or None
guid_to_button = {}         # GUID -> element button in the tree filter
selected_button = None      # button of the selected element, or None
default_opacity = 0.5       # mesh opacity while nothing is selected
selection_opacity = 1.0     # opacity of the selected mesh
background_opacity = 0.0    # opacity of the other visible meshes while something is selected

# glTF accessor componentType codes -> numpy dtypes (glTF 2.0 spec).
_GLTF_COMPONENT_DTYPES = {
    5120: np.int8,
    5121: np.uint8,
    5122: np.int16,
    5123: np.uint16,
    5125: np.uint32,
    5126: np.float32,
}

# glTF accessor "type" -> components per element (matrix types are not needed here).
_GLTF_TYPE_SIZES = {'SCALAR': 1, 'VEC2': 2, 'VEC3': 3, 'VEC4': 4}

# Row-vector rotation (applied as ``v @ _Z_UP_TO_Y_UP``) that maps IFC Z-up
# coordinates (x, y, z) to Three.js Y-up coordinates (x, z, -y).
_Z_UP_TO_Y_UP = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 0.0, -1.0],
    [0.0, 1.0, 0.0],
])


def _node_local_matrix(node, node_idx):
    """Return the 4x4 local transform of glTF node ``node`` (index ``node_idx``).

    A glTF node gives either a column-major ``matrix`` or a translation,
    rotation (quaternion ``[x, y, z, w]``) and scale, composed as ``T @ R @ S``.
    """
    matrix = getattr(node, 'matrix', None)
    if matrix:
        # glTF stores matrices column-major; reshape fills row-major, hence .T.
        return np.array(matrix, dtype=np.float64).reshape(4, 4).T

    translation = getattr(node, 'translation', None)
    rotation = getattr(node, 'rotation', None)
    scale = getattr(node, 'scale', None)

    rot = np.eye(3)
    if rotation and len(rotation) == 4:
        try:
            # scipy's default quaternion order is scalar-last (x, y, z, w), as in glTF.
            rot = Rotation.from_quat(rotation).as_matrix()
        except ValueError:
            print(f"Invalid quaternion in node {node_idx}: {rotation}, defaulting to identity rotation")

    scale_vec = np.array(scale, dtype=np.float64) if scale else np.ones(3)
    local = np.eye(4)
    # R @ diag(s): scale is applied before rotation, so it scales R's columns.
    local[:3, :3] = rot * scale_vec[None, :]
    if translation:
        local[:3, 3] = translation
    return local


def _make_buffer_reader(gltf, gltf_path):
    """Return a ``buffer_index -> bytes`` function that decodes each glTF buffer only once.

    Handles base64 data URIs, a buffer without a URI (the GLB binary chunk)
    and external files resolved relative to ``gltf_path``.
    """
    cache = {}
    base_dir = os.path.dirname(os.path.abspath(gltf_path))

    def read_buffer(buffer_index):
        if buffer_index not in cache:
            uri = gltf.buffers[buffer_index].uri
            if not uri:
                data = gltf.binary_blob()
            elif uri.startswith('data:'):
                data = gltf.get_data_from_buffer_uri(uri)
            else:
                from urllib.parse import unquote
                with open(os.path.join(base_dir, unquote(uri)), 'rb') as fh:
                    data = fh.read()
            if data is None:
                raise ValueError(f"buffer {buffer_index} has no data")
            cache[buffer_index] = data
        return cache[buffer_index]

    return read_buffer


def _read_accessor(gltf, accessor_idx, read_buffer):
    """Return a copy of accessor ``accessor_idx`` as a numpy array.

    The result has shape ``(count, n)`` for vector accessors and ``(count,)``
    for scalar ones. Honours ``byteStride``, so interleaved buffer views work.
    """
    accessor = gltf.accessors[accessor_idx]
    if accessor.bufferView is None:
        raise ValueError(f"accessor {accessor_idx} has no bufferView (sparse accessors are not supported)")
    dtype = np.dtype(_GLTF_COMPONENT_DTYPES[accessor.componentType])
    n_comp = _GLTF_TYPE_SIZES[accessor.type]
    view = gltf.bufferViews[accessor.bufferView]
    data = read_buffer(view.buffer)
    offset = (view.byteOffset or 0) + (accessor.byteOffset or 0)
    stride = view.byteStride or dtype.itemsize * n_comp
    # Strided view straight into the buffer (no slice copy of the whole buffer);
    # .copy() then yields a contiguous, writable array of just this accessor.
    arr = np.ndarray(
        shape=(accessor.count, n_comp),
        dtype=dtype,
        buffer=data,
        offset=offset,
        strides=(stride, dtype.itemsize),
    ).copy()
    return arr.ravel() if n_comp == 1 else arr


def load_gltf_to_pythreejs(file_path):
    """Load a glTF file and convert its meshes to pythreejs ``Mesh`` objects.

    Only nodes whose ``extras['GlobalId']`` is a key of the module-level
    ``metadata`` dict are converted. Each primitive becomes one pythreejs mesh,
    coloured by IFC type from ``COLOR_PALETTE``. Vertices get their node's world
    transform and are then rotated from IFC Z-up to Three.js Y-up.

    Args:
        file_path: Path to the glTF file.

    Returns:
        ``(meshes, type_counter, type_to_color, guid_to_mesh)``:

        - ``type_counter`` counts mesh-bearing nodes per IFC type.
        - ``type_to_color`` maps IFC type -> hex colour.
        - ``guid_to_mesh`` maps GUID -> the last mesh created for that GUID.

        Returns ``([], Counter(), {}, {})`` if the file is missing or cannot be loaded.
    """
    try:
        if not os.path.exists(file_path):
            print(f"GLTF file not found: {file_path}")
            return [], Counter(), {}, {}

        print(f"Loading GLTF file: {file_path}")
        gltf = GLTF2().load(file_path)
        read_buffer = _make_buffer_reader(gltf, file_path)

        meshes = []
        type_counter = Counter()
        guid_to_mesh = {}

        # Colours are assigned from the IFC types present in the JSON metadata.
        valid_guids = set(metadata.keys())
        all_types = {data.get('ifcType', 'Unknown') for data in metadata.values()}
        print(f"Found {len(all_types)} unique IFC types in JSON metadata")
        type_to_color = {
            t: COLOR_PALETTE[i % len(COLOR_PALETTE)]
            for i, t in enumerate(sorted(all_types, key=str))
        }

        def color_for(obj_type):
            # A type that appears only in glTF extras gets the next palette colour
            # instead of raising KeyError and aborting the whole load.
            if obj_type not in type_to_color:
                type_to_color[obj_type] = COLOR_PALETTE[len(type_to_color) % len(COLOR_PALETTE)]
            return type_to_color[obj_type]

        def build_primitive_mesh(primitive, world_matrix, normal_matrix, guid, obj_type, hex_color):
            """Convert one triangle primitive into a pythreejs mesh in Y-up world space."""
            mode = getattr(primitive, 'mode', None)
            if mode not in (None, 4):  # 4 == TRIANGLES
                raise ValueError(f"unsupported primitive mode {mode}")

            positions = _read_accessor(gltf, primitive.attributes.POSITION, read_buffer).astype(np.float64)
            # Apply the node transform in the file's own frame first, then swap axes.
            homogeneous = np.hstack([positions, np.ones((len(positions), 1))])
            world = (world_matrix @ homogeneous.T).T[:, :3] @ _Z_UP_TO_Y_UP
            attributes = {
                'position': p3.BufferAttribute(array=world.astype(np.float32), normalized=False),
            }

            # Primitives without indices are drawn as non-indexed triangle lists.
            if primitive.indices is not None:
                indices = _read_accessor(gltf, primitive.indices, read_buffer)
                attributes['index'] = p3.BufferAttribute(array=indices, normalized=False)

            normal_idx = getattr(primitive.attributes, 'NORMAL', None)
            if normal_idx is not None:
                normals = _read_accessor(gltf, normal_idx, read_buffer).astype(np.float64)
                # Normals use the inverse-transpose of the node's linear part
                # (row-vector form: n @ inv(M)), then the same axis swap as positions.
                normals = normals @ normal_matrix @ _Z_UP_TO_Y_UP
                lengths = np.linalg.norm(normals, axis=1, keepdims=True)
                lengths[lengths == 0] = 1.0  # avoid division by zero on degenerate normals
                attributes['normal'] = p3.BufferAttribute(
                    array=(normals / lengths).astype(np.float32), normalized=False)

            material = p3.MeshStandardMaterial(
                color=hex_color,
                side='DoubleSide',
                transparent=True,
                opacity=default_opacity,
                depthWrite=True
            )
            p3_mesh = p3.Mesh(geometry=p3.BufferGeometry(attributes=attributes), material=material)
            p3_mesh.name = f"mesh_{guid}"
            p3_mesh.userData = {
                'original_color': hex_to_rgb_normalized(hex_color),
                'hex_color': hex_color,
                'object_type': obj_type,
                'guid': guid,
                'default_depthWrite': True
            }
            return p3_mesh

        def add_node_mesh(node, node_idx, world_matrix):
            """Create pythreejs meshes for every primitive of ``node``'s glTF mesh."""
            mesh_idx = node.mesh
            extras = getattr(node, 'extras', None)
            guid = extras.get('GlobalId') if isinstance(extras, dict) else None
            if not guid or guid not in valid_guids:
                print(f"Skipping mesh in node {node_idx}: GUID {guid} not found in JSON metadata")
                return

            obj_type = extras.get('ifcType', metadata[guid].get('ifcType', 'Unknown'))
            type_counter[obj_type] += 1
            hex_color = color_for(obj_type)

            try:
                normal_matrix = np.linalg.inv(world_matrix[:3, :3])
            except np.linalg.LinAlgError:
                normal_matrix = np.eye(3)  # degenerate (e.g. zero-scale) node: leave normals as-is

            for primitive in gltf.meshes[mesh_idx].primitives:
                try:
                    p3_mesh = build_primitive_mesh(
                        primitive, world_matrix, normal_matrix, guid, obj_type, hex_color)
                except Exception as e:
                    print(f"Error processing primitive in mesh {mesh_idx} for GUID {guid}: {e}")
                    continue
                guid_to_mesh[guid] = p3_mesh
                meshes.append(p3_mesh)

        def process_node(node_idx, parent_matrix=None):
            """Walk the node tree depth-first, accumulating world = parent @ local."""
            if parent_matrix is None:
                parent_matrix = np.eye(4)
            node = gltf.nodes[node_idx]
            try:
                local_matrix = _node_local_matrix(node, node_idx)
            except Exception as e:
                print(f"Error processing transformation for node {node_idx}: {e}, using identity matrix")
                local_matrix = np.eye(4)
            world_matrix = parent_matrix @ local_matrix

            if getattr(node, 'mesh', None) is not None:
                add_node_mesh(node, node_idx, world_matrix)

            # Children are visited even when this node's own mesh was skipped.
            for child_idx in getattr(node, 'children', None) or []:
                process_node(child_idx, world_matrix)

        # Start from the default scene (glTF makes `scene` optional). With no
        # scenes at all, every node that is nobody's child is treated as a root.
        if gltf.scenes:
            scene_idx = gltf.scene if gltf.scene is not None else 0
            root_nodes = gltf.scenes[scene_idx].nodes or []
        else:
            child_ids = {c for n in gltf.nodes for c in (getattr(n, 'children', None) or [])}
            root_nodes = [i for i in range(len(gltf.nodes)) if i not in child_ids]
        for node_idx in root_nodes:
            process_node(node_idx)

        print(f"Loaded {len(meshes)} meshes from GLTF")
        return meshes, type_counter, type_to_color, guid_to_mesh

    except Exception as e:
        print(f"Failed to load GLTF: {e}")
        return [], Counter(), {}, {}

# Load the glTF file and add every converted mesh to the scene.
gltf_path = 'model.gltf'
gltf_meshes, type_counter, type_to_color, guid_to_mesh = load_gltf_to_pythreejs(gltf_path)
if gltf_meshes:
    # scene.add accepts a list (as with the lights). One call updates the synced
    # `children` trait once; adding meshes one at a time reassigns the growing
    # children tuple per mesh, which is quadratic for large models.
    scene.add(gltf_meshes)
    all_meshes_in_scene.extend(gltf_meshes)
    visible_guids = {mesh.userData['guid'] for mesh in gltf_meshes}
    print(f"Successfully loaded {len(gltf_meshes)} meshes from GLTF")
else:
    print("No valid meshes found in GLTF file")

# Build hierarchy from model_hierarchy.json
def build_json_hierarchy(hierarchy_data, metadata):
    """Group element GUIDs by storey and IFC type for the tree filter.

    Args:
        hierarchy_data: Parsed ``model_hierarchy.json``. Its ``'Storeys'`` list
            holds entries with ``GlobalId``, ``Name`` and ``Children`` (element
            GUIDs). If empty, storeys come from each element's
            ``SpatialContainer`` in ``metadata`` instead.
        metadata: Mapping GUID -> element properties (``ifcType``, ``ifcName``, ...).

    Returns:
        Nested ``defaultdict``s of the form
        ``{(storey_guid, storey_name, 'IfcBuildingStorey'): {ifc_type: [(guid, name), ...]}}``.
        Storey children that are not in ``metadata`` are skipped. Missing or
        null storey ids/names become ``'Unknown'``. A missing or null element
        name falls back to its GUID.
    """
    hierarchy = defaultdict(lambda: defaultdict(list))

    def add_element(story_guid, story_name, guid, data):
        # `or` (not a .get default) so explicit nulls also get the fallback;
        # the tree filter sorts and labels by these values.
        ifc_type = data.get('ifcType') or 'Unknown'
        ifc_name = data.get('ifcName') or guid
        hierarchy[(story_guid, story_name, 'IfcBuildingStorey')][ifc_type].append((guid, ifc_name))

    if not hierarchy_data:
        # Fallback: derive storeys from each element's SpatialContainer, which may be missing or null.
        for guid, data in metadata.items():
            container = data.get('SpatialContainer') or {}
            add_element(container.get('GlobalId') or 'Unknown',
                        container.get('Name') or 'Unknown', guid, data)
    else:
        for storey in hierarchy_data.get('Storeys', []):
            story_guid = storey.get('GlobalId') or 'Unknown'
            story_name = storey.get('Name') or 'Unknown'
            for element_guid in storey.get('Children') or []:
                if element_guid in metadata:
                    add_element(story_guid, story_name, element_guid, metadata[element_guid])
    return hierarchy

# (storey_guid, storey_name, 'IfcBuildingStorey') -> IFC type -> [(guid, name), ...]
hierarchy = build_json_hierarchy(hierarchy_data=hierarchy_data, metadata=metadata)

# Create tree filter
def create_tree_filter():
    """Build the storey -> IFC type -> element button tree for the sidebar.

    Returns a ``VBox`` with one single-pane ``Accordion`` per storey in the
    module-level ``hierarchy``, ordered by storey name.

    - Type buttons toggle visibility of that type within the storey.
    - Element buttons select the element. They are also registered in
      ``guid_to_button`` so that selection can recolour them.
    """
    def create_level(path, level_elements):
        buttons = []
        # str() sort keys keep sorting from raising TypeError if a type or
        # name is not a string (e.g. None from the JSON).
        for ifc_type, elements in sorted(level_elements.items(), key=lambda item: str(item[0])):
            type_button = Button(
                description=f"{ifc_type} ({len(elements)})",
                layout={'width': 'auto', 'margin': f'0 0 0 {20*len(path)}px'},
                style={'button_color': 'lightblue'}
            )
            # Default arguments bind the current loop values (avoids late binding).
            type_button.on_click(lambda b, p=path, t=ifc_type: toggle_type_visibility(p, t))
            buttons.append(type_button)
            for guid, name in sorted(elements, key=lambda x: str(x[1])):
                element_button = Button(
                    description=f"{name}",
                    layout={'width': 'auto', 'margin': f'0 0 0 {20*(len(path)+1)}px'},
                    style={'button_color': 'lightgreen'}
                )
                element_button.on_click(lambda b, g=guid: select_element(g))
                guid_to_button[guid] = element_button
                buttons.append(element_button)
        return buttons

    accordions = []
    for path in sorted(hierarchy, key=lambda p: str(p[1])):  # p[1] is the storey name
        buttons = create_level(path, hierarchy[path])
        if not buttons:
            continue
        accordion = Accordion(children=[VBox(buttons)])
        accordion.set_title(0, str(path[1]))
        accordions.append(accordion)
    return VBox(accordions)

# Add deselect button
def create_deselect_button():
    """Return a "Deselect All" button whose click calls ``deselect_all()``."""
    button = Button(
        description="Deselect All",
        layout={'width': 'auto', 'margin': '10px 0'},
        style={'button_color': 'lightcoral'},
    )

    def _on_click(_clicked):
        deselect_all()

    button.on_click(_on_click)
    return button

# Toggle visibility
def toggle_type_visibility(path, ifc_type):
    """Toggle all elements of ``ifc_type`` in storey ``path`` between shown and hidden.

    If every element is already in ``visible_guids`` they are all hidden;
    otherwise they are all shown. The scene is then refreshed.
    """
    type_guids = [entry[0] for entry in hierarchy[path][ifc_type]]
    if set(type_guids).issubset(visible_guids):
        visible_guids.difference_update(type_guids)
    else:
        visible_guids.update(type_guids)
    update_scene_visibility()

# Select element
def _element_property_rows(guid, data):
    """Flatten one element's metadata into PropertySet/Property/Value rows for display."""
    rows = [
        {'PropertySet': 'General', 'Property': 'GUID', 'Value': guid},
        {'PropertySet': 'General', 'Property': 'ifcType', 'Value': data.get('ifcType', 'N/A')},
        {'PropertySet': 'General', 'Property': 'ifcName', 'Value': data.get('ifcName', 'N/A')},
    ]
    container = data.get('SpatialContainer')
    if container:
        rows.append({'PropertySet': 'Spatial', 'Property': 'Container', 'Value': container.get('Name', 'N/A')})
    # `or {}` tolerates property sets / quantities that are null in the JSON.
    for pset_name, pset_data in (data.get('PropertySets') or {}).items():
        for key, value in (pset_data or {}).items():
            rows.append({'PropertySet': pset_name, 'Property': key, 'Value': str(value)})
    for qto_name, qto_data in (data.get('Quantities') or {}).items():
        for key, qto in (qto_data or {}).items():
            if isinstance(qto, dict):
                value = f"{qto.get('Value', 'N/A')} ({qto.get('Type', 'N/A')})"
            else:
                value = str(qto)
            rows.append({'PropertySet': qto_name, 'Property': key, 'Value': value})
    return rows


def select_element(guid):
    """Select element ``guid``: highlight its tree button and mesh, and show its properties.

    - The mesh state (opaque selected mesh on top, faded background) is
      applied by ``update_scene_visibility``, the same path the filter uses.
    - Element names and property values come from the model file and are
      HTML-escaped before rendering.
    """
    global selected_guid, selected_button
    from html import escape

    # Reset the previously selected button's colour.
    if selected_guid and selected_guid in guid_to_button:
        guid_to_button[selected_guid].style.button_color = 'lightgreen'

    selected_guid = guid
    selected_button = guid_to_button.get(guid)
    if selected_button is not None:
        selected_button.style.button_color = '#ffeb3b'  # yellow = active

    update_scene_visibility()

    with properties_output:
        properties_output.clear_output()
        if guid in metadata:
            data = metadata[guid]
            df = pd.DataFrame(_element_property_rows(guid, data))
            title = escape(str(data.get("ifcName", guid)))
            display(HTML(f'<h4>Properties for: {title}</h4>'))
            # Styler does not escape cell contents by default; values come from the model file.
            display(df.style.format(escape='html').set_caption('Element Properties'))
        else:
            display(HTML('<b>No properties found for this object</b>'))

# Deselect all
def deselect_all():
    """Clear the current selection and restore all filtered-in meshes to their default look."""
    global selected_guid, selected_button
    # Reset the selected button's colour.
    if selected_guid and selected_guid in guid_to_button:
        guid_to_button[selected_guid].style.button_color = 'lightgreen'

    selected_guid = None
    selected_button = None

    with properties_output:
        properties_output.clear_output()
        display(HTML('<i>No element selected</i>'))

    # With selected_guid cleared, this sets opacity/depthWrite/renderOrder/visible
    # for every mesh, so no separate reset loop is needed.
    update_scene_visibility()

# Update scene visibility
def update_scene_visibility():
    """Apply the current filter and selection state to every mesh and refresh the counter.

    - Meshes whose GUID is not in ``visible_guids`` are hidden.
    - With no selection, visible meshes use ``default_opacity``.
    - With a selection, the selected mesh is drawn opaque on top
      (``renderOrder`` 10) and the other visible meshes use ``background_opacity``.
    """
    shown = 0
    for mesh in all_meshes_in_scene:
        mesh_guid = mesh.userData.get('guid', '')
        is_shown = mesh_guid in visible_guids
        mesh.visible = is_shown

        if not is_shown:
            opacity, depth_write, order = 0.0, False, 0
        elif selected_guid is None:
            opacity, depth_write, order = default_opacity, mesh.userData.get('default_depthWrite', True), 0
        elif mesh_guid == selected_guid:
            opacity, depth_write, order = selection_opacity, True, 10
        else:
            opacity, depth_write, order = background_opacity, False, 0

        if is_shown:
            shown += 1
        mesh.material.opacity = opacity
        mesh.material.depthWrite = depth_write
        mesh.renderOrder = order

    filter_info_widget.value = f"<b>Visible Objects:</b> {shown}/{len(all_meshes_in_scene)}<br>"

# Renderer with orbit controls around the camera.
controller = p3.OrbitControls(controlling=camera)
renderer = p3.Renderer(camera=camera, scene=scene, controls=[controller], width=width, height=height)

# Widgets
filter_info_widget = HTML(value="<b>Visible Objects:</b> 0/0<br>")
properties_output = Output()
tree_filter = create_tree_filter()
deselect_button = create_deselect_button()

# Layout: viewer and properties on the left, filter sidebar on the right.
# Panel sizes follow the renderer's width/height so the layout stays aligned
# if those settings change.
filter_sidebar = VBox([
    HTML(value='<h3>Data Filter</h3>'),
    filter_info_widget,
    deselect_button,
    tree_filter
], layout=Layout(width='400px', height=f'{height}px', overflow='auto'))

viewer_section = VBox([
    renderer,
    HTML(value='<h3>Element Properties</h3>'),
    properties_output
], layout=Layout(width=f'{width}px'))

main_layout = HBox([viewer_section, filter_sidebar])

# Initialize properties
with properties_output:
    display(HTML('<i>No element selected</i>'))

# Display application
display(main_layout)

# Initial visibility update (also fills in the visible-object counter).
update_scene_visibility()