In [1]:
import numpy as np
import pandas as pd
from scipy.stats import qmc
import ipywidgets as widgets
from IPython.display import display, clear_output

from datetime import datetime

import os
import sys
sys.path.insert(1, '../../python-scripts-c6fxKDJrSsWp1xCxON1Y7g')
from api_calls import *
# Base URL of the NOMAD Oasis REST API (passed to get_all_uploads by LHSWidget).
url = 'https://nomad-hzb-ce.de/nomad-oasis/api/v1'
# Access token taken from the environment; raises KeyError if the variable is unset.
token = os.environ['NOMAD_CLIENT_ACCESS_TOKEN']

# Latin Hypercube Sampling (LHS)

In [2]:
# ----------------------------------------------------------
# Module 1: LHS generator function
# ----------------------------------------------------------
def generate_lhs(params, n_samples, seed=None):
    '''
    Generate a Latin Hypercube Sampling based on parameter definitions.

    params: list of dicts with keys ('name', 'min', 'max'). Each parameter is
        sampled on [min, max]. A parameter with min == max yields a constant
        column.
    n_samples: number of desired samples (rows).
    seed: optional seed forwarded to scipy's LatinHypercube for a reproducible
        design. With None (default), every call draws fresh randomness,
        exactly as before.

    Returns: pandas.DataFrame with one column per parameter, in the order of
        `params`, and `n_samples` rows.

    Raises: ValueError if any parameter has min > max.
    '''
    names = [p['name'] for p in params]
    l_bounds = np.array([p['min'] for p in params], dtype=float)
    u_bounds = np.array([p['max'] for p in params], dtype=float)

    inverted = [n for n, lo, hi in zip(names, l_bounds, u_bounds) if lo > hi]
    if inverted:
        raise ValueError(
            f'min must not exceed max for parameter(s): {", ".join(map(str, inverted))}'
        )

    # Only pass the seed when one is given, so the default call is the plain engine.
    engine_kwargs = {} if seed is None else {'seed': seed}
    sampler = qmc.LatinHypercube(d=len(params), **engine_kwargs)
    unit_sample = sampler.random(n=n_samples)

    # Same affine map qmc.scale applies, written out so that min == max is also
    # accepted (qmc.scale insists on min < max).
    scaled = unit_sample * (u_bounds - l_bounds) + l_bounds

    return pd.DataFrame(scaled, columns=names)

# ----------------------------------------------------------
# Module 2: Interactive parameter input widget
# ----------------------------------------------------------
class LHSWidget:
    '''
    Interactive ipywidgets front-end for building a Latin Hypercube Sampling table.

    Constructing the widget fetches up to 200 NOMAD uploads (offered as save
    targets) and immediately displays the parameter editor. The save controls
    (``save_box``) and their message area (``output_save``) are not displayed
    here; the caller displays them separately. The last successfully generated
    table is stored in ``self.df``.
    '''

    # Column template shared by the header and every parameter row so they line up.
    _GRID_COLUMNS = '150px 100px 100px 120px 50px'
    _GRID_GAP = '5px 10px'

    def __init__(self):
        # One tuple per row: (name, min, max, precision, row GridBox).
        self.param_boxes = []
        self.param_box_container = widgets.VBox()

        # Table header
        header = widgets.GridBox(
            children=[
                widgets.HTML('<b>Name</b>'),
                widgets.HTML('<b>Min</b>'),
                widgets.HTML('<b>Max</b>'),
                widgets.HTML('<b>Num of decimals</b>'),
                widgets.HTML('<b>Remove</b>')
            ],
            layout=self._grid_layout()
        )

        # Control widgets
        self.add_button = widgets.Button(description='➕ Add parameter', button_style='success')
        self.add_button.on_click(self.add_param_row)

        self.sample_count = widgets.IntText(value=10, description='Samples:')
        self.generate_button = widgets.Button(description='Generate LHS', button_style='primary')
        self.generate_button.on_click(self.generate_lhs_table)

        # Separate message areas: generation (table/warnings) vs. save/download.
        self.output = widgets.Output()
        self.output_save = widgets.Output()

        # Save options
        self.download_button = widgets.Button(description='Download CSV', button_style='info', layout=widgets.Layout(width='505px'))
        self.download_button.on_click(self.download_csv)

        # Network call; `url` and `token` are notebook-level globals.
        uploads = get_all_uploads(url, token, number_of_uploads=200)
        self.upload_dropdown = widgets.Dropdown(
            # `or` also covers uploads whose name is present but None/empty.
            options=[(u.get("upload_name") or "--no-name--", u) for u in uploads],
            layout=widgets.Layout(width='300px')
        )
        self.save_button = widgets.Button(description='Save to NOMAD upload', button_style='info', layout=widgets.Layout(width='200px'))
        self.save_button.on_click(self.save_to_upload)

        self.save_box = widgets.VBox([
            widgets.HBox([self.upload_dropdown, self.save_button]),
            self.download_button,
        ])

        # Display layout (save_box is shown separately by the caller)
        display(widgets.VBox([
            header,
            self.param_box_container,
            self.add_button,
            self.sample_count,
            self.generate_button,
            self.output
        ]))

    @classmethod
    def _grid_layout(cls):
        '''Build a fresh Layout for a 5-column table row (header or parameter row).'''
        return widgets.Layout(grid_template_columns=cls._GRID_COLUMNS, grid_gap=cls._GRID_GAP)

    @staticmethod
    def _show_message(output, message):
        '''Replace the contents of the Output widget `output` with the printed `message`.'''
        with output:
            clear_output()
            print(message)

    @staticmethod
    def _n_unique_values(lower, upper, precision):
        '''
        Count the distinct values in [lower, upper] after rounding to `precision` decimals.

        Counts grid points between the rounded bounds using integers. This avoids
        the float truncation of ``int((upper - lower) / 10**-precision)``, which
        e.g. counts only 1 value for 0.9..1.0 at 1 decimal instead of 2.
        '''
        scale = 10 ** precision
        return round(upper * scale) - round(lower * scale) + 1

    def _refresh_rows(self):
        '''Sync the displayed rows with ``self.param_boxes``.'''
        self.param_box_container.children = [p[4] for p in self.param_boxes]

    def add_param_row(self, b=None):
        '''Append an editable parameter row (name, min, max, decimals, remove button).'''
        name = widgets.Text(layout=widgets.Layout(width='150px'))
        min_val = widgets.FloatText(value=1.0, layout=widgets.Layout(width='100px'))
        max_val = widgets.FloatText(value=10.0, layout=widgets.Layout(width='100px'))
        precision = widgets.IntText(value=0, layout=widgets.Layout(width='120px'))
        remove_btn = widgets.Button(description='❌', layout=widgets.Layout(width='40px'))

        row_grid = widgets.GridBox(
            children=[name, min_val, max_val, precision, remove_btn],
            layout=self._grid_layout()
        )

        self.param_boxes.append((name, min_val, max_val, precision, row_grid))
        self._refresh_rows()

        def remove_row(_):
            # Filter by identity of this row's GridBox (in place, so the list object
            # is kept); a second click on an already-removed row is a no-op.
            self.param_boxes[:] = [p for p in self.param_boxes if p[4] is not row_grid]
            self._refresh_rows()

        remove_btn.on_click(remove_row)

    def generate_lhs_table(self, b=None):
        '''
        Validate the parameter rows, generate the LHS table and display it in ``self.output``.

        Rows with a blank name are ignored. Generation is refused, with a message,
        in any of these cases:
        - no parameter is named;
        - names repeat;
        - fewer than one sample is requested;
        - a Min is not below its Max;
        - a parameter's rounding grid has fewer distinct values than the
          requested samples.

        Each column is rounded to its "Num of decimals". Columns with 0 or fewer
        decimals are stored as integers. On success the table is stored in
        ``self.df``.
        '''
        params = []
        precisions = []
        for name, min_val, max_val, precision, _ in self.param_boxes:
            param_name = name.value.strip()
            if not param_name:
                continue
            params.append({
                'name': param_name,
                'min': min_val.value,
                'max': max_val.value
            })
            precisions.append(precision.value)

        if not params:
            self._show_message(self.output, '⚠️ Please add at least one parameter.')
            return

        names = [p['name'] for p in params]
        duplicates = sorted({n for n in names if names.count(n) > 1})
        if duplicates:
            # Duplicate names would create duplicate DataFrame columns and break per-column rounding.
            self._show_message(self.output, f'⚠️ Duplicate parameter name(s): {", ".join(duplicates)}.')
            return

        n_samples = self.sample_count.value
        if n_samples < 1:
            self._show_message(self.output, '⚠️ Please request at least one sample.')
            return

        for p, prec in zip(params, precisions):
            if p['min'] >= p['max']:
                # Reject degenerate/inverted ranges up front so nothing raises inside the button callback.
                self._show_message(self.output, f'⚠️ Parameter "{p["name"]}": Min must be smaller than Max.')
                return
            n_unique = self._n_unique_values(p['min'], p['max'], prec)
            if n_unique < n_samples:
                self._show_message(
                    self.output,
                    f'⚠️ Warning: Parameter "{p["name"]}" has only {n_unique} unique values with precision {prec}, '
                    f'but {n_samples} samples requested.'
                )
                return

        # Per-column rounding; "0 decimals" means whole numbers, so store those as ints.
        lhs = generate_lhs(params, n_samples).round(dict(zip(names, precisions)))
        integer_columns = {n: int for n, prec in zip(names, precisions) if prec <= 0}
        self.df = lhs.astype(integer_columns)

        with self.output:
            clear_output()
            display(self.df)

    def download_csv_old(self, b=None):
        '''
        Legacy download: write ``lhs.csv`` to the working directory and show a FileLink.

        Not wired to any button; ``download_csv`` is used instead.
        '''
        if not hasattr(self, 'df'):
            self._show_message(self.output_save, '⚠️ No sampling has been created yet.')
            return

        from IPython.display import FileLink
        filename = 'lhs.csv'
        self.df.to_csv(filename, index=False)
        with self.output_save:
            clear_output()
            print('✅ CSV saved:')
            display(FileLink(filename))
            display(self.df.head())

    def download_csv(self, b=None):
        '''Offer ``self.df`` as ``lhs_sampling.csv`` via ``create_zip_download_link``.'''
        if not hasattr(self, 'df'):
            self._show_message(self.output_save, '⚠️ No sampling has been created yet.')
            return

        file_name = 'lhs_sampling.csv'
        # NOTE(review): create_zip_download_link presumably displays a link; running it
        # inside output_save keeps that visible instead of lost in the callback log.
        with self.output_save:
            clear_output()
            create_zip_download_link([self.df], [file_name], 'lhs')

    def save_to_upload(self, b=None):
        '''
        Write ``self.df`` as ``lhs-<YYYY-MM-DD>.csv`` into the folder of the selected upload.

        The folder is ``../../<upload-name>-<upload_id>``: the name is lower-cased
        with spaces replaced by '-'. Unnamed uploads use just ``<upload_id>``.
        NOTE(review): assumes uploads are available at that relative path — confirm.
        A file saved earlier on the same day is overwritten. Outcome messages go to
        ``self.output_save``.
        '''
        if not hasattr(self, 'df'):
            self._show_message(self.output_save, '⚠️ No sampling has been created yet.')
            return

        selected_upload = self.upload_dropdown.value
        if not selected_upload:
            # Empty dropdown (no uploads returned).
            self._show_message(self.output_save, '⚠️ Please select an upload first.')
            return

        # `or ''` also covers an upload_name that is present but None.
        folder_name = (selected_upload.get('upload_name') or '').lower().replace(' ', '-')
        if folder_name:
            folder_name = f'{folder_name}-'
        folder_name = f'../../{folder_name}{selected_upload.get("upload_id")}'
        date_str = datetime.now().strftime("%Y-%m-%d")
        filename = f'{folder_name}/lhs-{date_str}.csv'

        try:
            self.df.to_csv(filename, index=False)
        except OSError as err:
            # e.g. the upload folder does not exist here; report instead of failing silently.
            self._show_message(self.output_save, f'❌ Could not save sampling to "{filename}": {err}')
            return

        with self.output_save:
            clear_output()
            print(f'✅ Sampling saved as "{filename}".')
            display(self.df.head())


# ----------------------------------------------------------
# Start the interactive app
# ----------------------------------------------------------
# Build the app: this fetches the NOMAD uploads and displays the parameter editor.
# The save controls are displayed in the next cell.
my_lhs_app = LHSWidget()


VBox(children=(GridBox(children=(HTML(value='<b>Name</b>'), HTML(value='<b>Min</b>'), HTML(value='<b>Max</b>')…

### Save LHS table

In [3]:
# Save/download controls of the app created above, plus their message area.
save_panel = widgets.VBox(children=[my_lhs_app.save_box, my_lhs_app.output_save])
display(save_panel)

VBox(children=(VBox(children=(HBox(children=(Dropdown(layout=Layout(width='300px'), options=(('--no-name--', {…