In [3]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output
import io
import re
from typing import Dict, List, Tuple
import numpy as np
from scipy.signal import find_peaks, peak_widths
import sys
import os
import math

# Keep your existing API imports
sys.path.append(os.path.dirname(os.getcwd()))
from api_calls import get_batch_ids, get_ids_in_batch, get_sample_description, get_all_eqe as get_all_xrd
import batch_selection
import plotting_utils

# --- Authentication & API setup ---
url_base = "https://nomad-hzb-se.de"
url = f"{url_base}/nomad-oasis/api/v1"


def resolve_token(api_url=None):
    """Return a NOMAD access token, or None if no source provides one.

    Sources are tried in this order:
      1. the ``NOMAD_CLIENT_ACCESS_TOKEN`` environment variable (used as-is);
      2. a ``TOKEN`` attribute of ``secrets.py`` one directory above the CWD;
      3. ``access_token.get_token(api_url)`` from the project-local module.
    Fallbacks 2 and 3 are best-effort: any failure moves on to the next source.

    Args:
        api_url: API base URL passed to ``access_token.get_token``; defaults to
            the module-level ``url``.
    """
    if api_url is None:
        api_url = url
    if 'NOMAD_CLIENT_ACCESS_TOKEN' in os.environ:
        return os.environ['NOMAD_CLIENT_ACCESS_TOKEN']

    found = None
    try:
        import importlib.util
        import pathlib
        secrets_path = pathlib.Path(os.getcwd()).parent / 'secrets.py'
        spec = importlib.util.spec_from_file_location('secrets', str(secrets_path))
        # Local name so the stdlib ``secrets`` module is not shadowed in the
        # notebook namespace.
        secrets_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(secrets_module)
        found = getattr(secrets_module, 'TOKEN', None)
    except Exception:
        found = None
    if found is None:
        try:
            import access_token
            found = access_token.get_token(api_url)
        except Exception:
            found = None
    return found


token = resolve_token()
if token is None:
    print("No token found in environment, secrets.py (one level above), or access_token. Please set one of these methods.")

# --- Utility: Get XRD Data from API (keep as is) ---
def get_xrd_data(try_sample_ids, variation):
    """Fetch ``HySprint_XRD_XY`` measurements for the given samples.

    Args:
        try_sample_ids: Sample ids to query.
        variation: Mapping ``sample_id -> description`` (looked up with
            ``.get``; missing ids get ``''``).

    Returns:
        ``{sample_id: {'angle', 'intensity', 'variation', 'name'}}`` where
        angle/intensity are the first two data columns as numpy arrays, or
        None when nothing was found. If a sample has several measurements,
        the last one iterated overwrites the earlier ones.
    """
    measurements = get_all_xrd(url, token, try_sample_ids, eqe_type="HySprint_XRD_XY")
    if len(measurements.keys()) == 0:
        return None

    collected = {}
    for sid, entries in measurements.items():
        for entry in entries:
            record = entry[0]
            frame = pd.DataFrame(record["data"])
            collected[sid] = dict(
                angle=frame.iloc[:, 0].to_numpy(),
                intensity=frame.iloc[:, 1].to_numpy(),
                variation=variation.get(sid, ''),
                name=record.get("name", ''),
            )
    return collected or None

# Fixed XRD Visualization Tool
class XRDVisualizationTool:
    """Interactive ipywidgets/plotly viewer for XRD patterns (2θ vs. intensity).

    Datasets come either from the API (``load_api_data``) or from uploaded
    ``.xy`` files (``on_file_upload``). Each dataset gets its own plot with
    adjustable peak detection plus an "Include" checkbox; checked datasets are
    drawn together in an overlay plot that can be vertically staggered.

    Attributes:
        data_files: ``{key: (x_data, y_data, metadata)}`` for every dataset.
        checkboxes: ``{key: Checkbox}`` selecting datasets for the overlay.
        individual_plots: ``{key: FigureWidget}`` per-dataset plot.
        peak_controls: ``{key: VBox}`` per-dataset peak-detection controls.
    """

    def __init__(self, api_data=None):
        self.data_files = {}  # {filename: (x_data, y_data, metadata)}
        self.checkboxes = {}
        self.individual_plots = {}
        self.peak_controls = {}
        self.overlay_plot_output = widgets.Output()
        self.colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

        # Per-curve vertical offset for the overlay plot; its range is rescaled
        # to the data in set_stagger_slider_range().
        self.stagger_slider = widgets.FloatSlider(
            value=0.0, min=0.0, max=1000.0, step=10.0,
            description='Stagger Offset:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )
        self.stagger_slider.observe(self.update_overlay_plot, names='value')

        # Only stored here; nothing is loaded until load_api_data() is called.
        self.api_data = api_data
        self.setup_ui()

    @staticmethod
    def _slider_step(span, divisions):
        """Return a step splitting ``span`` into ``divisions`` increments (1.0 if span <= 0)."""
        return span / divisions if span > 0 else 1.0

    def parse_xy_file(self, file_content: str, filename: str):
        """Parse the text of an ``.xy`` file into data columns and metadata.

        An optional first line of the form ``'Id: "..." Operator: "..."'`` is
        parsed into a ``{key: value}`` dict. Every line whose first two
        whitespace-separated tokens parse as floats becomes an ``(x, y)``
        point; all other lines are ignored.

        Args:
            file_content: Decoded file text (any line-ending style).
            filename: Name of the source file (not used by the parser).

        Returns:
            ``(x_data, y_data, metadata)``: two lists of floats and a dict.
        """
        lines = file_content.strip().splitlines()
        metadata = {}

        if lines and lines[0].startswith("'Id:"):
            metadata_line = lines[0].strip().strip("'")
            metadata = dict(re.findall(r'(\w+):\s*"([^"]*)"', metadata_line))

        # Scan every line: header-less files start directly with data, and
        # non-numeric lines (including the metadata header) fail float() below.
        x_data, y_data = [], []
        for line in lines:
            parts = line.split()
            if len(parts) < 2:
                continue
            try:
                x, y = float(parts[0]), float(parts[1])
            except ValueError:
                continue
            x_data.append(x)
            y_data.append(y)

        return x_data, y_data, metadata

    def find_peaks_in_data(self, x_data, y_data, height_threshold=None, prominence=None):
        """Locate peaks with ``scipy.signal.find_peaks`` (minimum spacing 5 samples).

        Args:
            x_data: 2θ values, same length as ``y_data``.
            y_data: Intensities.
            height_threshold: Minimum peak height; None means 10 % of max(y).
            prominence: Minimum peak prominence; None means 5 % of max(y).

        Returns:
            ``(peaks, peak_positions, peak_intensities)``: the index array into
            the data and lists of the matching x and y values. All are empty
            for empty input.
        """
        y_array = np.asarray(y_data, dtype=float)
        x_array = np.asarray(x_data, dtype=float)

        if y_array.size == 0:
            # np.max() raises on empty arrays; there is nothing to detect anyway.
            return np.array([], dtype=int), [], []

        y_max = np.max(y_array)
        if height_threshold is None:
            height_threshold = y_max * 0.1
        if prominence is None:
            prominence = y_max * 0.05

        peaks, _ = find_peaks(y_array, height=height_threshold, prominence=prominence, distance=5)
        return peaks, x_array[peaks].tolist(), y_array[peaks].tolist()

    def create_individual_plot(self, filename, x_data, y_data, metadata):
        """Return a FigureWidget showing one dataset as a line trace.

        The title comes from the API metadata (``sample_id`` present) or from
        the ``.xy`` header fields ``Id``/``Operator`` when available.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=x_data,
            y=y_data,
            mode='lines',
            name='Data',
            line=dict(width=2, color='blue')
        ))

        if 'sample_id' in metadata:  # API data
            title = f"Sample: {filename}<br>Variation: {metadata.get('variation', '')}<br>Name: {metadata.get('name', '')}"
        else:  # File data
            title_parts = [f"File: {filename}"]
            if 'Id' in metadata:
                title_parts.append(f"ID: {metadata['Id']}")
            if 'Operator' in metadata:
                title_parts.append(f"Operator: {metadata['Operator']}")
            title = '<br>'.join(title_parts)

        fig.update_layout(
            title=title,
            xaxis_title='2θ (degrees)',
            yaxis_title='Intensity',
            width=700,
            height=450,
            showlegend=True
        )
        return go.FigureWidget(fig)

    def create_peak_controls(self, filename):
        """Build peak-detection widgets for one dataset and run detection once.

        ``self.individual_plots[filename]`` must already exist, because the
        initial detection redraws it.

        Returns:
            A VBox with height/prominence sliders, a show-peaks checkbox and an
            output area listing the detected peaks.
        """
        x_data, y_data, metadata = self.data_files[filename]
        max_intensity = float(np.max(y_data)) if len(y_data) > 0 else 1.0
        # FloatSlider requires max >= min (0); use 1.0 for all-zero or negative data.
        slider_max = max_intensity if max_intensity > 0 else 1.0
        # 200 increments regardless of scale, so normalized (0..1) data stays adjustable.
        step = self._slider_step(slider_max, 200)

        default_height = slider_max * 0.1
        default_prominence = slider_max * 0.05

        height_slider = widgets.FloatSlider(
            value=default_height,
            min=0.0,
            max=slider_max,
            step=step,
            description='Min Height:',
            style={'description_width': '80px'},
            layout=widgets.Layout(width='300px')
        )

        prominence_slider = widgets.FloatSlider(
            value=default_prominence,
            min=0.0,
            max=slider_max,
            step=step,
            description='Prominence:',
            style={'description_width': '80px'},
            layout=widgets.Layout(width='300px')
        )

        show_peaks_checkbox = widgets.Checkbox(
            value=True,
            description='Show Peaks',
            style={'description_width': 'initial'}
        )

        peak_info_output = widgets.Output(layout=widgets.Layout(height='100px'))

        def update_peaks(change=None):
            self.update_individual_plot_peaks(filename, height_slider.value, prominence_slider.value, show_peaks_checkbox.value, peak_info_output)

        height_slider.observe(update_peaks, names='value')
        prominence_slider.observe(update_peaks, names='value')
        show_peaks_checkbox.observe(update_peaks, names='value')

        update_peaks()

        return widgets.VBox([
            widgets.HTML("<b>Peak Detection Controls:</b>"),
            widgets.HBox([height_slider, prominence_slider]),
            show_peaks_checkbox,
            widgets.HTML("<b>Detected Peaks:</b>"),
            peak_info_output
        ])

    def update_individual_plot_peaks(self, filename, height_threshold, prominence, show_peaks, peak_info_output):
        """Redraw one dataset's plot, optionally with peak markers, and list the peaks.

        The figure's traces are replaced with the data trace plus (when
        ``show_peaks`` is set and peaks exist) a marker trace; the peak list
        is printed into ``peak_info_output``.
        """
        x_data, y_data, metadata = self.data_files[filename]
        plot_widget = self.individual_plots[filename]

        with plot_widget.batch_update():
            plot_widget.data = []

            plot_widget.add_scatter(
                x=x_data,
                y=y_data,
                mode='lines',
                name='Data',
                line=dict(width=2, color='blue')
            )

            peak_info_text = "No peaks detected"
            if show_peaks:
                peaks, peak_positions, peak_intensities = self.find_peaks_in_data(x_data, y_data, height_threshold, prominence)

                if len(peak_positions) > 0:
                    plot_widget.add_scatter(
                        x=peak_positions,
                        y=peak_intensities,
                        mode='markers',
                        name='Peaks',
                        marker=dict(color='red', size=8, symbol='triangle-up'),
                        hovertemplate='Peak at 2θ: %{x:.2f}°<br>Intensity: %{y:.1f}<extra></extra>'
                    )

                    peak_info_lines = [f"Found {len(peak_positions)} peaks:"]
                    for i, (pos, intensity) in enumerate(zip(peak_positions, peak_intensities)):
                        peak_info_lines.append(f"Peak {i+1}: 2θ = {pos:.2f}°, I = {intensity:.1f}")
                    peak_info_text = "\n".join(peak_info_lines)

        with peak_info_output:
            clear_output(wait=True)
            print(peak_info_text)

    def set_stagger_slider_range(self):
        """Rescale the stagger slider to the loaded data.

        The maximum becomes the smallest power of ten >= the largest intensity
        over all datasets (at least 1). The value is reset to 10 % of that,
        and the step gives 100 increments across the range.
        """
        max_intensity = 1.0
        for _x, y_data, _meta in self.data_files.values():
            if len(y_data) > 0:
                max_intensity = max(max_intensity, float(np.max(y_data)))

        # max_intensity >= 1.0 here, so log10 is always defined.
        slider_max = 10 ** math.ceil(math.log10(max_intensity))

        self.stagger_slider.max = slider_max
        self.stagger_slider.value = slider_max * 0.1
        self.stagger_slider.step = self._slider_step(slider_max, 100)

    def _register_dataset(self, key, x_data, y_data, metadata):
        """Store one dataset and create its plot, overlay checkbox and peak controls."""
        self.data_files[key] = (x_data, y_data, metadata)
        # The plot must exist before create_peak_controls(), whose initial
        # detection redraws it.
        self.individual_plots[key] = self.create_individual_plot(key, x_data, y_data, metadata)

        checkbox = widgets.Checkbox(
            value=False,
            description=f'Include {key}',
            style={'description_width': 'initial'}
        )
        checkbox.observe(self.update_overlay_plot, names='value')
        self.checkboxes[key] = checkbox

        self.peak_controls[key] = self.create_peak_controls(key)

    def load_api_data(self, api_data):
        """Replace all datasets (including uploaded files) with API samples.

        Args:
            api_data: ``{sample_id: info}``, where each ``info`` dict provides
                ``'angle'`` and ``'intensity'`` plus optional ``'variation'``
                and ``'name'``.
        """
        self.data_files = {}
        self.checkboxes = {}
        self.individual_plots = {}
        self.peak_controls = {}

        for sample_id, info in api_data.items():
            metadata = {
                'variation': info.get('variation', ''),
                'name': info.get('name', ''),
                'sample_id': sample_id
            }
            self._register_dataset(sample_id, info['angle'], info['intensity'], metadata)

        self.set_stagger_slider_range()
        self.update_display()
        # set_stagger_slider_range() only triggers a redraw when the slider value
        # changes; refresh explicitly so the overlay never shows stale curves.
        self.update_overlay_plot()

    def on_file_upload(self, change):
        """FileUpload observer: parse and add every uploaded ``.xy`` file.

        Items of ``change['new']`` are expected to expose ``.name`` and
        ``.content``. Files with the same name replace the existing entry.
        Other extensions are ignored; the match is case-insensitive.
        """
        for file_obj in change['new']:
            filename = file_obj.name
            if not filename.lower().endswith('.xy'):
                continue
            raw = bytes(file_obj.content)
            try:
                file_content = raw.decode('utf-8')
            except UnicodeDecodeError:
                # latin-1 accepts any byte sequence, so numeric columns still parse.
                file_content = raw.decode('latin-1')
            x_data, y_data, metadata = self.parse_xy_file(file_content, filename)
            self._register_dataset(filename, x_data, y_data, metadata)

        self.set_stagger_slider_range()
        self.update_display()
        self.update_overlay_plot()

    def update_overlay_plot(self, change=None):
        """Redraw the overlay plot from the checked datasets.

        The i-th selected curve is shifted up by ``i * stagger``. ``change`` is
        the ipywidgets observer payload and is ignored.
        """
        with self.overlay_plot_output:
            clear_output(wait=True)

            selected_files = [name for name, checkbox in self.checkboxes.items() if checkbox.value]
            if not selected_files:
                print("No files selected for overlay plot")
                return

            fig = go.Figure()
            stagger_offset = self.stagger_slider.value

            for i, filename in enumerate(selected_files):
                x_data, y_data, metadata = self.data_files[filename]
                offset = i * stagger_offset
                # :g keeps fractional offsets visible (":.0f" showed "+0" for 0.1).
                display_name = f"{filename} (+{offset:g})" if stagger_offset > 0 else filename

                fig.add_trace(go.Scatter(
                    x=x_data,
                    y=np.asarray(y_data, dtype=float) + offset,
                    mode='lines',
                    name=display_name,
                    line=dict(color=self.colors[i % len(self.colors)], width=2)
                ))

            title = 'Overlay Plot - Selected Files'
            if stagger_offset > 0:
                title += f' (Staggered by {stagger_offset:g})'

            fig.update_layout(
                title=title,
                xaxis_title='2θ (degrees)',
                yaxis_title='Intensity',
                width=900,
                height=600,
                showlegend=True,
                legend=dict(yanchor="top", y=0.99, xanchor="left", x=1.01)
            )

            display(go.FigureWidget(fig))

    def update_display(self):
        """Re-render each dataset's checkbox, plot and peak controls into main_output."""
        with self.main_output:
            clear_output(wait=True)

            if not self.data_files:
                print("No files uploaded yet. Please upload .xy files above or load data from API.")
                return

            print(f"Loaded {len(self.data_files)} files/samples:")
            print("=" * 50)

            for filename in self.data_files:
                print(f"\n{filename}:")
                display(self.checkboxes[filename])
                display(self.individual_plots[filename])

                peak_controls = self.peak_controls.get(filename)
                if peak_controls:
                    display(peak_controls)

                print("-" * 50)

    def setup_ui(self):
        """Create and display the instructions, upload widget, dataset area and overlay section."""
        self.file_upload = widgets.FileUpload(
            accept='.xy',
            multiple=True,
            description='Upload .xy files'
        )
        self.file_upload.observe(self.on_file_upload, names='value')

        self.main_output = widgets.Output()

        instructions = widgets.HTML("""
        <h2>XRD Data Visualization Tool</h2>
        <p><strong>Instructions:</strong></p>
        <ol>
            <li>Option 1: Select batches from the API above. The samples will appear below for visualization.</li>
            <li>Option 2: If you have local .xy files, upload them below to visualize.</li>
            <li>Use checkboxes to select samples for overlay, and adjust the stagger offset as needed.</li>
            <li>Each individual plot has peak detection controls that can be adjusted.</li>
        </ol>
        """)

        display(instructions)
        display(self.file_upload)
        display(self.main_output)

        stagger_section = widgets.HTML("""
        <h3>Overlay Plot Controls:</h3>
        <p>Adjust the stagger offset to vertically separate curves in the overlay plot:</p>
        """)

        print("\nOverlay Plot:")
        print("=" * 50)
        display(stagger_section)
        display(self.stagger_slider)
        display(self.overlay_plot_output)

        self.update_display()

def launch_xrd_visualization(api_data=None):
    """Create an XRDVisualizationTool and load ``api_data`` into it when given.

    Args:
        api_data: Optional ``{sample_id: info}`` dict. Falsy values (None or an
            empty dict) leave the tool empty, ready for file uploads.

    Returns:
        The newly created tool.
    """
    viewer = XRDVisualizationTool(api_data=api_data)
    if not api_data:
        return viewer
    viewer.load_api_data(api_data)
    return viewer


# Module-level handle to the active tool: assigned by main(), read by
# on_load_data_clicked().
main_tool = None

# --- Callback: Load Data from Batch Selection ---
def on_load_data_clicked(batch_ids_selector):
    """Batch-selection callback: fetch XRD data for the chosen batches.

    Reads the selected batch ids from ``batch_ids_selector.value`` and fetches
    their samples' XRD data with ``get_xrd_data``. The result goes into the
    module-level ``data`` global, with a shallow copy in ``original_data``,
    and is then loaded into the global ``main_tool``.
    """
    global data, original_data, main_tool

    print("Loading Data...")
    selected_batches = batch_ids_selector.value
    sample_ids = get_ids_in_batch(url, token, selected_batches)
    descriptions = get_sample_description(url, token, list(sample_ids))
    data = get_xrd_data(sample_ids, descriptions)

    if data is None:
        print("The batches selected don't contain any XRD measurements")
        return

    original_data = data.copy()
    print("Data Loaded Successfully!")

    if main_tool is None:
        print("Error: Main tool not initialized")
        return
    main_tool.load_api_data(data)

# --- Batch Selection Widget with Optional Filtering ---
def _top_level_batch_ids(batch_ids):
    """Return ``batch_ids`` in order, dropping ids whose parent id is also listed.

    The parent of ``"A_B_C"`` is ``"A_B"`` (the last ``_``-separated part
    removed). Ids without ``_`` have parent ``""`` and are therefore kept.
    """
    batch_ids = list(batch_ids)
    known = set(batch_ids)  # set lookup instead of O(n) list membership per id
    return [b for b in batch_ids if "_".join(b.split("_")[:-1]) not in known]


def create_batch_selection_with_optional_filtering():
    """Wrap the standard batch selector with an optional "only batches with XRD" filter.

    Returns:
        A VBox containing a header, the filter button, a status output and the
        original batch-selection widget. Clicking the button queries every
        top-level batch for XRD data, which can take minutes. The selector's
        options are then restricted to the batches that returned data.
    """
    original_batch_widget = batch_selection.create_batch_selection(url, token, on_load_data_clicked)

    # The SelectMultiple inside the stock widget is the list the filter narrows.
    batch_selector = next(
        (child for child in original_batch_widget.children
         if isinstance(child, widgets.SelectMultiple)),
        None,
    )
    total_batches = len(batch_selector.options) if batch_selector else 0

    idle_label = "🔍 Filter to show only batches with XRD data"
    filter_button = widgets.Button(
        description=idle_label,
        button_style='info',
        tooltip=f'Click to filter {total_batches} batches (this may take a few minutes)',
        layout=widgets.Layout(width='400px')
    )

    filter_status = widgets.Output()

    def start_filtering(b):
        """Button callback: test each batch for XRD data and narrow the selector."""
        filter_button.disabled = True
        filter_button.description = "🔄 Filtering in progress..."

        with filter_status:
            filter_status.clear_output(wait=True)
            print("Finding batches with XRD data...")

            try:
                all_batch_ids = _top_level_batch_ids(get_batch_ids(url, token))
            except Exception as err:
                # Re-enable the button; otherwise it stays stuck on "in progress"
                # and the user cannot retry.
                print(f"❌ Could not retrieve the batch list: {err}")
                filter_button.description = idle_label
                filter_button.disabled = False
                return

            print(f"Testing {len(all_batch_ids)} batches...")
            valid_batches = []

            for i, batch_id in enumerate(all_batch_ids):
                if i % 10 == 0 or i == len(all_batch_ids) - 1:
                    filter_status.clear_output(wait=True)
                    print(f"Progress: {i+1}/{len(all_batch_ids)} - Found {len(valid_batches)} valid batches")
                    print(f"Currently testing: {batch_id}")

                try:
                    sample_ids = get_ids_in_batch(url, token, [batch_id])
                    if sample_ids:
                        identifiers = get_sample_description(url, token, list(sample_ids))
                        xrd_data = get_xrd_data(sample_ids, identifiers)
                        if xrd_data is not None:
                            valid_batches.append(batch_id)
                            filter_status.clear_output(wait=True)
                            print(f"✅ Found valid batch: {batch_id} ({len(xrd_data)} samples)")
                            print(f"Total found so far: {len(valid_batches)}")
                except Exception:
                    # Best-effort: a batch whose lookup fails is treated as having
                    # no XRD data. Exception (not bare except) keeps the loop
                    # interruptible via KeyboardInterrupt.
                    continue

            if batch_selector:
                batch_selector.options = valid_batches

            filter_status.clear_output(wait=True)
            print("=" * 60)
            print("FILTERING COMPLETE")
            print("=" * 60)
            print(f"✅ Found {len(valid_batches)} batches with XRD data out of {total_batches} total")

            if len(valid_batches) > 0:
                print(f"Valid batches: {valid_batches}")
            else:
                print("⚠️  No batches with XRD data found!")

            # The button stays disabled: the selector now only holds filtered options.
            filter_button.description = f"✅ Filtering complete - {len(valid_batches)} valid batches found"
            filter_button.disabled = True

            info_html = widgets.HTML(
                value=f"<p><b>Showing {len(valid_batches)} of {total_batches} batches with confirmed XRD data</b></p>"
            )
            original_batch_widget.children = (info_html,) + original_batch_widget.children

    filter_button.on_click(start_filtering)

    complete_widget = widgets.VBox([
        widgets.HTML(f"<p>Select batches from all {total_batches} available batches, or use the filter button below:</p>"),
        filter_button,
        filter_status,
        original_batch_widget
    ])

    return complete_widget

# --- MAIN EXECUTION ---
def main():
    """Build and display the application and return the visualization tool.

    Shows the optional manual, the batch selector with its XRD filter, and an
    empty XRDVisualizationTool. The tool is also stored in the global
    ``main_tool`` so that ``on_load_data_clicked`` can load API data into it.
    """
    global main_tool

    # The manual is optional; a failure to build it must not stop the app.
    # NOTE(review): "eqe_manual.md" looks copied from the EQE notebook; confirm
    # whether an XRD-specific manual should be shown instead.
    try:
        display(plotting_utils.create_manual("eqe_manual.md"))
    except Exception as err:
        print(f"Manual not available: {err}")

    batch_widget = create_batch_selection_with_optional_filtering()
    display(batch_widget)

    # Created without API data so local .xy uploads work immediately.
    main_tool = launch_xrd_visualization()
    return main_tool

# Module-level state shared with the callbacks: the last fetched API data,
# its shallow copy, and the active visualization tool.
data = original_data = main_tool = None

# A notebook kernel runs cells with __name__ == "__main__", so running this
# cell launches the UI; importing the file as a module does not.
if __name__ == "__main__":
    main_tool = main()

VBox(children=(ToggleButton(value=False, description='Manual'), Output()))

VBox(children=(HTML(value='<p>Select batches from all 109 available batches, or use the filter button below:</…

HTML(value='\n        <h2>XRD Data Visualization Tool</h2>\n        <p><strong>Instructions:</strong></p>\n   …

FileUpload(value=(), accept='.xy', description='Upload .xy files', multiple=True)

Output()


Overlay Plot:


HTML(value='\n        <h3>Overlay Plot Controls:</h3>\n        <p>Adjust the stagger offset to vertically sepa…

FloatSlider(value=0.0, description='Stagger Offset:', layout=Layout(width='400px'), max=1000.0, step=10.0, sty…

Output()