In [2]:
%%capture
%matplotlib widget
#!pip install requests_cache

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import time
import requests
import pandas as pd
import numpy as np
from datetime import datetime
import math
from scipy.optimize import curve_fit

import ipywidgets as widgets
from IPython.display import display, clear_output

import sys
sys.path.insert(1, '../python-scripts-c6fxKDJrSsWp1xCxON1Y7g')
sys.path.insert(1, '../../python-scripts-c6fxKDJrSsWp1xCxON1Y7g')
from api_calls import *
from periodic_table import *

# Base URL of the NOMAD Oasis API used by every request in this notebook.
url = 'https://nomad-hzb-ce.de/nomad-oasis/api/v1'

import os
# Access token for the NOMAD API, read from the environment (raises KeyError if the variable is unset).
token = os.environ['NOMAD_CLIENT_ACCESS_TOKEN']

In [3]:
def get_transmission_data(url, token):
    """Fetch the `data` sections of all visible XAS transmission entries from NOMAD.

    Sends one query to `{url}/entries/archive/query` for entries of type
    `Bessy2_KMC3_XASTransmission` or `Bessy2_KMC2_XASTransmission`. Each entry's
    archive `data` dict is returned, extended with the entry's `entry_id` and
    `upload_id`.

    Only a single page of at most 10000 results is requested; further pages are
    not fetched.

    Args:
        url: Base URL of the NOMAD API (e.g. `.../api/v1`).
        token: Access token sent as a Bearer `Authorization` header.

    Returns:
        list[dict]: One `data` dict per returned entry. Entries whose archive has
        no `data` section are skipped.

    Raises:
        requests.HTTPError: If NOMAD answers with an HTTP error status.
    """
    query = {
        'required': {
            'data': '*',
        },
        'owner': 'visible',
        'query': {
            'entry_type:any': ['Bessy2_KMC3_XASTransmission', 'Bessy2_KMC2_XASTransmission']
        },
        'pagination': {
            'page_size': 10000
        }
    }
    response = requests.post(f'{url}/entries/archive/query',
                             headers={'Authorization': f'Bearer {token}'}, json=query)
    # Surface authentication/server errors directly instead of as a KeyError on 'data'.
    response.raise_for_status()
    res = []
    for ldata in response.json()['data']:
        data_dict = (ldata.get('archive') or {}).get('data')
        if data_dict is None:
            # An entry without a data section cannot be used for calibration.
            continue
        data_dict['entry_id'] = ldata.get('entry_id')
        data_dict['upload_id'] = ldata.get('upload_id')
        res.append(data_dict)
    return res

def get_xas_entryids(url, token):
    """List the metadata of all visible XAS entries (KMC2/KMC3, transmission and fluorescence).

    Sends one query to `{url}/entries/archive/query`. Only a single page of at
    most 10000 results is requested.

    Args:
        url: Base URL of the NOMAD API (e.g. `.../api/v1`).
        token: Access token sent as a Bearer `Authorization` header.

    Returns:
        list[dict]: Dicts with the keys 'entry_name', 'entry_id', 'upload_name',
        'upload_id' and 'create_time' (the upload's creation time as delivered by
        NOMAD). Values may be None when NOMAD omits them. Sorted by 'create_time',
        newest first; entries without a creation time come last.

    Raises:
        requests.HTTPError: If NOMAD answers with an HTTP error status.
    """
    query = {
        'required': {
            'metadata': '*',
        },
        'owner': 'visible',
        'query': {
            'entry_type:any': [
                'Bessy2_KMC3_XASTransmission',
                'Bessy2_KMC3_XASFluorescence',
                'Bessy2_KMC2_XASTransmission',
                'Bessy2_KMC2_XASFluorescence',
            ]
        },
        'pagination': {
            'page_size': 10000
        }
    }
    response = requests.post(f'{url}/entries/archive/query',
                             headers={'Authorization': f'Bearer {token}'}, json=query)
    # Surface authentication/server errors directly instead of as a KeyError on 'data'.
    response.raise_for_status()
    res = []
    for ldata in response.json()['data']:
        metadata = ldata['archive']['metadata']
        res.append({
            'entry_name': metadata.get('entry_name'),
            'entry_id': ldata.get('entry_id'),
            'upload_name': metadata.get('upload_name'),
            'upload_id': ldata.get('upload_id'),
            'create_time': metadata.get('upload_create_time'),
        })
    # A missing time (None) cannot be compared with a string and would raise TypeError.
    # Map it to '', which sorts last in descending order.
    # Assumes the times are ISO strings, as delivered in the JSON response.
    return sorted(res, key=lambda x: x['create_time'] or '', reverse=True)

def link_xas_energy_shift(url, token, entry_id, energy_shift, ocv_link):
    """Write a manual energy shift and a calibration link into one NOMAD entry.

    Posts to `{url}/entries/{entry_id}/edit` with two `upsert` changes:
    `data/manual_energy_shift` <- `energy_shift` and
    `data/connected_measurements` <- `ocv_link`. In this notebook `ocv_link` is
    the list built by `get_ocv_link`.

    The HTTP status is not checked. The decoded JSON body is returned as-is,
    whether or not NOMAD accepted the edit.
    """
    changes = [
        {'path': quantity_path, 'new_value': value, 'action': 'upsert'}
        for quantity_path, value in (
            ('data/manual_energy_shift', energy_shift),
            ('data/connected_measurements', ocv_link),
        )
    ]
    auth_headers = {'Authorization': f'Bearer {token}'}
    edit_response = requests.post(f'{url}/entries/{entry_id}/edit',
                                  headers=auth_headers, json={'changes': changes})
    return edit_response.json()

def gauss(x, A, x0, sigma):
    """Gaussian of height `A`, centre `x0` and standard deviation `sigma`, evaluated at `x`.

    Works element-wise on scalars, numpy arrays and pandas Series.
    """
    exponent = -((x - x0) ** 2) / (2 * sigma ** 2)
    return A * np.exp(exponent)
    
def get_first_derivative(x, y):
    """Return dy/dx via `np.gradient` (central differences inside, one-sided at the ends).

    Both inputs are converted to float arrays first.

    Raises:
        ValueError: If `x` and `y` do not have the same shape.
    """
    xs, ys = (np.asarray(values, dtype=float) for values in (x, y))
    if xs.shape == ys.shape:
        return np.gradient(ys, xs)
    raise ValueError('Can only compute derivative for axis with same length.')

def get_gaussian_fit(df, peak_df, initial_peak_idx_guess, x_col_name, y_col_name):
    """Fit `gauss` to the points of `peak_df` and return `curve_fit`'s `(popt, pcov)`.

    Args:
        df: Full data frame. The row `initial_peak_idx_guess` provides the initial
            peak height and position.
        peak_df: Subset of rows (the fit window) the Gaussian is fitted to.
        initial_peak_idx_guess: Index label in `df` of the initial peak guess.
        x_col_name: Column holding the x values.
        y_col_name: Column holding the y values.

    Returns:
        popt: Fitted `(A, x0, sigma)`. A lies within the window's y range, x0
            within its x range, sigma within [0, 1].
        pcov: Covariance matrix of `popt`.

    Raises:
        ValueError: If the window has fewer points than fit parameters, or if
            `curve_fit` rejects the bounds (e.g. a flat window).
        RuntimeError: If `curve_fit` does not converge within `maxfev` evaluations.
    """
    if len(peak_df) < 3:
        raise ValueError('Not enough data points close to the selected E0 to fit a Gaussian.')
    peak_x = peak_df[x_col_name]
    peak_y = peak_df[y_col_name]
    # min/max instead of first/last row, so a descending scan still yields start <= end.
    start = peak_x.min()
    end = peak_x.max()
    sigma = end - start

    # TODO maybe find better min/max values here
    max_peak_height = np.max(peak_y)
    min_peak_height = np.min(peak_y)
    guess_height = df.loc[initial_peak_idx_guess, y_col_name]
    guess_position = df.loc[initial_peak_idx_guess, x_col_name]
    if guess_height < 1:
        # Heuristic: for small derivative values allow amplitudes up to 1, but never
        # cap the amplitude below the largest value actually observed in the window.
        max_peak_height = max(max_peak_height, 1)

    lower = [min_peak_height, start, 0]  # (A, x0, sigma)
    upper = [max_peak_height, end, 1]
    # curve_fit rejects a p0 outside the bounds ("x0 is infeasible"), e.g. when the
    # initial index lies outside the window. Clamp the guess into the bounds instead.
    p0 = np.clip([guess_height, guess_position, sigma], lower, upper)
    popt, pcov = curve_fit(gauss, peak_x, peak_y, p0=p0, bounds=(lower, upper), maxfev=5000)
    return popt, pcov

def get_energy_gauss_df(df, peak_df, initial_peak_guess, x_col_name, y_col_name):
    """Fit a Gaussian to `peak_df` and store the results as new columns of `df` (in place).

    With `<y>` = `y_col_name`, the added columns are:
      * `<y>_edge_area`: y values of `peak_df`, aligned on the index (NaN elsewhere),
      * `<y>_edge`: fitted peak position `popt[1]`, repeated in every row,
      * `<y>_gauss`: Gaussian evaluated on `peak_df`'s x values (NaN elsewhere),
      * `<y>_gauss_complete`: Gaussian evaluated on every x value of `df`.

    Returns the same, mutated `df`.
    """
    popt, _pcov = get_gaussian_fit(df, peak_df, initial_peak_guess, x_col_name, y_col_name)
    new_columns = {
        f'{y_col_name}_edge_area': peak_df[y_col_name],
        f'{y_col_name}_edge': popt[1],
        f'{y_col_name}_gauss': gauss(peak_df[x_col_name], *popt),
        f'{y_col_name}_gauss_complete': gauss(df[x_col_name], *popt),
    }
    for column_name, column_values in new_columns.items():
        df.loc[:, column_name] = column_values
    return df

def get_edge_fit(energy, derivative_ref, estimated_peak_energy, window_below=0.005, window_above=0.0025):
    """Fit a Gaussian to the derivative peak near `estimated_peak_energy`.

    The fit uses the points with energy in
    [estimated_peak_energy - window_below, estimated_peak_energy + window_above].
    With the default widths and energies in keV (as plotted in this notebook),
    that is 5 eV below to 2.5 eV above the estimate.

    Args:
        energy: Energy axis.
        derivative_ref: Derivative values, same length as `energy`.
        estimated_peak_energy: Where the peak is expected.
        window_below: Extent of the fit window below the estimate (same unit as `energy`).
        window_above: Extent of the fit window above the estimate (same unit as `energy`).

    Returns:
        (edge_energy, gauss_complete): The fitted peak position, and a Series of the
        fitted Gaussian evaluated on every energy point.

    Raises:
        ValueError: If no data point lies in the fit window, or the fit is rejected.
        RuntimeError: If the Gaussian fit does not converge.
    """
    df = pd.DataFrame({'energy': energy, 'd_ref': derivative_ref})
    in_window = ((df['energy'] >= estimated_peak_energy - window_below)
                 & (df['energy'] <= estimated_peak_energy + window_above))
    df_edge_area = df.loc[in_window, ['energy', 'd_ref']].copy()
    if df_edge_area.empty:
        raise ValueError('Could not find a peak close to the selected E0.')
    # Initial guess: the lowest-energy window point at or above the estimate. For an
    # ascending scan this is the first such point. Picking it from inside the window keeps
    # the guess within the fit bounds, even if the scan has no point above the estimate or
    # is not sorted ascending.
    window_energy = df_edge_area['energy']
    at_or_above = window_energy[window_energy >= estimated_peak_energy]
    if at_or_above.empty:
        peak_idx_guess = (window_energy - estimated_peak_energy).abs().idxmin()
    else:
        peak_idx_guess = at_or_above.idxmin()
    # fit gaussian to derivatives
    df = get_energy_gauss_df(df, df_edge_area, peak_idx_guess, 'energy', 'd_ref')
    return df['d_ref_edge'].iloc[0], df['d_ref_gauss_complete']

def get_energy_shift(energy, derivative_ref, e0, estimated_peak_energy):
    """Return `(e0 - fitted_edge_energy, gauss_fit)` for the derivative peak near `estimated_peak_energy`.

    `gauss_fit` is the fitted Gaussian evaluated on every energy point, as returned
    by `get_edge_fit`.
    """
    fitted_edge, gauss_curve = get_edge_fit(energy, derivative_ref, estimated_peak_energy)
    shift = e0 - fitted_edge
    return shift, gauss_curve

def get_energy_plot(energy, derivative_ref, gauss_fit, energy_shift, E0):
    """Build a plotly figure of dµ/dE, its Gaussian fit, and the E0 / fitted-peak markers.

    Args:
        energy: Energy axis in keV. Expected to have the same length as
            `derivative_ref` and `gauss_fit`.
        derivative_ref: First derivative of the absorbance.
        gauss_fit: Fitted Gaussian evaluated on `energy`.
        energy_shift: `E0 - fitted peak position` in keV. The fitted peak is drawn
            at `E0 - energy_shift`; the legend title shows the shift in eV.
        E0: Selected edge energy in keV.

    Returns:
        plotly.graph_objects.Figure
    """
    # The vertical marker lines span the finite range of the derivative. The builtin
    # min()/max() give order-dependent results when NaN is present, so filter first.
    derivative_values = np.asarray(derivative_ref, dtype=float)
    finite_values = derivative_values[np.isfinite(derivative_values)]
    if finite_values.size:
        y_span = [float(finite_values.min()), float(finite_values.max())]
    else:
        y_span = [0.0, 0.0]
    fitted_peak = E0 - energy_shift

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=energy,
        y=derivative_ref,
        mode='lines+markers',
        name='dµ/dE',
        line=dict(width=2),
        marker=dict(size=6)
    ))

    fig.add_trace(go.Scatter(
        x=energy,
        y=gauss_fit,
        mode='lines',
        name='gauss fit',
        line=dict(width=2, color='orange'),
    ))

    fig.add_trace(
        go.Scatter(
            x=[fitted_peak, fitted_peak],
            y=y_span,
            mode='lines',
            line=dict(color='orange', dash='dot'),
            name='fitted gauss peak',
        )
    )

    fig.add_trace(
        go.Scatter(
            x=[E0, E0],
            y=y_span,
            mode='lines',
            line=dict(color='red', dash='dot'),
            name=f'E0 = {E0} keV'
        )
    )

    # Clean and minimal layout
    fig.update_layout(
        title='First derivative of absorbance of the reference (dµ/dE)',
        xaxis_title='Energy (keV)',
        yaxis_title='dµ/dE (a.u.)',
        template='simple_white',  # Clean white background
        legend_title=dict(text=f'manual shift = {energy_shift*1000:.6f} eV'),
        margin=dict(l=40, r=40, t=50, b=40),
        height=400
    )

    # Remove grid lines for a minimalist look
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig

def parse_datetime(dt_str):
    """Parse an ISO 8601 timestamp string into a `datetime` using `datetime.fromisoformat`."""
    parsed = datetime.fromisoformat(dt_str)
    return parsed

def format_datetime_string(dt_obj):
    """Render a datetime as 'DD.MM.YYYY HH:MM', e.g. '19.07.2025 10:15'."""
    return format(dt_obj, '%d.%m.%Y %H:%M')

def get_ocv_link(upload_id, entry_id):
    """Return a one-element list holding a relative reference to an entry's `data` section.

    Format: `../uploads/<upload_id>/archive/<entry_id>#data`.
    """
    reference = '../uploads/{}/archive/{}#data'.format(upload_id, entry_id)
    return [reference]

In [4]:
# widgets
# Output area for the calibration result of step 2.
ocp_selection_output = widgets.Output()

# Calibration state: the computed energy shift and the link to the XAS entry it came from.
default_xas = {'link': None, 'energy_shift': None}

# Work on a copy so that updating the selection never mutates `default_xas`.
selected_xas = dict(default_xas)

# Calibration of XAS with manual energy shift

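This notebook helps you set the `manual_energy_shift` of multiple NOMAD entries, based on an XAS transmission measurement uploaded as a `Bessy2_KMC3_XASTransmission` or `Bessy2_KMC2_XASTransmission` entry. The shift is the difference between the chosen $E_{0}$ and the position of a Gaussian fitted to the peak of the first derivative of the measured absorbance (dµ/dE).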

The `manual_energy_shift` is then automatically annotated within the selected NOMAD entries.

### 1) Select E0

In [5]:
default_E0 = 6.539  # keV; presumably the Mn K-edge — TODO confirm the intended default
E0_widget = widgets.FloatText(
    value=default_E0,
    description=r'$E_{0}$ (keV)',
    step=0.001,
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px')
)

element = None  # NOTE(review): not referenced anywhere else in this notebook


def _reset_calibration():
    """Discard the calibration in `selected_xas` and clear its output area.

    The shift is `E0 - fitted peak`, so it is only valid for the E0 it was computed with.
    """
    global selected_xas
    selected_xas = default_xas.copy()
    with ocp_selection_output:
        ocp_selection_output.clear_output()


def on_click_periodic_table(symbol):
    """Periodic-table callback: set E0 from `k_edge_data` (if known for `symbol`) and reset the calibration."""
    k = k_edge_data.get(symbol)
    if k:
        E0_widget.value = k
    _reset_calibration()


# Also reset when E0 is edited by hand, so the apply button never uses a shift computed for another E0.
E0_widget.observe(lambda change: _reset_calibration(), names='value')

gridbox = create_periodic_table(on_click_periodic_table)

display(gridbox)

GridBox(children=(Button(description='H', layout=Layout(width='40px'), style=ButtonStyle(button_color='#eee'))…

**If the absorption edge is not available in the table above, enter the $E_{0}$ (in keV) you want to use below.**

In [9]:
# Optional estimate of the derivative-peak energy, used to place the fit window.
# FloatText always holds a float: value=None falls back to its default 0.0, so pass 0.0 explicitly.
estimated_peak_energy_widget = widgets.FloatText(
    value=0.0,
    description='Estimated Peak Energy (keV) (setting this is optional)',
    step=0.001,
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px')
)

display(E0_widget, estimated_peak_energy_widget)

FloatText(value=6.539, description='$E_{0}$ (keV)', layout=Layout(width='200px'), step=0.001, style=Descriptio…

FloatText(value=0.0, description='Estimated Peak Energy (keV) (setting this is optional)', layout=Layout(width…

### 2) Select XAS Transmission Measurement for Calibration

In [6]:
# All XAS transmission measurements, newest first (by the measurement's 'datetime').
xas_list = sorted(
    get_transmission_data(url, token),
    key=lambda xas_entry: parse_datetime(xas_entry['datetime']),
    reverse=True,
)

In [53]:
# One option per measurement: label "DD.MM.YYYY HH:MM - <data file>", value = position in `xas_list`.
dropdown_options = [
    (f'{format_datetime_string(parse_datetime(xas_entry["datetime"]))} - {xas_entry["data_file"]}', position)
    for position, xas_entry in enumerate(xas_list)
]

dropdown = widgets.Dropdown(
    description='Select an XAS entry for calibration:',
    options=dropdown_options,
    value=None,
    layout=widgets.Layout(width='60%'),
    style={'description_width': 'initial'},
)

def on_dropdown_change(change):
    """Compute and display the energy shift for the XAS entry chosen in `dropdown`.

    Dropdown observer; only `value` changes are handled. Any previous calibration in
    the global `selected_xas` is discarded first. On success, `selected_xas` holds the
    new `energy_shift` (`E0 - fitted peak`, in keV) and a link to the chosen entry,
    and the derivative/fit plot is shown in `ocp_selection_output`. On failure, an
    explanation is printed and `selected_xas` stays reset, so the apply button cannot
    use a stale shift.
    """
    global selected_xas
    if change['type'] != 'change' or change['name'] != 'value':
        return
    idx = change['new']
    with ocp_selection_output:
        ocp_selection_output.clear_output()
        selected_xas = default_xas.copy()
        if idx is None:
            return
        selected_entry = xas_list[idx]
        print('Calculate energy shift. This might take a moment...')
        m_def = selected_entry.get('m_def') or ''
        if 'Bessy2_KMC3_XASTransmission' in m_def:
            # NOTE(review): KMC3 uses the sample absorbance, KMC2 the reference absorbance — confirm this is intended.
            mu = selected_entry.get('absorbance_of_the_sample')
        elif 'Bessy2_KMC2_XASTransmission' in m_def:
            mu = selected_entry.get('absorbance_of_the_reference')
        else:
            print(f'Unsupported entry type: {m_def}')
            print('Could not calculate the energy shift from the given E0 and selected XAS entry.')
            return
        # FloatText has no empty state: an untouched optional field reads 0.0 (never None),
        # so fall back to E0 whenever no estimate was entered.
        peak_guess = estimated_peak_energy_widget.value or E0_widget.value
        removed_nan = False
        try:
            energy = np.asarray(selected_entry.get('energy'), dtype=float)
            derivative = get_first_derivative(energy, mu)
            if np.isnan(derivative).any():
                removed_nan = True
                print('Derivative contains nan values that will be removed. Please check selected µ and energy.')
                mu_array = np.asarray(mu, dtype=float)
                mask = np.isfinite(energy) & np.isfinite(mu_array)
                # Keep the energy axis aligned with the cleaned derivative; the plot below uses it too.
                energy = energy[mask]
                derivative = get_first_derivative(energy, mu_array[mask])
            selected_entry['derivative'] = derivative
            energy_shift, gauss_fit = get_energy_shift(energy, derivative, E0_widget.value, peak_guess)
        except (ValueError, RuntimeError) as e:
            # curve_fit raises RuntimeError when the fit does not converge within maxfev.
            print(e)
            print('Could not calculate the energy shift from the given E0 and selected XAS entry.')
            return
        fig = get_energy_plot(energy, derivative, gauss_fit, energy_shift, E0_widget.value)
        selected_xas['energy_shift'] = energy_shift
        selected_xas['link'] = get_ocv_link(selected_entry.get('upload_id'), selected_entry.get('entry_id'))
        ocp_selection_output.clear_output()
        if removed_nan:
            # Re-print the warning, which the clear_output above removed.
            print('Derivative contains nan values that will be removed. Please check selected µ and energy.')
        print('Selected entry:')
        if selected_entry.get('quality_annotation') == 'ICR out of bounds':
            print('Please check the ICR values of the selected entry. Some ICR values are not in the recommended bounds [0; 250000].')
        fig.show()
        print(f'Calculated energy shift: {energy_shift} keV')
        print('The energy shift is calculated as the difference of the selected E0 and the first peak of the derivative of µ.')

# Observe all trait changes of the dropdown; the handler itself ignores everything but `value`.
dropdown.observe(on_dropdown_change)

# Show the selector together with the area where the calculated shift and plot appear.
display(dropdown, ocp_selection_output)

Dropdown(description='Select an XAS entry for calibration:', layout=Layout(width='60%'), options=(('19.07.2025…

Output(outputs=({'name': 'stdout', 'text': 'Selected entry:\n', 'output_type': 'stream'}, {'output_type': 'dis…

### 3) In which entries would you like to use this calibration?

Hold Shift and click to select a range of entries; hold Ctrl (Cmd on macOS) and click to add or remove individual entries.

In [55]:
xas_entry_list = get_xas_entryids(url, token)

# One option per XAS entry: label "<upload name> - <entry name> - <entry id>", value = entry id.
# `get_xas_entryids` stores 'upload_name' even when it is None, so `.get(key, default)` would
# never fall back; `or` also covers the None (and empty) case.
options = [
    (f'{item.get("upload_name") or "--no name given--"} - {item["entry_name"]} - {item["entry_id"]}', item['entry_id'])
    for item in xas_entry_list
]

multi_select_entries = widgets.SelectMultiple(
    options=options,
    description='Entries:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='75%', height='150px')
)

# Output widget that echoes the current selection.
entry_selection_output = widgets.Output()

def on_selection_change(change):
    """Echo the IDs of the entries currently selected in `multi_select_entries`."""
    with entry_selection_output:
        entry_selection_output.clear_output()
        selected_ids = change['new']
        if not selected_ids:
            # The widget selects NOMAD entries, not uploads.
            print('No entries selected.')
            return
        print('Selected Entry IDs:')
        for sid in selected_ids:
            print(sid)

# Update the echo of selected entry IDs whenever the selection changes.
multi_select_entries.observe(on_selection_change, names='value')

display(multi_select_entries, entry_selection_output)

SelectMultiple(description='Entries:', layout=Layout(height='150px', width='75%'), options=(('XAS Workflow Exa…

Output()

### 4) Apply energy shift

The button below will associate all selected NOMAD entries with the calculated manual energy shift from the XAS Transmission entry.  

**Please note that this calibration process is not easily reversible.** Any existing `connected_measurements` in the selected NOMAD entries will be overwritten.

In [56]:
# Output area for progress and result messages of the calibration button.
calibration_output = widgets.Output()

button = widgets.Button(
    layout=widgets.Layout(width='auto'),
    button_style='info',
    description='Link XAS Transmission entry and calculated shift for calibration',
)

def on_button_click(b):
    """Write the calibration in `selected_xas` into every entry selected in `multi_select_entries`.

    Each entry is updated with `link_xas_energy_shift`. A network or JSON-decoding error
    on one entry is reported and the remaining entries are still processed.
    NOTE(review): `link_xas_energy_shift` does not check the HTTP status, so an edit that
    NOMAD rejects with an error response is still reported as applied here.
    """
    with calibration_output:
        calibration_output.clear_output()
        energy_shift = selected_xas.get('energy_shift')
        if energy_shift is None:
            print('Please calculate an energy shift before using this button!')
            return
        entry_ids = multi_select_entries.value
        if not entry_ids:
            print('Please select at least one entry in step 3 before using this button!')
            return
        link = selected_xas.get('link')
        print('Please wait for the "All entries updated. DONE." at the bottom')
        failed = []
        for entry_id in entry_ids:
            try:
                link_xas_energy_shift(url, token, entry_id, energy_shift, link)
            except (requests.RequestException, ValueError) as e:
                print(f'Could not update NOMAD entry {entry_id}: {e}')
                failed.append(entry_id)
                continue
            print(f'Use calibration of {energy_shift} keV in NOMAD entry {entry_id}')
        if failed:
            print(f'{len(failed)} of {len(entry_ids)} entries could not be updated.')
        else:
            print('All entries updated. DONE.')

# Run the calibration when the button is clicked.
button.on_click(on_button_click)

# Display the button and its output area.
display(button, calibration_output)

Button(button_style='info', description='Link XAS Transmission entry and calculated shift for calibration', la…

Output()