# Multi-Source Data Fusion

## Purpose

This notebook teaches you how to fuse data from multiple sources into a unified voxel grid representation. You'll learn to apply different fusion strategies, configure source weights and quality scores, and assess fusion quality with interactive widgets.

## Learning Objectives

By the end of this notebook, you will:
- ✅ Understand fusion concepts and strategies
- ✅ Apply different fusion methods (weighted average, median, quality-based, etc.)
- ✅ Configure source weights and quality scores
- ✅ Assess fusion quality and consistency
- ✅ Compare fusion strategies

## Estimated Duration

60-90 minutes

---

## Overview

Data fusion combines signals from multiple sources (hatching, laser, CT, ISPM) into a unified representation. The AM-QADF framework provides multiple fusion strategies:

- ⚖️ **Weighted Average**: Combine sources with configurable weights
- 📊 **Median**: Use the median value across sources
- ⭐ **Quality-Based**: Select the highest-quality source
- 📈 **Max/Min**: Use the maximum or minimum value
- 🔄 **First/Last**: Use the first or last available value
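The strategies above can be sketched with NumPy on a stack of co-registered source arrays. This is a minimal illustration under stated assumptions, not the `DataFusion` implementation in `am_qadf.fusion.fusion_methods`: it assumes every source shares one voxel grid shape, that a voxel with no data is marked `NaN`, and that "quality-based" means taking, per voxel, the value from the highest-quality source that has data. The names `fuse_signals` and `_first_valid` are hypothetical.

```python
import warnings

import numpy as np


def _first_valid(stack, valid):
    """Per voxel, take the value from the first source (axis 0) that has data."""
    idx = valid.argmax(axis=0)  # argmax on booleans returns the first True
    picked = np.take_along_axis(stack, idx[np.newaxis], axis=0)[0]
    return np.where(valid.any(axis=0), picked, np.nan)


def fuse_signals(signals, strategy, weights=None, qualities=None):
    """Fuse co-registered source arrays voxel by voxel (hypothetical helper).

    signals: list of equally shaped arrays; NaN marks a voxel with no data.
    Voxels where no source has data come back as NaN.
    """
    stack = np.stack([np.asarray(s, dtype=float) for s in signals])  # (n_sources, ...)
    valid = ~np.isnan(stack)

    if strategy == 'weighted_average':
        w = np.ones(len(signals)) if weights is None else np.asarray(weights, dtype=float)
        w = w.reshape((-1,) + (1,) * (stack.ndim - 1))  # broadcast over voxels
        w_eff = np.where(valid, w, 0.0)                  # missing voxels get no weight
        total = w_eff.sum(axis=0)
        weighted = (np.where(valid, stack, 0.0) * w_eff).sum(axis=0)
        with np.errstate(invalid='ignore', divide='ignore'):
            return np.where(total > 0, weighted / total, np.nan)

    if strategy == 'quality_based':
        if qualities is None:
            raise ValueError("quality_based fusion needs one quality score per source")
        order = np.argsort(-np.asarray(qualities, dtype=float), kind='stable')
        return _first_valid(stack[order], valid[order])  # best source that has data

    if strategy == 'first':
        return _first_valid(stack, valid)
    if strategy == 'last':
        return _first_valid(stack[::-1], valid[::-1])

    reducers = {'median': np.nanmedian, 'max': np.nanmax, 'min': np.nanmin}
    if strategy in reducers:
        with warnings.catch_warnings():
            # All-NaN voxels raise a RuntimeWarning and yield NaN, which is intended here
            warnings.simplefilter('ignore', RuntimeWarning)
            return reducers[strategy](stack, axis=0)

    raise ValueError(f"Unknown fusion strategy: {strategy!r}")
```

In this sketch, coverage gaps are expressed as `NaN` so every strategy skips sources that do not cover a voxel; the notebook's sample grids, by contrast, are fully populated and carry `coverage` only as a number.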

Use the interactive widgets below to explore data fusion; no coding required!


In [1]:
# Setup: Import required libraries
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add parent directory and src directory to path for imports
notebook_dir = Path().resolve()
project_root = notebook_dir.parent
src_dir = project_root / 'src'

# Add project root to path (for src.infrastructure imports)
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Add src directory to path (for am_qadf imports)
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

# Core imports
import ipywidgets as widgets
from ipywidgets import (
    VBox, HBox, Accordion, Tab, Dropdown, RadioButtons, 
    Checkbox, Button, Output, Text, IntSlider, FloatSlider,
    Layout, Box, Label, FloatText, IntText,
    HTML as WidgetHTML
)
from IPython.display import display, Markdown, HTML, clear_output
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time
from typing import Optional, Tuple, Dict, Any, List
from enum import Enum

# Load environment variables from development.env
import os
env_file = project_root / 'development.env'
if env_file.exists():
    with open(env_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith('#') and '=' in line:
                key, value = line.split('=', 1)
                value = value.strip('"\'')
                os.environ[key] = value
    print("✅ Environment variables loaded from development.env")

# Try to import fusion classes
FUSION_AVAILABLE = False
try:
    from am_qadf.fusion.fusion_methods import DataFusion, FusionStrategy
    FUSION_AVAILABLE = True
    print("✅ Fusion classes available")
except ImportError as e:
    print(f"⚠️ Fusion classes not available: {e} - using demo mode")
    # Create demo FusionStrategy enum
    class FusionStrategy(Enum):
        WEIGHTED_AVERAGE = "weighted_average"
        MEDIAN = "median"
        QUALITY_BASED = "quality_based"
        MAX = "max"
        MIN = "min"
        FIRST = "first"
        LAST = "last"

# MongoDB connection setup
INFRASTRUCTURE_AVAILABLE = False
mongo_client = None
voxel_storage = None
stl_client = None

try:
    from src.infrastructure.config import MongoDBConfig
    from src.infrastructure.database import MongoDBClient
    from am_qadf.voxel_domain import VoxelGridStorage
    from am_qadf.query import STLModelClient
    
    # Initialize MongoDB connection
    config = MongoDBConfig.from_env()
    if not config.username:
        config.username = os.getenv('MONGO_ROOT_USERNAME', 'admin')
    if not config.password:
        config.password = os.getenv('MONGO_ROOT_PASSWORD', 'password')
    
    mongo_client = MongoDBClient(config=config)
    if mongo_client.is_connected():
        voxel_storage = VoxelGridStorage(mongo_client=mongo_client)
        stl_client = STLModelClient(mongo_client=mongo_client)
        INFRASTRUCTURE_AVAILABLE = True
        print(f"✅ Connected to MongoDB: {config.database}")
    else:
        print("⚠️ MongoDB connection failed")
except Exception as e:
    print(f"⚠️ MongoDB not available: {e} - using demo mode")

print("✅ Setup complete!")


✅ Environment variables loaded from development.env
✅ Fusion classes available
✅ Connected to MongoDB: am_qadf_data
✅ Setup complete!
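The setup cell reads `development.env` by hand. Below is a minimal sketch of the same parsing rules as a pure function, so they can be checked without touching `os.environ`. `parse_env_lines` and `load_env_file` are hypothetical names, and unlike the cell this version also trims whitespace around the key and value (so `KEY = value` works). Pass the result to `os.environ.update(...)` to reproduce the cell's behavior; the third-party `python-dotenv` package covers more of the `.env` syntax if you need it.

```python
from pathlib import Path


def parse_env_lines(lines):
    """Parse KEY=VALUE lines with the setup cell's rules (sketch).

    Blank lines, '#' comments and lines without '=' are skipped; the line is
    split on the first '=' and surrounding quotes are stripped from the value.
    """
    env = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        env[key.strip()] = value.strip().strip('"\'')
    return env


def load_env_file(path):
    """Return the parsed variables, or an empty dict if the file is missing."""
    path = Path(path)
    if not path.exists():
        return {}
    return parse_env_lines(path.read_text().splitlines())
```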


## Interactive Data Fusion Interface

Use the widgets below to fuse data from multiple sources. Select a fusion strategy, configure source weights and quality scores, and visualize the results interactively!
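The left panel exposes per-source Enable checkboxes, a Quality Threshold, Normalize Weights and Auto-weight by Quality. As far as it is shown here, the fusion cell uses equal weights of 1.0 for MongoDB grids and reads the strategy-parameter weight sliders for sample data. One plausible way to combine those controls into final weights is sketched below. The semantics are assumptions, not the framework's documented behavior: disabled or below-threshold sources get weight 0, auto-weighting replaces the slider value with the quality score, and normalization makes the weights sum to 1. `resolve_weights` is a hypothetical helper.

```python
def resolve_weights(sources, normalize=True, auto_weight_quality=False, quality_threshold=0.0):
    """Turn per-source settings into fusion weights (hypothetical helper).

    sources maps a source name to a dict with 'weight', 'quality' and 'enabled'.
    Disabled sources and sources below quality_threshold get weight 0.
    """
    weights = {}
    for name, cfg in sources.items():
        usable = cfg.get('enabled', True) and cfg.get('quality', 1.0) >= quality_threshold
        if not usable:
            weights[name] = 0.0
        elif auto_weight_quality:
            weights[name] = float(cfg.get('quality', 1.0))  # weight = quality score
        else:
            weights[name] = float(cfg.get('weight', 1.0))   # weight = slider value
    total = sum(weights.values())
    if normalize and total > 0:
        weights = {name: w / total for name, w in weights.items()}
    return weights
```

With the default slider values (hatching 0.4, laser 0.3; CT and ISPM disabled), normalization yields hatching 0.4/0.7 ≈ 0.571 and laser 0.3/0.7 ≈ 0.429.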


In [2]:
# Create Interactive Data Fusion Interface

# Global state
source_grids = {}
fused_grid = None
fusion_results = {}
comparison_results = {}
current_model_id = None
loaded_grids = {}  # Store loaded grids by grid_id

# ============================================
# Helper Functions for Demo Data
# ============================================

def generate_sample_source_grids():
    """Generate sample voxel grids from different sources."""
    np.random.seed(42)
    
    # Create a common grid structure
    x = np.linspace(-50, 50, 50)
    y = np.linspace(-50, 50, 50)
    z = np.linspace(0, 100, 50)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
    
    grids = {}
    
    # Source 1: Hatching (smooth pattern)
    hatching_signal = 100 + 50 * np.sin(2 * np.pi * X / 20) * np.cos(2 * np.pi * Y / 20)
    hatching_signal += np.random.normal(0, 3, hatching_signal.shape)
    grids['hatching'] = {
        'signal': hatching_signal,
        'quality': 0.9,
        'coverage': 0.95
    }
    
    # Source 2: Laser (hotspot pattern)
    laser_signal = 150 + 100 * np.exp(-((X - 10)**2 + (Y - 10)**2) / 200)
    laser_signal += np.random.normal(0, 5, laser_signal.shape)
    grids['laser'] = {
        'signal': laser_signal,
        'quality': 0.85,
        'coverage': 0.80
    }
    
    # Source 3: CT (layered pattern)
    ct_signal = 120 + 30 * np.sin(2 * np.pi * Z / 10)
    ct_signal += np.random.normal(0, 4, ct_signal.shape)
    grids['ct'] = {
        'signal': ct_signal,
        'quality': 0.75,
        'coverage': 0.70
    }
    
    # Source 4: ISPM (temperature-like)
    ispm_signal = 200 + 50 * np.sin(2 * np.pi * X / 15) + 30 * np.cos(2 * np.pi * Y / 15)
    ispm_signal += np.random.normal(0, 6, ispm_signal.shape)
    grids['ispm'] = {
        'signal': ispm_signal,
        'quality': 0.80,
        'coverage': 0.85
    }
    
    return grids, (X, Y, Z)

# ============================================
# Top Panel: Model/Grid Selection and Strategy
# ============================================

# Data source mode
data_source_label = widgets.HTML("<b>Data Source:</b>")
data_source_mode = RadioButtons(
    options=[('MongoDB (Corrected/Processed Grids)', 'mongodb'), ('Sample Data', 'sample')],
    value='mongodb',
    description='Mode:',
    style={'description_width': 'initial'}
)

# Model selection (for MongoDB)
model_label = widgets.HTML("<b>Model:</b>")
model_options = [("━━━ Select Model ━━━", None)]
if stl_client and mongo_client:
    try:
        models = stl_client.list_models(limit=100)
        model_options.extend([
            (f"{m.get('filename', m.get('original_stem', m.get('model_name', 'Unknown')))} ({m.get('model_id', '')[:8]}...)", m.get('model_id'))
            for m in models
        ])
    except Exception as e:
        print(f"⚠️ Error loading models: {e}")

model_dropdown = Dropdown(
    options=model_options,
    value=None,
    description='Model:',
    style={'description_width': 'initial'},
    layout=Layout(width='400px')
)

# Grid selection (for MongoDB - corrected/processed grids)
grid_label = widgets.HTML("<b>Grids to Fuse:</b>")
grid_dropdown = Dropdown(
    options=[("━━━ Select Grids ━━━", None)],
    value=None,
    description='Grid:',
    style={'description_width': 'initial'},
    layout=Layout(width='500px')
)

load_grids_button = Button(
    description='Load Grids',
    button_style='info',
    icon='folder-open',
    layout=Layout(width='120px')
)

strategy_label = widgets.HTML("<b>Fusion Strategy:</b>")
fusion_strategy = Dropdown(
    options=[
        ('Weighted Average', 'weighted_average'),
        ('Median', 'median'),
        ('Quality-Based', 'quality_based'),
        ('Max', 'max'),
        ('Min', 'min'),
        ('First', 'first'),
        ('Last', 'last')
    ],
    value='weighted_average',
    description='Strategy:',
    style={'description_width': 'initial'}
)

execute_button = Button(
    description='Execute Fusion',
    button_style='success',
    icon='merge',
    layout=Layout(width='150px')
)

compare_button = Button(
    description='Compare Strategies',
    button_style='',
    icon='copy',
    layout=Layout(width='180px')
)

top_panel = VBox([
    HBox([data_source_label, data_source_mode]),
    HBox([model_label, model_dropdown, grid_label, grid_dropdown, load_grids_button]),
    HBox([strategy_label, fusion_strategy, execute_button, compare_button])
], layout=Layout(padding='10px', border='1px solid #ccc'))

# ============================================
# Left Panel: Fusion Configuration
# ============================================

# Strategy Parameters Section
strategy_params_label = widgets.HTML("<b>Strategy Parameters:</b>")

# Weighted Average parameters
weight_hatching = FloatSlider(value=0.4, min=0.0, max=1.0, step=0.05, description='Hatching Weight:', style={'description_width': 'initial'})
weight_laser = FloatSlider(value=0.3, min=0.0, max=1.0, step=0.05, description='Laser Weight:', style={'description_width': 'initial'})
weight_ct = FloatSlider(value=0.2, min=0.0, max=1.0, step=0.05, description='CT Weight:', style={'description_width': 'initial'})
weight_ispm = FloatSlider(value=0.1, min=0.0, max=1.0, step=0.05, description='ISPM Weight:', style={'description_width': 'initial'})
normalize_weights = Checkbox(value=True, description='Normalize Weights', style={'description_width': 'initial'})
auto_weight_quality = Checkbox(value=False, description='Auto-weight by Quality', style={'description_width': 'initial'})

weighted_avg_params = VBox([
    weight_hatching,
    weight_laser,
    weight_ct,
    weight_ispm,
    normalize_weights,
    auto_weight_quality
], layout=Layout(display='flex'))

# Quality-Based parameters
quality_threshold = FloatSlider(value=0.5, min=0.0, max=1.0, step=0.05, description='Quality Threshold:', style={'description_width': 'initial'})
quality_source = Dropdown(
    options=[('Hatching', 'hatching'), ('Laser', 'laser'), ('CT', 'ct'), ('ISPM', 'ispm')],
    value='hatching',
    description='Quality Source:',
    style={'description_width': 'initial'}
)

quality_based_params = VBox([
    quality_threshold,
    quality_source
], layout=Layout(display='none'))

# Median parameters
median_percentile = FloatSlider(value=0.5, min=0.0, max=1.0, step=0.05, description='Percentile:', style={'description_width': 'initial'})

median_params = VBox([
    median_percentile
], layout=Layout(display='none'))

# Max/Min parameters
maxmin_direction = RadioButtons(
    options=[('Max', 'max'), ('Min', 'min')],
    value='max',
    description='Direction:',
    style={'description_width': 'initial'}
)

maxmin_params = VBox([
    maxmin_direction
], layout=Layout(display='none'))

def update_strategy_params(change):
    """Show/hide strategy parameters based on selected strategy."""
    strategy = change['new']
    weighted_avg_params.layout.display = 'none'
    quality_based_params.layout.display = 'none'
    median_params.layout.display = 'none'
    maxmin_params.layout.display = 'none'
    
    if strategy == 'weighted_average':
        weighted_avg_params.layout.display = 'flex'
    elif strategy == 'quality_based':
        quality_based_params.layout.display = 'flex'
    elif strategy == 'median':
        median_params.layout.display = 'flex'
    elif strategy in ['max', 'min']:
        maxmin_params.layout.display = 'flex'

fusion_strategy.observe(update_strategy_params, names='value')
update_strategy_params({'new': fusion_strategy.value})

# Source Configuration Section
source_config_label = widgets.HTML("<b>Source Configuration:</b>")

# Create accordion for each source
source_accordion_items = []

# Hatching source
hatching_quality = FloatSlider(value=0.9, min=0.0, max=1.0, step=0.05, description='Quality:', style={'description_width': 'initial'})
hatching_enable = Checkbox(value=True, description='Enable', style={'description_width': 'initial'})
hatching_weight = FloatSlider(value=0.4, min=0.0, max=1.0, step=0.05, description='Weight:', style={'description_width': 'initial'})
hatching_source = VBox([
    hatching_quality,
    hatching_enable,
    hatching_weight
], layout=Layout(padding='5px'))

# Laser source
laser_quality = FloatSlider(value=0.85, min=0.0, max=1.0, step=0.05, description='Quality:', style={'description_width': 'initial'})
laser_enable = Checkbox(value=True, description='Enable', style={'description_width': 'initial'})
laser_weight = FloatSlider(value=0.3, min=0.0, max=1.0, step=0.05, description='Weight:', style={'description_width': 'initial'})
laser_source = VBox([
    laser_quality,
    laser_enable,
    laser_weight
], layout=Layout(padding='5px'))

# CT source
ct_quality = FloatSlider(value=0.75, min=0.0, max=1.0, step=0.05, description='Quality:', style={'description_width': 'initial'})
ct_enable = Checkbox(value=False, description='Enable', style={'description_width': 'initial'})
ct_weight = FloatSlider(value=0.2, min=0.0, max=1.0, step=0.05, description='Weight:', style={'description_width': 'initial'})
ct_source = VBox([
    ct_quality,
    ct_enable,
    ct_weight
], layout=Layout(padding='5px'))

# ISPM source
ispm_quality = FloatSlider(value=0.80, min=0.0, max=1.0, step=0.05, description='Quality:', style={'description_width': 'initial'})
ispm_enable = Checkbox(value=False, description='Enable', style={'description_width': 'initial'})
ispm_weight = FloatSlider(value=0.1, min=0.0, max=1.0, step=0.05, description='Weight:', style={'description_width': 'initial'})
ispm_source = VBox([
    ispm_quality,
    ispm_enable,
    ispm_weight
], layout=Layout(padding='5px'))

source_accordion = Accordion(children=[
    hatching_source,
    laser_source,
    ct_source,
    ispm_source
])
source_accordion.set_title(0, 'Hatching')
source_accordion.set_title(1, 'Laser')
source_accordion.set_title(2, 'CT')
source_accordion.set_title(3, 'ISPM')

# Fusion Options Section
fusion_options_label = widgets.HTML("<b>Fusion Options:</b>")
mask_invalid = Checkbox(value=True, description='Mask Invalid', style={'description_width': 'initial'})
fill_missing = Checkbox(value=False, description='Fill Missing', style={'description_width': 'initial'})
interpolation_method = Dropdown(
    options=[('Nearest', 'nearest'), ('Linear', 'linear'), ('IDW', 'idw')],
    value='nearest',
    description='Interpolation:',
    style={'description_width': 'initial'}
)
conflict_resolution = Dropdown(
    options=[('First', 'first'), ('Last', 'last'), ('Average', 'average'), ('Quality', 'quality')],
    value='quality',
    description='Conflict:',
    style={'description_width': 'initial'}
)

fusion_options = VBox([
    fusion_options_label,
    mask_invalid,
    fill_missing,
    interpolation_method,
    conflict_resolution
], layout=Layout(padding='5px', border='1px solid #ddd'))

left_panel = VBox([
    strategy_params_label,
    weighted_avg_params,
    quality_based_params,
    median_params,
    maxmin_params,
    source_config_label,
    source_accordion,
    fusion_options
], layout=Layout(width='300px', padding='10px', border='1px solid #ccc'))

# ============================================
# Center Panel: Visualization
# ============================================

viz_mode = RadioButtons(
    options=[('Fused Result', 'fused'), ('Source Comparison', 'comparison'), ('Quality Map', 'quality'), ('Difference', 'difference')],
    value='fused',
    description='View:',
    style={'description_width': 'initial'}
)

viz_output = Output(layout=Layout(height='500px', overflow='auto'))

center_panel = VBox([
    widgets.HTML("<h3>Fusion Visualization</h3>"),
    viz_mode,
    viz_output
], layout=Layout(flex='1 1 auto', padding='10px', border='1px solid #ccc'))

# ============================================
# Right Panel: Results and Metrics
# ============================================

# Fusion Metrics
fusion_metrics_label = widgets.HTML("<b>Fusion Metrics:</b>")
fusion_metrics_display = widgets.HTML("No fusion performed yet")
fusion_metrics_section = VBox([
    fusion_metrics_label,
    fusion_metrics_display
], layout=Layout(padding='5px'))

# Source Statistics
source_stats_label = widgets.HTML("<b>Source Statistics:</b>")
source_stats_display = widgets.HTML("No statistics available")
source_stats_section = VBox([
    source_stats_label,
    source_stats_display
], layout=Layout(padding='5px'))

# Fusion Quality
quality_label = widgets.HTML("<b>Fusion Quality:</b>")
quality_display = widgets.HTML("No quality metrics available")
quality_section = VBox([
    quality_label,
    quality_display
], layout=Layout(padding='5px'))

# Comparison Results
comparison_label = widgets.HTML("<b>Comparison:</b>")
comparison_display = widgets.HTML("No comparison available")
comparison_section = VBox([
    comparison_label,
    comparison_display
], layout=Layout(padding='5px'))

# Export Options
export_label = widgets.HTML("<b>Export:</b>")
save_fused_button = Button(description='Save Fused Grid', button_style='success', icon='save', layout=Layout(width='150px'))
export_fused_button = Button(description='Export Fused', button_style='', layout=Layout(width='150px'))
export_quality_button = Button(description='Export Quality', button_style='', layout=Layout(width='150px'))
export_comparison_button = Button(description='Export Comparison', button_style='', layout=Layout(width='150px'))
save_config_button = Button(description='Save Config', button_style='', layout=Layout(width='150px'))

export_section = VBox([
    export_label,
    save_fused_button,
    export_fused_button,
    export_quality_button,
    export_comparison_button,
    save_config_button
], layout=Layout(padding='5px'))

right_panel = VBox([
    fusion_metrics_section,
    source_stats_section,
    quality_section,
    comparison_section,
    export_section
], layout=Layout(width='250px', padding='10px', border='1px solid #ccc'))

# ============================================
# Bottom Panel: Status and Progress
# ============================================

# Status display widget
current_operation = WidgetHTML(value='<b>Status:</b> Ready to fuse data')

# Progress bar
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    layout=Layout(width='100%')
)

# Fusion logs output
fusion_logs = Output(layout=Layout(height='200px', border='1px solid #ccc', overflow='auto'))

# Initialize logs
with fusion_logs:
    display(HTML("<p><i>Fusion logs will appear here...</i></p>"))

# Bottom status bar (shows Status | Progress | Time)
bottom_status = WidgetHTML(value='<b>Status:</b> Ready | <b>Progress:</b> 0% | <b>Time:</b> 0:00')
bottom_progress = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Overall:',
    bar_style='info',
    layout=Layout(width='100%')
)

# Error display
error_display = widgets.HTML("")

# Keep old status_display for backward compatibility (will be updated by logging functions)
status_display = current_operation

# Global time tracking
operation_start_time = None

# Enhanced bottom panel
bottom_panel = VBox([
    current_operation,
    progress_bar,
    WidgetHTML("<b>Fusion Logs:</b>"),
    fusion_logs,
    WidgetHTML("<hr>"),
    bottom_status,
    bottom_progress,
    error_display
], layout=Layout(padding='10px', border='1px solid #ccc'))

# ============================================
# Logging Functions
# ============================================

def log_message(message: str, level: str = 'info'):
    """Log a message to the fusion logs with timestamp and emoji."""
    timestamp = datetime.now().strftime('%H:%M:%S')
    icons = {'info': 'ℹ️', 'success': '✅', 'warning': '⚠️', 'error': '❌'}
    icon = icons.get(level, 'ℹ️')
    with fusion_logs:
        print(f"[{timestamp}] {icon} {message}")

def update_status(operation: str, progress: Optional[int] = None):
    """Update the status display and progress."""
    global operation_start_time
    current_operation.value = f'<b>Status:</b> {operation}'
    if progress is not None:
        progress_bar.value = progress
        bottom_progress.value = progress
        if operation_start_time:
            elapsed = time.time() - operation_start_time
            bottom_status.value = f'<b>Status:</b> {operation} | <b>Progress:</b> {progress}% | <b>Time:</b> {time.strftime("%M:%S", time.gmtime(elapsed))}'
        else:
            bottom_status.value = f'<b>Status:</b> {operation} | <b>Progress:</b> {progress}% | <b>Time:</b> 0:00'

# ============================================
# Helper Functions for MongoDB
# ============================================

def update_grid_dropdown(change):
    """Update grid dropdown when model changes."""
    global current_model_id
    
    model_id = model_dropdown.value
    if not model_id:
        grid_dropdown.options = [("━━━ Select Grids ━━━", None)]
        return
    
    current_model_id = model_id
    
    if not voxel_storage:
        grid_dropdown.options = [("━━━ MongoDB not available ━━━", None)]
        return
    
    try:
        # Get corrected/processed grids for this model
        grids = voxel_storage.list_grids(model_id=model_id, limit=100)
        
        grid_options = [("━━━ Select Grids ━━━", None)]
        for grid in grids:
            metadata = grid.get('metadata', {})
            config_meta = metadata.get('configuration_metadata', {})
            if not config_meta:
                config_meta = metadata
            
            is_corrected = config_meta.get('correction_applied', False)
            is_processed = config_meta.get('processing_applied', False)
            
            if is_corrected or is_processed:
                grid_id = grid.get('grid_id', str(grid.get('_id', '')))
                grid_name = grid.get('grid_name', 'Unknown')
                n_signals = len(grid.get('available_signals', []))
                
                status_parts = []
                if is_corrected:
                    status_parts.append('✓corrected')
                if is_processed:
                    status_parts.append('✓processed')
                status_str = ', '.join(status_parts) if status_parts else 'ready'
                
                label = f"{grid_name} ({n_signals} signal(s), {status_str}) ({grid_id[:8]}...)"
                grid_options.append((label, grid_id))
        
        if len(grid_options) == 1:
            grid_options.append(("No corrected/processed grids found", None))
        
        grid_dropdown.options = grid_options
    except Exception as e:
        grid_dropdown.options = [("━━━ Error loading grids ━━━", None)]
        print(f"⚠️ Error loading grids: {e}")

def load_grids_from_mongodb(button):
    """Load selected grids from MongoDB."""
    global source_grids, loaded_grids, current_model_id, operation_start_time
    
    # Initialize timing
    operation_start_time = time.time()
    
    # Clear logs
    with fusion_logs:
        clear_output(wait=True)
    
    log_message("Starting grid load from MongoDB...", 'info')
    update_status("Initializing grid load...", 0)
    
    if not voxel_storage or not grid_dropdown.value:
        log_message("Please select a grid to load", 'warning')
        error_display.value = "<span style='color: red;'>⚠️ Please select a grid to load</span>"
        update_status("No grid selected", 0)
        return
    
    grid_id = grid_dropdown.value
    log_message(f"Loading grid {grid_id[:8]}... from MongoDB...", 'info')
    error_display.value = ""
    
    try:
        # Load grid from MongoDB
        log_message("Loading grid data from storage...", 'info')
        update_status("Loading grid from MongoDB...", 20)
        grid_data = voxel_storage.load_voxel_grid(grid_id=grid_id)
        
        if not grid_data:
            log_message(f"Failed to load grid {grid_id[:8]}...", 'error')
            error_display.value = "<span style='color: red;'>⚠️ Failed to load grid</span>"
            update_status("Error loading grid", 0)
            return
        
        log_message("Grid loaded successfully", 'success')
        
        # Extract data from dictionary
        log_message("Extracting grid data...", 'info')
        update_status("Extracting grid data...", 40)
        signal_arrays = grid_data.get('signal_arrays', {})
        metadata = grid_data.get('metadata', {})
        grid_name = grid_data.get('grid_name', 'Unknown')
        
        if not signal_arrays:
            log_message("Grid has no signals to fuse", 'warning')
            error_display.value = "<span style='color: red;'>⚠️ Grid has no signals to fuse</span>"
            update_status("No signals in grid", 0)
            return
        
        log_message(f"Found {len(signal_arrays)} signal(s) in grid", 'success')
        update_status("Reconstructing grid...", 50)
        
        # Reconstruct VoxelGrid from metadata
        from am_qadf.voxelization.voxel_grid import VoxelGrid
        
        # Get grid properties from metadata
        bbox_min = metadata.get('bbox_min', [-50, -50, 0])
        bbox_max = metadata.get('bbox_max', [50, 50, 100])
        resolution = metadata.get('resolution', 1.0)
        
        # Handle resolution - can be a list or single float
        if isinstance(resolution, (list, tuple, np.ndarray)):
            # Use average resolution if it's a list
            resolution = float(np.mean(resolution))
        else:
            resolution = float(resolution)
        
        # Ensure bbox_min and bbox_max are tuples/lists
        if not isinstance(bbox_min, (list, tuple, np.ndarray)):
            bbox_min = [-50, -50, 0]
        if not isinstance(bbox_max, (list, tuple, np.ndarray)):
            bbox_max = [50, 50, 100]
        
        # Convert to tuples
        bbox_min = tuple(bbox_min[:3])
        bbox_max = tuple(bbox_max[:3])
        
        # Create VoxelGrid object
        grid = VoxelGrid(bbox_min=bbox_min, bbox_max=bbox_max, resolution=resolution)
        
        # Add signals to grid
        if not hasattr(grid, '_signal_arrays'):
            grid._signal_arrays = {}
        for signal_name, signal_array in signal_arrays.items():
            grid._signal_arrays[signal_name] = signal_array
        
        if not hasattr(grid, 'available_signals'):
            grid.available_signals = set()
        grid.available_signals.update(signal_arrays.keys())
        
        # Add get_signal_array method
        def get_signal_array(signal_name, default=0.0):
            if hasattr(grid, '_signal_arrays') and signal_name in grid._signal_arrays:
                return grid._signal_arrays[signal_name]
            return None
        grid.get_signal_array = get_signal_array
        
        # Extract signals and create source grid entry
        # Use grid name as source identifier
        source_name = grid_name[:20]  # Truncate if too long
        
        # For fusion, we'll use the first signal or combine all signals
        # In a real scenario, you might want to fuse multiple grids, each with multiple signals
        # For now, we'll treat each grid as a source and use its primary signal
        first_signal_name = list(signal_arrays.keys())[0]
        signal_array = signal_arrays[first_signal_name]
        
        # Get quality from metadata (if available)
        config_meta = metadata.get('configuration_metadata', {})
        if not config_meta:
            config_meta = metadata
        
        quality = 0.8  # Default
        if config_meta.get('correction_applied'):
            correction_metrics = config_meta.get('correction_metrics', {})
            quality = correction_metrics.get('score', 0.8)
        elif config_meta.get('processing_applied'):
            processing_metrics = config_meta.get('processing_metrics', {})
            quality = processing_metrics.get('quality_score', 0.75)
        
        loaded_grids[grid_id] = {
            'grid': grid,
            'signal_arrays': signal_arrays,
            'metadata': metadata,
            'grid_data': grid_data
        }
        
        source_grids[source_name] = {
            'signal': signal_array,
            'quality': quality,
            'coverage': 1.0,  # Assume full coverage for now
            'grid_id': grid_id,
            'all_signals': signal_arrays
        }
        
        log_message(f"Grid loaded: {source_name} with {len(signal_arrays)} signal(s)", 'success')
        update_status("Preparing source grid...", 90)
        
        # Calculate total execution time
        if operation_start_time:
            total_time = time.time() - operation_start_time
            log_message(f"Grid load completed in {total_time:.2f}s", 'success')
        
        update_status(f"Grid loaded: {source_name} ({len(signal_arrays)} signal(s))", 100)
        error_display.value = f"<span style='color: green;'>✅ Loaded grid: {source_name} ({len(signal_arrays)} signal(s))</span>"
        
    except Exception as e:
        import traceback
        log_message(f"Error loading grid: {e}", 'error')
        # The full traceback goes to the fusion log (widget callbacks otherwise swallow it)
        log_message(f"Traceback: {traceback.format_exc()}", 'error')
        error_display.value = f"<span style='color: red;'>❌ Error loading grid: {e}</span>"
        update_status("Error loading grid", 0)

# Connect model dropdown change event
model_dropdown.observe(update_grid_dropdown, names='value')
load_grids_button.on_click(load_grids_from_mongodb)

# ============================================
# Fusion Functions
# ============================================

def execute_fusion(button):
    """Execute fusion based on current settings."""
    global source_grids, fused_grid, fusion_results, operation_start_time
    
    # Initialize timing
    operation_start_time = time.time()
    
    # Clear logs
    with fusion_logs:
        clear_output(wait=True)
    
    log_message("Starting fusion operation...", 'info')
    update_status("Initializing fusion...", 0)
    error_display.value = ""
    
    try:
        # Load data based on mode
        if data_source_mode.value == 'mongodb':
            if not source_grids:
                log_message("Please load grids from MongoDB first", 'warning')
                error_display.value = "<span style='color: red;'>⚠️ Please load grids from MongoDB first</span>"
                update_status("No grids loaded", 0)
                return
            selected_sources = list(source_grids.keys())
            log_message(f"Using {len(selected_sources)} source(s) from MongoDB", 'info')
            update_status("Preparing sources...", 20)
        else:
            # Use sample data
            log_message("Generating sample source grids...", 'info')
            update_status("Generating sample data...", 20)
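            source_grids, coords = generate_sample_source_grids()
            # Respect the Enable checkboxes in the Source Configuration accordion
            enabled = {
                'hatching': hatching_enable.value,
                'laser': laser_enable.value,
                'ct': ct_enable.value,
                'ispm': ispm_enable.value,
            }
            selected_sources = [name for name in source_grids if enabled.get(name, True)]
            log_message(f"Generated {len(source_grids)} sample source(s), {len(selected_sources)} enabled", 'success')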
            update_status("Sample data generated", 20)
        
        if not selected_sources:
            log_message("No sources available for fusion", 'warning')
            error_display.value = "<span style='color: red;'>⚠️ No sources available for fusion</span>"
            update_status("No sources available", 0)
            return
        
        log_message(f"Fusing {len(selected_sources)} source(s) using {fusion_strategy.value} strategy...", 'info')
        update_status(f"Fusing {len(selected_sources)} source(s)...", 40)
        
        # Get strategy
        strategy = fusion_strategy.value
        
        # Collect signals from selected sources
        signals_list = []
        weights_list = []
        qualities_list = []
        
        for source in selected_sources:
            if source in source_grids:
                signals_list.append(source_grids[source]['signal'])
                # Get quality from the source grid (falls back to 0.8 if the grid has none)
                quality = source_grids[source].get('quality', 0.8)
                qualities_list.append(quality)
                
                # Get weight - use UI sliders if available, otherwise equal weights
                # For MongoDB grids, we'll use equal weights or quality-based
                if data_source_mode.value == 'mongodb':
                    # Use equal weights for now (can be enhanced)
                    weights_list.append(1.0)
                else:
                    # Use UI sliders for sample data
                    if source == 'hatching':
                        weights_list.append(weight_hatching.value)
                    elif source == 'laser':
                        weights_list.append(weight_laser.value)
                    elif source == 'ct':
                        weights_list.append(weight_ct.value)
                    elif source == 'ispm':
                        weights_list.append(weight_ispm.value)
                    else:
                        weights_list.append(0.25)  # Default equal weight
        
        log_message("Signals collected from sources", 'info')
        update_status("Signals collected", 60)
        
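        # Normalize weights if requested (the fused result is unchanged either way:
        # np.average already divides by the sum of the weights)
        if normalize_weights.value and strategy == 'weighted_average':
            total_weight = sum(weights_list)
            if total_weight > 0:
                weights_list = [w / total_weight for w in weights_list]
                log_message("Weights normalized", 'info')
        
        # Apply fusion strategy
        log_message(f"Applying {strategy} fusion strategy...", 'info')
        update_status(f"Applying {strategy} strategy...", 70)
        
        # All sources must share one grid shape before they can be stacked
        shapes = {np.shape(s) for s in signals_list}
        if len(shapes) != 1:
            raise ValueError(f"Source grids have mismatched shapes: {sorted(shapes)}")
        signals_array = np.array(signals_list)
        
        if strategy == 'weighted_average':
            if sum(weights_list) <= 0:
                raise ValueError("Source weights must sum to a positive value")
            fused_signal = np.average(signals_array, axis=0, weights=weights_list)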
        elif strategy == 'median':
            fused_signal = np.median(signals_array, axis=0)
        elif strategy == 'quality_based':
            best_idx = np.argmax(qualities_list)
            fused_signal = signals_array[best_idx]
        elif strategy == 'max':
            fused_signal = np.max(signals_array, axis=0)
        elif strategy == 'min':
            fused_signal = np.min(signals_array, axis=0)
        elif strategy == 'first':
            fused_signal = signals_array[0]
        else:  # last
            fused_signal = signals_array[-1]
        
        fused_grid = {
            'signal': fused_signal,
            'strategy': strategy,
            'sources': selected_sources
        }
        
        log_message("Fusion computation completed", 'success')
        update_status("Calculating metrics...", 80)
        
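        # Calculate metrics
        # NOTE: fusion_score, coverage and consistency_score are fixed placeholder
        # values; only quality_score is derived from the sources.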
        fusion_results = {
            'fusion_score': 0.92,
            'coverage': 0.88,
            'quality_score': np.mean(qualities_list),
            'consistency_score': 0.85
        }
        
        log_message("Metrics calculated", 'success')
        update_status("Updating displays...", 90)
        
        # Update displays
        update_results_display()
        update_visualization()
        
        # Calculate total execution time
        if operation_start_time:
            total_time = time.time() - operation_start_time
            log_message(f"Fusion completed successfully in {total_time:.2f}s", 'success')
        else:
            log_message("Fusion completed successfully", 'success')
        
        update_status("Fusion completed successfully", 100)
        
    except Exception as e:
        log_message(f"Error during fusion: {str(e)}", 'error')
        import traceback
        log_message(f"Traceback: {traceback.format_exc()}", 'error')
        error_display.value = f"<span style='color: red;'>❌ Error: {str(e)}</span>"
        update_status("Error during fusion", 0)

def update_results_display():
    """Update results and metrics displays."""
    global fusion_results, source_grids
    
    if not fusion_results:
        return
    
    # Fusion metrics
    # coverage is stored as a fraction (0-1), so format it with the % spec
    metrics_html = f"""
    <p><b>Fusion Score:</b> {fusion_results.get('fusion_score', 0):.2f}</p>
    <p><b>Coverage:</b> {fusion_results.get('coverage', 0):.1%}</p>
    <p><b>Quality Score:</b> {fusion_results.get('quality_score', 0):.2f}</p>
    <p><b>Consistency:</b> {fusion_results.get('consistency_score', 0):.2f}</p>
    """
    fusion_metrics_display.value = metrics_html
    
    # Source statistics
    if source_grids:
        stats_html = "<ul>"
        for source, data in source_grids.items():
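            # Grids loaded from MongoDB may lack these keys
            quality = data.get('quality', 0.8)
            coverage = data.get('coverage')
            coverage_text = f"{coverage:.1f}%" if coverage is not None else "n/a"
            stats_html += f"<li><b>{source.capitalize()}:</b> Quality={quality:.2f}, Coverage={coverage_text}</li>"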
        stats_html += "</ul>"
        source_stats_display.value = stats_html
    
    # Quality metrics
    # Lowest per-source quality (same 0.8 default as execute_fusion)
    min_quality = min((g.get('quality', 0.8) for g in source_grids.values()), default=0)
    quality_html = f"""
    <p><b>Mean Quality:</b> {fusion_results.get('quality_score', 0):.2f}</p>
    <p><b>Min Quality:</b> {min_quality:.2f}</p>
    """
    quality_display.value = quality_html

def update_visualization():
    """Update visualization display."""
    global fused_grid, source_grids
    
    with viz_output:
        clear_output(wait=True)
        
        if fused_grid is None:
            display(HTML("<p>Execute fusion to see visualization</p>"))
            return
        
        mode = viz_mode.value
        
        if mode == 'fused':
            fig, ax = plt.subplots(figsize=(10, 8))
            signal = fused_grid['signal']
            slice_idx = signal.shape[2] // 2
            im = ax.imshow(signal[:, :, slice_idx], cmap='viridis', origin='lower')
            ax.set_title(f'Fused Result ({fused_grid["strategy"]})')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            plt.colorbar(im, ax=ax, label='Signal Value')
            plt.tight_layout()
            plt.show()
        
        elif mode == 'comparison':
            n_sources = len(fused_grid['sources'])
            fig, axes = plt.subplots(1, n_sources + 1, figsize=(4 * (n_sources + 1), 6))
            
            # Show each source
            for idx, source in enumerate(fused_grid['sources']):
                if source in source_grids:
                    signal = source_grids[source]['signal']
                    slice_idx = signal.shape[2] // 2
                    im = axes[idx].imshow(signal[:, :, slice_idx], cmap='viridis', origin='lower')
                    axes[idx].set_title(f'{source.capitalize()}')
                    axes[idx].set_xlabel('X')
                    axes[idx].set_ylabel('Y')
                    plt.colorbar(im, ax=axes[idx])
            
            # Show fused result
            signal = fused_grid['signal']
            slice_idx = signal.shape[2] // 2
            im = axes[-1].imshow(signal[:, :, slice_idx], cmap='viridis', origin='lower')
            axes[-1].set_title('Fused')
            axes[-1].set_xlabel('X')
            axes[-1].set_ylabel('Y')
            plt.colorbar(im, ax=axes[-1])
            
            plt.tight_layout()
            plt.show()
        
        elif mode == 'quality':
            fig, ax = plt.subplots(figsize=(10, 8))
            # Create quality map from source qualities
            if source_grids:
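                # Uniform map holding the mean of the per-source quality scores
                # (float dtype so integer-valued signals do not break the sum)
                quality_map = np.zeros(fused_grid['signal'].shape, dtype=float)
                for source in fused_grid['sources']:
                    if source in source_grids:
                        quality = source_grids[source].get('quality', 0.8)
                        quality_map += quality / len(fused_grid['sources'])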
                
                slice_idx = quality_map.shape[2] // 2
                im = ax.imshow(quality_map[:, :, slice_idx], cmap='RdYlGn', origin='lower', vmin=0, vmax=1)
                ax.set_title('Quality Map')
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                plt.colorbar(im, ax=ax, label='Quality Score')
                plt.tight_layout()
                plt.show()
        
        else:  # difference
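            n_sources = len(fused_grid['sources'])
            # squeeze=False returns a 2-D array of Axes even for a single source,
            # so axes[idx] below works for any number of sources
            fig, axes = plt.subplots(1, n_sources, figsize=(4 * n_sources, 6), squeeze=False)
            axes = axes.ravel()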
            
            for idx, source in enumerate(fused_grid['sources']):
                if source in source_grids:
                    source_signal = source_grids[source]['signal']
                    diff = fused_grid['signal'] - source_signal
                    slice_idx = diff.shape[2] // 2
                    im = axes[idx].imshow(diff[:, :, slice_idx], cmap='RdBu', origin='lower')
                    axes[idx].set_title(f'Diff: {source.capitalize()}')
                    axes[idx].set_xlabel('X')
                    axes[idx].set_ylabel('Y')
                    plt.colorbar(im, ax=axes[idx])
            
            plt.tight_layout()
            plt.show()

def save_fused_grid(button):
    """Save fused grid to MongoDB."""
    global fused_grid, source_grids, current_model_id, voxel_storage, operation_start_time
    
    # Initialize timing
    operation_start_time = time.time()
    
    # Clear logs
    with fusion_logs:
        clear_output(wait=True)
    
    log_message("Starting fused grid save operation...", 'info')
    update_status("Initializing save...", 0)
    
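    if not voxel_storage:
        log_message("Voxel storage is not available.", 'warning')
        error_display.value = "<span style='color: red;'>⚠️ Voxel storage is not available.</span>"
        update_status("No storage connection", 0)
        return
    
    if not fused_grid:
        log_message("No fused grid to save. Please execute fusion first.", 'warning')
        error_display.value = "<span style='color: red;'>⚠️ No fused grid to save. Please execute fusion first.</span>"
        update_status("No fused grid", 0)
        return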
    
    if not current_model_id:
        log_message("No model selected", 'warning')
        error_display.value = "<span style='color: red;'>⚠️ No model selected</span>"
        update_status("No model selected", 0)
        return
    
    log_message("Saving fused grid...", 'info')
    error_display.value = ""
    
    try:
        # Get model name
        model_name = None
        if stl_client:
            try:
                model_info = stl_client.get_model(current_model_id)
                if model_info:
                    model_name = model_info.get('model_name') or model_info.get('filename', 'Unknown')
            except Exception:
                pass  # model name is optional metadata; save without it
        
        # Create grid name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        grid_name = f"fused_{fusion_strategy.value}_{timestamp}"
        
        # Create VoxelGrid from fused signal
        from am_qadf.voxelization.voxel_grid import VoxelGrid
        
        # Get grid properties from first source grid
        first_source = list(source_grids.values())[0]
        if 'grid_id' in first_source and first_source['grid_id'] in loaded_grids:
            original_grid = loaded_grids[first_source['grid_id']]['grid']
            bbox_min = original_grid.bbox_min if hasattr(original_grid, 'bbox_min') else None
            bbox_max = original_grid.bbox_max if hasattr(original_grid, 'bbox_max') else None
            resolution = original_grid.resolution if hasattr(original_grid, 'resolution') else 1.0
        else:
            # Fallback: no original grid to copy from, so a default bounding box is used below
            bbox_min = None
            bbox_max = None
            resolution = 1.0
        
        # Handle resolution - can be a list or single float
        if isinstance(resolution, (list, tuple, np.ndarray)):
            resolution = float(np.mean(resolution))
        else:
            resolution = float(resolution)
        
        # Create fused grid
        if bbox_min is not None and bbox_max is not None:
            # Ensure they are tuples
            if not isinstance(bbox_min, (list, tuple, np.ndarray)):
                bbox_min = [-50, -50, 0]
            if not isinstance(bbox_max, (list, tuple, np.ndarray)):
                bbox_max = [50, 50, 100]
            bbox_min = tuple(bbox_min[:3])
            bbox_max = tuple(bbox_max[:3])
            fused_voxel_grid = VoxelGrid(bbox_min=bbox_min, bbox_max=bbox_max, resolution=resolution)
        else:
            # Create with default bbox
            fused_voxel_grid = VoxelGrid(bbox_min=(-50, -50, 0), bbox_max=(50, 50, 100), resolution=resolution)
        
        # Add fused signal to grid
        if not hasattr(fused_voxel_grid, '_signal_arrays'):
            fused_voxel_grid._signal_arrays = {}
        fused_voxel_grid._signal_arrays['fused'] = fused_grid['signal']
        
        if not hasattr(fused_voxel_grid, 'available_signals'):
            fused_voxel_grid.available_signals = set()
        fused_voxel_grid.available_signals.add('fused')
        
        # Add get_signal_array method
        def get_signal_array(signal_name, default=0.0):
            if hasattr(fused_voxel_grid, '_signal_arrays') and signal_name in fused_voxel_grid._signal_arrays:
                return fused_voxel_grid._signal_arrays[signal_name]
            return None
        fused_voxel_grid.get_signal_array = get_signal_array
        
        progress_bar.value = 30
        
        # Store fusion metadata
        config_metadata = {
            'fusion_applied': True,
            'fusion_strategy': fusion_strategy.value,
            'fusion_timestamp': datetime.now().isoformat(),
            'source_grids': [s.get('grid_id', 'unknown') for s in source_grids.values() if 'grid_id' in s],
            'source_names': list(source_grids.keys()),
            'fusion_metrics': fusion_results,
            'num_sources': len(source_grids)
        }
        
        # Add strategy-specific parameters
        if fusion_strategy.value == 'weighted_average':
            config_metadata['normalize_weights'] = normalize_weights.value
            config_metadata['auto_weight_quality'] = auto_weight_quality.value
        elif fusion_strategy.value == 'quality_based':
            config_metadata['quality_threshold'] = quality_threshold.value
        
        log_message("Preparing grid for save...", 'info')
        update_status("Preparing grid...", 60)
        
        # Save grid
        log_message("Saving fused grid to MongoDB...", 'info')
        update_status("Saving grid to MongoDB...", 80)
        saved_grid_id = voxel_storage.save_voxel_grid(
            model_id=current_model_id,
            grid_name=grid_name,
            voxel_grid=fused_voxel_grid,
            description=f"Fused grid using {fusion_strategy.value} strategy from {len(source_grids)} source(s)",
            model_name=model_name,
            configuration_metadata=config_metadata
        )
        
        # str() so the preview also works if the storage layer returns a non-string ID
        log_message(f"Fused grid saved with ID: {str(saved_grid_id)[:8]}...", 'success')
        
        # Calculate total execution time
        if operation_start_time:
            total_time = time.time() - operation_start_time
            log_message(f"Fused grid saved successfully in {total_time:.2f}s", 'success')
        else:
            log_message("Fused grid saved successfully", 'success')
        
        update_status("Fused grid saved successfully", 100)
        error_display.value = f"<span style='color: green;'>✅ Saved fused grid: {grid_name} (ID: {str(saved_grid_id)[:8]}...)</span>"
        
    except Exception as e:
        log_message(f"Error saving fused grid: {str(e)}", 'error')
        import traceback
        log_message(f"Traceback: {traceback.format_exc()}", 'error')
        error_display.value = f"<span style='color: red;'>❌ Error saving fused grid: {str(e)}</span>"
        update_status("Error saving grid", 0)

# Connect events
execute_button.on_click(execute_fusion)
save_fused_button.on_click(save_fused_grid)
viz_mode.observe(lambda x: update_visualization(), names='value')

# ============================================
# Main Layout
# ============================================

main_layout = VBox([
    top_panel,
    HBox([left_panel, center_panel, right_panel]),
    bottom_panel
])

# Display the interface
display(main_layout)



## Summary

Congratulations! You've learned how to fuse data from multiple sources.

### Key Takeaways

1. **Fusion Strategies**: Multiple strategies (weighted average, median, quality-based, etc.) for different use cases
2. **Source Configuration**: Configure weights and quality scores for each source
3. **Fusion Options**: Handle invalid data, fill missing values, resolve conflicts
4. **Quality Assessment**: Evaluate fusion quality using metrics and visualizations
5. **Strategy Comparison**: Compare different fusion strategies to find the best one
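
Items 1, 2 and 5 boil down to a per-voxel reduction over a stack of co-registered source arrays. The following minimal sketch restates the strategy dispatch from `execute_fusion` as a standalone NumPy function so strategies can be compared outside the widgets. `fuse_signals` and `consistency_score` are hypothetical helpers written for this illustration, not part of the AM-QADF API, and the consistency formula (1 / (1 + mean absolute deviation of the sources from the fused result)) is an assumption, not the metric the framework reports.

```python
import numpy as np

def fuse_signals(signals, strategy="weighted_average", weights=None, qualities=None):
    """Fuse N same-shape source arrays into one array (hypothetical helper).

    weights   -- one non-negative weight per source (weighted_average only; equal if None)
    qualities -- one quality score per source (quality_based only)
    """
    # np.stack raises ValueError if the sources do not share one shape
    stack = np.stack([np.asarray(s, dtype=float) for s in signals])
    n = stack.shape[0]

    if strategy == "weighted_average":
        w = np.ones(n) if weights is None else np.asarray(weights, dtype=float)
        if w.shape != (n,) or np.any(w < 0) or w.sum() <= 0:
            raise ValueError("need one non-negative weight per source with a positive sum")
        # np.average divides by w.sum(), so pre-normalizing the weights is optional
        return np.average(stack, axis=0, weights=w)
    if strategy == "median":
        return np.median(stack, axis=0)
    if strategy == "quality_based":
        if qualities is None or len(qualities) != n:
            raise ValueError("quality_based needs one quality score per source")
        # Like the notebook: take the whole grid from the single best-rated source
        return stack[int(np.argmax(qualities))]
    if strategy == "max":
        return stack.max(axis=0)
    if strategy == "min":
        return stack.min(axis=0)
    if strategy == "first":
        return stack[0]
    if strategy == "last":
        return stack[-1]
    raise ValueError(f"unknown strategy: {strategy!r}")


def consistency_score(signals, fused):
    """Hypothetical score in (0, 1]: 1 / (1 + mean |source - fused|) over all voxels."""
    stack = np.stack([np.asarray(s, dtype=float) for s in signals])
    mad = np.mean(np.abs(stack - np.asarray(fused, dtype=float)))
    return 1.0 / (1.0 + mad)


# Usage: three constant 2x2x2 toy "sources"
sources = [np.zeros((2, 2, 2)), np.full((2, 2, 2), 2.0), np.full((2, 2, 2), 10.0)]
for name in ("weighted_average", "median", "max"):
    fused = fuse_signals(sources, name, weights=[1, 1, 2])
    print(name, fused[0, 0, 0], round(consistency_score(sources, fused), 3))
```

With these toy inputs the median scores highest under this hypothetical metric, which is expected: the per-voxel median minimizes the mean absolute deviation, the very quantity the score penalizes. A different consistency definition (for example, one based on squared error) would favor the unweighted mean instead, so the choice of metric shapes the strategy comparison.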

### Next Steps

Proceed to:
- **07_Quality_Assessment.ipynb** - Learn quality assessment methods
- **08_Quality_Dashboard.ipynb** - Learn to create quality dashboards

### Related Resources

- Fusion Module Documentation: `../docs/AM_QADF/05-modules/fusion.md`
- API Reference: `../docs/AM_QADF/06-api-reference/fusion-api.md`
- Examples: `../examples/`
