In [2]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output
import io
import re
from typing import Dict, List, Tuple
import numpy as np
from scipy.signal import find_peaks, peak_widths

class XRDVisualizationTool:
    def __init__(self, url=None, token=None):
        """Resolve API credentials, create shared state and render the UI.

        Args:
            url: API base URL. When ``None``, the module-level global named
                ``url`` is used if one exists, otherwise ``None``.
            token: API token. When ``None``, the module-level global named
                ``token`` is used if one exists, otherwise ``None``.
        """
        # dict.get() never raises, so the lookup needs no try/except.
        module_globals = globals()
        self.url = url if url is not None else module_globals.get('url')
        self.token = token if token is not None else module_globals.get('token')

        self.data_files = {}  # Parsed data: {filename: (x_data, y_data, metadata)}
        self.checkboxes = {}  # Overlay-selection checkbox per file
        self.individual_plots = {}  # FigureWidget per file
        self.peak_controls = {}  # Peak-detection control panel per file
        self.overlay_plot_output = widgets.Output()

        # Color cycle for overlay plot (all solid lines)
        self.colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

        # Stagger control: every change of the slider redraws the overlay plot
        self.stagger_slider = widgets.FloatSlider(
            value=0.0,
            min=0.0,
            max=1000.0,
            step=10.0,
            description='Stagger Offset:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )
        self.stagger_slider.observe(self.update_overlay_plot, names='value')

        # Must run last: setup_ui() relies on the attributes created above.
        self.setup_ui()
    
    def get_xrd_data_from_api(self, sample_ids: List[str]) -> Dict:
        """
        Fetch XRD data from API for given sample IDs
        Uses the same API pattern as your other tools
        """
        if not self.url or not self.token:
            print("Warning: API URL and token not available. Check your global variables.")
            return {}
        
        try:
            # Use the same pattern as get_all_abspl but for XRD
            # You'll need to replace this with your actual XRD API function
            # For now, using a placeholder that follows your pattern
            xrd_data = get_all_xrd(self.url, self.token, sample_ids, measurement_type="XRD_Measurement")
            return xrd_data
        except NameError:
            print("❌ get_all_xrd function not found. Please make sure you have the XRD API function imported.")
            return {}
        except Exception as e:
            print(f"❌ Error fetching XRD data from API: {e}")
            return {}
    
    def parse_api_xrd_data(self, api_data: Dict) -> Dict:
        """
        Convert an API response into the ``{filename: (x, y, metadata)}``
        layout used for uploaded files.

        ``api_data`` maps each sample ID to a sequence of measurements; the
        first element of every measurement is read for its ``"results"``
        (and optional ``"name"``). Measurements lacking ``two_theta`` or
        ``intensity`` are skipped silently; measurements that raise while
        being read are reported and skipped.
        """
        converted = {}

        for sample_id, measurements in api_data.items():
            for number, measurement in enumerate(measurements, start=1):
                try:
                    entry = measurement[0]
                    results = entry["results"]

                    has_pattern = "two_theta" in results and "intensity" in results
                    if not has_pattern:
                        continue

                    info = {
                        'Id': sample_id,
                        'Measurement': f"{number}",
                        'Source': 'API'
                    }
                    if "name" in entry:
                        info['Name'] = entry["name"]

                    # Synthetic filename keeps each measurement uniquely keyed
                    key = f"{sample_id}_XRD_{number}.xy"
                    converted[key] = (results["two_theta"], results["intensity"], info)
                except Exception as e:
                    print(f"Error parsing XRD data for sample {sample_id}: {e}")

        return converted
    
    def load_data_from_batches(self, batch_ids: List[str]):
        """
        Load XRD data from selected batches via API
        """
        if not self.url or not self.token:
            print("❌ API credentials not available. Cannot load data from batches.")
            return
        
        print(f"Loading XRD data from {len(batch_ids)} batches...")
        
        try:
            # Get sample IDs from batches (same as your other tools)
            all_sample_ids = []
            for batch_id in batch_ids:
                try:
                    sample_ids = get_ids_in_batch(self.url, self.token, [batch_id])
                    all_sample_ids.extend(sample_ids)
                except Exception as e:
                    print(f"Warning: Could not get samples from batch {batch_id}: {e}")
                    continue
            
            if not all_sample_ids:
                print("❌ No sample IDs found in selected batches.")
                return
            
            print(f"Found {len(all_sample_ids)} samples. Fetching XRD data...")
            
            # Fetch XRD data from API
            api_data = self.get_xrd_data_from_api(all_sample_ids)
            
            if not api_data:
                print("❌ No XRD data found for these samples.")
                print("Note: Make sure you have a 'get_all_xrd' function or modify the API call in get_xrd_data_from_api()")
                return
            
            # Parse and add the data
            parsed_data = self.parse_api_xrd_data(api_data)
            
            if not parsed_data:
                print("❌ Could not parse XRD data from API response.")
                return
            
            # Add parsed data to existing data
            for filename, (x_data, y_data, metadata) in parsed_data.items():
                self.add_data_to_tool(filename, x_data, y_data, metadata)
            
            print(f"✅ Successfully loaded {len(parsed_data)} XRD measurements from API.")
            self.update_display()
            
        except Exception as e:
            print(f"❌ Error loading data from batches: {e}")
    
    def add_data_to_tool(self, filename: str, x_data: List[float], y_data: List[float], metadata: Dict):
        """
        Register one dataset (uploaded or fetched) and build its widgets.

        Stores the data under ``filename`` and creates the per-file plot
        widget, the overlay-selection checkbox and the peak-detection
        controls. An existing entry with the same filename is replaced.
        """
        self.data_files[filename] = (x_data, y_data, metadata)

        # Title: filename first, then whichever of these metadata keys exist
        labelled_keys = (('Id', 'ID'), ('Operator', 'Operator'), ('Source', 'Source'))
        title_parts = [f"File: {filename}"]
        title_parts.extend(
            f"{label}: {metadata[key]}" for key, label in labelled_keys if key in metadata
        )

        layout_options = dict(
            title='<br>'.join(title_parts),
            xaxis_title='2θ (degrees)',
            yaxis_title='Intensity',
            width=700,
            height=450,
            showlegend=True,
        )
        figure = go.Figure()
        figure.update_layout(**layout_options)

        # The widget starts without traces; they are drawn by the peak controls
        self.individual_plots[filename] = go.FigureWidget(figure)

        include_box = widgets.Checkbox(
            value=False,
            description=f'Include {filename}',
            style={'description_width': 'initial'}
        )
        include_box.observe(self.update_overlay_plot, names='value')
        self.checkboxes[filename] = include_box

        # Building the controls also triggers the first drawing of the plot
        controls = self.create_peak_controls(filename)
        if not hasattr(self, 'peak_controls'):
            self.peak_controls = {}
        self.peak_controls[filename] = controls
    
    def parse_xy_file(self, file_content: str, filename: str) -> Tuple[List[float], List[float], Dict]:
        """Split the text of a .xy file into x values, y values and metadata.

        An optional first line starting with ``'Id:`` is treated as a header
        of ``Key: "value"`` pairs. Every other line contributes a point when
        its first two whitespace-separated fields parse as floats; all other
        lines are ignored. ``filename`` is accepted but not used.
        """
        rows = file_content.strip().split('\n')
        has_header = bool(rows) and rows[0].startswith("'Id:")

        metadata = {'Source': 'File Upload'}
        if has_header:
            # Drop the surrounding quotes, then collect the Key: "value" pairs
            header = rows[0].strip("'")
            metadata.update(dict(re.findall(r'(\w+):\s*"([^"]*)"', header)))

        x_data, y_data = [], []
        first_data_row = 1 if has_header else 0
        for row in rows[first_data_row:]:
            fields = row.split()
            if len(fields) < 2:
                continue  # blank line or a single column
            try:
                x_value, y_value = float(fields[0]), float(fields[1])
            except ValueError:
                continue  # non-numeric line
            x_data.append(x_value)
            y_data.append(y_value)

        return x_data, y_data, metadata
    
    def find_peaks_in_data(self, x_data: List[float], y_data: List[float], 
                          height_threshold: float = None, prominence: float = None) -> Tuple[np.ndarray, List[float], List[float]]:
        """Find peaks in XRD data using scipy's find_peaks.

        Args:
            x_data: 2θ positions, same length as ``y_data``.
            y_data: Intensities.
            height_threshold: Minimum peak height. Defaults to 10% of the
                maximum intensity.
            prominence: Minimum peak prominence. Defaults to 5% of the
                maximum intensity.

        Returns:
            ``(peaks, peak_positions, peak_intensities)`` where ``peaks`` is
            the NumPy array of sample indices returned by ``find_peaks`` and
            the other two are plain lists of the x and y values at those
            indices. All three are empty when ``y_data`` is empty.
        """
        y_array = np.array(y_data)
        x_array = np.array(x_data)

        # np.max() raises on an empty array, so an empty pattern (for example
        # a file without any parsable rows) is answered directly.
        if y_array.size == 0:
            return np.array([], dtype=int), [], []

        # Set default parameters if not provided
        if height_threshold is None:
            height_threshold = np.max(y_array) * 0.1  # 10% of max intensity
        if prominence is None:
            prominence = np.max(y_array) * 0.05  # 5% of max intensity

        # distance=5 keeps detected peaks at least five samples apart
        peaks, _ = find_peaks(y_array,
                              height=height_threshold,
                              prominence=prominence,
                              distance=5)

        # Get peak positions and intensities
        peak_positions = x_array[peaks].tolist()
        peak_intensities = y_array[peaks].tolist()

        return peaks, peak_positions, peak_intensities
    
    def create_peak_controls(self, filename: str) -> widgets.VBox:
        """Create the peak-detection control panel for one loaded file.

        Builds a minimum-height slider, a prominence slider, a "Show Peaks"
        checkbox and an output area listing the detected peaks. Any change
        to a control redraws the file's individual plot, and the plot is
        drawn once before this method returns.

        Args:
            filename: Key of the dataset in ``self.data_files``.

        Returns:
            A ``VBox`` containing the controls.
        """
        _, y_data, _ = self.data_files[filename]
        max_intensity = max(y_data) if y_data else 1000
        if max_intensity <= 0:
            # A flat or non-positive pattern would yield a zero/negative
            # slider range and step; fall back to a unit scale.
            max_intensity = 1.0

        # The lower bound scales with the data: a fixed 0.1 would exceed the
        # slider maximum (or clamp the default value) for low-intensity
        # patterns such as normalised data. For max_intensity >= 100 this is
        # still 0.1.
        slider_min = min(0.1, max_intensity * 0.001)

        # Peak detection parameters with dynamic ranges
        height_slider = widgets.FloatSlider(
            value=max_intensity * 0.1,
            min=slider_min,
            max=max_intensity,
            step=max_intensity * 0.01,
            description='Min Height:',
            style={'description_width': '80px'},
            layout=widgets.Layout(width='300px')
        )

        prominence_slider = widgets.FloatSlider(
            value=max_intensity * 0.05,
            min=slider_min,
            max=max_intensity * 0.5,
            step=max_intensity * 0.005,
            description='Prominence:',
            style={'description_width': '80px'},
            layout=widgets.Layout(width='300px')
        )

        show_peaks_checkbox = widgets.Checkbox(
            value=True,
            description='Show Peaks',
            style={'description_width': 'initial'}
        )

        peak_info_output = widgets.Output(layout=widgets.Layout(height='100px'))

        # Shared callback: re-run detection with the current control values
        def update_peaks(change=None):
            self.update_individual_plot_peaks(filename, height_slider.value,
                                              prominence_slider.value, show_peaks_checkbox.value,
                                              peak_info_output)

        for control in (height_slider, prominence_slider, show_peaks_checkbox):
            control.observe(update_peaks, names='value')

        # Initial peak detection (also draws the data trace)
        update_peaks()

        # Create controls layout
        controls = widgets.VBox([
            widgets.HTML(f"<b>Peak Detection Controls for {filename}:</b>"),
            widgets.HBox([height_slider, prominence_slider]),
            show_peaks_checkbox,
            widgets.HTML("<b>Detected Peaks:</b>"),
            peak_info_output
        ])

        return controls
    
    def update_individual_plot_peaks(self, filename: str, height_threshold: float, 
                                    prominence: float, show_peaks: bool, peak_info_output: widgets.Output):
        """Redraw one file's plot and refresh its peak summary.

        The plot's traces are replaced by the data line and, when
        ``show_peaks`` is true and peaks are found, a marker trace at the
        peak positions. The summary text is then written to
        ``peak_info_output``.
        """
        x_values, y_values, _ = self.data_files[filename]
        figure_widget = self.individual_plots[filename]

        summary = "No peaks detected"

        # Group the trace changes into a single widget update
        with figure_widget.batch_update():
            figure_widget.data = []

            figure_widget.add_scatter(
                x=x_values,
                y=y_values,
                mode='lines',
                name='Data',
                line=dict(width=2, color='blue')
            )

            if show_peaks:
                _, positions, heights = self.find_peaks_in_data(
                    x_values, y_values, height_threshold, prominence
                )

                if len(positions) > 0:
                    marker_style = dict(color='red', size=8, symbol='triangle-up')
                    figure_widget.add_scatter(
                        x=positions,
                        y=heights,
                        mode='markers',
                        name='Peaks',
                        marker=marker_style,
                        hovertemplate='Peak at 2θ: %{x:.2f}°<br>Intensity: %{y:.1f}<extra></extra>'
                    )

                    # One header line followed by one line per peak
                    report = [f"Found {len(positions)} peaks:"]
                    report.extend(
                        f"Peak {rank}: 2θ = {angle:.2f}°, I = {height:.1f}"
                        for rank, (angle, height) in enumerate(zip(positions, heights), start=1)
                    )
                    summary = "\n".join(report)

        with peak_info_output:
            clear_output(wait=True)
            print(summary)
    
    def on_file_upload(self, change):
        """Handle a change of the upload widget's value.

        Each entry of ``change['new']`` is read through its ``name`` and
        ``content`` attributes. Files whose name ends in ``.xy`` (in any
        letter case) are decoded as UTF-8, parsed and registered; other
        files are reported and skipped. A failure in one file is printed
        and does not stop the remaining files. The main display is
        refreshed at the end.
        """
        uploaded_files = change['new']

        for file_obj in uploaded_files:
            filename = file_obj.name

            # Compare case-insensitively so e.g. "SCAN.XY" is not dropped,
            # and say so when a file is ignored instead of skipping silently.
            if not filename.lower().endswith('.xy'):
                print(f"Skipping {filename}: not a .xy file")
                continue

            try:
                # Content arrives as a memoryview: convert to bytes, then text
                file_content = bytes(file_obj.content).decode('utf-8')
                x_data, y_data, metadata = self.parse_xy_file(file_content, filename)

                self.add_data_to_tool(filename, x_data, y_data, metadata)
            except Exception as e:
                print(f"Error processing file {filename}: {e}")

        self.update_display()
    
    def create_batch_selector(self):
        """Create the batch selection interface for loading data via the API.

        Returns:
            A ``VBox`` with a multi-select list of batch IDs, a load button
            and a status area, or an ``HTML`` widget carrying an error
            message when credentials are missing, the API helper functions
            are not defined, no batches exist, or building the selector
            fails.
        """
        if not self.url or not self.token:
            return widgets.HTML("<p>❌ API credentials not available. Make sure 'url' and 'token' are defined in your notebook.</p>")

        try:
            all_batch_ids = list(get_batch_ids(self.url, self.token))

            # Drop every ID whose prefix (the ID without its last
            # "_"-separated part) is itself a batch ID. A set makes each
            # membership test O(1) instead of scanning the whole list.
            known_ids = set(all_batch_ids)
            batch_ids_list = [
                b for b in all_batch_ids
                if "_".join(b.split("_")[:-1]) not in known_ids
            ]

            if not batch_ids_list:
                return widgets.HTML("<p>❌ No batches found in the database.</p>")

            # Create batch selector
            batch_selector = widgets.SelectMultiple(
                options=batch_ids_list,
                description='Select Batches:',
                layout=widgets.Layout(width='400px', height='200px')
            )

            # Create load button
            load_button = widgets.Button(
                description='Load XRD Data from Batches',
                button_style='primary'
            )

            # Status output: messages printed while loading are captured here
            status_output = widgets.Output()

            def on_load_click(b):
                with status_output:
                    status_output.clear_output(wait=True)
                    if batch_selector.value:
                        self.load_data_from_batches(list(batch_selector.value))
                    else:
                        print("Please select at least one batch.")

            load_button.on_click(on_load_click)

            return widgets.VBox([
                widgets.HTML("<h3>Load XRD Data from API:</h3>"),
                widgets.HTML(f"<p>Found {len(batch_ids_list)} batches in database</p>"),
                batch_selector,
                load_button,
                status_output
            ])

        except NameError as e:
            return widgets.HTML(f"<p>❌ Missing API function: {e}. Make sure you have imported get_batch_ids and related functions.</p>")
        except Exception as e:
            return widgets.HTML(f"<p>❌ Error creating batch selector: {e}</p>")
    
    def update_overlay_plot(self, change=None):
        """Redraw the overlay plot from the currently ticked checkboxes.

        Each selected dataset becomes one solid line; the n-th selected
        dataset (counting from zero) is shifted upward by n times the
        stagger slider value. ``change`` is accepted so the method can be
        used as a widget observer and is otherwise unused.
        """
        with self.overlay_plot_output:
            clear_output(wait=True)

            chosen = [name for name, box in self.checkboxes.items() if box.value]
            if not chosen:
                print("No files selected for overlay plot")
                return

            offset = self.stagger_slider.value
            is_staggered = offset > 0

            overlay = go.Figure()
            palette_size = len(self.colors)

            for position, name in enumerate(chosen):
                xs, ys, _ = self.data_files[name]

                # Vertical shift grows with the position in the selection
                shift = position * offset
                label = f"{name} (+{shift:.0f})" if is_staggered else name

                overlay.add_trace(go.Scatter(
                    x=xs,
                    y=[value + shift for value in ys],
                    mode='lines',
                    name=label,
                    line=dict(
                        color=self.colors[position % palette_size],
                        width=2
                    )
                ))

            heading = 'Overlay Plot - Selected Files'
            if is_staggered:
                heading += f' (Staggered by {offset})'

            # Legend sits just outside the right edge of the plot area
            legend_options = dict(yanchor="top", y=0.99, xanchor="left", x=1.01)
            overlay.update_layout(
                title=heading,
                xaxis_title='2θ (degrees)',
                yaxis_title='Intensity',
                width=900,
                height=600,
                showlegend=True,
                legend=legend_options
            )

            display(go.FigureWidget(overlay))
    
    def update_display(self):
        """Re-render the main output area: one section per loaded dataset.

        Every section shows the overlay checkbox, the individual plot and,
        when present, the peak-detection controls of that dataset.
        """
        with self.main_output:
            clear_output(wait=True)

            if not self.data_files:
                print("No data loaded yet. Please upload .xy files or load data from API.")
                return

            print(f"Loaded {len(self.data_files)} XRD measurements:")
            print("=" * 50)

            # peak_controls may not exist yet if no data was ever added
            controls_by_file = getattr(self, 'peak_controls', {})

            for filename in self.data_files:
                print(f"\n{filename}:")

                include_box = self.checkboxes[filename]
                figure_widget = self.individual_plots[filename]
                controls = controls_by_file.get(filename)

                display(include_box)
                display(figure_widget)
                if controls:
                    display(controls)

                print("-" * 50)
    
    def setup_ui(self):
        """Build and display the whole user interface.

        Displays, in this order: the instructions, a two-tab data-source
        panel (file upload / API batches), the main output area for the
        individual plots, and the overlay section (heading, stagger slider,
        overlay output). Also creates ``self.file_upload`` and
        ``self.main_output``. The order of the ``display`` calls determines
        the order in which the parts appear.
        """
        # Instructions
        instructions = widgets.HTML("""
        <h2>XRD Data Visualization Tool</h2>
        <p><strong>Instructions:</strong></p>
        <ol>
            <li><strong>Upload Files:</strong> Upload .xy files using the upload button below</li>
            <li><strong>Load from API:</strong> Select batches and load XRD data from the database</li>
            <li><strong>Individual Plots:</strong> Each measurement will be displayed with peak detection controls</li>
            <li><strong>Overlay Plot:</strong> Check boxes next to plots you want to overlay</li>
            <li><strong>Stagger Control:</strong> Use the stagger offset slider to vertically separate curves</li>
        </ol>
        """)
        
        # File upload widget; on_file_upload runs whenever its value changes
        self.file_upload = widgets.FileUpload(
            accept='.xy',
            multiple=True,
            description='Upload .xy files'
        )
        self.file_upload.observe(self.on_file_upload, names='value')
        
        # Main output area for individual plots (written to by update_display)
        self.main_output = widgets.Output()
        
        # Display the interface
        display(instructions)
        
        # Create tabs for different data sources
        file_upload_tab = widgets.VBox([
            widgets.HTML("<h3>Upload XRD Files:</h3>"),
            self.file_upload
        ])
        
        # Either the batch selector or an HTML widget with an error message
        api_tab = self.create_batch_selector()
        
        # Create tabs
        data_source_tabs = widgets.Tab([file_upload_tab, api_tab])
        data_source_tabs.set_title(0, 'File Upload')
        data_source_tabs.set_title(1, 'API Data')
        
        display(data_source_tabs)
        display(self.main_output)
        
        # Stagger control section
        stagger_section = widgets.HTML("""
        <h3>Overlay Plot Controls:</h3>
        <p>Adjust the stagger offset to vertically separate curves in the overlay plot:</p>
        """)
        
        print("\nOverlay Plot:")
        print("=" * 50)
        display(stagger_section)
        display(self.stagger_slider)
        display(self.overlay_plot_output)
        
        # Initial display: fills main_output (with the "no data" hint when empty)
        self.update_display()

# Create the tool; constructing it builds and displays the interface.
# With explicit credentials: tool = XRDVisualizationTool(url=your_api_url, token=your_api_token)
# Without arguments, the constructor falls back to module-level `url` and
# `token` variables if they exist; otherwise the API tab shows an error message.
tool = XRDVisualizationTool()

HTML(value='\n        <h2>XRD Data Visualization Tool</h2>\n        <p><strong>Instructions:</strong></p>\n   …

Tab(children=(VBox(children=(HTML(value='<h3>Upload XRD Files:</h3>'), FileUpload(value=(), accept='.xy', desc…

Output()


Overlay Plot:


HTML(value='\n        <h3>Overlay Plot Controls:</h3>\n        <p>Adjust the stagger offset to vertically sepa…

FloatSlider(value=0.0, description='Stagger Offset:', layout=Layout(width='400px'), max=1000.0, step=10.0, sty…

Output()