In [1]:
# Photoluminescence Analysis App - Main Notebook
# This is the main notebook file to run with Voila

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from scipy.signal import find_peaks
from lmfit import Model, Parameters
from lmfit.models import GaussianModel, VoigtModel, LorentzianModel, LinearModel, PolynomialModel, ExponentialModel
import ipywidgets as widgets
from IPython.display import display, HTML
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
from pl_data_loader import PLDataLoader
from pl_fitting_models import PLFittingModels
from pl_visualization import PLVisualization
from pl_peak_detection import PLPeakDetection
from pl_export_utils import PLExportUtils

def get_var():
    """Restore ``h5_path`` from IPython's persistent ``%store`` database.

    The restore magic does not return anything; it writes the stored value into
    the interactive (user) namespace, where ``PLAnalysisApp.check_for_h5_path``
    then looks it up as a global name and treats a ``NameError`` as "not stored".

    NOTE(review): relies on IPython magic syntax, so this only works when the
    code runs as a notebook cell (module globals == IPython user namespace),
    not when the file is imported as a plain Python module.
    """
    %store -r h5_path

class PLAnalysisApp:
    def __init__(self):
        """Create helper objects and widgets, and auto-load HDF5 data when a stored path exists."""
        # Workers that do the actual loading / fitting / plotting / exporting
        self.data_loader = PLDataLoader()
        self.fitting_models = PLFittingModels()
        self.visualization = PLVisualization()
        self.peak_detection = PLPeakDetection()
        self.export_utils = PLExportUtils()

        # Dataset state, filled in by on_file_upload_new()
        self.data_matrix = self.wavelengths = self.timestamps = None
        self.current_time_idx = 0
        self.current_spectrum = None

        # Heatmap rendering state (use_figure_widget=True selects a FigureWidget)
        self.heatmap_fig = None
        self.heatmap_widget = None
        self.use_figure_widget = True

        # Fit results storage
        self.fitting_results = {}

        # Must run before setup_ui(): it decides between the HDF5 dropdown and
        # the file-upload widget by setting self.data_stored_in_h5.
        self.h5_path = None
        self.check_for_h5_path()

        self.setup_ui()

        if self.data_stored_in_h5:
            # Load right away; the placeholder payload only needs to be truthy.
            self.on_file_upload_new(change={'new': [None]})
        
    def check_for_h5_path(self):
        """Detect whether an HDF5 path was restored via ``get_var()``.

        Sets ``self.data_stored_in_h5``; when a global ``h5_path`` exists after the
        restore, it is also copied to ``self.h5_path``.
        """
        get_var()
        try:
            restored_path = h5_path  # global expected to be created by get_var()
        except NameError:
            self.data_stored_in_h5 = False
        else:
            self.h5_path = restored_path
            self.data_stored_in_h5 = True
        
    def setup_ui(self):
        """Create the data-source selector, time navigation widgets and output areas."""
        if not self.data_stored_in_h5:
            # No stored HDF5 path: the user uploads a text/CSV/DAT file instead.
            self.file_upload = widgets.FileUpload(
                accept='.txt,.csv,.dat',
                multiple=False,
                description='Upload PL Data'
            )
            self.file_upload.observe(self.on_file_upload_new, names='value')
        else:
            # Data lives in the stored HDF5 file; choose which measurement to load.
            measurement_choices = [
                ('PL raw', 'pl_raw'),
                ('PL binned & bgs', 'pl_binned'),
                ('GIWAXS', 'giwaxs'),
            ]
            self.mode_dropdown = widgets.Dropdown(
                options=measurement_choices,
                value='giwaxs',
                description='Measurement:'
            )
            self.mode_dropdown.observe(self.on_file_upload_new, names='value')

        # Slider and numeric box select the same time index; both start disabled
        # with placeholder bounds until a dataset is loaded.
        index_range = dict(value=0, min=0, max=100, step=1, disabled=True)

        self.time_slider = widgets.IntSlider(
            **index_range,
            description='',
            continuous_update=False,
            layout=widgets.Layout(width='300px')
        )
        self.time_slider.observe(self.on_time_change_slider, names='value')

        self.time_input = widgets.BoundedIntText(
            **index_range,
            description='Time Index:',
            layout=widgets.Layout(width='150px')
        )
        self.time_input.observe(self.on_time_change_input, names='value')

        self.time_display = widgets.Label(value="Time: Not loaded")

        self.setup_action_buttons()
        self.setup_fitting_controls()

        # Output areas: heatmap, spectrum, status messages, time-series plots, export results
        (self.heatmap_output,
         self.spectrum_output,
         self.status_output,
         self.time_series_output,
         self.export_output) = (widgets.Output() for _ in range(5))
        
    def on_file_upload_new(self, change):
        """Handle file upload"""
        if not change['new']:
            return

        try:
            with self.status_output:
                self.status_output.clear_output()
                print("Loading file...")
            
            if self.data_stored_in_h5:
                mode = self.mode_dropdown.value
                h5_path =self.h5_path
                self.data_matrix, self.wavelengths, self.timestamps = self.data_loader.load_h5_data(mode=mode, h5_path=h5_path)
            else:
                # Get the first uploaded file
                uploaded_files = change['new']
                uploaded_file = uploaded_files[0]
                file_content = uploaded_file['content']
                
                # Load data
                self.data_matrix, self.wavelengths, self.timestamps = self.data_loader.load_data(file_content)
    
            # Update UI controls
            max_time_idx = len(self.timestamps) - 1
            self.time_slider.max = max_time_idx
            self.time_slider.disabled = False
            self.time_slider.value = 0
            self.time_input.max = max_time_idx
            self.time_input.disabled = False
            self.time_input.value = 0
            
            # Enable buttons and controls
            self.detect_peaks_btn.disabled = False
            self.fit_current_btn.disabled = False
            self.fit_all_btn.disabled = False
            
            # Enable background subtraction controls
            self.background_subtract_checkbox.disabled = False
            self.background_time_start.disabled = False
            self.background_time_start.max = len(self.timestamps) - 1
            self.background_num_curves.disabled = False
            self.background_num_curves.max = min(50, len(self.timestamps))
            self.subtract_background_btn.disabled = False
            
            # Enable fit range controls
            self.fit_start_idx.disabled = False
            self.fit_start_idx.max = len(self.timestamps) - 1
            self.fit_end_idx.disabled = False
            self.fit_end_idx.max = len(self.timestamps)
            self.fit_end_idx.value = len(self.timestamps)
            self.fit_all_range_btn.disabled = False
    
            # Initial visualization
            self.update_visualizations()
    
            with self.status_output:
                self.status_output.clear_output()
                print(f"✅ Successfully loaded data!")
                print(f"📊 {self.data_matrix.shape[0]} time points, {self.data_matrix.shape[1]} wavelengths")
                print(f"⏱️ Time range: {self.timestamps.min():.3f} - {self.timestamps.max():.3f} s")
                print(f"🌈 Wavelength range: {self.wavelengths.min():.1f} - {self.wavelengths.max():.1f} nm")
                
        except Exception as e:
            import traceback
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error loading file: {str(e)}")
                print("\\n🔍 Full traceback:")
                print(traceback.format_exc())
        
    def setup_fitting_controls(self):
        """Build the fitting-related control groups.

        Background subtraction and peak detection/fitting each get their own
        collapsible Accordion, which their setup helpers populate; the fit-range
        widgets are created separately.
        """
        def new_accordion():
            # Starts empty; the section's setup method supplies child and title.
            return widgets.Accordion(children=[], titles=[])

        self.background_accordion = new_accordion()
        self.setup_background_controls()

        self.peak_accordion = new_accordion()
        self.setup_peak_controls()

        self.setup_fit_range_controls()
        
    def setup_peak_controls(self):
        """Build the auto-detect button and the dynamic peak-model list inside ``self.peak_accordion``."""
        detect_button = widgets.Button(
            description='Auto Detect Peaks',
            button_style='info',
            disabled=True
        )
        detect_button.on_click(self.on_detect_peaks)
        self.detect_peaks_btn = detect_button

        # Empty model list + container, then seed it with one default model
        self.peak_models = []
        self.models_container = widgets.VBox([])
        self.add_peak_model()

        section = widgets.VBox([
            detect_button,
            widgets.HTML("<h4>Peak Models</h4>"),
            self.models_container,
        ])

        accordion = self.peak_accordion
        accordion.children = [section]
        accordion.set_title(0, 'Peak Detection & Fitting')
        accordion.selected_index = 0  # section starts expanded
        
    def setup_background_controls(self):
        """Create the background-subtraction widgets and place them in ``self.background_accordion``.

        The start-index box, curve-count box and apply button start disabled;
        ``on_file_upload_new`` enables them and sets their bounds once data is
        loaded. The checkbox is enabled from the start.
        """
        self.background_subtract_checkbox = widgets.Checkbox(
            value=True,
            description='Enable Background Subtraction',
            disabled=False
        )

        # First time index of the block of spectra averaged into the background
        self.background_time_start = widgets.BoundedIntText(
            value=0,
            min=0,
            max=100,
            step=1,
            description='Start Index:',
            disabled=True,
            layout=widgets.Layout(width='180px')
        )

        # Number of consecutive spectra averaged into the background
        self.background_num_curves = widgets.BoundedIntText(
            value=10,
            min=1,
            max=50,
            step=1,
            description='# Curves:',
            disabled=True,
            layout=widgets.Layout(width='180px')
        )

        self.subtract_background_btn = widgets.Button(
            description='Apply Background Subtraction',
            button_style='warning',
            disabled=True,
            layout=widgets.Layout(width='200px')
        )
        self.subtract_background_btn.on_click(self.on_subtract_background)

        # Background controls container
        background_controls = widgets.VBox([
            self.background_subtract_checkbox,
            widgets.HBox([self.background_time_start, self.background_num_curves]),
            self.subtract_background_btn
        ])

        # Put in accordion (collapsible)
        self.background_accordion.children = [background_controls]
        self.background_accordion.set_title(0, 'Background Subtraction')
        # Index 0 means the section starts EXPANDED (set None to start collapsed).
        self.background_accordion.selected_index = 0
        
    def setup_fit_range_controls(self):
        """Create the Start/End index boxes and the 'Fit Range' button, all disabled until data loads."""
        def index_box(label, initial):
            # Placeholder 0..100 bounds; on_file_upload_new sets the real maximum.
            return widgets.BoundedIntText(
                value=initial,
                min=0,
                max=100,
                step=1,
                description=label,
                disabled=True,
                layout=widgets.Layout(width='180px')
            )

        self.fit_start_idx = index_box('Start:', 0)
        self.fit_end_idx = index_box('End:', 100)

        fit_range_button = widgets.Button(
            description='Fit Range',
            button_style='warning',
            disabled=True,
            layout=widgets.Layout(width='100px')
        )
        fit_range_button.on_click(self.on_fit_range)
        self.fit_all_range_btn = fit_range_button
        
    def setup_action_buttons(self):
        """Create the (initially disabled) fit / update / fit-all / export buttons and the R² label."""
        def make_button(label, style, handler):
            button = widgets.Button(
                description=label,
                button_style=style,
                disabled=True
            )
            button.on_click(handler)
            return button

        self.fit_current_btn = make_button('Fit Current Spectrum', 'success', self.on_fit_current)
        self.update_params_btn = make_button('Update Parameters', 'info', self.on_update_parameters)
        self.fit_all_btn = make_button('Fit All Spectra', 'warning', self.on_fit_all)
        self.export_btn = make_button('Export Results', 'info', self.on_export_results)

        # Goodness-of-fit readout, updated after each single-spectrum fit
        self.r_squared_display = widgets.Label(value="R²: Not fitted")
    
    def on_time_change_slider(self, change):
        """Handle time slider change"""
        # Update input to match slider
        self.time_input.value = change['new']
        self.update_time_display_and_spectrum(change['new'])
        
    def on_time_change_input(self, change):
        """Handle time input change"""
        # Update slider to match input
        self.time_slider.value = change['new']
        self.update_time_display_and_spectrum(change['new'])
        
    def update_time_display_and_spectrum(self, new_idx):
        """Update time display and spectrum"""
        self.current_time_idx = new_idx
        self.time_display.value = f"Time: {self.timestamps[self.current_time_idx]:.3f}s"
        self.current_spectrum = self.data_matrix[self.current_time_idx, :]
        
        # Update spectrum plot (this is fast)
        self.update_spectrum_plot()
        
        # For now, recreate heatmap to ensure it's visible
        # (We can optimize this later once we confirm it works)
        self.update_heatmap()
        
    def on_subtract_background(self, button):
        """Apply background subtraction - MODIFIED to force heatmap recreation"""
        if self.data_matrix is None:
            return
            
        try:
            start_idx = self.background_time_start.value
            num_curves = self.background_num_curves.value
            
            # Ensure we don't go beyond the data
            end_idx = min(start_idx + num_curves, len(self.timestamps))
            
            if start_idx >= len(self.timestamps):
                with self.status_output:
                    self.status_output.clear_output()
                    print("❌ Start time index is beyond the data range")
                return
            
            # Calculate background (average of selected time range)
            background_spectra = self.data_matrix[start_idx:end_idx, :]
            background_average = np.mean(background_spectra, axis=0)
            
            # Store original data if not already stored
            if not hasattr(self, 'original_data_matrix'):
                self.original_data_matrix = self.data_matrix.copy()
            
            # Apply background subtraction
            if self.background_subtract_checkbox.value:
                self.data_matrix = self.original_data_matrix - background_average
            else:
                self.data_matrix = self.original_data_matrix.copy()
            
            # Update current spectrum
            self.current_spectrum = self.data_matrix[self.current_time_idx, :]
            
            # Force heatmap recreation since data changed
            self.heatmap_fig = None
            
            # Update visualizations
            self.update_visualizations()
            
            with self.status_output:
                self.status_output.clear_output()
                if self.background_subtract_checkbox.value:
                    print(f"✅ Background subtracted using time range {self.timestamps[start_idx]:.3f}s - {self.timestamps[end_idx-1]:.3f}s")
                else:
                    print("✅ Background subtraction removed - original data restored")
                    
        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error in background subtraction: {str(e)}")
        
    def add_peak_model(self):
        """Append a new peak-model editor (type, center, height, sigma) and refresh the container.

        The new model's ``'index'`` is its list position. Its Remove / Add Below
        buttons act on the model's *current* ``'index'``, which
        ``reindex_models`` keeps equal to the list position after every
        add/remove.
        """
        model_idx = len(self.peak_models)

        # Model type selection
        model_type = widgets.Dropdown(
            options=['Gaussian', 'Voigt', 'Lorentzian'],
            value='Gaussian',
            description=f'Model {model_idx + 1}:',
            layout=widgets.Layout(width='150px')
        )

        # Parameter inputs
        center_input = widgets.FloatText(
            value=500.0,
            description='Center (nm):',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='150px')
        )

        height_input = widgets.FloatText(
            value=1000.0,
            description='Height:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='150px')
        )

        sigma_input = widgets.FloatText(
            value=10.0,
            description='Sigma:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='150px')
        )

        # Control buttons
        remove_btn = widgets.Button(
            description='Remove',
            button_style='danger',
            layout=widgets.Layout(width='80px')
        )

        add_btn = widgets.Button(
            description='Add Below',
            button_style='info',
            layout=widgets.Layout(width='100px')
        )

        # Create model widget group
        model_widget = widgets.VBox([
            widgets.HBox([model_type, remove_btn, add_btn]),
            widgets.HBox([center_input, height_input, sigma_input])
        ])

        # Store model data
        model_data = {
            'widget': model_widget,
            'type': model_type,
            'center': center_input,
            'height': height_input,
            'sigma': sigma_input,
            'remove_btn': remove_btn,
            'add_btn': add_btn,
            'index': model_idx  # kept equal to the list position by reindex_models()
        }

        # Read the model's index when the button is clicked instead of binding
        # model_idx now: reindex_models() renumbers models after add/remove, so a
        # creation-time index would later target the wrong model (or none).
        def remove_callback(b):
            self.remove_peak_model(model_data['index'])

        def add_callback(b):
            self.add_peak_model_below(model_data['index'])

        remove_btn.on_click(remove_callback)
        add_btn.on_click(add_callback)

        self.peak_models.append(model_data)
        self.update_models_container()
        
    def remove_peak_model(self, idx):
        """Remove a peak model"""
        if len(self.peak_models) > 1:  # Keep at least one model
            # Find model with matching index
            model_to_remove = None
            for i, model in enumerate(self.peak_models):
                if model.get('index', i) == idx:
                    model_to_remove = i
                    break
            
            if model_to_remove is not None:
                self.peak_models.pop(model_to_remove)
                self.update_models_container()
                self.reindex_models()
            
    def add_peak_model_below(self, idx):
        """Add a peak model below the current one"""
        # Find position in list
        insert_pos = 0
        for i, model in enumerate(self.peak_models):
            if model.get('index', i) == idx:
                insert_pos = i + 1
                break
        
        # Create new model
        self.add_peak_model()
        
        # Move it to correct position if not at the end
        if insert_pos < len(self.peak_models) - 1:
            new_model = self.peak_models.pop()
            self.peak_models.insert(insert_pos, new_model)
            self.update_models_container()
            
        self.reindex_models()
        
    def reindex_models(self):
        """Update model indices after add/remove operations"""
        for i, model in enumerate(self.peak_models):
            model['type'].description = f'Model {i + 1}:'
            model['index'] = i  # Update stored index
            
    def update_models_container(self):
        """Update the models container widget"""
        self.models_container.children = [model['widget'] for model in self.peak_models]
        
    def on_detect_peaks(self, button):
        """Auto-detect peaks in current spectrum"""
        if self.current_spectrum is None:
            return
            
        try:
            # Use more sensitive peak detection
            peaks_info = self.peak_detection.detect_peaks_advanced(
                self.wavelengths, 
                self.current_spectrum,
                min_height=np.max(self.current_spectrum) * 0.05,  # Lower threshold
                min_prominence=np.std(self.current_spectrum) * 2,  # Adaptive prominence
                min_distance=10,  # Minimum distance between peaks in indices
                adaptive_threshold=True
            )
            
            # Clear existing models
            self.peak_models.clear()
            
            # Add models for detected peaks
            for peak in peaks_info:
                self.add_peak_model()
                model = self.peak_models[-1]
                model['center'].value = peak['center']
                model['height'].value = peak['height']
                model['sigma'].value = max(peak['sigma'], 2.0)  # Minimum sigma of 2nm
                
            with self.status_output:
                self.status_output.clear_output()
                print(f"🔍 Detected {len(peaks_info)} peaks")
                for i, peak in enumerate(peaks_info):
                    print(f"  Peak {i+1}: {peak['center']:.1f} nm, height: {peak['height']:.0f}")
                    
        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error detecting peaks: {str(e)}")

    def update_spectrum_plot(self, fit_result=None):
        """Update spectrum plot"""
        with self.spectrum_output:
            self.spectrum_output.clear_output()
            fig = self.visualization.create_spectrum_plot(
                self.wavelengths,
                self.current_spectrum,
                fit_result=fit_result
            )
            fig.show(renderer="jupyterlab")
                
    def on_fit_current(self, button):
        """Fit current spectrum"""
        if self.current_spectrum is None:
            return
            
        try:
            # Prepare fitting parameters
            fit_params = self.prepare_fit_parameters()
            
            # Perform fitting
            result = self.fitting_models.fit_spectrum(
                self.wavelengths,
                self.current_spectrum,
                fit_params
            )
            
            # Store the result for potential parameter update
            self.last_fit_result = result
            
            # Update R-squared display
            self.r_squared_display.value = f"R²: {result.rsquared:.4f}"
            
            # Enable update parameters button
            self.update_params_btn.disabled = False
            
            # Update spectrum plot with fit
            self.update_spectrum_plot(fit_result=result)
            
            with self.status_output:
                self.status_output.clear_output()
                print(f"✅ Fit completed. R² = {result.rsquared:.4f}")
                print("🔄 Click 'Update Parameters' to use fitted values")
                
        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error fitting spectrum: {str(e)}")
                
    def on_update_parameters(self, button):
        """Update model parameters with fitted values"""
        try:
            if not hasattr(self, 'last_fit_result') or self.last_fit_result is None:
                with self.status_output:
                    self.status_output.clear_output()
                    print("❌ No fit result available. Please fit current spectrum first.")
                return
                
            result = self.last_fit_result
            
            # Extract fitted parameters
            fitted_params = result.params
            
            with self.status_output:
                self.status_output.clear_output()
                print("🔄 Updating peak model parameters with fitted values...")
                
            # Update peak model parameters in the UI fields
            updated_count = 0
            for param_name, param in fitted_params.items():
                if param_name.startswith('p') and '_' in param_name:
                    parts = param_name.split('_', 1)
                    try:
                        peak_num = int(parts[0][1:])  # Extract number from p0, p1, etc.
                        param_type = parts[1]
                        
                        if peak_num < len(self.peak_models):
                            model = self.peak_models[peak_num]
                            
                            if param_type == 'center':
                                model['center'].value = round(param.value, 2)
                                print(f"  Peak {peak_num + 1} center: {param.value:.2f} nm")
                                updated_count += 1
                                
                            elif param_type == 'amplitude':
                                # Convert amplitude to height estimate
                                sigma_param = f'p{peak_num}_sigma'
                                if sigma_param in fitted_params:
                                    sigma = fitted_params[sigma_param].value
                                    if sigma > 0:
                                        height = param.value / (sigma * np.sqrt(2 * np.pi))
                                        model['height'].value = round(height, 1)
                                        print(f"  Peak {peak_num + 1} height: {height:.1f}")
                                        updated_count += 1
                                else:
                                    # Fallback: use amplitude directly
                                    model['height'].value = round(param.value, 1)
                                    updated_count += 1
                                    
                            elif param_type == 'sigma':
                                model['sigma'].value = round(param.value, 2)
                                print(f"  Peak {peak_num + 1} sigma: {param.value:.2f} nm")
                                updated_count += 1
                                
                    except (ValueError, IndexError) as e:
                        print(f"  Warning: Could not parse parameter {param_name}: {e}")
                        continue
            
            with self.status_output:
                print(f"✅ Updated {updated_count} parameters in the Peak Models fields!")
                print("You can now refine the fit or run batch fitting with these optimized values.")
                
        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error updating parameters: {str(e)}")
                import traceback
                print(traceback.format_exc())
                
    def on_export_results(self, button):
        """Export fitting results with visualizations as PNG files in a zip archive"""
        if not self.fitting_results:
            return
            
        try:
            import zipfile
            import tempfile
            import shutil
            import base64
            import io
            
            with self.export_output:
                self.export_output.clear_output()
                print("🔄 Exporting results...")
                
            # Create zip file in memory
            zip_buffer = io.BytesIO()
            
            with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file:
                # Export Excel file
                excel_filename = f"pl_fitting_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
                excel_buffer = io.BytesIO()
                
                # Create temporary file for Excel export
                temp_dir = tempfile.mkdtemp()
                temp_excel_path = os.path.join(temp_dir, excel_filename)
                
                try:
                    self.export_utils.export_to_excel(
                        self.fitting_results,
                        self.timestamps,
                        filename=temp_excel_path
                    )
                    
                    # Read Excel file and add to zip
                    with open(temp_excel_path, 'rb') as f:
                        zip_file.writestr(excel_filename, f.read())
                    
                    # Clean up temp directory
                    shutil.rmtree(temp_dir)
                    
                except Exception as e:
                    print(f"Error creating Excel file: {e}")
                    if os.path.exists(temp_dir):
                        shutil.rmtree(temp_dir)
                
                # Export PNG images of current plots
                png_files = []
                
                # Check if kaleido is available for PNG export
                try:
                    import plotly.io as pio
                    kaleido_available = True
                except ImportError:
                    kaleido_available = False
                
                # 1. Export heatmap if it exists
                if hasattr(self, 'data_matrix') and self.data_matrix is not None:
                    try:
                        heatmap_fig = self.visualization.create_heatmap(
                            self.data_matrix,
                            self.wavelengths,
                            self.timestamps,
                            current_time_idx=self.current_time_idx
                        )
                        
                        # Save as HTML (always works)
                        html_str = heatmap_fig.to_html(include_plotlyjs='cdn')
                        zip_file.writestr("heatmap.html", html_str)
                        
                        # Try to save as PNG if kaleido is available
                        if kaleido_available:
                            try:
                                img_bytes = pio.to_image(heatmap_fig, format='png', width=800, height=500)
                                zip_file.writestr("heatmap.png", img_bytes)
                                png_files.append("heatmap.png")
                            except Exception as img_err:
                                print(f"Note: Could not save heatmap as PNG: {img_err}")
                                
                    except Exception as e:
                        print(f"Error saving heatmap: {e}")
                        
                # 2. Export current spectrum if it exists
                if hasattr(self, 'current_spectrum') and self.current_spectrum is not None:
                    try:
                        fit_result = getattr(self, 'last_fit_result', None)
                        spectrum_fig = self.visualization.create_spectrum_plot(
                            self.wavelengths,
                            self.current_spectrum,
                            fit_result=fit_result
                        )
                        
                        # Save as HTML
                        html_str = spectrum_fig.to_html(include_plotlyjs='cdn')
                        zip_file.writestr("current_spectrum.html", html_str)
                        
                        # Try to save as PNG if kaleido is available
                        if kaleido_available:
                            try:
                                img_bytes = pio.to_image(spectrum_fig, format='png', width=800, height=600)
                                zip_file.writestr("current_spectrum.png", img_bytes)
                                png_files.append("current_spectrum.png")
                            except Exception as img_err:
                                print(f"Note: Could not save spectrum as PNG: {img_err}")
                                
                    except Exception as e:
                        print(f"Error saving current spectrum: {e}")
                        
                # 3. Export time series plots if they exist
                if hasattr(self, 'time_series_fig'):
                    try:
                        # Save as HTML
                        html_str = self.time_series_fig.to_html(include_plotlyjs='cdn')
                        zip_file.writestr("time_series_plots.html", html_str)
                        
                        # Try to save as PNG if kaleido is available
                        if kaleido_available:
                            try:
                                img_bytes = pio.to_image(self.time_series_fig, format='png', width=1000, height=600)
                                zip_file.writestr("time_series_plots.png", img_bytes)
                                png_files.append("time_series_plots.png")
                            except Exception as img_err:
                                print(f"Note: Could not save time series as PNG: {img_err}")
                                
                    except Exception as e:
                        print(f"Error saving time series plots: {e}")
                        
            # Create download link
            zip_buffer.seek(0)
            b64 = base64.b64encode(zip_buffer.getvalue()).decode()
            
            # Create filename
            zip_filename = f"pl_analysis_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
            
            # JavaScript to trigger download
            js_code = f"""
            var link = document.createElement('a');
            link.href = 'data:application/zip;base64,{b64}';
            link.download = '{zip_filename}';
            document.body.appendChild(link);
            link.click();
            document.body.removeChild(link);
            """
            
            with self.export_output:
                self.export_output.clear_output()
                print(f"✅ Results exported successfully!")
                print(f"📦 Zip file ready: {zip_filename}")
                print(f"📊 Contains: Excel file + HTML plots" + (f" + {len(png_files)} PNG images" if png_files else ""))
                if not kaleido_available:
                    print("📝 Note: PNG export not available. Install kaleido for PNG support: pip install kaleido")
                print()
                
                # Display download button
                download_button_html = f"""
                <button onclick="{js_code}" style="
                    background-color: #28a745;
                    color: white;
                    padding: 10px 20px;
                    border: none;
                    border-radius: 5px;
                    cursor: pointer;
                    font-size: 14px;
                    font-weight: bold;
                ">
                    📥 Download {zip_filename}
                </button>
                """
                
                display(HTML(download_button_html))
                print("Download initiated. If the download doesn't start automatically, click the button above.")
                
        except Exception as e:
            with self.export_output:
                self.export_output.clear_output()
                print(f"❌ Error exporting results: {str(e)}")
                import traceback
                print(traceback.format_exc())
                
    def on_fit_range(self, button):
        """Fit every spectrum whose row index lies in ``[fit_start_idx, fit_end_idx)``.

        Button callback. Reads the range from the ``fit_start_idx`` /
        ``fit_end_idx`` widgets, clamps the end to the number of time points
        (writing the clamped value back to the widget), fits the selected rows
        via ``self.fitting_models.fit_all_spectra`` and stores the results in
        ``self.fitting_results`` keyed by the ORIGINAL row index. Previous
        fitting results are replaced. Progress and errors are reported in
        ``self.status_output``.

        Args:
            button: The clicked widget (unused).
        """
        if self.data_matrix is None:
            return

        try:
            start_idx = self.fit_start_idx.value
            end_idx = self.fit_end_idx.value

            # Clamp before validating so the check below sees the effective
            # range; otherwise a start past the data would slip through and
            # produce an empty fit that wipes the previous results.
            n_times = len(self.timestamps)
            if end_idx > n_times:
                end_idx = n_times
                self.fit_end_idx.value = end_idx

            if start_idx < 0:
                with self.status_output:
                    self.status_output.clear_output()
                    print("❌ Start index must be non-negative")
                return

            if start_idx >= end_idx:
                with self.status_output:
                    self.status_output.clear_output()
                    print("❌ Start index must be less than end index")
                return

            with self.status_output:
                self.status_output.clear_output()
                print(f"Starting range fitting from {start_idx} to {end_idx}...")

            fit_params = self.prepare_fit_parameters()

            # Results come back keyed by position within the slice (0-based).
            results = self.fitting_models.fit_all_spectra(
                self.wavelengths,
                self.data_matrix[start_idx:end_idx, :],
                self.timestamps[start_idx:end_idx],
                fit_params,
            )

            # Summarise the first few results so failures are visible early.
            sample_results = []
            for i, result in list(results.items())[:3]:
                if result and result.get('success', False):
                    sample_results.append(f"Success {i}: R²={result.get('r_squared', 'N/A')}, params={len(result.get('parameters', {}))}")
                elif result:
                    sample_results.append(f"Failed {i}: {result.get('error', 'Unknown error')}")
                else:
                    sample_results.append(f"None result for index {i}")

            total_success = sum(1 for r in results.values() if r and r.get('success', False))

            # Shift keys (and any stored 'index') from slice-relative to
            # absolute row indices so they line up with self.timestamps.
            adjusted_results = {}
            for i, result in results.items():
                if result is not None and 'index' in result:
                    result['index'] = start_idx + result['index']
                adjusted_results[start_idx + i] = result

            self.fitting_results = adjusted_results

            # Enable export button
            self.export_btn.disabled = False

            # Create time series visualizations
            self.create_time_series_plots()

            # Print inside the Output widget: bare prints in a widget
            # callback are not shown in the app.
            with self.status_output:
                print("Sample results:")
                for sample in sample_results:
                    print(f"  {sample}")
                print(f"Total successful fits: {total_success} out of {len(results)}")
                print("✅ Range fitting completed!")
                print(f"📊 Fitted {total_success} successful spectra from index {start_idx} to {end_idx-1}")

        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error in range fitting: {str(e)}")
                import traceback
                print(traceback.format_exc())
                
    def on_fit_all(self, button):
        """Fit every spectrum in ``self.data_matrix``.

        Button callback. Delegates the work to
        ``self.fitting_models.fit_all_spectra`` (whether that runs in parallel
        is decided by the fitting module, not here) and replaces
        ``self.fitting_results`` with its output. On success the export
        button is enabled and the time-series plots are rebuilt. Progress and
        errors are reported in ``self.status_output``; exceptions are not
        propagated.

        Args:
            button: The clicked widget (unused).
        """
        if self.data_matrix is None:
            return

        try:
            with self.status_output:
                self.status_output.clear_output()
                print("Starting batch fitting... This may take a while.")

            fit_params = self.prepare_fit_parameters()

            self.fitting_results = self.fitting_models.fit_all_spectra(
                self.wavelengths,
                self.data_matrix,
                self.timestamps,
                fit_params,
            )

            # Enable export button
            self.export_btn.disabled = False

            # Create time series visualizations
            self.create_time_series_plots()

            # Report successes separately: failed/None fits are also entries.
            total_success = sum(
                1 for r in self.fitting_results.values() if r and r.get('success', False)
            )

            with self.status_output:
                self.status_output.clear_output()
                print("✅ Batch fitting completed!")
                print(f"📊 Fitted {len(self.fitting_results)} spectra ({total_success} successful).")

        except Exception as e:
            with self.status_output:
                self.status_output.clear_output()
                print(f"❌ Error in batch fitting: {str(e)}")
                import traceback
                print(traceback.format_exc())
                
    def prepare_fit_parameters(self):
        """Collect the peak definitions from the UI into a fitting-parameter dict.

        Returns:
            dict: ``background_model`` (always ``'None'``), ``poly_degree``
            (always 2) and ``peak_models`` — a list with one dict per entry of
            ``self.peak_models``, holding the current ``type``, ``center``,
            ``height`` and ``sigma`` widget values, in the same order.
        """
        field_names = ('type', 'center', 'height', 'sigma')
        peak_specs = [
            {name: row[name].value for name in field_names}
            for row in self.peak_models
        ]
        return {
            'background_model': 'None',  # background models are not supported yet
            'poly_degree': 2,
            'peak_models': peak_specs,
        }
    
    def create_time_series_plots(self):
        """Plot peak center, height, FWHM and area against time for all fitted peaks.

        Builds a 2x2 plotly subplot figure from the DataFrame returned by
        ``self.export_utils._create_peak_parameters_dataframe`` (a ``Time``
        column plus per-peak columns named ``<peak_id>_center``,
        ``<peak_id>_height``, ``<peak_id>_FWHM`` and ``<peak_id>_area``).
        The figure is stored in ``self.time_series_fig``, which the zip export
        reads, and rendered into ``self.time_series_output``.

        Does nothing when ``self.fitting_results`` is empty. Errors are
        reported in ``self.time_series_output`` instead of being raised. This
        method only plots: it does not write any files.
        """
        if not self.fitting_results:
            return

        # (column suffix, y-axis title, subplot row, subplot col, trace label)
        panels = (
            ('center', 'Center (nm)', 1, 1, 'Center'),
            ('height', 'Height', 1, 2, 'Height'),
            ('FWHM', 'FWHM (nm)', 2, 1, 'FWHM'),
            ('area', 'Area', 2, 2, 'Area'),
        )

        try:
            peak_params_df, _ = self.export_utils._create_peak_parameters_dataframe(self.fitting_results)

            if peak_params_df.empty:
                with self.time_series_output:
                    self.time_series_output.clear_output()
                    print("No successful fits to plot")
                return

            # Sort by time once so every trace is drawn as a continuous line;
            # dropna() below preserves this order.
            peak_params_df = peak_params_df.sort_values('Time').reset_index(drop=True)

            # Peak IDs are the column prefixes before the first underscore.
            peak_ids = []
            for column in peak_params_df.columns:
                if '_center' in column or '_amplitude' in column:
                    peak_id = column.split('_')[0]
                    if peak_id not in peak_ids:
                        peak_ids.append(peak_id)

            # Order numerically by the digits after the first character
            # (e.g. "P2" before "P10"); non-numeric IDs sort first.
            peak_ids.sort(key=lambda x: int(x[1:]) if x[1:].isdigit() else 0)

            if not peak_ids:
                with self.time_series_output:
                    self.time_series_output.clear_output()
                    print("No peak parameters found to plot")
                return

            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Peak Centers vs Time', 'Peak Heights vs Time',
                                'Peak FWHM vs Time', 'Peak Areas vs Time'),
                vertical_spacing=0.1,
                horizontal_spacing=0.1
            )

            colors = px.colors.qualitative.Plotly

            for i, peak_id in enumerate(peak_ids):
                color = colors[i % len(colors)]
                for suffix, _, row, col_idx, label in panels:
                    column = f'{peak_id}_{suffix}'
                    if column not in peak_params_df.columns:
                        continue
                    valid_data = peak_params_df.dropna(subset=[column])
                    if valid_data.empty:
                        continue
                    fig.add_trace(go.Scatter(
                        x=valid_data['Time'],
                        y=valid_data[column],
                        mode='markers+lines',
                        name=f'{peak_id} {label}',
                        line=dict(color=color, width=2),
                        marker=dict(size=4),
                        legendgroup=peak_id,
                        # One legend entry per peak (on the center trace); the
                        # shared legendgroup toggles the other panels with it.
                        showlegend=(suffix == 'center'),
                        connectgaps=False
                    ), row=row, col=col_idx)

            fig.update_layout(
                title="Peak Parameters Evolution Over Time",
                height=600,
                template='plotly_white',
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1
                )
            )

            fig.update_xaxes(title_text="Time (s)")
            for _, y_title, row, col_idx, _ in panels:
                fig.update_yaxes(title_text=y_title, row=row, col=col_idx)

            # Store the figure for export
            self.time_series_fig = fig

            with self.time_series_output:
                self.time_series_output.clear_output()
                fig.show()

        except Exception as e:
            with self.time_series_output:
                self.time_series_output.clear_output()
                print(f"Error creating time series plots: {str(e)}")
                import traceback
                print(traceback.format_exc())
                
    def update_visualizations(self):
        """Sync the current spectrum, time label, heatmap and spectrum plot to ``current_time_idx``.

        Assumes data is loaded (``self.data_matrix`` and ``self.timestamps``
        are indexable by ``self.current_time_idx``).
        """
        idx = self.current_time_idx
        self.current_spectrum = self.data_matrix[idx, :]
        current_time = self.timestamps[idx]
        self.time_display.value = f"Time: {current_time:.3f}s"

        # The heatmap is fully rebuilt rather than just moving its position
        # line; then the spectrum plot is redrawn.
        for refresh in (self.update_heatmap, self.update_spectrum_plot):
            refresh()
        
    def update_heatmap(self):
        """Rebuild the heatmap figure and render it into ``self.heatmap_output``.

        A new figure is created from the full data set, passing the current
        time index to ``create_heatmap`` (presumably to place the position
        line — see that helper). The figure is stored in ``self.heatmap_fig``
        and rendered with the static SVG renderer.
        """
        output = self.heatmap_output
        with output:
            output.clear_output()
            fig = self.visualization.create_heatmap(
                self.data_matrix,
                self.wavelengths,
                self.timestamps,
                current_time_idx=self.current_time_idx,
            )
            self.heatmap_fig = fig
            fig.show(renderer="svg")
    
    def update_heatmap_line(self):
        """Move only the position line on the existing heatmap, falling back to a full rebuild.

        If no heatmap exists yet, if ``self.visualization`` has no
        ``update_heatmap_line_position`` helper, if that helper returns
        ``None``, or if updating or redisplaying raises, the heatmap is
        rebuilt exactly once via ``self.update_heatmap()``.
        """
        if self.heatmap_fig is None:
            # If heatmap doesn't exist, create it
            self.update_heatmap()
            return

        move_line = getattr(self.visualization, 'update_heatmap_line_position', None)
        if move_line is not None:
            try:
                updated_fig = move_line(
                    self.heatmap_fig,
                    self.timestamps,
                    self.current_time_idx
                )
                if updated_fig is not None:
                    self.heatmap_fig = updated_fig
                    with self.heatmap_output:
                        self.heatmap_output.clear_output()
                        # Same renderer as update_heatmap so both paths draw
                        # identically into the same Output widget.
                        self.heatmap_fig.show(renderer="svg")
                    return
            except Exception as e:
                print(f"Line update failed: {e}")

        # Fallback, deliberately outside the try: a failing rebuild is not retried.
        self.update_heatmap()
              
    def display_app(self):
        """Build the application layout and render it.

        Layout: a header, then a row with a fixed-width (420px) control panel
        on the left (data loading, time control, fitting controls) and a
        flexible visualization column on the right (heatmap, current spectrum,
        time-series analysis). When ``self.data_stored_in_h5`` is true the
        loading section shows ``self.mode_dropdown``; otherwise it shows
        ``self.file_upload``.
        """
        header = widgets.HTML("<h1>🔬 Photoluminescence Analysis App</h1>")
        panel_width = '420px'

        data_source_widget = self.mode_dropdown if self.data_stored_in_h5 else self.file_upload
        file_section = widgets.VBox([
            widgets.HTML("<h3>📁 Data Loading</h3>"),
            data_source_widget,
            self.status_output,
        ])

        time_section = widgets.VBox([
            widgets.HTML("<h3>⏱️ Time Control</h3>"),
            widgets.HBox([self.time_input, self.time_display]),
            self.time_slider,
        ])

        # Single-spectrum fitting actions, shown below the collapsible
        # background and peak accordions.
        fit_action_widgets = [
            widgets.HTML("<h3>🔧 Fitting Actions</h3>"),
            widgets.HBox([self.fit_current_btn, self.update_params_btn]),
            self.r_squared_display,
        ]
        # Batch fitting: the range selection comes before the action buttons.
        batch_widgets = [
            widgets.HTML("<h3>📊 Batch Fitting</h3>"),
            widgets.HTML("<b>Range Selection:</b>"),
            widgets.HBox([self.fit_start_idx, self.fit_end_idx]),
            widgets.HTML("<b>Batch Actions:</b>"),
            widgets.HBox([self.fit_all_btn, self.fit_all_range_btn]),
            self.export_btn,
            self.export_output,
        ]
        fitting_section = widgets.VBox(
            [self.background_accordion, self.peak_accordion, *fit_action_widgets, *batch_widgets],
            layout=widgets.Layout(width=panel_width),
        )

        viz_section = widgets.VBox(
            [
                widgets.HTML("<h3>📊 Visualizations</h3>"),
                widgets.HTML("<h4>Heatmap</h4>"),
                self.heatmap_output,
                widgets.HTML("<h4>Current Spectrum</h4>"),
                self.spectrum_output,
                widgets.HTML("<h4>Time Series Analysis</h4>"),
                self.time_series_output,
            ],
            layout=widgets.Layout(flex='2', width='auto'),
        )

        control_panel = widgets.VBox(
            [file_section, time_section, fitting_section],
            layout=widgets.Layout(
                width=panel_width,
                padding='10px',
                border='1px solid #ddd',
                margin='0px 10px 0px 0px',
            ),
        )

        main_layout = widgets.HBox(
            [control_panel, viz_section],
            layout=widgets.Layout(width='100%', height='auto'),
        )

        for top_level_widget in (header, main_layout):
            display(top_level_widget)

# Create and display the app.
# NOTE(review): `app` is presumably kept as a module-level name so later
# notebook cells can inspect the running instance (e.g. its fitting results);
# keep the binding rather than inlining the constructor call.
app = PLAnalysisApp()
app.display_app()

no stored variable or alias h5_path


HTML(value='<h1>🔬 Photoluminescence Analysis App</h1>')

HBox(children=(VBox(children=(VBox(children=(HTML(value='<h3>📁 Data Loading</h3>'), FileUpload(value=(), accep…