# NMR (Nuclear Magnetic Resonance) Analysis

## How to use this notebook:
1. Select batches to analyze (only batches of type "hysprint_batch" are considered)
2. The data will be loaded into a pandas DataFrame
3. Use the plotting tools to visualize your NMR spectra:
   - Create joint plots showing all NMR spectra overlaid
   - Chemical shift (ppm) on x-axis, intensity on y-axis
4. Access advanced features for data table viewing and statistics: select a single spectrum to detect peaks and to integrate over a chemical shift range

In [None]:
%matplotlib ipympl
%load_ext autoreload
%autoreload 2
import os
import base64
import io
import time
import sys
import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import display, Markdown, HTML
import pandas as pd
import numpy as np
import json

sys.path.append(os.path.dirname(os.getcwd()))
from api_calls import get_ids_in_batch, get_sample_description, get_batch_ids,  get_all_eqe as get_all_nmr, get_all_batches_wth_data
import batch_selection
import access_token

# Host of the server this notebook queries.
url_base = "https://nomad-hzb-se.de"
# API root that is handed to every api_calls helper below.
url = url_base + "/nomad-oasis/api/v1"
# Token obtained once here and passed along with `url` to the API helpers.
token = access_token.get_token(url)
access_token.log_notebook_usage()

In [None]:
def create_nmr_plotting_interface():
    """
    🧪 Unified NMR plotting interface 🎨

    Display one colour picker per sample, a vertical-offset field and a single
    Plotly figure with all spectra overlaid. Changing any picker or the offset
    redraws the figure immediately; no button is needed.

    Reads the module-level ``data`` DataFrame and uses its ``sample_id``,
    ``variation``, ``chemical_shift`` and ``intensity`` columns. Only the first
    row of each sample is plotted. Prints a message and returns early when
    ``data`` is ``None`` or contains no samples.

    Returns:
        None. Everything is rendered through ``display``.
    """
    global data

    if data is None:
        print("❌ No data available. Please load data first using the batch selection above.")
        return

    # Get unique samples
    unique_samples = data['sample_id'].unique()

    if len(unique_samples) == 0:
        print("❌ No samples found in data.")
        return

    # Default colors (nice bright colors); reused cyclically when there are
    # more samples than entries.
    default_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                      '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    # Points below this multiple of the spectrum's median intensity are not drawn.
    median_threshold_factor = 20

    # Create offset control widget
    offset_widget = widgets.FloatText(
        value=0.0,
        description='Y-axis Offset:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px'),
        tooltip='Vertical offset between spectra (0 = overlapping)'
    )

    print(f"🔬 Setting up interface for {len(unique_samples)} NMR spectra...")

    # Extract every spectrum once, so redraws neither re-filter the DataFrame
    # nor depend on the global `data` still being set when a widget fires.
    spectra = []
    for sample_id in unique_samples:
        first_row = data[data['sample_id'] == sample_id].iloc[0]
        variation = first_row['variation']
        spectra.append((
            sample_id,
            variation if variation else sample_id,  # fall back to the id
            np.array(first_row['chemical_shift']),
            np.array(first_row['intensity']),
        ))

    # Create color pickers for each sample
    color_pickers = {}
    for i, (sample_id, sample_name, _, _) in enumerate(spectra):
        color_pickers[sample_id] = widgets.ColorPicker(
            concise=False,
            description=f'{sample_name}:',
            value=default_colors[i % len(default_colors)],
            disabled=False,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='320px', margin='2px')
        )

    # Create single output area for the plot
    plot_output = widgets.Output()

    def update_plot():
        """Redraw all spectra with the current colours and offset."""
        with plot_output:
            # wait=True keeps the old figure until the new one is ready,
            # which avoids flicker while a colour or the offset is edited.
            plot_output.clear_output(wait=True)

            fig = go.Figure()
            stack_offset = offset_widget.value

            for i, (sample_id, sample_name, chemical_shift, intensity) in enumerate(spectra):
                # First spectrum at y=0, second at offset, third at 2*offset, etc.
                offset_value = i * stack_offset
                intensity_with_offset = intensity + offset_value
                # The mask is computed on the raw intensities, not the shifted ones.
                keep = intensity >= np.median(intensity) * median_threshold_factor

                fig.add_trace(go.Scatter(
                    x=chemical_shift[keep],
                    y=intensity_with_offset[keep],
                    mode='lines',
                    name=sample_name,
                    line=dict(width=2, color=color_pickers[sample_id].value),
                    hovertemplate='<b>%{fullData.name}</b><br>' +
                                 'Chemical Shift: %{x:.2f} ppm<br>' +
                                 'Intensity: %{y:.2f}<br>' +
                                 f'Offset: {offset_value:.2f}<br>' +
                                 '<extra></extra>'
                ))

            fig.update_layout(
                title='🧪 NMR Spectra with Custom Colors',
                xaxis_title='Chemical Shift (ppm)',
                yaxis_title='Intensity',
                xaxis=dict(autorange='reversed'),  # Typical for NMR
                hovermode='closest',
                legend=dict(
                    orientation="v",
                    yanchor="top",
                    y=1,
                    xanchor="left",
                    x=1.02
                ),
                width=1500,
                height=600
            )

            fig.show()
            print(f"✅ Updated plot with {len(unique_samples)} spectra")

    # Connect color pickers and offset widget to auto-update plot
    for color_picker in color_pickers.values():
        color_picker.observe(lambda change: update_plot(), names='value')

    offset_widget.observe(lambda change: update_plot(), names='value')

    # Layout widgets
    title_widget = widgets.HTML("""
        <h3>🧪 Unified NMR Plotting Interface</h3>
        <h4>🎨 Select Colors and Offset (plot updates automatically when you change any setting):</h4>
    """)

    offset_controls = widgets.VBox([
        widgets.HTML("<h4>📊 Spectrum Positioning:</h4>"),
        offset_widget
    ], layout=widgets.Layout(
        border='1px solid #ddd',
        padding='10px',
        margin='5px'
    ))

    color_widgets = widgets.VBox(
        [widgets.HTML("<h4>🎨 Spectrum Colors:</h4>")] + list(color_pickers.values()),
        layout=widgets.Layout(
            border='1px solid #ddd',
            padding='10px',
            margin='5px'
        )
    )

    # Display everything
    display(title_widget)
    display(offset_controls)
    display(color_widgets)
    display(plot_output)

    # Create initial plot
    update_plot()


In [None]:
from scipy.signal import find_peaks
import numpy as np

def _normalize_intensity(y_data):
    """Scale ``y_data`` linearly onto the 0-1 range used for peak detection.

    An empty spectrum, or one whose values are all equal, maps to zeros
    instead of dividing by zero.
    """
    if y_data.size == 0:
        return np.zeros(0, dtype=float)
    y_min = np.min(y_data)
    span = np.max(y_data) - y_min
    if span == 0:
        return np.zeros(y_data.shape, dtype=float)
    return (y_data - y_min) / span


def _integrate_range(x_data, y_data, start_ppm, end_ppm):
    """Integrate ``y_data`` over ``start_ppm <= x <= end_ppm`` (trapezoidal rule).

    The selected points are sorted by ascending x first, so the result is the
    same whether the chemical-shift axis is stored ascending or descending.

    Returns:
        ``(integral, mean_intensity)`` as floats, or ``None`` when no point
        lies inside the range.
    """
    mask = (x_data >= start_ppm) & (x_data <= end_ppm)
    if not np.any(mask):
        return None
    x_range = x_data[mask]
    y_range = y_data[mask]
    order = np.argsort(x_range, kind="stable")
    # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever exists.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    integral = trapezoid(y_range[order], x_range[order])
    return float(integral), float(np.mean(y_range))


def create_single_spectrum_analyzer():
    """
    🔍 Single spectrum analyzer with peak detection 🏔️

    Display a dropdown to pick one sample, a minimum-peak-height slider, an
    integration range with an "Integrate" button and a colour picker, followed
    by the plot and two result areas. Peaks are found with
    ``scipy.signal.find_peaks`` on the intensity normalised to 0-1 and are
    marked with red crosses; the integration range is shaded on the plot.

    Reads the module-level ``data`` DataFrame and uses its ``sample_id``,
    ``variation``, ``chemical_shift`` and ``intensity`` columns. Only the first
    row of the selected sample is analysed. Prints a message and returns early
    when ``data`` is ``None`` or contains no samples.

    Returns:
        None. Everything is rendered through ``display``.
    """
    global data

    if data is None:
        print("❌ No data available. Please load data first using the batch selection above.")
        return

    # Get unique samples
    unique_samples = data['sample_id'].unique()

    if len(unique_samples) == 0:
        print("❌ No samples found in data.")
        return

    def first_row_of(sample_id):
        """Return the first stored row for ``sample_id``."""
        return data[data['sample_id'] == sample_id].iloc[0]

    def label_for(sample_id, row):
        """Return "<variation> (<sample_id>)", or just the id without a distinct variation."""
        sample_name = row['variation'] if row['variation'] else sample_id
        if sample_name and sample_name != sample_id:
            return f"{sample_name} ({sample_id})"
        return sample_id

    # Dropdown for sample selection: (label shown, value returned) pairs
    sample_dropdown = widgets.Dropdown(
        options=[(label_for(sample_id, first_row_of(sample_id)), sample_id)
                 for sample_id in unique_samples],
        description='Select Sample:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='400px')
    )

    # Peak detection parameter, expressed on the normalised 0-1 intensity scale
    height_slider = widgets.FloatSlider(
        value=0.1,
        min=0.01,
        max=1.0,
        step=0.01,
        description='Min Peak Height:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px')
    )

    # Color picker for spectrum
    color_picker = widgets.ColorPicker(
        concise=False,
        description='Spectrum Color:',
        value='#1f77b4',
        disabled=False,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='300px')
    )

    # Integration range controls
    range_start = widgets.FloatText(
        value=0.0,
        description='Range Start (ppm):',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='200px')
    )

    range_end = widgets.FloatText(
        value=1.0,
        description='Range End (ppm):',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='200px')
    )

    integrate_button = widgets.Button(
        description='🧮 Integrate Range',
        button_style='info',
        tooltip='Calculate integral over the specified chemical shift range',
        layout=widgets.Layout(width='200px')
    )

    # Output areas: plot, peak table, integration results
    plot_output = widgets.Output()
    status_output = widgets.Output()
    integration_output = widgets.Output()

    def current_range():
        """Return the integration bounds as (low, high), whatever order they were typed in."""
        bounds = (range_start.value, range_end.value)
        return min(bounds), max(bounds)

    def calculate_integration():
        """Calculate the integral over the specified range"""
        with integration_output:
            integration_output.clear_output(wait=True)

            row = first_row_of(sample_dropdown.value)
            x_data = np.array(row['chemical_shift'])
            y_data = np.array(row['intensity'])

            start_ppm, end_ppm = current_range()
            result = _integrate_range(x_data, y_data, start_ppm, end_ppm)

            if result is None:
                print(f"❌ No data points found in range {start_ppm:.2f} - {end_ppm:.2f} ppm")
                return

            integral_value, mean_intensity = result

            # Create results DataFrame
            integration_df = pd.DataFrame({
                'Parameter': ["Start (ppm)", "End (ppm)", "Integral", "Avg Intensity"],
                'Value': [f"{start_ppm:.2f}",
                         f"{end_ppm:.2f}",
                         f"{integral_value:.4f}",
                         f"{mean_intensity:.4f}"]
            })

            print("🧮 Integration Results:")
            display(integration_df)
            print(f"📊 Peak integral from {start_ppm:.2f} to {end_ppm:.2f} ppm = {integral_value:.4f}")

    # Connect integration button
    integrate_button.on_click(lambda b: calculate_integration())

    def update_single_plot():
        """Update the single spectrum plot with peak detection"""
        with plot_output:
            plot_output.clear_output(wait=True)

            selected_sample_id = sample_dropdown.value
            row = first_row_of(selected_sample_id)
            label = label_for(selected_sample_id, row)

            x_data = np.array(row['chemical_shift'])
            y_data = np.array(row['intensity'])

            # Peaks are searched on the normalised curve so that the slider
            # value is independent of the spectrum's absolute scale.
            peaks, _ = find_peaks(
                _normalize_intensity(y_data),
                height=height_slider.value
            )

            fig = go.Figure()

            # Add spectrum trace
            fig.add_trace(go.Scatter(
                x=x_data,
                y=y_data,
                mode='lines',
                name=label,
                line=dict(width=2, color=color_picker.value),
                hovertemplate='<b>%{fullData.name}</b><br>' +
                             'Chemical Shift: %{x:.2f} ppm<br>' +
                             'Intensity: %{y:.2f}<br>' +
                             '<extra></extra>'
            ))

            # Add peak markers (at the original, un-normalised intensities)
            if len(peaks) > 0:
                fig.add_trace(go.Scatter(
                    x=x_data[peaks],
                    y=y_data[peaks],
                    mode='markers',
                    name=f'Peaks ({len(peaks)} found)',
                    marker=dict(
                        symbol='x',
                        size=8,
                        color='red'
                    ),
                    hovertemplate='<b>Peak</b><br>' +
                                 'Chemical Shift: %{x:.2f} ppm<br>' +
                                 'Intensity: %{y:.2f}<br>' +
                                 '<extra></extra>'
                ))

            # Add integration range visualization
            if range_start.value != range_end.value:
                start_ppm, end_ppm = current_range()
                mask = (x_data >= start_ppm) & (x_data <= end_ppm)
                if np.any(mask):
                    x_range = x_data[mask]
                    y_range = y_data[mask]

                    # Closed polygon: along the curve, then back along y=0.
                    fig.add_trace(go.Scatter(
                        x=np.concatenate([x_range, x_range[::-1]]),
                        y=np.concatenate([y_range, np.zeros_like(y_range)]),
                        fill='toself',
                        fillcolor='rgba(255, 255, 0, 0.3)',
                        line=dict(color='rgba(255, 255, 0, 0)'),
                        name='Integration Range',
                        hoverinfo='skip'
                    ))

            fig.update_layout(
                title=f'🔍 Single Spectrum Analysis: {label}',
                xaxis_title='Chemical Shift (ppm)',
                yaxis_title='Intensity',
                xaxis=dict(autorange='reversed'),  # Typical for NMR
                hovermode='closest',
                legend=dict(
                    orientation="v",
                    yanchor="top",
                    y=1,
                    xanchor="left",
                    x=1.02
                ),
                width=1500,
                height=600
            )

            fig.show()

        # Update status with peak information
        with status_output:
            status_output.clear_output(wait=True)
            if len(peaks) > 0:
                peak_shifts = x_data[peaks]
                peak_intensities = y_data[peaks]

                # Create a pandas DataFrame for the peak table
                peak_df = pd.DataFrame({
                    'Peak #': range(1, len(peaks) + 1),
                    'Chemical Shift (ppm)': [f"{shift:.2f}" for shift in peak_shifts],
                    'Intensity': [f"{intensity:.2f}" for intensity in peak_intensities]
                })

                print(f"🏔️ Found {len(peaks)} peaks:")
                display(peak_df)

                # Also show a simple text summary
                print(f"\n📋 Summary: {len(peaks)} peaks detected between {peak_shifts.min():.2f} - {peak_shifts.max():.2f} ppm")
            else:
                print("❌ No peaks found with current parameters. Try adjusting the sliders.")

    # Any change to these controls redraws the plot and the peak table
    for control in (sample_dropdown, height_slider, color_picker, range_start, range_end):
        control.observe(lambda change: update_single_plot(), names='value')

    # Layout widgets
    title_widget = widgets.HTML("""
        <h3>🔍 Single Spectrum Analyzer with Peak Detection</h3>
        <p><em>Select a spectrum and adjust peak detection parameters. Peaks are marked with red crosses.</em></p>
    """)

    controls = widgets.VBox([
        widgets.HTML("<h4>📋 Sample Selection:</h4>"),
        sample_dropdown,
        widgets.HTML("<h4>🏔️ Peak Detection Parameters:</h4>"),
        widgets.HBox([height_slider]),
        widgets.HTML("<h4>🧮 Integration Range:</h4>"),
        widgets.HBox([range_start, range_end, integrate_button]),
        widgets.HTML("<h4>🎨 Appearance:</h4>"),
        color_picker
    ], layout=widgets.Layout(
        border='1px solid #ddd',
        padding='10px',
        margin='5px'
    ))

    # Display everything
    display(title_widget)
    display(controls)
    display(plot_output)
    display(widgets.HTML("<h4>📊 Peak Detection Results:</h4>"))
    display(status_output)
    display(widgets.HTML("<h4>🧮 Integration Results:</h4>"))
    display(integration_output)

    # Create initial plot
    update_single_plot()



In [None]:
# Unicode code point U+26A0 (warning sign).
warning_sign = "\u26A0"

# Output areas. `out` receives the loading status messages and
# `dynamic_content` receives the plotting interfaces once data is loaded.
out, out2, read = widgets.Output(), widgets.Output(), widgets.Output()
dynamic_content = widgets.Output()

# Height-limited, scrollable output area (no border is set).
results_content = widgets.Output(layout=dict(max_height='1000px', overflow='scroll'))

cell_edit = widgets.VBox()

# Presets for how the samples will be named in the plot.
default_variables = widgets.Dropdown(
    options=['sample name', 'batch', "sample description", 'custom'],
    index=0,
    description='name preset:',
    disabled=False,
    tooltip="Presets for how the samples will be named in the plot",
)

# Loaded NMR DataFrame and an untouched copy of it (kept for filter reset);
# both stay None until a batch has been loaded.
data = original_data = None


def get_nmr_data(try_sample_ids, variation):
    """Fetch the NMR measurements of the given samples as one DataFrame.

    Args:
        try_sample_ids: Sample ids to query; passed unchanged to the API helper.
        variation: Mapping from sample id to the text stored in the
            ``variation`` column; ids missing from it get an empty string.

    Returns:
        A DataFrame with one row per NMR entry and the columns ``sample_id``,
        ``variation``, ``name``, ``chemical_shift`` and ``intensity`` (the last
        two hold one numpy array per row), indexed 0..n-1. ``None`` when the
        API returns no NMR data for these samples.
    """
    # The API helper is shared with the EQE notebooks, hence the `eqe_type`
    # keyword; here it selects the NMR entry type.
    all_nmr = get_all_nmr(url, token, try_sample_ids, eqe_type="HySprint_Simple_NMR")

    # No sample has NMR data
    if not all_nmr:
        return None

    columns = ["sample_id", "variation", "name", "chemical_shift", "intensity"]
    rows = []
    for sample_id, sample_data in all_nmr.items():
        for nmr_entry in sample_data:
            # Only the first element of each entry is read.
            measurement = nmr_entry[0]
            spectrum = measurement["data"]
            rows.append([
                sample_id,
                variation.get(sample_id, ''),
                measurement.get("name", ''),
                np.array(spectrum["chemical_shift"]),
                np.array(spectrum["intensity"]),
            ])

    # Samples may be listed without any entries
    if not rows:
        return None

    # Build the frame in one go: a single allocation and a unique index,
    # instead of concatenating one single-row frame per measurement.
    return pd.DataFrame(rows, columns=columns)


def on_load_data_clicked(batch_ids_selector):
    """Load the NMR data of the selected batches and show the plotting tools.

    Callback handed to the batch selection widget. Stores the loaded
    DataFrame in the module-level ``data`` and a copy in ``original_data``,
    reports progress in ``out`` and renders both plotting interfaces into
    ``dynamic_content``.

    Args:
        batch_ids_selector: Widget whose ``value`` holds the selected batch ids.
    """
    global data, original_data

    # Remove the interfaces of any previous load before starting a new one.
    dynamic_content.clear_output()

    with out:
        out.clear_output()
        print("Loading Data")

        sample_ids = get_ids_in_batch(url, token, batch_ids_selector.value)
        descriptions = get_sample_description(url, token, list(sample_ids))
        data = get_nmr_data(sample_ids, descriptions)

        # Replace the "Loading Data" line with the final status.
        out.clear_output()
        if data is None:
            print("The batches selected don't contain any NMR measurements")
            return

        # Untouched copy, kept so filters can be reset later.
        original_data = data.copy()
        print("Data Loaded")

        # The plotting widgets are only created once data is available.
        with dynamic_content:
            dynamic_content.clear_output()
            create_nmr_plotting_interface()
            create_single_spectrum_analyzer()

# BATCH SELECTION WITH OPTIONAL FILTERING
def create_batch_selection_with_optional_filtering():
    """Create the batch selection widget with an optional NMR-only filter.

    Wraps the widget built by ``batch_selection.create_batch_selection`` and
    adds a button that narrows the selectable batches to those reported by
    ``get_all_batches_wth_data`` for ``"HySprint_Simple_NMR"``.

    Returns:
        A ``widgets.VBox`` holding an intro line, the filter button, the
        filter status output and the original batch selection widget.
    """
    original_batch_widget = batch_selection.create_batch_selection(url, token, on_load_data_clicked)

    # The first SelectMultiple child is the list of batches (None if absent).
    batch_selector = next(
        (child for child in original_batch_widget.children
         if isinstance(child, widgets.SelectMultiple)),
        None,
    )

    total_batches = len(batch_selector.options) if batch_selector else 0

    idle_description = "🔍 Filter to show only batches with NMR data"
    filter_button = widgets.Button(
        description=idle_description,
        button_style='info',
        tooltip=f'Click to filter {total_batches} batches (this may take a few minutes)',
        layout=widgets.Layout(width='400px')
    )

    filter_status = widgets.Output()

    def start_filtering(b):
        """Button callback: restrict the selector to batches with NMR data."""
        filter_button.disabled = True
        filter_button.description = "🔄 Filtering in progress..."

        with filter_status:
            # The try sits inside the Output context so that a failure is both
            # shown in `filter_status` and leaves the button usable again.
            try:
                filter_status.clear_output(wait=True)
                print("Finding batches with NMR data...")

                batch_ids_list_tmp = list(get_batch_ids(url, token))
                known_ids = set(batch_ids_list_tmp)  # O(1) membership tests
                # Drop ids whose name minus the last "_"-separated part is
                # itself a known batch id.
                all_batch_ids = [
                    batch for batch in batch_ids_list_tmp
                    if "_".join(batch.split("_")[:-1]) not in known_ids
                ]

                print(f"Testing {len(all_batch_ids)} batches...")
                valid_batches = get_all_batches_wth_data(url, token, "HySprint_Simple_NMR")

                # Update batch selector
                if batch_selector:
                    batch_selector.options = valid_batches

                filter_status.clear_output(wait=True)
                print("=" * 60)
                print("FILTERING COMPLETE")
                print("=" * 60)
                print(f"✅ Found {len(valid_batches)} batches with NMR data out of {total_batches} total")

                if len(valid_batches) > 0:
                    print(f"Valid batches: {valid_batches}")
                else:
                    print("⚠️  No batches with NMR data found!")

                # The button stays disabled: filtering is a one-shot action.
                filter_button.description = f"✅ Filtering complete - {len(valid_batches)} valid batches found"
                filter_button.disabled = True

                # Add info to original widget
                info_html = widgets.HTML(
                    value=f"<p><b>Showing {len(valid_batches)} of {total_batches} batches with confirmed NMR data</b></p>"
                )
                original_batch_widget.children = (info_html,) + original_batch_widget.children
            except Exception:
                # Let the user retry instead of leaving the button stuck on
                # "in progress"; the error itself is still reported.
                filter_button.description = idle_description
                filter_button.disabled = False
                raise

    filter_button.on_click(start_filtering)

    complete_widget = widgets.VBox([
        widgets.HTML(f"<p>Select batches from all {total_batches} available batches, or use the filter button below:</p>"),
        filter_button,
        filter_status,
        original_batch_widget
    ])

    return complete_widget


# Create and display the batch selection widget with optional filtering
# Build the batch selection UI, then show it followed by the status area and
# the area that receives the plotting interfaces after data is loaded.
batch_widget = create_batch_selection_with_optional_filtering()
display(batch_widget, out, dynamic_content)