# 1 Select Batches
Only batches of type "hysprint_batch" are considered.

In [8]:
%matplotlib ipympl
%load_ext autoreload
%autoreload 2
import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import display, Markdown, HTML
import os
import pandas as pd
from api_calls import get_batch_ids, get_ids_in_batch, get_sample_description, get_all_eqe
import numpy as np
import json
import batch_selection
import plotting_utils
# Connection settings used by every API call in this notebook.
url_base = "https://nomad-hzb-se.de"
url = f"{url_base}/nomad-oasis/api/v1"
# The access token is read from the environment; a missing variable raises KeyError here.
token = os.environ["NOMAD_CLIENT_ACCESS_TOKEN"]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Unicode warning sign, printed in front of status messages that need attention.
warning_sign = "\u26A0"

# Output areas. `out` shows the loading status, `read` shows whether the
# variable names were confirmed; `out2` is not referenced in the visible cells.
out = widgets.Output()
out2 = widgets.Output()
read = widgets.Output()
dynamic_content = widgets.Output()  # For dynamically updated content
results_content = widgets.Output(layout={
    # 'border': '1px solid black',  # Optional: adds a border to the widget
    # Layout traits are spelled with underscores: the CSS-style key
    # 'max-height' is not a Layout trait, so the limit was never applied.
    'max_height': '1000px',  # Limit the height
    'overflow': 'scroll',  # Adds a scrollbar if content overflows
    })
# Shared container that shows the per-curve options of whichever sample was expanded last.
cell_edit = widgets.VBox()

# Preset that decides how the plot name of each sample is pre-filled.
default_variables = widgets.Dropdown(
    options=['subbatches', 'batches', "sample description", 'custom'],
    index=0,
    description='preset:',
    disabled=False,
)

#widget group for selecting cells in a sample
class cellSelector(widgets.widget_box.VBox):
    """Widget group for naming one sample and choosing which of its curves are plotted.

    The box shows the sample id, a counter of selected curves, a text field for
    the name used in plots and three buttons (select all, select none, expand
    the per-curve options). The per-curve widgets are not children of this box;
    they are placed into the shared ``cell_box`` when "expand options" is clicked.

    Reads the module-level ``data`` dictionary (``params``, ``entries`` and,
    for the "sample description" preset, ``properties``).
    """

    def __init__(self, sample_id, default, cell_box):
        """Create the widgets for one sample.

        Args:
            sample_id: id of the sample; used as key into ``data``.
            default: naming preset ('batches', 'subbatches',
                'sample description' or anything else for an empty name).
            cell_box: box widget whose children are replaced by the per-curve
                widgets when the options are expanded.
        """
        self.sample_id = sample_id
        self.sample_id_text = widgets.Label(value=sample_id, layout={'width': '200px'})
        self.count_text = widgets.Label(layout={'width': '100px'})

        # Ids of the form "<batch>&<variable>" are split at the first "&";
        # ids without "&" have no batch part and are used whole as variable.
        batch, separator, remainder = sample_id.partition("&")
        if separator:
            variable = remainder
        else:
            batch, variable = "", sample_id

        if default == "batches":
            # Without a batch part fall back to the id minus its last "_" segment.
            default_value = batch if batch else "_".join(sample_id.split("_")[:-1])
        elif default == "subbatches":
            default_value = variable
        elif default == "sample description":
            default_value = data["properties"].loc[sample_id, "description"]
            # A missing description (None/NaN) is not a valid value for a Text
            # widget; start with an empty field instead of failing.
            if not isinstance(default_value, str):
                default_value = ""
        else:
            default_value = ""
        self.text_input = widgets.Text(value=default_value, placeholder='Name in Plot', layout={'width': '300px'})

        self.display_all_button = widgets.Button(description="show all cells", layout={'width': '100px'})
        self.display_none_button = widgets.Button(description="show none", layout={'width': '100px'})
        self.edit_curves_button = widgets.Button(description="expand options", layout={'width': '100px'})
        self.display_all_button.on_click(self.select_all)
        self.display_none_button.on_click(self.disselect_all)
        self.edit_curves_button.on_click(self.expand_options)

        super().__init__([widgets.HBox([self.sample_id_text, self.count_text]),
                          self.text_input,
                          widgets.HBox([self.display_all_button, self.display_none_button, self.edit_curves_button])])

        # One checkbox (visibility) and one text field (custom name) per curve,
        # plus the name that is used when the text field is left empty.
        self.select_individual_cells = []
        self.name_defaults = []
        self.name_individual_cells = []
        for entry_index, curve_index in data["params"].loc[sample_id].index:
            entry_name = data["entries"].loc[(sample_id, entry_index), "entry_names"]
            current_select_box = widgets.Checkbox(description=entry_name + " " + str(curve_index), value=True)
            current_select_box.observe(self.update_count, "value")
            self.select_individual_cells.append(current_select_box)
            self.name_individual_cells.append(widgets.Text(placeholder="Name"))
            self.name_defaults.append(entry_name.removeprefix(sample_id) + " " + str(curve_index))

        self.individual_widget_list = [
            widgets.HBox([select_box, name_field])
            for select_box, name_field in zip(self.select_individual_cells, self.name_individual_cells)
        ]

        #box for containing the widgets for editing individual curve names and visibility
        self.edit_box = cell_box

        #initialize value for the counter text
        self.update_count(None)

    def get_name(self):
        """Return the plot name of the sample; the sample id if the field is blank."""
        entered = self.text_input.value
        return entered if entered.strip() else self.sample_id

    def get_cell_selection(self):
        """Return one bool per curve, True where the curve is selected for plotting."""
        return [cell.value for cell in self.select_individual_cells]

    def get_curve_names(self):
        """Return one name per curve, using the default name where the field is blank."""
        return [
            name_field.value if name_field.value.strip() else fallback
            for name_field, fallback in zip(self.name_individual_cells, self.name_defaults)
        ]

    def select_all(self, b):
        """Button callback: tick every curve checkbox."""
        for checkbox in self.select_individual_cells:
            checkbox.value = True

    def disselect_all(self, b):
        """Button callback: untick every curve checkbox."""
        for checkbox in self.select_individual_cells:
            checkbox.value = False

    def expand_options(self, b):
        """Button callback: show this sample's per-curve widgets in the shared edit box."""
        self.edit_box.children = self.individual_widget_list

    def update_count(self, change):
        """Refresh the "<selected>/<total> shown" label; ``change`` is ignored."""
        self.count_text.value = f"{self.get_cell_selection().count(True)}/{len(self.select_individual_cells)} shown"

def create_widgets_table(elements_list):
    """Create a cellSelector for every sample id in ``elements_list``.

    Returns a tuple of a VBox stacking all selectors in input order and a
    dict mapping each sample id to its selector. The naming preset is taken
    from ``default_variables`` and all selectors share ``cell_edit``.
    """
    preset = default_variables.value
    id_selector_pairs = [
        (sample_id, cellSelector(sample_id, preset, cell_edit))
        for sample_id in elements_list
    ]
    stacked_rows = widgets.VBox([selector for _, selector in id_selector_pairs])
    return stacked_rows, dict(id_selector_pairs)

#this function takes sample ids and returns the eqe curves and parameters as Dataframes
def get_eqe_data(try_sample_ids):
    """Fetch the EQE measurements of the given samples and arrange them as DataFrames.

    The raw API answer is also written to ``data.json`` in the working directory.
    Samples without EQE entries and entries without measurements are skipped.

    Args:
        try_sample_ids: sample ids handed to ``get_all_eqe``.

    Returns:
        Tuple ``(curves, params, sample_ids, entries)``:
        ``curves`` has the nested index sample / entry / curve / data point and
        the columns photon_energy_array, wavelength_array and eqe_array;
        ``params`` has the index sample / entry / curve and one column per
        parameter of a single measurement; ``sample_ids`` is a Series of the
        sample ids that are present in the frames; ``entries`` has the index
        sample / entry and the columns entry_names and entry_description.

    Raises:
        ValueError: if no EQE measurement was found for any requested sample.
    """
    curve_columns = ["photon_energy_array", "wavelength_array", "eqe_array"]
    #parameters of single eqe measurement
    eqe_params_names = ["light_bias","bandgap_eqe","integrated_jsc","integrated_j0rad","voc_rad","urbach_energy","urbach_energy_fit_std_dev"]
    #make api call, result has everything in json format
    all_eqe = get_all_eqe(url, token, try_sample_ids)
    with open('data.json', 'w', encoding='utf-8') as f:
        json.dump(all_eqe, f)

    kept_sample_ids = []
    eqe_curves_list = []
    eqe_params_list = []
    description_list = []
    for sample_id, eqe_entries in all_eqe.items():
        entry_names_list = []
        entry_description_list = []
        sample_curves_list = []
        sample_params_list = []
        # `or ()` lets a sample that maps to None/[] fall through to the skip below.
        for eqe_entry in eqe_entries or ():
            entry = eqe_entry[0]
            measurements = entry["eqe_data"]
            if not measurements:
                # pd.concat cannot combine an empty list; nothing to plot for this entry
                continue
            entry_curves = [pd.DataFrame(measurement, columns=curve_columns) for measurement in measurements]
            sample_curves_list.append(pd.concat(entry_curves, keys=np.arange(len(entry_curves))))
            sample_params_list.append(pd.DataFrame(measurements, columns=eqe_params_names))

            entry_names_list.append(entry["name"])
            entry_description_list.append(entry["description"])
        if not sample_curves_list:
            continue

        #unify all measurements of a single sample into a dataframe, put these frames of different samples in a list
        entry_keys = np.arange(len(sample_curves_list))
        kept_sample_ids.append(sample_id)
        eqe_curves_list.append(pd.concat(sample_curves_list, keys=entry_keys))
        eqe_params_list.append(pd.concat(sample_params_list, keys=entry_keys))
        description_list.append(pd.DataFrame({"entry_names": entry_names_list, "entry_description": entry_description_list}))

    if not kept_sample_ids:
        raise ValueError("No EQE measurements were found for the requested samples")

    existing_sample_ids = pd.Series(kept_sample_ids)
    #resulting frame for curves has nested index structure sample/eqe_entry/curve/datapoint, string/string/int/int, params lack last layer
    return (
        pd.concat(eqe_curves_list, keys=existing_sample_ids),
        pd.concat(eqe_params_list, keys=existing_sample_ids),
        existing_sample_ids,
        pd.concat(description_list, keys=existing_sample_ids),
    )

def on_load_data_clicked(batch_ids_selector):
    """Load the EQE data of the selected batches and show the variables menu.

    Resets the module-level ``data`` dictionary and fills the keys ``curves``,
    ``params``, ``sample_ids``, ``entries`` and ``properties``. Progress and
    errors are shown in the ``out`` widget.

    Args:
        batch_ids_selector: widget whose ``value`` holds the selected batch ids.
    """
    #global dictionary to hold data
    global data
    data = {}
    dynamic_content.clear_output()
    loaded = False
    with out:
        out.clear_output()
        print("Loading Data")

        try_sample_ids = get_ids_in_batch(url, token, batch_ids_selector.value)

        #extract EQE here
        data["curves"], data["params"], data["sample_ids"], data["entries"] = get_eqe_data(try_sample_ids)

        identifiers = get_sample_description(url, token, list(data["sample_ids"]))
        # per-curve plot settings; they are filled in when the variables are confirmed
        data["params"].loc[:, "plot"] = False
        data["params"].loc[:, "name"] = ""

        # explicit dtype: the empty "name" column later holds strings
        data["properties"] = pd.DataFrame({"description": pd.Series(identifiers), "name": pd.Series(dtype=object)})

        out.clear_output()
        print("Data Loaded")
        loaded = True
    # The Output context manager displays an exception inside the widget and,
    # when running under IPython, suppresses it. Without this check a failed
    # load would continue here and fail again on the incomplete dictionary.
    if loaded:
        make_variables_menu(data["sample_ids"])
    return

def on_confirm_clicked(selectors_dict):
    """Store the names and curve selection from the selectors and export the data.

    For every sample the plot name, the per-curve visibility and the per-curve
    names are copied from its selector into ``data``. Afterwards the four
    frames in ``data`` are written to CSV files in the working directory and a
    confirmation is printed into ``read``.
    """
    read.clear_output()
    params = data["params"]
    plot_names = {}
    for sample_id, selector in selectors_dict.items():
        plot_names[sample_id] = selector.get_name()
        params.loc[sample_id, "plot"] = selector.get_cell_selection()
        params.loc[sample_id, "name"] = selector.get_curve_names()
    data["properties"]["name"] = pd.Series(plot_names)

    # same export order as the frames are kept in: curves, params, properties, entries
    exports = (
        ("curves", "eqe_curve.csv"),
        ("params", "eqe_params.csv"),
        ("properties", "eqe_properties.csv"),
        ("entries", "eqe_entries.csv"),
    )
    for frame_key, filename in exports:
        data[frame_key].to_csv(filename)

    with read:
        print("Variables loaded")

def make_variables_menu(sample_ids):
    """Display the overview table and the menu for naming samples and curves.

    Everything is rendered into ``dynamic_content``: the results area, a short
    instruction text, the preset dropdown, one selector per sample next to the
    shared per-curve edit box, and the confirm button with its status output.
    The status output is reset to "Variables not loaded" because the newly
    created selectors have not been confirmed yet.

    Args:
        sample_ids: ids of the samples for which selectors are created.
    """
    variables_markdown = f"""
# 1a Add variable names
{len(sample_ids)} samples have been found.
Enter the name of the samples that should be used in the plot.
Curves with the same name will be grouped together
"""
    # The edit box may still show the per-curve widgets of a selector from the
    # previous menu; those selectors are discarded below, so empty the box.
    cell_edit.children = []
    with dynamic_content:
        display(results_content)
        display(Markdown(variables_markdown))
        display(default_variables)
        widgets_table, selectors_dict = create_widgets_table(sample_ids)
        retrieve_button = widgets.Button(description="Confirm variables", button_style='primary')
        retrieve_button.on_click(lambda b: on_confirm_clicked(selectors_dict))
        display(widgets.HBox([widgets_table, cell_edit]))
        button_group = widgets.HBox([retrieve_button, read])
        display(button_group)

    create_overview_table(results_content)
    with read:
        read.clear_output()
        print(f"{warning_sign} Variables not loaded")
    return

def on_change_default_variables(b):
    """Rebuild the variables menu after a different naming preset was chosen.

    ``b`` is the change notification handed over by ``observe``; it is not used.
    """
    dynamic_content.clear_output()
    loaded_sample_ids = data["sample_ids"]
    make_variables_menu(loaded_sample_ids)

def create_overview_table(output_widget):
    """Render summary statistics and the full parameter table into ``output_widget``.

    The overview has one row per sample plus an "All Data" row and, for each
    parameter, the columns min, mean, mean std and max. Below it the
    per-curve parameter table is shown. Floats are formatted in scientific
    notation with two decimals.
    """
    parameter_names = ["bandgap_eqe","integrated_jsc","integrated_j0rad","voc_rad","urbach_energy","light_bias"]
    # column label -> name of the pandas reduction that fills the column
    reductions = {"min": "min", "mean": "mean", "mean std": "std", "max": "max"}

    columns = pd.MultiIndex.from_product([parameter_names, list(reductions)])
    overview = pd.DataFrame(columns=columns)
    params = data["params"]
    for column in columns:
        parameter, statistic = column
        reduction = reductions[statistic]
        for sid in data["sample_ids"]:
            overview.loc[sid, column] = getattr(params.loc[sid, parameter], reduction)()
        #add statistics for entire table
        overview.loc["All Data", column] = getattr(params.loc[:, parameter], reduction)()

    with output_widget, pd.option_context('display.float_format', '{:,.2e}'.format):
        output_widget.clear_output()
        display(HTML(overview.to_html()))
        display(HTML(params.to_html(columns=parameter_names, justify="left", border=1)))

# Rebuild the variables menu whenever a different naming preset is selected.
default_variables.observe(on_change_default_variables,names=['value'])

# Bind the 'Load Data' button click event
# NOTE(review): `button` is not defined in any visible cell. It looks like a
# leftover from before the batch selection moved into `batch_selection`;
# on_click would also pass the button itself, which has no batch selection
# `value` as on_load_data_clicked expects. Confirm and remove.
button.on_click(on_load_data_clicked)

# Bind the Search function to changes in the search field
# NOTE(review): `search_field` and `on_search_enter` are not defined in any
# visible cell either; presumably handled inside `batch_selection` now. Confirm.
search_field.observe(on_search_enter, "value")

# Display the initial UI components
# NOTE(review): `load_group` is not used afterwards in the visible cells.
load_group = widgets.HBox([button, out])

# Batch selection UI; on_load_data_clicked is handed over as the load callback.
display(batch_selection.create_batch_selection(url, token, on_load_data_clicked))
display(out)
display(dynamic_content)  # This will be updated dynamically with the variables menu

VBox(children=(Text(value='', description='Search Batch'), SelectMultiple(description='Batches', layout=Layout…

Output()

Output()

In [10]:
#set styling template
import plotly.io as pio
from bisect import bisect
import itertools
import scipy

# Figure template shared by all plots: plotly's white theme with the Vivid
# palette as default line colors of scatter traces.
template = pio.templates["plotly_white"]
template.data.scatter = [
    go.Scatter(line_color=vivid_color)
    for vivid_color in px.colors.qualitative.Vivid
]

# Endless iterator over the same palette for traces that get their color assigned by hand.
color_iterator = itertools.cycle(px.colors.qualitative.Vivid)

# EQE curve plots

In [11]:
# Output area for the EQE curve figure. `overflow` is a Layout trait, so it has
# to be passed through `layout` (as done for results_content); given as a plain
# keyword argument it is not applied to the widget.
curve_out = widgets.Output(layout={'overflow': 'scroll'})
# Toggle between one aggregated trace per name and one trace per curve.
intervals = widgets.Checkbox(description="group curves with same name", indent=False, value=True)

def update_curve_plot(b):
    """Redraw the EQE curve figure inside ``curve_out``.

    Only curves whose ``plot`` flag in ``data["params"]`` is set are used.
    With "group curves with same name" ticked, all curves that share a display
    name are interpolated onto a common x grid of 500 points and drawn as one
    center line with a shaded band: mean +/- standard deviation or median with
    lower/upper quartile, depending on the "group type" toggle. Otherwise
    every selected curve is drawn as its own trace.

    Args:
        b: the clicked button (passed by ``on_click``); not used.
    """
    with curve_out:
        axis_title, column_name = unit_selector.value
        layout = go.Layout(
            width=curve_options.width.value,
            height=curve_options.height.value,
            xaxis={"title": {"text": axis_title}},
            yaxis={"title": {"text": "external quantum efficiency"}},
            template=template,
        )
        curve_out.clear_output()
        figure = go.Figure(layout=layout)
        # function selected in the plot options that builds the display name
        make_name = curve_options.name.value

        if intervals.value:
            #Dictionary with every unique given name as index, contains list of all curves that have this given name
            data_organized_by_given_name = {}
            for sample_id in data["sample_ids"]:
                sample_name = data["properties"].loc[sample_id, "name"]
                sample_params = data["params"].loc[sample_id]
                samples_filtered = sample_params.loc[sample_params["plot"]]
                for i in samples_filtered.index:
                    name = make_name(sample_name, samples_filtered.loc[i, "name"])
                    data_organized_by_given_name.setdefault(name, []).append(data["curves"].loc[(sample_id, *i), :])

            use_wavelength = column_name == "wavelength_array"
            # Start the palette from its first color on every refresh so that a
            # group keeps its color when the plot is redrawn.
            group_colors = itertools.cycle(px.colors.qualitative.Vivid)

            for name, curve_list in data_organized_by_given_name.items():
                #get minimum and maximum energies/wavelengths
                # The first row of a curve is taken as its lowest photon energy,
                # i.e. its largest wavelength, and the last row as the opposite end.
                if use_wavelength:
                    max_x = max(curve[column_name].iloc[0] for curve in curve_list)
                    min_x = min(curve[column_name].iloc[-1] for curve in curve_list)
                else:
                    max_x = max(curve[column_name].iloc[-1] for curve in curve_list)
                    min_x = min(curve[column_name].iloc[0] for curve in curve_list)

                xcoords = np.linspace(min_x, max_x, 500)

                interpolated = []
                for curve in curve_list:
                    x_values = curve[column_name]
                    y_values = curve["eqe_array"]
                    if use_wavelength:
                        #order of datapoints is flipped to get wavelengths in ascending order
                        x_values, y_values = x_values.iloc[::-1], y_values.iloc[::-1]
                    # points outside the range of a curve become NaN instead of being extrapolated
                    interpolated.append(np.interp(xcoords, x_values, y_values, left=np.nan, right=np.nan))
                interpolated_curves = pd.DataFrame(interpolated)

                # center line and the two edges of the shaded band
                if standart_deviation_area.value:
                    center = interpolated_curves.mean()
                    spread = interpolated_curves.std()
                    band_forward, band_backward = center + spread, center - spread
                else:
                    center = interpolated_curves.median()
                    band_forward = interpolated_curves.quantile(q=0.25, interpolation='linear')
                    band_backward = interpolated_curves.quantile(q=0.75, interpolation='linear')

                #Plot the results
                color = next(group_colors)
                # closed outline: along one edge, then back along the other edge in reversed order
                figure.add_scatter(x=np.concatenate([xcoords, xcoords[::-1]]),
                                   y=pd.concat([band_forward, band_backward.iloc[::-1]]),
                                   line_color='rgba(255,255,255,0)', #make outline of area invisible
                                   fillcolor=f"rgba({color[4:-1]},0.2)", #manipulate color string to add transparency
                                   fill="toself",
                                   legendgroup=name,
                                   showlegend=False,
                                   name=name)
                figure.add_scatter(x=xcoords,
                                   y=center,
                                   name=name,
                                   line_color=color,
                                   legendgroup=name)

        else:
            #iterate over every sample and cell with parameter plot set to true
            for sample_id in data["sample_ids"]:
                sample_name = data["properties"].loc[sample_id, "name"]
                sample_params = data["params"].loc[sample_id]
                samples_filtered = sample_params.loc[sample_params["plot"]]
                for i in samples_filtered.index:
                    figure.add_scatter(x=data["curves"].loc[(sample_id, *i), column_name],
                                       y=data["curves"].loc[(sample_id, *i), "eqe_array"],
                                       name=make_name(sample_name, samples_filtered.loc[i, "name"]))
        figure.show()

# Each option is a tuple (description, value); the value is itself a tuple
# (axis title, column name in the curves frame).
unit_selector = widgets.ToggleButtons(
    options=[
        ("photon energy", ("photon energy / eV", "photon_energy_array")),
        ("wavelength", ("wavelength / nm", "wavelength_array")),
    ],
    index=0,
)
# True selects mean with standard deviation, False selects median with quartiles.
standart_deviation_area = widgets.ToggleButtons(
    description="group type",
    options=[("median, quartiles", False), ("average, std", True)],
    index=0,
)
curve_button = widgets.Button(description="refresh plot", button_style='primary')
curve_button.on_click(update_curve_plot)

curve_options = plotting_utils.plot_options(default_name=0)

display(intervals, standart_deviation_area, unit_selector, curve_options, curve_button, curve_out)

Checkbox(value=True, description='group curves with same name', indent=False)

ToggleButtons(description='group type', options=(('median, quartiles', False), ('average, std', True)), value=…

ToggleButtons(options=(('photon energy', ('photon energy / eV', 'photon_energy_array')), ('wavelength', ('wave…

plot_options(children=(ToggleButtons(options=(('sample + curve name', <function sample_and_curve_name at 0x7f8…

Button(button_style='primary', description='refresh plot', style=ButtonStyle())

Output()

# Boxplot

In [5]:
# Selectable parameters as (button label, (column in the params frame, y-axis title)).
options = [
    ("bandgap", ("bandgap_eqe", "bandgap / eV")),
    ("jsc", ("integrated_jsc", "jsc / A/cm²")),
    ("j0rad", ("integrated_j0rad", "j0rad / A/cm²")),
    ("voc rad", ("voc_rad", "Voc rad / V")),
    ("urbach energy", ("urbach_energy", "urbach energy / eV")),
]
# Output area that holds the box plot figure.
box_out = widgets.Output()

def update_box_plot(b):
    """Redraw the box plot of the selected parameter inside ``box_out``.

    Every curve whose ``plot`` flag is set contributes one value; values are
    grouped on the x axis by the display name built from sample and curve name.
    ``b`` is the clicked button and is not used.
    """
    with box_out:
        column_name, axis_title = parameter_selector.value
        layout = go.Layout(
            width=box_options.width.value,
            height=box_options.height.value,
            yaxis={"title": {"text": axis_title}},
            template=template,
        )
        box_out.clear_output()
        figure = go.Figure(layout=layout)

        group_labels, values = [], []
        for sample_id in data["sample_ids"]:
            sample_name = data["properties"].loc[sample_id, "name"]
            sample_params = data["params"].loc[sample_id]
            shown_cells = sample_params.loc[sample_params["plot"]]
            for cell_index in shown_cells.index:
                group_labels.append(box_options.name.value(sample_name, shown_cells.loc[cell_index, "name"]))
                values.append(data["params"].loc[(sample_id, *cell_index), column_name])

        figure.add_box(x=group_labels, y=values)
        figure.show()
        
# Controls of the box plot; the second entry of `options` is preselected.
parameter_selector = widgets.ToggleButtons(
    options=options,
    index=1,
)
box_button = widgets.Button(
    description="refresh plot",
    button_style='primary',
)
box_button.on_click(update_box_plot)

box_options = plotting_utils.plot_options(default_name=1)

display(parameter_selector, box_options, box_button, box_out)

ToggleButtons(index=1, options=(('bandgap', ('bandgap_eqe', 'bandgap / eV')), ('jsc', ('integrated_jsc', 'jsc …

plot_options(children=(ToggleButtons(index=1, options=(('sample + curve name', <function sample_and_curve_name…

Button(button_style='primary', description='refresh plot', style=ButtonStyle())

Output()