# 3D Visualization

## Purpose

This notebook shows how to create interactive 3D visualizations of voxel data. You'll learn 3D volume rendering, 2D slice visualization, multi-resolution views, and animation, all controlled through interactive widgets.

## Learning Objectives

By the end of this notebook, you will:
- ✅ Create 3D volume visualizations (Surface, Volume, Points)
- ✅ Visualize 2D slices (XY, XZ, YZ planes)
- ✅ Use multi-resolution visualization
- ✅ Create animations (time, layer, parameter)
- ✅ Control camera views and export visualizations
- ✅ Use interactive PyVista viewers
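To make the XY, XZ, and YZ planes concrete: with the `(x, y, z)` array layout that `np.meshgrid(..., indexing='ij')` produces, an XY slice fixes the z index, an XZ slice fixes y, and a YZ slice fixes x. The sketch below assumes that layout. `extract_slice` is a hypothetical helper (not part of AM-QADF), and averaging `thickness` consecutive planes starting at `index` is only one plausible reading of the notebook's Thickness control.

```python
import numpy as np


def extract_slice(volume: np.ndarray, axis: str, index: int, thickness: int = 1) -> np.ndarray:
    """Return a 2D slice of an (nx, ny, nz) volume.

    Hypothetical helper (not the AM-QADF API): 'xy' fixes z, 'xz' fixes y,
    'yz' fixes x, and `thickness` consecutive planes starting at `index`
    are averaged (the window is clipped at the array edge).
    """
    if volume.ndim != 3:
        raise ValueError("expected a 3D array indexed as (x, y, z)")
    if thickness < 1:
        raise ValueError("thickness must be >= 1")
    fixed_axis = {'xy': 2, 'xz': 1, 'yz': 0}[axis]  # axis held constant
    n = volume.shape[fixed_axis]
    if not 0 <= index < n:
        raise IndexError(f"index {index} out of range for axis of length {n}")
    stop = min(index + thickness, n)
    window = np.take(volume, np.arange(index, stop), axis=fixed_axis)
    return window.mean(axis=fixed_axis)


# Example: a small 4 x 3 x 2 volume
vol = np.arange(24, dtype=float).reshape(4, 3, 2)
print(extract_slice(vol, 'xy', 0).shape)  # (4, 3)
print(extract_slice(vol, 'xz', 1).shape)  # (4, 2)
print(extract_slice(vol, 'yz', 2).shape)  # (3, 2)
```

Because the layout is `(x, y, z)`, the maximum valid slice index depends on the chosen plane: it is `nz - 1` for XY but `nx - 1` for YZ.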

## Estimated Duration

60-90 minutes

---

## Overview

3D visualization enables intuitive exploration of voxel domain data. The AM-QADF framework provides comprehensive 3D visualization capabilities:

- 🎨 **3D Volume Rendering**: Surface, Volume, Points rendering modes
- 📐 **Slice Visualization**: XY, XZ, YZ plane slices with interactive navigation
- 🔍 **Multi-Resolution**: Adaptive resolution display for large datasets
- 🎬 **Animation**: Time-series, layer-by-layer, parameter sweep animations
- 📷 **Camera Controls**: Interactive camera with preset views
- 💾 **Export**: Export images, videos, and 3D models
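A common way to provide multi-resolution views is block averaging, where each coarser level merges neighboring voxels. The sketch below assumes that level `n` averages blocks of `2**(n - 1)` voxels per axis and crops trailing voxels that do not fill a whole block. `downsample_level` and the voxel-budget rule in `auto_level` are hypothetical illustrations, not AM-QADF's documented behavior.

```python
import numpy as np


def downsample_level(volume: np.ndarray, level: int) -> np.ndarray:
    """Block-average a 3D volume; level 1 returns the input unchanged.

    Assumption for illustration: level n averages blocks of 2**(n - 1) voxels per axis.
    """
    if volume.ndim != 3:
        raise ValueError("expected a 3D array")
    if level < 1:
        raise ValueError("level must be >= 1")
    f = 2 ** (level - 1)
    if f == 1:
        return volume
    nx, ny, nz = (s // f for s in volume.shape)
    if min(nx, ny, nz) == 0:
        raise ValueError(f"volume {volume.shape} is too small for level {level}")
    # Crop so each axis divides evenly, then average every f x f x f block
    cropped = volume[:nx * f, :ny * f, :nz * f]
    return cropped.reshape(nx, f, ny, f, nz, f).mean(axis=(1, 3, 5))


def auto_level(shape, max_voxels: int = 100_000, max_level: int = 5) -> int:
    """Return the finest level whose voxel count fits `max_voxels` (hypothetical policy)."""
    for level in range(1, max_level + 1):
        f = 2 ** (level - 1)
        count = 1
        for s in shape:
            count *= max(s // f, 1)
        if count <= max_voxels:
            return level
    return max_level


# Example with a grid the same size as the notebook's demo data (50 x 50 x 30)
demo = np.random.default_rng(0).normal(200, 2, size=(50, 50, 30))
print(downsample_level(demo, 2).shape)            # (25, 25, 15)
print(auto_level(demo.shape, max_voxels=10_000))  # 2
```

Averaging smooths noise as it shrinks the grid. If peak values matter (for example, hot spots in a temperature field), `max(axis=(1, 3, 5))` on the reshaped blocks is an alternative reduction.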

Use the interactive widgets below to create and customize 3D visualizations - no coding required!


In [1]:
# Setup: Import required libraries
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add parent directory and src directory to path for imports
notebook_dir = Path().resolve()
project_root = notebook_dir.parent
src_dir = project_root / 'src'

# Add project root to path (for src.infrastructure imports)
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Add src directory to path (for am_qadf imports)
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

# Core imports
import ipywidgets as widgets
from ipywidgets import (
    VBox, HBox, Accordion, Tab, Dropdown, RadioButtons, 
    Checkbox, Button, Output, Text, IntSlider, FloatSlider,
    Layout, Box, Label, FloatText, IntText, SelectMultiple
)
from IPython.display import display, Markdown, HTML, clear_output
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from scipy import stats
from datetime import datetime
from typing import Optional, Tuple, Dict, Any, List

# Try to import PyVista for advanced 3D visualization
PYVISTA_AVAILABLE = False
try:
    import pyvista as pv
    pv.set_jupyter_backend('static')  # Use static backend for notebooks
    PYVISTA_AVAILABLE = True
except ImportError:
    print("⚠️ PyVista not available - using matplotlib 3D for visualization")

# Load environment variables from development.env (simple KEY=VALUE lines)
import os
env_file = project_root / 'development.env'
if env_file.exists():
    with open(env_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines, comments, and lines without an assignment
            if line and not line.startswith('#') and '=' in line:
                key, value = line.split('=', 1)
                # Trim whitespace around key and value, then surrounding quotes
                os.environ[key.strip()] = value.strip().strip('"\'')
    print("✅ Environment variables loaded from development.env")

# Try to import voxel renderer classes
RENDERER_AVAILABLE = False
try:
    from am_qadf.visualization.voxel_renderer import VoxelRenderer
    RENDERER_AVAILABLE = True
    print("✅ Voxel renderer classes available")
except ImportError as e:
    print(f"⚠️ Voxel renderer classes not available: {e} - using demo mode")

# MongoDB connection setup
INFRASTRUCTURE_AVAILABLE = False
mongo_client = None
voxel_storage = None
stl_client = None

try:
    from src.infrastructure.config import MongoDBConfig
    from src.infrastructure.database import MongoDBClient
    from am_qadf.voxel_domain import VoxelGridStorage
    from am_qadf.query import STLModelClient
    
    # Initialize MongoDB connection
    config = MongoDBConfig.from_env()
    if not config.username:
        config.username = os.getenv('MONGO_ROOT_USERNAME', 'admin')
    if not config.password:
        config.password = os.getenv('MONGO_ROOT_PASSWORD', 'password')
    
    mongo_client = MongoDBClient(config=config)
    if mongo_client.is_connected():
        voxel_storage = VoxelGridStorage(mongo_client=mongo_client)
        stl_client = STLModelClient(mongo_client=mongo_client)
        INFRASTRUCTURE_AVAILABLE = True
        print(f"✅ Connected to MongoDB: {config.database}")
    else:
        print("⚠️ MongoDB connection failed")
except Exception as e:
    print(f"⚠️ MongoDB not available: {e} - using demo mode")

print("✅ Setup complete!")


✅ Environment variables loaded from development.env
✅ Voxel renderer classes available


Failed to connect to MongoDB: localhost:27017: [Errno 111] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30.0s, Topology Description: <TopologyDescription id: 6960145710d7cb6fceb3402d, topology_type: Unknown, servers: [<ServerDescription ('localhost', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('localhost:27017: [Errno 111] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]>


⚠️ MongoDB not available: localhost:27017: [Errno 111] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30.0s, Topology Description: <TopologyDescription id: 6960145710d7cb6fceb3402d, topology_type: Unknown, servers: [<ServerDescription ('localhost', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('localhost:27017: [Errno 111] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]> - using demo mode
✅ Setup complete!
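The `development.env` loader in the setup cell reads simple `KEY=VALUE` lines. Pulling that logic into a pure function makes it easy to test without touching the real environment. `parse_env_lines` is a hypothetical name; it follows the same basic rules as the cell (skip blank lines and `#` comments, split on the first `=`, strip quote characters) plus trimming whitespace around keys and values, and it does not implement the full dotenv syntax.

```python
from typing import Dict, Iterable


def parse_env_lines(lines: Iterable[str]) -> Dict[str, str]:
    """Parse KEY=VALUE lines into a dict (simplified; not full dotenv syntax)."""
    env: Dict[str, str] = {}
    for raw in lines:
        line = raw.strip()
        # Skip blank lines, comments, and lines without an assignment
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        # Trim whitespace, then any surrounding quote characters
        env[key.strip()] = value.strip().strip('"\'')
    return env


sample = [
    "# MongoDB credentials",
    "MONGO_ROOT_USERNAME=admin",
    'MONGO_ROOT_PASSWORD = "change-me"',
    "",
    "NOT_AN_ASSIGNMENT",
]
print(parse_env_lines(sample))
# {'MONGO_ROOT_USERNAME': 'admin', 'MONGO_ROOT_PASSWORD': 'change-me'}
```

Since file objects iterate line by line, the setup cell could apply it as `with open(env_file) as f: os.environ.update(parse_env_lines(f))`.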


## Interactive 3D Visualization Interface

Use the widgets below to create and customize 3D visualizations. Select visualization mode, configure rendering settings, and interact with 3D viewers!
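The camera widgets in the next cell expose azimuth, elevation, and zoom. One common convention, assumed here rather than taken from the AM-QADF renderer, places the camera on a sphere around the grid center: azimuth rotates about the vertical z axis, elevation tilts up from the XY plane, and zoom divides the viewing distance. `camera_position` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple


def camera_position(center: Sequence[float], distance: float, azimuth_deg: float,
                    elevation_deg: float, zoom: float = 1.0) -> Tuple[float, float, float]:
    """Return an (x, y, z) camera position orbiting `center` (illustrative convention)."""
    if zoom <= 0:
        raise ValueError("zoom must be positive")
    az = math.radians(azimuth_deg)
    el = math.radians(elevation_deg)
    r = distance / zoom  # higher zoom moves the camera closer to the center
    cx, cy, cz = center
    return (
        cx + r * math.cos(el) * math.cos(az),
        cy + r * math.cos(el) * math.sin(az),
        cz + r * math.sin(el),
    )


# The demo grid spans (0, 0, 0) to (10, 10, 5), so its center is (5, 5, 2.5)
pos = camera_position((5.0, 5.0, 2.5), distance=20.0, azimuth_deg=45, elevation_deg=30)
print(tuple(round(c, 2) for c in pos))  # (17.25, 17.25, 12.5)
```

With PyVista, a position computed this way could be applied as `plotter.camera_position = [pos, center, (0, 0, 1)]`, where the entries are the camera position, the focal point, and the view-up vector.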


In [2]:
# Create Interactive 3D Visualization Interface

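# Global state (initialized here so the callbacks below never hit undefined names)
voxel_data = None
current_visualization = None
current_model_id = None
current_grid_id = None
loaded_grid_data = None
signal_arrays = {}
stl_mesh = None
stl_path = None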

# ============================================
# Helper Functions for Demo Data
# ============================================

def generate_demo_voxel_data(nx=50, ny=50, nz=30):
    """Generate demo voxel data for visualization."""
    # Create 3D grid
    x = np.linspace(0, 10, nx)
    y = np.linspace(0, 10, ny)
    z = np.linspace(0, 5, nz)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
    
    # Create signal data (temperature-like)
    signal = (
        200 + 
        20 * np.sin(2 * np.pi * X / 5) * np.cos(2 * np.pi * Y / 5) +
        10 * np.exp(-((X - 5)**2 + (Y - 5)**2) / 4) +
        np.random.normal(0, 2, (nx, ny, nz))
    )
    
    return {
        'x': X,
        'y': Y,
        'z': Z,
        'temperature': signal,
        'dims': (nx, ny, nz),
        'bbox_min': (0, 0, 0),
        'bbox_max': (10, 10, 5),
        'resolution': 0.2
    }

# Initialize demo data
voxel_data = generate_demo_voxel_data()

# ============================================
# Top Panel: Data Source and Visualization Mode
# ============================================

# Data source mode
data_source_label = widgets.HTML("<b>Data Source:</b>")
data_source_mode = RadioButtons(
    options=[('MongoDB', 'mongodb'), ('Sample Data', 'sample')],
    value='mongodb',
    description='Source:',
    style={'description_width': 'initial'}
)

# Model selection (for MongoDB)
model_label = widgets.HTML("<b>Model:</b>")
model_options = [("━━━ Select Model ━━━", None)]
if stl_client and mongo_client:
    try:
        models = stl_client.list_models(limit=100)
        model_options.extend([
            # Guard against a missing/None model_id before slicing it for the label
            (f"{m.get('filename', m.get('original_stem', m.get('model_name', 'Unknown')))} ({(m.get('model_id') or '')[:8]}...)", m.get('model_id'))
            for m in models
        ])
    except Exception as e:
        print(f"⚠️ Error loading models: {e}")

model_dropdown = Dropdown(
    options=model_options,
    value=None,
    description='Model:',
    style={'description_width': 'initial'},
    layout=Layout(width='400px')
)

# Grid type filter
grid_type_label = widgets.HTML("<b>Grid Type:</b>")
grid_type_filter = Dropdown(
    options=[
        ('All Grids', 'all'),
        ('Fused', 'fused'),
        ('Corrected', 'corrected'),
        ('Processed', 'processed'),
        ('Signal-Mapped', 'signal_mapped'),
        ('Raw', 'raw')
    ],
    value='fused',  # Default to fused grids
    description='Type:',
    style={'description_width': 'initial'}
)

# Grid selection (for MongoDB)
grid_label = widgets.HTML("<b>Grid:</b>")
grid_dropdown = Dropdown(
    options=[("‚îÅ‚îÅ‚îÅ Select Grid ‚îÅ‚îÅ‚îÅ", None)],
    value=None,
    description='Grid:',
    style={'description_width': 'initial'},
    layout=Layout(width='500px')
)

load_grid_button = Button(
    description='Load Grid',
    button_style='info',
    icon='refresh',
    layout=Layout(width='120px')
)

# Visualization mode
viz_mode = RadioButtons(
    options=[
        ('3D Volume', '3d_volume'),
        ('Slices', 'slices'),
        ('Multi-Resolution', 'multi_res'),
        ('Animation', 'animation')
    ],
    value='3d_volume',
    description='Mode:',
    style={'description_width': 'initial'}
)

signal_label = widgets.HTML("<b>Signal:</b>")
signal_selector = Dropdown(
    options=[("‚îÅ‚îÅ‚îÅ Select Signal ‚îÅ‚îÅ‚îÅ", None)],
    value=None,
    description='Signal:',
    style={'description_width': 'initial'}
)

load_button = Button(
    description='Load Visualization',
    button_style='success',
    icon='refresh',
    layout=Layout(width='180px')
)

export_button = Button(
    description='Export Visualization',
    button_style='primary',
    icon='download',
    layout=Layout(width='180px')
)

top_panel = VBox([
    HBox([data_source_label, data_source_mode, viz_mode]),
    HBox([model_label, model_dropdown, grid_type_label, grid_type_filter]),
    HBox([grid_label, grid_dropdown, load_grid_button]),
    HBox([signal_label, signal_selector, load_button, export_button])
], layout=Layout(justify_content='flex-start', padding='10px', border='1px solid #ccc'))

# ============================================
# Left Panel: Visualization Configuration
# ============================================

# 3D Volume Settings
volume_render_mode = RadioButtons(
    options=[('Surface', 'surface'), ('Volume', 'volume'), ('Points', 'points')],
    value='surface',
    description='Render Mode:',
    style={'description_width': 'initial'}
)
volume_colormap = Dropdown(
    options=[('Viridis', 'viridis'), ('Plasma', 'plasma'), ('Inferno', 'inferno'), 
             ('Magma', 'magma'), ('Cool', 'cool'), ('Hot', 'hot')],
    value='plasma',
    description='Colormap:',
    style={'description_width': 'initial'}
)
volume_opacity = FloatSlider(value=1.0, min=0.0, max=1.0, step=0.1, description='Opacity:', style={'description_width': 'initial'})
volume_threshold = FloatSlider(value=180.0, min=150.0, max=250.0, step=1.0, description='Threshold:', style={'description_width': 'initial'})
volume_isosurface = FloatSlider(value=200.0, min=150.0, max=250.0, step=1.0, description='Isosurface:', style={'description_width': 'initial'})

volume_config = VBox([
    widgets.HTML("<b>3D Volume Settings:</b>"),
    volume_render_mode,
    volume_colormap,
    volume_opacity,
    volume_threshold,
    volume_isosurface
], layout=Layout(padding='5px', border='1px solid #ddd'))

# Slice Settings
slice_axis = RadioButtons(
    options=[('XY', 'xy'), ('XZ', 'xz'), ('YZ', 'yz')],
    value='xy',
    description='Slice Axis:',
    style={'description_width': 'initial'}
)
slice_position = IntSlider(value=15, min=0, max=29, step=1, description='Position:', style={'description_width': 'initial'})
slice_thickness = IntSlider(value=1, min=1, max=100, step=1, description='Thickness:', style={'description_width': 'initial'})
slice_interpolation = Dropdown(
    options=[('Nearest', 'nearest'), ('Linear', 'linear'), ('Cubic', 'cubic')],
    value='linear',
    description='Interpolation:',
    style={'description_width': 'initial'}
)

slice_config = VBox([
    widgets.HTML("<b>Slice Settings:</b>"),
    slice_axis,
    slice_position,
    slice_thickness,
    slice_interpolation
], layout=Layout(padding='5px', border='1px solid #ddd'))

# Multi-Resolution Settings
multires_level = IntSlider(value=1, min=1, max=5, step=1, description='Resolution Level:', style={'description_width': 'initial'})
multires_auto = Checkbox(value=False, description='Auto-Level', style={'description_width': 'initial'})
multires_level_selector = Dropdown(
    options=[('Level 1', 1), ('Level 2', 2), ('Level 3', 3), ('Level 4', 4), ('Level 5', 5)],
    value=1,
    description='Level:',
    style={'description_width': 'initial'}
)

multires_config = VBox([
    widgets.HTML("<b>Multi-Resolution Settings:</b>"),
    multires_level,
    multires_auto,
    multires_level_selector
], layout=Layout(padding='5px', border='1px solid #ddd'))

# Animation Settings
anim_type = RadioButtons(
    options=[('Time', 'time'), ('Layer', 'layer'), ('Parameter', 'parameter')],
    value='time',
    description='Animation Type:',
    style={'description_width': 'initial'}
)
anim_frame_rate = IntSlider(value=10, min=1, max=60, step=1, description='Frame Rate:', style={'description_width': 'initial'})
anim_play_pause = Button(description='Play', button_style='success', layout=Layout(width='150px'))
anim_frame = IntSlider(value=0, min=0, max=29, step=1, description='Frame:', style={'description_width': 'initial'})

anim_config = VBox([
    widgets.HTML("<b>Animation Settings:</b>"),
    anim_type,
    anim_frame_rate,
    anim_play_pause,
    anim_frame
], layout=Layout(padding='5px', border='1px solid #ddd'))

# STL Model Display
show_stl_model = Checkbox(
    value=False,
    description='Show STL Model',
    style={'description_width': 'initial'}
)
stl_opacity = FloatSlider(value=0.3, min=0.0, max=1.0, step=0.1, description='STL Opacity:', style={'description_width': 'initial'})
stl_color = Dropdown(
    options=[('Gray', 'gray'), ('White', 'white'), ('Blue', 'blue'), ('Green', 'green'), ('Red', 'red')],
    value='gray',
    description='STL Color:',
    style={'description_width': 'initial'}
)
stl_wireframe = Checkbox(
    value=False,
    description='Wireframe',
    style={'description_width': 'initial'}
)

stl_config = VBox([
    widgets.HTML("<b>STL Model Display:</b>"),
    show_stl_model,
    stl_opacity,
    stl_color,
    stl_wireframe
], layout=Layout(padding='5px', border='1px solid #ddd'))

# Camera Controls
camera_azimuth = FloatSlider(value=45, min=-180, max=180, step=5, description='Azimuth:', style={'description_width': 'initial'})
camera_elevation = FloatSlider(value=30, min=-90, max=90, step=5, description='Elevation:', style={'description_width': 'initial'})
camera_zoom = FloatSlider(value=1.0, min=0.1, max=10.0, step=0.1, description='Zoom:', style={'description_width': 'initial'})
camera_reset = Button(description='Reset Camera', button_style='', layout=Layout(width='150px'))
camera_preset = Dropdown(
    options=[('Front', 'front'), ('Back', 'back'), ('Top', 'top'), ('Bottom', 'bottom'), 
             ('Left', 'left'), ('Right', 'right'), ('Isometric', 'isometric')],
    value='isometric',
    description='Preset:',
    style={'description_width': 'initial'}
)

camera_config = VBox([
    widgets.HTML("<b>Camera Controls:</b>"),
    camera_azimuth,
    camera_elevation,
    camera_zoom,
    camera_reset,
    camera_preset
], layout=Layout(padding='5px', border='1px solid #ddd'))

# Dynamic configuration accordion
config_accordion = Accordion(children=[
    volume_config,
    slice_config,
    multires_config,
    anim_config,
    stl_config,
    camera_config
])
config_accordion.set_title(0, '3D Volume')
config_accordion.set_title(1, 'Slices')
config_accordion.set_title(2, 'Multi-Resolution')
config_accordion.set_title(3, 'Animation')
config_accordion.set_title(4, 'STL Model')
config_accordion.set_title(5, 'Camera')

left_panel = VBox([
    widgets.HTML("<h3>Visualization Configuration</h3>"),
    config_accordion
], layout=Layout(width='300px', padding='10px', border='1px solid #ccc'))

# ============================================
# Center Panel: Visualization
# ============================================

viz_output = Output(layout=Layout(height='600px', overflow='auto'))

center_panel = VBox([
    widgets.HTML("<h3>3D Visualization</h3>"),
    viz_output
], layout=Layout(flex='1 1 auto', padding='10px', border='1px solid #ccc'))

# ============================================
# Right Panel: Visualization Info
# ============================================

# Visualization Statistics
stats_label = widgets.HTML("<b>Visualization Statistics:</b>")
stats_display = widgets.HTML("No data loaded")
stats_section = VBox([
    stats_label,
    stats_display
], layout=Layout(padding='5px'))

# View Settings
view_label = widgets.HTML("<b>View Settings:</b>")
view_display = widgets.HTML("No view information")
view_section = VBox([
    view_label,
    view_display
], layout=Layout(padding='5px'))

# Export Options
export_label = widgets.HTML("<b>Export:</b>")
export_image_button = Button(description='Export Image', button_style='', layout=Layout(width='150px'))
export_video_button = Button(description='Export Video', button_style='', layout=Layout(width='150px'))
export_model_button = Button(description='Export 3D Model', button_style='', layout=Layout(width='150px'))
save_config_button = Button(description='Save Config', button_style='', layout=Layout(width='150px'))

export_section = VBox([
    export_label,
    export_image_button,
    export_video_button,
    export_model_button,
    save_config_button
], layout=Layout(padding='5px'))

right_panel = VBox([
    stats_section,
    view_section,
    export_section
], layout=Layout(width='250px', padding='10px', border='1px solid #ccc'))

# ============================================
# Bottom Panel: Status and Performance
# ============================================

status_display = widgets.HTML("<b>Status:</b> Ready to visualize")
performance_display = widgets.HTML("<b>Performance:</b> FPS: N/A | Render Time: N/A")
controls_help = widgets.HTML("<b>Controls:</b> Mouse: Rotate | Scroll: Zoom | Shift+Drag: Pan")

bottom_panel = VBox([
    status_display,
    performance_display,
    controls_help
], layout=Layout(padding='10px', border='1px solid #ccc'))

# ============================================
# Helper Functions for MongoDB
# ============================================

def update_grid_dropdown(change=None):
    """Update grid dropdown when model or grid type changes."""
    global current_model_id
    
    model_id = model_dropdown.value
    grid_type = grid_type_filter.value
    
    if not model_id:
        grid_dropdown.options = [("━━━ Select Grid ━━━", None)]
        return
    
    current_model_id = model_id
    
    if not voxel_storage:
        grid_dropdown.options = [("━━━ MongoDB not available ━━━", None)]
        return
    
    try:
        # Get all grids for this model
        all_grids = voxel_storage.list_grids(model_id=model_id, limit=1000)
        
        # Filter by grid type
        filtered_grids = []
        for grid in all_grids:
            metadata = grid.get('metadata', {})
            config_meta = metadata.get('configuration_metadata', {})
            if not config_meta:
                config_meta = metadata
            
            is_fused = config_meta.get('fusion_applied', False)
            is_corrected = config_meta.get('correction_applied', False)
            is_processed = config_meta.get('processing_applied', False)
            has_signals = len(grid.get('available_signals', [])) > 0
            
            grid_type_match = False
            if grid_type == 'all':
                grid_type_match = True
            elif grid_type == 'fused' and is_fused:
                grid_type_match = True
            elif grid_type == 'corrected' and is_corrected:
                grid_type_match = True
            elif grid_type == 'processed' and is_processed:
                grid_type_match = True
            elif grid_type == 'signal_mapped' and has_signals and not is_fused and not is_corrected:
                grid_type_match = True
            elif grid_type == 'raw' and not has_signals and not is_fused and not is_corrected:
                grid_type_match = True
            
            if grid_type_match:
                grid_id = grid.get('grid_id', '')
                grid_name = grid.get('grid_name', 'Unknown')
                n_signals = len(grid.get('available_signals', []))
                
                # Build status label
                status_parts = []
                if is_fused:
                    status_parts.append('Fused')
                if is_corrected:
                    status_parts.append('Corrected')
                if is_processed:
                    status_parts.append('Processed')
                if has_signals:
                    status_parts.append(f'{n_signals} signal(s)')
                
                status_str = f" [{', '.join(status_parts)}]" if status_parts else ""
                filtered_grids.append((f"{grid_name}{status_str} ({grid_id[:8]}...)", grid_id))
        
        if filtered_grids:
            grid_dropdown.options = [("━━━ Select Grid ━━━", None)] + filtered_grids
        else:
            grid_dropdown.options = [("━━━ No grids found ━━━", None)]
    except Exception as e:
        grid_dropdown.options = [("━━━ Error loading grids ━━━", None)]
        error_msg = str(e)
        status_display.value = f"<b>Status:</b> <span style='color: red;'>⚠️ Error loading grids: {error_msg}</span>"
        print(f"⚠️ Error loading grids: {e}")
        import traceback
        traceback.print_exc()

def load_grid_from_mongodb(button):
    """Load selected grid from MongoDB."""
    global current_model_id, current_grid_id, loaded_grid_data, signal_arrays, voxel_data
    
    if not voxel_storage or not grid_dropdown.value:
        status_display.value = "<b>Status:</b> <span style='color: red;'>⚠️ Please select a grid to load</span>"
        return
    
    grid_id = grid_dropdown.value
    current_grid_id = grid_id
    
    status_display.value = "<b>Status:</b> Loading grid from MongoDB..."
    
    try:
        # Load grid from MongoDB
        grid_data = voxel_storage.load_voxel_grid(grid_id=grid_id)
        
        if not grid_data:
            status_display.value = "<b>Status:</b> <span style='color: red;'>⚠️ Failed to load grid</span>"
            return
        
        # Extract data from dictionary
        signal_arrays = grid_data.get('signal_arrays', {})
        metadata = grid_data.get('metadata', {})
        grid_name = grid_data.get('grid_name', 'Unknown')
        
        # Store loaded data
        loaded_grid_data = {
            'grid_data': grid_data,
            'metadata': metadata,
            'signal_arrays': signal_arrays
        }
        
        # Convert signal arrays to voxel_data format for visualization
        if signal_arrays:
            # Get first signal to determine dimensions
            first_signal_name = list(signal_arrays.keys())[0]
            first_signal = signal_arrays[first_signal_name]
            
            # Get grid dimensions from signal shape (signals are already 3D arrays)
            if first_signal.ndim == 3:
                nx, ny, nz = first_signal.shape
            elif first_signal.ndim == 1:
                # Try to get dimensions from metadata
                if 'grid_data' in grid_data:
                    vg_data = grid_data.get('grid_data', {})
                    if isinstance(vg_data, dict):
                        if 'dims' in vg_data:
                            nx, ny, nz = vg_data['dims']
                        elif 'resolution' in vg_data and 'bbox_min' in metadata and 'bbox_max' in metadata:
                            bbox_min = metadata['bbox_min']
                            bbox_max = metadata['bbox_max']
                            resolution = vg_data['resolution']
                            # Round instead of truncating so float error (e.g. 2.9999...) doesn't drop a voxel
                            nx = int(round((bbox_max[0] - bbox_min[0]) / resolution))
                            ny = int(round((bbox_max[1] - bbox_min[1]) / resolution))
                            nz = int(round((bbox_max[2] - bbox_min[2]) / resolution))
                        else:
                            # Estimate from signal length (assume cubic)
                            n_total = len(first_signal)
                            n_per_dim = int(np.cbrt(n_total))
                            nx = ny = nz = n_per_dim
                    else:
                        # VoxelGrid object
                        if hasattr(vg_data, 'dims'):
                            nx, ny, nz = vg_data.dims
                        else:
                            n_total = len(first_signal)
                            n_per_dim = int(np.cbrt(n_total))
                            nx = ny = nz = n_per_dim
                else:
                    # Estimate from signal length
                    n_total = len(first_signal)
                    n_per_dim = int(np.cbrt(n_total))
                    nx = ny = nz = n_per_dim
            else:
                # 2D (or >3D) signals cannot be unpacked into (nx, ny, nz)
                raise ValueError(f"Unsupported signal shape {first_signal.shape}: expected a 1D or 3D array")
            
            # Get bounding box from metadata
            bbox_min = metadata.get('bbox_min', [0, 0, 0])
            bbox_max = metadata.get('bbox_max', [10, 10, 5])
            
            # Ensure bbox_min and bbox_max are 3-element lists (metadata may hold lists, tuples, or NumPy arrays)
            if isinstance(bbox_min, (list, tuple, np.ndarray)) and len(bbox_min) == 3:
                bbox_min = list(bbox_min)
            else:
                bbox_min = [0, 0, 0]
            
            if isinstance(bbox_max, (list, tuple, np.ndarray)) and len(bbox_max) == 3:
                bbox_max = list(bbox_max)
            else:
                bbox_max = [10, 10, 5]
            
            # Create coordinate arrays
            x = np.linspace(bbox_min[0], bbox_max[0], nx)
            y = np.linspace(bbox_min[1], bbox_max[1], ny)
            z = np.linspace(bbox_min[2], bbox_max[2], nz)
            X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
            
            # Create voxel_data structure with all signals
            voxel_data = {
                'x': X,
                'y': Y,
                'z': Z,
                'dims': (nx, ny, nz),
                'bbox_min': bbox_min,
                'bbox_max': bbox_max
            }
            
            # Add all signals
            for signal_name, signal_array in signal_arrays.items():
                # Ensure signal is 3D
                if signal_array.ndim == 1:
                    # Reshape if needed
                    signal_array = signal_array.reshape((nx, ny, nz))
                elif signal_array.ndim == 3 and signal_array.shape != (nx, ny, nz):
                    # Resize if needed (simplified)
                    from scipy.ndimage import zoom
                    zoom_factors = (nx/signal_array.shape[0], ny/signal_array.shape[1], nz/signal_array.shape[2])
                    signal_array = zoom(signal_array, zoom_factors, order=1)
                voxel_data[signal_name] = signal_array
            
            # Update signal selector
            signal_options = [(name.replace('_', ' ').title(), name) for name in signal_arrays.keys()]
            if signal_options:
                signal_selector.options = signal_options
                signal_selector.value = signal_options[0][1]  # Select first signal
                
                # Update slice position and animation frame max based on loaded data
                first_signal = voxel_data[signal_options[0][1]]
                slice_position.max = max(first_signal.shape) - 1
                anim_frame.max = first_signal.shape[2] - 1 if len(first_signal.shape) > 2 else first_signal.shape[0] - 1
        else:
            status_display.value = "<b>Status:</b> <span style='color: orange;'>⚠️ No signals in grid</span>"
            return
        
        status_display.value = f"<b>Status:</b> <span style='color: green;'>✅ Loaded grid: {grid_name} ({len(signal_arrays)} signal(s))</span>"
        
        # Try to load STL model if available
        load_stl_model()
        
    except Exception as e:
        status_display.value = f"<b>Status:</b> <span style='color: red;'>❌ Error loading grid: {str(e)}</span>"
        import traceback
        traceback.print_exc()

def load_stl_model():
    """Load STL model for the current model."""
    global stl_mesh, stl_path, current_model_id
    
    stl_mesh = None
    stl_path = None
    
    if not stl_client or not current_model_id:
        return
    
    try:
        # Get STL file path from database
        stl_path = stl_client.load_stl_file(current_model_id)
        if stl_path and stl_path.exists():
            print(f"üìÇ Loading STL from: {stl_path}")
            # Priority 1: Try PyVista (best performance for large meshes)
            if PYVISTA_AVAILABLE:
                try:
                    stl_mesh = pv.read(str(stl_path))
                    n_faces = stl_mesh.n_cells  # Use n_cells instead of deprecated n_faces
                    print(f"✅ PyVista loaded: {len(stl_mesh.points)} vertices, {n_faces} faces")
                    status_display.value = f"<b>Status:</b> <span style='color: green;'>✅ STL model loaded with PyVista: {stl_path.name} ({n_faces} faces)</span>"
                    return
                except Exception as e:
                    print(f"⚠️ PyVista failed to load STL: {e}")
                    import traceback
                    traceback.print_exc()
            
            # Priority 2: Try trimesh (good for conversion to PyVista)
            try:
                import trimesh
                trimesh_mesh = trimesh.load(str(stl_path))
                print(f"✅ Trimesh loaded: {len(trimesh_mesh.vertices)} vertices, {len(trimesh_mesh.faces)} faces")
                # Convert to PyVista if available
                if PYVISTA_AVAILABLE:
                    vertices = trimesh_mesh.vertices
                    faces = trimesh_mesh.faces
                    # PyVista expects faces in format: [n, v1, v2, v3, ...]
                    faces_pv = np.column_stack([
                        np.full(len(faces), 3),  # Triangle faces
                        faces
                    ]).flatten()
                    stl_mesh = pv.PolyData(vertices, faces_pv)
                    n_faces = stl_mesh.n_cells  # Use n_cells instead of deprecated n_faces
                    print(f"✅ Converted to PyVista: {n_faces} faces")
                    status_display.value = f"<b>Status:</b> <span style='color: green;'>✅ STL model loaded (trimesh→PyVista): {stl_path.name} ({n_faces} faces)</span>"
                else:
                    # Store trimesh for matplotlib fallback
                    stl_mesh = trimesh_mesh
                    status_display.value = f"<b>Status:</b> <span style='color: green;'>✅ STL model loaded with trimesh: {stl_path.name} ({len(trimesh_mesh.faces)} faces)</span>"
                return
            except ImportError:
                pass
            except Exception as e:
                print(f"⚠️ trimesh failed to load STL: {e}")
                import traceback
                traceback.print_exc()
            
            # Priority 3: Try numpy-stl (fallback for matplotlib)
            try:
                from stl import mesh
                stl_mesh = mesh.Mesh.from_file(str(stl_path))
                print(f"‚úÖ numpy-stl loaded: {len(stl_mesh.vectors)} triangles")
                status_display.value = f"<b>Status:</b> <span style='color: green;'>‚úÖ STL model loaded with numpy-stl: {stl_path.name} ({len(stl_mesh.vectors)} triangles)</span>"
            except ImportError:
                status_display.value = f"<b>Status:</b> <span style='color: orange;'>⚠️ STL path found but no STL library available. Install pyvista, trimesh, or numpy-stl.</span>"
            except Exception as e:
                print(f"⚠️ numpy-stl failed to load STL: {e}")
                import traceback
                traceback.print_exc()
        else:
            # No STL file found for this model
            if stl_path:
                print(f"⚠️ STL file path exists but file not found: {stl_path}")
            else:
                print(f"⚠️ No STL file path found for model_id: {current_model_id}")
            stl_mesh = None
            stl_path = None
    except Exception as e:
        # STL loading failed, but continue without it
        stl_mesh = None
        stl_path = None
        print(f"⚠️ Could not load STL model: {e}")
        import traceback
        traceback.print_exc()

# Function to update UI based on data source mode
def update_data_source_mode(change):
    """Show/hide MongoDB widgets based on data source mode."""
    if change['new'] == 'mongodb':
        model_dropdown.layout.display = 'flex'
        grid_type_filter.layout.display = 'flex'
        grid_dropdown.layout.display = 'flex'
        load_grid_button.layout.display = 'flex'
    else:
        model_dropdown.layout.display = 'none'
        grid_type_filter.layout.display = 'none'
        grid_dropdown.layout.display = 'none'
        load_grid_button.layout.display = 'none'
        # Reset to demo data
        global voxel_data
        voxel_data = generate_demo_voxel_data()
        signal_selector.options = [('Temperature', 'temperature'), ('Density', 'density'), ('Power', 'power')]
        signal_selector.value = 'temperature'
        # Update max values for demo data
        if voxel_data and 'temperature' in voxel_data:
            signal = voxel_data['temperature']
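            # Match the slider range to the selected slice axis (xy: Z, xz: Y, yz: X)
            slice_position.max = signal.shape[{'xy': 2, 'xz': 1, 'yz': 0}.get(slice_axis.value, 2)] - 1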
            anim_frame.max = signal.shape[2] - 1 if len(signal.shape) > 2 else signal.shape[0] - 1

# ============================================
# Visualization Functions
# ============================================

def create_3d_visualization():
    """Create 3D visualization based on current settings."""
    global voxel_data, current_visualization
    
    if not voxel_data:
        return None
    
    mode = viz_mode.value
    signal_name = signal_selector.value
    
    if not signal_name or signal_name not in voxel_data:
        return None
    
    if mode == '3d_volume':
        return create_3d_volume_viz()
    elif mode == 'slices':
        return create_slice_viz()
    elif mode == 'multi_res':
        return create_multires_viz()
    elif mode == 'animation':
        return create_animation_viz()
    else:
        return None

def create_3d_volume_viz():
    """Create 3D volume visualization."""
    global voxel_data, stl_mesh
    
    signal_name = signal_selector.value
    if not signal_name or signal_name not in voxel_data:
        return None
    
    signal = voxel_data[signal_name]
    render_mode = volume_render_mode.value
    colormap = volume_colormap.value
    opacity = volume_opacity.value
    threshold = volume_threshold.value
    
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    
    # Apply threshold
    mask = signal > threshold
    x_vals = voxel_data['x'][mask]
    y_vals = voxel_data['y'][mask]
    z_vals = voxel_data['z'][mask]
    signal_vals = signal[mask]
    
    scatter = None  # Initialize scatter variable
    
    if render_mode == 'points':
        scatter = ax.scatter(x_vals, y_vals, z_vals, c=signal_vals, cmap=colormap, 
                            alpha=opacity, s=10)
    elif render_mode == 'surface':
        # Simplified surface (show isosurface)
        isovalue = volume_isosurface.value
        # Create contour plot on a slice
        z_slice_idx = slice_position.value
        if z_slice_idx >= signal.shape[2]:
            z_slice_idx = signal.shape[2] - 1
        contour = ax.contour(voxel_data['x'][:, :, z_slice_idx], 
                  voxel_data['y'][:, :, z_slice_idx],
                  signal[:, :, z_slice_idx], 
                  levels=[isovalue], cmap=colormap, alpha=opacity)
        scatter = ax.scatter(x_vals, y_vals, z_vals, c=signal_vals, cmap=colormap, 
                 alpha=opacity*0.5, s=5)
    else:  # volume
        scatter = ax.scatter(x_vals, y_vals, z_vals, c=signal_vals, cmap=colormap, 
                  alpha=opacity, s=20)
    
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'3D Volume Visualization: {signal_name.replace("_", " ").title()}')
    
    # Add STL model if enabled
    if show_stl_model.value and stl_mesh is not None:
        try:
            # Check if it's a PyVista mesh (best performance)
            if PYVISTA_AVAILABLE and isinstance(stl_mesh, pv.PolyData):
                # PyVista mesh - much more efficient!
                # Note: We can't directly add PyVista to matplotlib, so we'll extract vertices/faces
                vertices = stl_mesh.points
                
                # PyVista stores faces as [n, v1, v2, v3, ...] where n is number of vertices
                # For triangular meshes, n=3, so format is [3, v1, v2, v3, 3, v4, v5, v6, ...]
                faces_array = stl_mesh.faces
                
                # Reshape to get individual faces
                # Each face starts with the count (3 for triangles), followed by 3 vertex indices
                n_faces = stl_mesh.n_cells  # Use n_cells instead of deprecated n_faces
                if len(faces_array) > 0:
                    # Reshape: [3, v1, v2, v3, 3, v4, v5, v6, ...] -> [[3, v1, v2, v3], [3, v4, v5, v6], ...]
                    faces = faces_array.reshape(n_faces, 4)[:, 1:]  # Remove face count (first column), keep vertex indices
                    
                    # Sample faces for performance (max 2000 faces for matplotlib)
                    if len(faces) > 2000:
                        step = max(1, len(faces) // 2000)
                        sampled_faces = faces[::step]
                    else:
                        sampled_faces = faces
                    
                    # Keep triangles whose indices are in range and whose three corners are distinct
                    valid_triangles = []
                    for face in sampled_faces:
                        if len(face) == 3 and all(0 <= idx < len(vertices) for idx in face):
                            triangle = vertices[face]
                            if len(np.unique(triangle, axis=0)) >= 3:
                                valid_triangles.append(triangle)
                    
                    if stl_wireframe.value:
                        for triangle in valid_triangles:
                            ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                   [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                   [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                   color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                    elif valid_triangles:
                        # Draw all faces as one Poly3DCollection; per-triangle plot_trisurf
                        # raises on faces that are vertical (collinear in the XY projection)
                        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                        ax.add_collection3d(Poly3DCollection(valid_triangles, facecolor=stl_color.value,
                                                             alpha=stl_opacity.value, edgecolor='none'))
            elif hasattr(stl_mesh, 'vectors'):  # numpy-stl format
                # numpy-stl mesh - sample triangles for performance
                vectors = stl_mesh.vectors
                step = max(1, len(vectors) // 1000)  # Sample for performance
                sampled_vectors = vectors[::step]
                
                if stl_wireframe.value:
                    for triangle in sampled_vectors:
                        ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                               [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                               [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                               color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                elif len(sampled_vectors) > 0:
                    # One Poly3DCollection instead of per-triangle plot_trisurf, which
                    # raises on faces that are vertical (collinear in the XY projection)
                    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                    ax.add_collection3d(Poly3DCollection(sampled_vectors, facecolor=stl_color.value,
                                                         alpha=stl_opacity.value, edgecolor='none'))
            elif hasattr(stl_mesh, 'vertices'):  # trimesh format
                # trimesh format
                vertices = stl_mesh.vertices
                faces = stl_mesh.faces
                step = max(1, len(faces) // 1000)  # Sample for performance
                sampled_faces = faces[::step]
                
                triangles = vertices[sampled_faces]  # (n, 3, 3) corner coordinates
                if stl_wireframe.value:
                    for triangle in triangles:
                        ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                               [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                               [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                               color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                elif len(triangles) > 0:
                    # One Poly3DCollection instead of per-triangle plot_trisurf, which
                    # raises on faces that are vertical (collinear in the XY projection)
                    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                    ax.add_collection3d(Poly3DCollection(triangles, facecolor=stl_color.value,
                                                         alpha=stl_opacity.value, edgecolor='none'))
        except Exception as e:
            print(f"⚠️ Error displaying STL model: {e}")
    
    # Add colorbar if scatter exists
    if scatter is not None:
        plt.colorbar(scatter, ax=ax, label=signal_name.replace('_', ' ').title())
    
    plt.tight_layout()
    return fig

def create_slice_viz():
    """Create slice visualization."""
    global voxel_data, stl_mesh
    
    signal_name = signal_selector.value
    if not signal_name or signal_name not in voxel_data:
        return None
    
    signal = voxel_data[signal_name]
    axis = slice_axis.value
    pos = slice_position.value
    colormap = volume_colormap.value
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    if axis == 'xy':
        slice_data = signal[:, :, pos]
        x_data = voxel_data['x'][:, :, pos]
        y_data = voxel_data['y'][:, :, pos]
        axes[0].contourf(x_data, y_data, slice_data, levels=20, cmap=colormap)
        axes[0].set_xlabel('X')
        axes[0].set_ylabel('Y')
        axes[0].set_title(f'XY Slice at Z={voxel_data["z"][0, 0, pos]:.2f}')
        
        # 3D view with slice highlighted
        axes[1].remove()  # replace the 2D placeholder axes so the 3D view does not overlap it
        ax = fig.add_subplot(122, projection='3d')
        ax.scatter(voxel_data['x'].flatten()[::100], 
                  voxel_data['y'].flatten()[::100],
                  voxel_data['z'].flatten()[::100],
                  c=signal.flatten()[::100], cmap=colormap, alpha=0.3, s=1)
        z_slice = voxel_data['z'][0, 0, pos]
        ax.plot_surface(voxel_data['x'][:, :, pos], voxel_data['y'][:, :, pos],
                       np.full_like(voxel_data['x'][:, :, pos], z_slice),
                       facecolors=plt.get_cmap(colormap)((slice_data - slice_data.min()) / (slice_data.max() - slice_data.min() + 1e-8)),
                       alpha=0.8)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('3D View with Slice')
        
        # Add STL model if enabled (only in 3D view, not 2D slice)
        if show_stl_model.value and stl_mesh is not None:
            try:
                # Check if it's a PyVista mesh (best performance)
                if PYVISTA_AVAILABLE and isinstance(stl_mesh, pv.PolyData):
                    # PyVista mesh - extract faces properly using get_cell
                    try:
                        vertices = stl_mesh.points
                        n_faces = stl_mesh.n_cells
                        
                        # Sample faces for performance; an empty mesh contributes nothing
                        # (returning here would discard the already-drawn figure)
                        max_faces = 2000
                        if len(vertices) == 0 or n_faces == 0:
                            sampled_face_indices = []
                        elif n_faces > max_faces:
                            step = max(1, n_faces // max_faces)
                            sampled_face_indices = list(range(0, n_faces, step))
                        else:
                            sampled_face_indices = list(range(n_faces))
                        
                        valid_triangles = []
                        for face_idx in sampled_face_indices:
                            try:
                                cell = stl_mesh.get_cell(face_idx)
                                if cell.n_points == 3:
                                    point_ids = cell.point_ids
                                    if len(point_ids) == 3:
                                        triangle_verts = vertices[point_ids]
                                        if len(np.unique(triangle_verts, axis=0)) >= 3:
                                            valid_triangles.append(triangle_verts)
                            except Exception:
                                continue
                        
                        for triangle in valid_triangles:
                            if stl_wireframe.value:
                                ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                       [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                       [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                       color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                            else:
                                try:
                                    ax.plot_trisurf(
                                        triangle[:, 0], triangle[:, 1], triangle[:, 2],
                                        color=stl_color.value, alpha=stl_opacity.value,
                                        linewidth=0, edgecolor='none'
                                    )
                                except Exception:
                                    ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                           [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                           [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                           color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                    except Exception as e:
                        print(f"⚠️ Error extracting faces from PyVista mesh: {e}")
                        import traceback
                        traceback.print_exc()
                elif hasattr(stl_mesh, 'vectors'):  # numpy-stl format
                    vectors = stl_mesh.vectors
                    step = max(1, len(vectors) // 1000)  # Sample for performance
                    sampled_vectors = vectors[::step]
                    
                    if stl_wireframe.value:
                        for triangle in sampled_vectors:
                            ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                   [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                   [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                   color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                    elif len(sampled_vectors) > 0:
                        # One Poly3DCollection instead of per-triangle plot_trisurf, which
                        # raises on faces that are vertical (collinear in the XY projection)
                        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                        ax.add_collection3d(Poly3DCollection(sampled_vectors, facecolor=stl_color.value,
                                                             alpha=stl_opacity.value, edgecolor='none'))
                elif hasattr(stl_mesh, 'vertices'):  # trimesh format
                    vertices = stl_mesh.vertices
                    faces = stl_mesh.faces
                    step = max(1, len(faces) // 1000)  # Sample for performance
                    sampled_faces = faces[::step]
                    
                    triangles = vertices[sampled_faces]  # (n, 3, 3) corner coordinates
                    if stl_wireframe.value:
                        for triangle in triangles:
                            ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                   [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                   [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                   color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                    elif len(triangles) > 0:
                        # One Poly3DCollection instead of per-triangle plot_trisurf, which
                        # raises on faces that are vertical (collinear in the XY projection)
                        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                        ax.add_collection3d(Poly3DCollection(triangles, facecolor=stl_color.value,
                                                             alpha=stl_opacity.value, edgecolor='none'))
            except Exception as e:
                print(f"⚠️ Error displaying STL model: {e}")
    
    elif axis == 'xz':
        slice_data = signal[:, pos, :]
        x_data = voxel_data['x'][:, pos, :]
        z_data = voxel_data['z'][:, pos, :]
        axes[0].contourf(x_data, z_data, slice_data, levels=20, cmap=colormap)
        axes[0].set_xlabel('X')
        axes[0].set_ylabel('Z')
        axes[0].set_title(f'XZ Slice at Y={voxel_data["y"][0, pos, 0]:.2f}')
        axes[1].axis('off')
    
    else:  # yz
        slice_data = signal[pos, :, :]
        y_data = voxel_data['y'][pos, :, :]
        z_data = voxel_data['z'][pos, :, :]
        axes[0].contourf(y_data, z_data, slice_data, levels=20, cmap=colormap)
        axes[0].set_xlabel('Y')
        axes[0].set_ylabel('Z')
        axes[0].set_title(f'YZ Slice at X={voxel_data["x"][pos, 0, 0]:.2f}')
        axes[1].axis('off')
    
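    # Colorbar for the 2D slice, using the in-plane grids of the selected axis
    # (the yz branch defines no x_data, so it must not be referenced there)
    if axis == 'xy':
        h_data, v_data = x_data, y_data
    elif axis == 'xz':
        h_data, v_data = x_data, z_data
    else:  # yz
        h_data, v_data = y_data, z_data
    plt.colorbar(axes[0].contourf(h_data, v_data, slice_data, levels=20, cmap=colormap), ax=axes[0])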
    plt.tight_layout()
    return fig

def create_multires_viz():
    """Create multi-resolution visualization."""
    global voxel_data, stl_mesh
    
    signal_name = signal_selector.value
    if not signal_name or signal_name not in voxel_data:
        return None
    
    level = multires_level.value
    signal = voxel_data[signal_name]
    colormap = volume_colormap.value
    
    # Downsample based on level
    step = 2 ** (level - 1)
    downsampled = signal[::step, ::step, ::step]
    x_down = voxel_data['x'][::step, ::step, ::step]
    y_down = voxel_data['y'][::step, ::step, ::step]
    z_down = voxel_data['z'][::step, ::step, ::step]
    
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    
    scatter = ax.scatter(x_down.flatten(), y_down.flatten(), z_down.flatten(),
                        c=downsampled.flatten(), cmap=colormap, s=50, alpha=0.6)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'Multi-Resolution Visualization (Level {level}, {downsampled.size:,} voxels)')
    
    # Add STL model if enabled
    if show_stl_model.value and stl_mesh is not None:
        try:
            # Check if it's a PyVista mesh (best performance)
            if PYVISTA_AVAILABLE and isinstance(stl_mesh, pv.PolyData):
                # PyVista mesh - extract faces properly using get_cell
                try:
                    vertices = stl_mesh.points
                    n_faces = stl_mesh.n_cells
                    
                    # Sample faces for performance; an empty mesh contributes nothing
                    # (returning here would discard the already-drawn figure)
                    max_faces = 2000
                    if len(vertices) == 0 or n_faces == 0:
                        sampled_face_indices = []
                    elif n_faces > max_faces:
                        step = max(1, n_faces // max_faces)
                        sampled_face_indices = list(range(0, n_faces, step))
                    else:
                        sampled_face_indices = list(range(n_faces))
                    
                    valid_triangles = []
                    invalid_count = 0
                    for face_idx in sampled_face_indices:
                        try:
                            cell = stl_mesh.get_cell(face_idx)
                            if cell.n_points == 3:
                                point_ids = cell.point_ids
                                if len(point_ids) == 3:
                                    triangle_verts = vertices[point_ids]
                                    
                                    # More robust validation: check that all 3 points are different
                                    p0, p1, p2 = triangle_verts[0], triangle_verts[1], triangle_verts[2]
                                    
                                    # Check if points are distinct (with small tolerance for floating point)
                                    dist_01 = np.linalg.norm(p0 - p1)
                                    dist_02 = np.linalg.norm(p0 - p2)
                                    dist_12 = np.linalg.norm(p1 - p2)
                                    
                                    # All distances must be > 1e-6 (very small threshold)
                                    if dist_01 > 1e-6 and dist_02 > 1e-6 and dist_12 > 1e-6:
                                        valid_triangles.append(triangle_verts)
                                    else:
                                        invalid_count += 1
                        except Exception:
                            invalid_count += 1
                            continue
                    
                    if invalid_count > 0:
                        print(f"⚠️ Skipped {invalid_count} invalid/degenerate triangles")
                    
                    # Plot valid triangles
                    # Use Poly3DCollection for filled triangles (plot_trisurf doesn't work well for single triangles)
                    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                    
                    if not stl_wireframe.value and len(valid_triangles) > 0:
                        # Collect all triangles for batch rendering
                        poly3d = Poly3DCollection(valid_triangles, alpha=stl_opacity.value, 
                                                  facecolor=stl_color.value, edgecolor='none')
                        ax.add_collection3d(poly3d)
                    else:
                        # Wireframe mode - plot each triangle individually
                        for triangle in valid_triangles:
                            ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                                   [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                                   [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                                   color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                except Exception as e:
                    print(f"⚠️ Error extracting faces from PyVista mesh: {e}")
                    import traceback
                    traceback.print_exc()
            elif hasattr(stl_mesh, 'vectors'):  # numpy-stl format
                vectors = stl_mesh.vectors
                step = max(1, len(vectors) // 1000)  # Sample for performance
                sampled_vectors = vectors[::step]
                
                if stl_wireframe.value:
                    for triangle in sampled_vectors:
                        ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                               [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                               [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                               color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                elif len(sampled_vectors) > 0:
                    # One Poly3DCollection instead of per-triangle plot_trisurf, which
                    # raises on faces that are vertical (collinear in the XY projection)
                    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                    ax.add_collection3d(Poly3DCollection(sampled_vectors, facecolor=stl_color.value,
                                                         alpha=stl_opacity.value, edgecolor='none'))
            elif hasattr(stl_mesh, 'vertices'):  # trimesh format
                vertices = stl_mesh.vertices
                faces = stl_mesh.faces
                step = max(1, len(faces) // 1000)  # Sample for performance
                sampled_faces = faces[::step]
                
                triangles = vertices[sampled_faces]  # (n, 3, 3) corner coordinates
                if stl_wireframe.value:
                    for triangle in triangles:
                        ax.plot([triangle[0, 0], triangle[1, 0], triangle[2, 0], triangle[0, 0]],
                               [triangle[0, 1], triangle[1, 1], triangle[2, 1], triangle[0, 1]],
                               [triangle[0, 2], triangle[1, 2], triangle[2, 2], triangle[0, 2]],
                               color=stl_color.value, alpha=stl_opacity.value, linewidth=0.5)
                elif len(triangles) > 0:
                    # One Poly3DCollection instead of per-triangle plot_trisurf, which
                    # raises on faces that are vertical (collinear in the XY projection)
                    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
                    ax.add_collection3d(Poly3DCollection(triangles, facecolor=stl_color.value,
                                                         alpha=stl_opacity.value, edgecolor='none'))
        except Exception as e:
            print(f"⚠️ Error displaying STL model: {e}")
    
    plt.colorbar(scatter, ax=ax, label=signal_name.replace('_', ' ').title())
    plt.tight_layout()
    return fig

def create_animation_viz():
    """Create animation visualization (static frame)."""
    global voxel_data
    
    signal_name = signal_selector.value
    if not signal_name or signal_name not in voxel_data:
        return None
    
    anim_type_val = anim_type.value
    frame = anim_frame.value
    signal = voxel_data[signal_name]
    colormap = volume_colormap.value
    
    fig = plt.figure(figsize=(12, 10))
    
    if anim_type_val == 'layer':
        # Show layer-by-layer
        slice_data = signal[:, :, frame]
        x_data = voxel_data['x'][:, :, frame]
        y_data = voxel_data['y'][:, :, frame]
        ax = fig.add_subplot(111)
        contour = ax.contourf(x_data, y_data, slice_data, levels=20, cmap=colormap)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title(f'Layer {frame} / {signal.shape[2] - 1}')
        plt.colorbar(contour, ax=ax, label=signal_name.replace('_', ' ').title())
    else:
        # Time or parameter animation (simplified)
        ax = fig.add_subplot(111, projection='3d')
        # Show frame as slice
        slice_data = signal[:, :, frame]
        x_data = voxel_data['x'][:, :, frame]
        y_data = voxel_data['y'][:, :, frame]
        z_data = np.full_like(x_data, voxel_data['z'][0, 0, frame])
        ax.plot_surface(x_data, y_data, z_data, facecolors=plt.get_cmap(colormap)(
            (slice_data - slice_data.min()) / (slice_data.max() - slice_data.min() + 1e-8)),
            alpha=0.8)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Animation Frame {frame} / {signal.shape[2] - 1}')
    
    plt.tight_layout()
    return fig

def update_visualization():
    """Update visualization display."""
    global current_visualization, voxel_data
    
    with viz_output:
        clear_output(wait=True)
        
        if not voxel_data:
            display(HTML("<p>Load data first to see visualization</p>"))
            return
        
        signal_name = signal_selector.value
        if not signal_name or signal_name not in voxel_data:
            display(HTML("<p>Select a signal to visualize</p>"))
            return
        
        try:
            fig = create_3d_visualization()
            if fig:
                plt.show()
                current_visualization = fig
            else:
                display(HTML("<p>Could not create visualization. Check data and settings.</p>"))
        except Exception as e:
            display(HTML(f"<p style='color: red;'>Error creating visualization: {str(e)}</p>"))
            import traceback
            traceback.print_exc()

def update_stats_display():
    """Update statistics display."""
    global voxel_data
    
    if voxel_data:
        signal_name = signal_selector.value
        if signal_name and signal_name in voxel_data:
            signal = voxel_data[signal_name]
            stats_html = f"<p><b>Signal:</b> {signal_name.replace('_', ' ').title()}</p>"
            stats_html += f"<p><b>Voxel Count:</b> {signal.size:,}</p>"
            stats_html += f"<p><b>Dimensions:</b> {signal.shape}</p>"
            stats_html += f"<p><b>Signal Range:</b> {signal.min():.2f} - {signal.max():.2f}</p>"
            stats_html += f"<p><b>Mean:</b> {signal.mean():.2f}</p>"
            stats_html += f"<p><b>Std:</b> {signal.std():.2f}</p>"
            stats_display.value = stats_html
        else:
            stats_display.value = "<p>Select a signal to see statistics</p>"
    else:
        stats_display.value = "<p>No data loaded</p>"

def update_view_display():
    """Update view settings display."""
    mode = viz_mode.value
    view_html = f"<p><b>Mode:</b> {mode}</p>"
    view_html += f"<p><b>Signal:</b> {signal_selector.value}</p>"
    
    if mode == 'slices':
        view_html += f"<p><b>Slice Axis:</b> {slice_axis.value}</p>"
        view_html += f"<p><b>Slice Position:</b> {slice_position.value}</p>"
    elif mode == 'multi_res':
        view_html += f"<p><b>Resolution Level:</b> {multires_level.value}</p>"
    elif mode == 'animation':
        view_html += f"<p><b>Frame:</b> {anim_frame.value}</p>"
    
    view_html += f"<p><b>Colormap:</b> {volume_colormap.value}</p>"
    view_display.value = view_html

def on_load_visualization(button):
    """Load visualization."""
    global voxel_data
    
    if data_source_mode.value == 'sample' and not voxel_data:
        # Generate demo data if using sample mode
        voxel_data = generate_demo_voxel_data()
        signal_selector.options = [('Temperature', 'temperature'), ('Density', 'density'), ('Power', 'power')]
        signal_selector.value = 'temperature'
    
    if not voxel_data:
        status_display.value = "<b>Status:</b> <span style='color: red;'>⚠️ Please load data first</span>"
        return
    
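    status_display.value = "<b>Status:</b> Loading visualization..."
    import time
    start = time.perf_counter()
    update_visualization()
    render_time = time.perf_counter() - start
    update_stats_display()
    update_view_display()
    status_display.value = "<b>Status:</b> <span style='color: green;'>✅ Visualization loaded</span>"
    # Report the measured wall-clock render time instead of fixed placeholder numbers
    performance_display.value = f"<b>Performance:</b> Render Time: {render_time:.2f}s"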

# Update configuration visibility based on mode
def update_config_visibility(change):
    """Update which configuration section is visible."""
    mode = change['new']
    
    # Show relevant accordion section
    config_accordion.selected_index = {
        '3d_volume': 0,
        'slices': 1,
        'multi_res': 2,
        'animation': 3
    }.get(mode, 0)
    
    # STL model section is always accessible (index 4)
    # Users can manually open it from any mode

# Connect events
data_source_mode.observe(update_data_source_mode, names='value')
update_data_source_mode({'new': data_source_mode.value})
model_dropdown.observe(update_grid_dropdown, names='value')
grid_type_filter.observe(update_grid_dropdown, names='value')
load_grid_button.on_click(load_grid_from_mongodb)
load_button.on_click(on_load_visualization)
viz_mode.observe(update_config_visibility, names='value')
viz_mode.observe(lambda x: update_visualization(), names='value')
signal_selector.observe(lambda x: update_visualization(), names='value')
show_stl_model.observe(lambda x: update_visualization(), names='value')
stl_opacity.observe(lambda x: update_visualization(), names='value')
stl_color.observe(lambda x: update_visualization(), names='value')
stl_wireframe.observe(lambda x: update_visualization(), names='value')

# Keep the slice-position slider range in sync with the selected slice axis and signal
def update_slice_position_max(change):
    """Update slice position max based on axis."""
    global voxel_data
    if not voxel_data:
        return
    
    signal_name = signal_selector.value
    if not signal_name or signal_name not in voxel_data:
        return
    
    signal = voxel_data[signal_name]
    axis = slice_axis.value
    if axis == 'xy':
        slice_position.max = signal.shape[2] - 1
    elif axis == 'xz':
        slice_position.max = signal.shape[1] - 1
    else:  # yz
        slice_position.max = signal.shape[0] - 1

slice_axis.observe(update_slice_position_max, names='value')
signal_selector.observe(update_slice_position_max, names='value')

# ============================================
# Main Layout
# ============================================

main_layout = VBox([
    top_panel,
    HBox([left_panel, center_panel, right_panel]),
    bottom_panel
])

# Display the interface
display(main_layout)


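The slice slider logic in `update_slice_position_max` relies on one convention: an XY slice steps along z (`shape[2]`), an XZ slice along y (`shape[1]`), and a YZ slice along x (`shape[0]`). The sketch below shows that mapping and the matching NumPy indexing on a plain array. It assumes signals are stored as `(x, y, z)`-ordered NumPy arrays, as the slider code implies; `AXIS_TO_DIM`, `slice_max`, and `extract_slice` are hypothetical helpers for illustration, not part of the AM-QADF API.

```python
import numpy as np

# Hypothetical mapping: slice plane -> index of the array dimension it steps through.
# Assumes arrays are ordered (x, y, z), matching update_slice_position_max.
AXIS_TO_DIM = {'xy': 2, 'xz': 1, 'yz': 0}


def slice_max(signal: np.ndarray, axis: str) -> int:
    """Largest valid slice index for the given plane (i.e. the slider's max)."""
    return signal.shape[AXIS_TO_DIM[axis]] - 1


def extract_slice(signal: np.ndarray, axis: str, position: int) -> np.ndarray:
    """Return the 2D plane at `position`, clamped into the valid index range."""
    dim = AXIS_TO_DIM[axis]
    position = int(np.clip(position, 0, signal.shape[dim] - 1))
    # A scalar index with np.take drops that dimension, leaving a 2D array
    return np.take(signal, position, axis=dim)


# Usage on a small (x=4, y=5, z=6) volume
volume = np.arange(4 * 5 * 6).reshape(4, 5, 6)
print(slice_max(volume, 'xy'))                # 5
print(extract_slice(volume, 'xy', 2).shape)   # (4, 5)
print(extract_slice(volume, 'yz', 99).shape)  # (5, 6); position clamped to 3
```

Clamping inside the helper mirrors what an `IntSlider` does when its `max` shrinks below the current value, so a stale position never indexes past the end of the array.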

## Summary

Congratulations! You've learned how to create 3D visualizations of voxel data.

### Key Takeaways

1. **3D Volume Rendering**: Surface, Volume, and Points rendering modes with customizable colormaps and opacity
2. **Slice Visualization**: Interactive XY, XZ, YZ plane slices with position and thickness controls
3. **Multi-Resolution**: Adaptive resolution display for efficient visualization of large datasets
4. **Animation**: Time-series, layer-by-layer, and parameter sweep animations
5. **Camera Controls**: Interactive camera with azimuth, elevation, zoom, and preset views
6. **Export Options**: Export images, videos, and 3D models
7. **Interactive Viewers**: Real-time updates with mouse controls (rotate, zoom, pan)
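Multi-resolution display comes down to keeping coarser copies of the grid and rendering whichever level fits the performance budget. The sketch below builds such a pyramid by averaging non-overlapping voxel blocks with NumPy. Block averaging is an assumed stand-in, not necessarily how AM-QADF computes its levels; `downsample_mean` and `build_pyramid` are hypothetical helpers, and treating resolution level `k` as a downsampling factor of `2**k` is likewise an assumption.

```python
import numpy as np


def downsample_mean(volume: np.ndarray, factor: int) -> np.ndarray:
    """Average non-overlapping factor x factor x factor blocks of a 3D array.

    Edge blocks that extend past the array are averaged over their real voxels only.
    """
    if factor < 1:
        raise ValueError("factor must be >= 1")
    data = volume.astype(float)
    if factor == 1:
        return data
    # Pad each dimension up to a multiple of `factor` with NaN so that
    # partial edge blocks ignore the padding when averaged with nanmean.
    pad = [(0, (-s) % factor) for s in data.shape]
    padded = np.pad(data, pad, mode='constant', constant_values=np.nan)
    nx, ny, nz = (s // factor for s in padded.shape)
    # Split every axis into (n_blocks, factor) and average over the factor axes
    blocks = padded.reshape(nx, factor, ny, factor, nz, factor)
    return np.nanmean(blocks, axis=(1, 3, 5))


def build_pyramid(volume: np.ndarray, levels: int) -> list:
    """Level 0 is full resolution; level k is downsampled by a factor of 2**k."""
    return [downsample_mean(volume, 2 ** k) for k in range(levels)]


# Usage: a 3-level pyramid over a 64^3 grid
pyramid = build_pyramid(np.random.rand(64, 64, 64), levels=3)
print([level.shape for level in pyramid])  # [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
```

Averaging, rather than striding (`volume[::2, ::2, ::2]`), lets every fine voxel contribute to its coarse voxel, at the cost of smoothing sharp peaks; swapping `np.nanmean` for `np.nanmax` would preserve peaks such as local hot spots instead.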

### Next Steps

Proceed to:
- **16_Advanced_Visualization.ipynb** - Advanced visualization techniques and custom visualizations

### Related Resources

- Visualization Documentation: `../docs/AM_QADF/05-modules/visualization.md`
- API Reference: `../docs/AM_QADF/06-api-reference/visualization-api.md`
- PyVista Documentation: https://docs.pyvista.org/
- Examples: `../examples/`
