In [10]:
%%capture --no-display
# %%capture hides the text output (stdout/stderr) of this cell. If the cell fails, remove the %%capture line to see the error.
# %pip installs the packages listed in requirements.txt into the environment of the running kernel.
%pip install -r requirements.txt;


### Welcome
Welcome to the main program. This notebook was developed using Python 3.12.2 in [VSCode](https://code.visualstudio.com/) and uses the packages listed in requirements.txt. These can be installed using the cell above. If that doesn't work, try creating a new virtual environment first.

This notebook has the following cells
- Parameter file and dataset selection
- Ground truth, prediction, and feature selection
- Metric selection
- Calculations and Results
- Saving parameters to file

In the first cell, a dataset can be selected. This must be a .csv file. You can also upload a parameter file in this cell. Such a file, which can be created at the end of this notebook, stores all the selections made during a previous session. When uploading a parameter file, the program will try to load the corresponding .csv file from the designated folder. If it is not found, you will be prompted to upload the dataset manually.

In the second section, you can select the columns that contain predictions and ground truth. The program will give feedback indicating whether the column names have been found. Additionally, you can select protected attributes—each of these features will be evaluated separately by the selected metrics.

In the next section, you can choose which metrics to use. Some metrics require predictions, ground truth, or both. If these requirements are not met, the corresponding metric is disabled.

As mentioned earlier, the final section allows you to save your selections to a parameter file. The program checks whether the filename is available and valid.

It is recommended to run each cell and fill out its selections before running the next cell. In the selection cells, click **Save changes** before moving on: the following cells read the selections from the saved session file, not from the widgets. There should be no need to change any code; all selections are made with [ipywidgets](https://ipywidgets.readthedocs.io/en/latest/how-to/index.html). Cells can be collapsed by double-clicking the left edge of a cell. 


If you have access to the model that produced the predictions, you can also use SHAP to find the model's feature importances. The file [shap fairness_explainer](https://github.com/bytnater/bias-analyse/blob/main/fairness_metrics/Fairness_explainer/shap_fairness_explainer.py) contains an example of how to use this module.

We have also made a small start on Hierarchical Bias-Aware Clustering ([HBAC](https://github.com/bytnater/bias-analyse/blob/main/HBAC.ipynb)). HBAC is designed to detect groups in the data that may be treated unfairly, particularly in unsupervised settings (i.e., without known labels). It clusters the data not on feature similarity alone, but on disparities in a chosen bias metric (e.g., fairness or outcome imbalance).



In [11]:
import utils
import os
from io import BytesIO
from IPython.display import display

import warnings
import torch

import ipywidgets as widgets
from ipywidgets import Layout, Box, VBox, Accordion, Label, Valid, Text, IntText, Combobox, Checkbox, RadioButtons, Button

## part of the project
from fairness_metrics.Predicted_outcomes.Error_rate_metrics import Error_rate_metrics
from fairness_metrics.Predicted_outcomes.Predictive_value_metrics import Predictive_value_metrics
from fairness_metrics.Predicted_outcomes.statistical_parity import statistical_parity
from fairness_metrics.predicted_probs.balance_in_pos_neg import balance_in_pos_neg
from fairness_metrics.predicted_probs.well_callibrated import well_calibration
from fairness_metrics.similarity_based.similarity_based import LipschitzFairness

# Silence all Python warnings for the rest of the kernel session.
# NOTE(review): this would also hide deprecation warnings from ipywidgets/traitlets and torch,
# which point at APIs that may stop working after an upgrade; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [19]:
# Parameter file and dataset file selection

# Radio buttons to choose between uploading a session file or starting a new one.
# The option values are used as booleans by `toggle`: 1 = upload a session file, 0 = new session.
SESSION = RadioButtons(
    options=[('Upload session file', 1), ('Create new session', 0)],
    value=0,
    style={'description_width': 'initial'},
    description='Load params file:',)

# Layout shared by both FileUpload widgets
layout = Layout(width='auto')

# UI box for uploading a .pt file (parameter file).
# children[0] is the instruction label, children[1] the FileUpload widget used by the handlers.
PARAMS_UPLOAD_BOX = VBox([
    Label(value='Upload your params file, this must be a .pt file'),
    widgets.FileUpload(
        accept='.pt',
        description='Upload file',
        tooltip='No file',
        layout=layout,
    )
])

# UI box for uploading a .csv file (dataset); same child order as PARAMS_UPLOAD_BOX
DATA_UPLOAD_BOX = VBox([
    Label(value='Please select the dataset you want to analyise, this must be a .csv file'),
    widgets.FileUpload(
        accept='.csv',
        description='Upload file',
        tooltip='No file',
        layout=layout,
    )
])

# Label for feedback messages during upload (the trait is `value`, lowercase)
UPLOAD_FEEDBACK = Label(value='')

# Group all upload options and feedback into one vertical layout
FILE_INFO_BOX = VBox([
    SESSION,
    PARAMS_UPLOAD_BOX,
    DATA_UPLOAD_BOX,
    UPLOAD_FEEDBACK,
])

# Add spacing below the radio buttons
SESSION.layout.padding = '0px 0px 30px 0px'

# Hide the parameter upload box by default (the default mode is "Create new session").
# 'none' matches the value used by the toggle/upload handlers.
PARAMS_UPLOAD_BOX.layout.display = 'none'

output_file_selection = widgets.Output()
display(FILE_INFO_BOX, output_file_selection)

# Function to toggle between session options
def toggle(change):
    """Show the upload box that matches the selected session mode.

    Registered with SESSION.observe, so it runs on every change of the radio
    button value. Switching modes changes where the dataset comes from, so the
    dataset loaded in the previous mode (and its upload widget) is cleared.
    Otherwise a stale dataset would silently pass the `assert dataset` check of
    the next cell while the UI shows no uploaded file.

    Args:
        change: traitlets change dict (unused; SESSION.value is read directly).
    """
    global dataset

    UPLOAD_FEEDBACK.value = ''

    # Forget the dataset of the previous mode and reset its upload widget.
    # Resetting the value to () makes upload_csv fire with an empty value, which it ignores.
    dataset = None
    data_upload = DATA_UPLOAD_BOX.children[1]
    data_upload.value = ()
    data_upload.description = 'Upload file'
    data_upload.tooltip = 'No file'

    if SESSION.value:
        # User chooses to upload a session file; the dataset box is shown again
        # by upload_param_file if that file has no usable dataset.
        PARAMS_UPLOAD_BOX.layout.display = 'block'
        DATA_UPLOAD_BOX.layout.display = 'none'
    else:
        # User chooses to create a new session
        DATA_UPLOAD_BOX.layout.display = 'block'
        PARAMS_UPLOAD_BOX.layout.display = 'none'
        params_upload = PARAMS_UPLOAD_BOX.children[1]
        params_upload.value = ()
        params_upload.description = 'Upload file'
        params_upload.tooltip = 'No file'

        # Start the new session from an empty parameter file
        torch.save({}, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)

# Function to handle uploaded .pt parameter files
def upload_param_file(change):
    """Load an uploaded .pt parameter file and the dataset it refers to.

    The uploaded parameters become the session parameter file
    (utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET). If they name a dataset
    ('used_dataset') that exists under utils.SAVED_DATASET_PATH, it is loaded
    into the global `dataset`. Otherwise `dataset` is cleared and the dataset
    upload box is shown so the user can upload the .csv manually.

    Args:
        change: traitlets change dict from the FileUpload widget; change['new']
            is the tuple of uploaded files (empty when the widget is reset).
    """
    global dataset

    with output_file_selection:
        if not change['new']:
            # Widget was reset to (); nothing to load
            return

        uploaded_file = change['new'][0]
        params_upload = PARAMS_UPLOAD_BOX.children[1]
        data_upload = DATA_UPLOAD_BOX.children[1]

        # Update file upload UI text
        params_upload.description = f'Uploaded "{uploaded_file.name}"'
        params_upload.tooltip = f'{uploaded_file.name}'

        # Load parameter file from uploaded content.
        # weights_only=True restricts unpickling to tensors and plain containers.
        params = torch.load(BytesIO(uploaded_file.content), weights_only=True)
        if not isinstance(params, dict):
            UPLOAD_FEEDBACK.value = f'"{uploaded_file.name}" is not a valid parameter file, please upload another file'
            return

        used_dataset = params.get('used_dataset', '')
        dataset_path = utils.SAVED_DATASET_PATH + used_dataset if used_dataset else None

        # A dataset loaded earlier does not belong to these parameters: drop it and
        # reset the dataset upload widget before deciding where the new dataset comes from.
        dataset = None
        data_upload.value = ()
        data_upload.description = 'Upload file'
        data_upload.tooltip = 'No file'

        if dataset_path is not None and os.path.exists(dataset_path):
            # Dataset named in the parameter file is available on disk
            dataset = utils.Dataset(dataset_path)
            DATA_UPLOAD_BOX.layout.display = 'none'
            UPLOAD_FEEDBACK.value = f'Dataset file found at "{dataset_path}"'
        else:
            # Ask for a manual upload
            DATA_UPLOAD_BOX.layout.display = 'block'
            if dataset_path is None:
                UPLOAD_FEEDBACK.value = 'No dataset file found connected to this preset file, please upload a file manually'
            else:
                UPLOAD_FEEDBACK.value = f'file "{used_dataset}" not found at "{dataset_path}", please upload a file manually'

        # Save the parameter file for later use by the other cells
        torch.save(params, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)

# Function to handle uploaded .csv dataset files
def upload_csv(change):
    """Load an uploaded .csv file as the global `dataset` and remember its name.

    The file name is written to 'used_dataset' in the session parameter file,
    so a preset saved later can locate the dataset again.

    Args:
        change: traitlets change dict from the FileUpload widget; change['new']
            is the tuple of uploaded files (empty when the widget is reset).
    """
    global dataset

    with output_file_selection:
        files = change['new']
        if not files:
            return

        first_file = files[0]
        upload_widget = DATA_UPLOAD_BOX.children[1]

        # Reflect the uploaded file in the upload button
        upload_widget.description = f'Uploaded "{first_file.name}"'
        upload_widget.tooltip = f'{first_file.name}'

        dataset = utils.Dataset(upload_widget=files)

        # Record the dataset name in the session parameter file
        preset_file = utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET
        session_params = torch.load(preset_file, weights_only=True)
        session_params['used_dataset'] = first_file.name
        torch.save(session_params, preset_file)

# Connect each widget to the handler that reacts to changes of its value
SESSION.observe(toggle, 'value')
PARAMS_UPLOAD_BOX.children[1].observe(upload_param_file, 'value')
DATA_UPLOAD_BOX.children[1].observe(upload_csv, 'value')

# No dataset is loaded when this cell (re)runs; the upload handlers set it
dataset = None

# Start the session with an empty parameter file, which the later cells read and update
params = dict()
torch.save(params, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)


VBox(children=(RadioButtons(description='Load params file:', index=1, layout=Layout(padding='0px 0px 30px 0px'…

Output()

In [20]:
# Ground truth, prediction, and protected attributes selection

# Make sure a dataset is selected (`dataset` is the global set by the upload handlers of the previous cell)
assert dataset, 'Make sure to select a dataset to analise'

# Load saved parameters from the session parameter file
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

# Define style and layout for widgets
style = {'background':'white'}
item_layout = Layout(width='auto')

# Create a checkbox for each feature to select as a protected attribute.
# NOTE(review): `dataset.i2c` is used as an indexable list of column names (index -> name); confirm in utils.Dataset.
protected_items = [Checkbox(layout=item_layout, description=dataset.i2c[i], indent=False, style=style) for i in range(len(dataset.i2c))]

# Layout settings for the scrollable box that contains all checkboxes
box_layout = Layout(
    overflow='hidden scroll',
    border='empty',
    width='auto',
    height='300px',
    flex_flow='column',
    display='flex',
    padding='0'
)

# Box that displays all protected attribute checkboxes
PROTECTED_SELECTION = Box(children=protected_items, layout=box_layout)

# Combobox for selecting the ground truth feature.
# The extra '' option makes "no column" an accepted value (ensure_option=True restricts the value to the options).
GROUND_SELECTION = Combobox(
    value=None,
    placeholder='Choose a feature',
    options=dataset.i2c+[''],
    description='Ground truth:',
    ensure_option=True,
    style={'description_width':'150px'},
    layout=Layout(width='auto'),
)

# Combobox for selecting the predicted feature ('' means no prediction column)
PREDICTED_SELECTION = Combobox(
    value=None,
    placeholder='Choose a feature',
    options=dataset.i2c+[''],
    description='Predicted by model:',
    ensure_option=True,
    style={'description_width':'150px'},
    layout=Layout(width='auto'),
)

# Label to show feedback about the selected features
FEEDBACK_INFO_SLECETION = Label(value='')
FEEDBACK_INFO_SLECETION.layout.padding = '0px 0px 90px 0px'

# Combine all widgets into a single vertical layout
SELECTION_BOX = VBox([
    Label('Please select the ground truth and the predicted probability (if available):'), 
    GROUND_SELECTION,
    PREDICTED_SELECTION,
    FEEDBACK_INFO_SLECETION,
    Label('Please, select the protected attributes you want to analyse:'), 
    PROTECTED_SELECTION,
    ])

# Restore previously saved values for protected attributes and selected columns.
# Defaults: no protected attribute selected, no ground truth / prediction column ('').
# NOTE(review): zip() stops at the shorter sequence, so a preset saved for a dataset with a
# different number of columns is applied by position and only partially, without a warning.
preset_protected_values = params.get('protected_values', torch.zeros(len(dataset.i2c), dtype=bool))
for item, value in zip(protected_items, preset_protected_values):
    item.value = bool(value)
GROUND_SELECTION.value=params.get('ground_truth_column', '')
PREDICTED_SELECTION.value=params.get('prediction_column', '')

# Display output
output_selection = widgets.Output()
display(SELECTION_BOX, output_selection)

# Function to update feedback based on current selection
def update_selection_feedback(id):
    """Describe the current ground truth / prediction choice in FEEDBACK_INFO_SLECETION.

    Warns when the same non-empty column is chosen for both. Otherwise it reports
    which columns (if any) are selected; an empty string means "none".

    Args:
        id: whatever the widget callback passes in; it is ignored.
    """
    with output_selection:
        ground = GROUND_SELECTION.value
        predicted = PREDICTED_SELECTION.value

        if ground == predicted and predicted != '':
            FEEDBACK_INFO_SLECETION.value = f'Both are now selected as "{ground}", Please make sure that the ground truth is not the same as the predictions'
            return

        ground_text = f'ground truth as "{ground}"' if ground else 'no ground truth'
        predicted_text = f'predictions as "{predicted}"' if predicted else 'no predictions'
        FEEDBACK_INFO_SLECETION.value = 'You have selected ' + ground_text + ' and ' + predicted_text

# Keep the feedback label in sync with the ground truth / prediction selectors.
# observe(..., names='value') only fires when the selected column changes; the deprecated
# on_trait_change(handler) fires on changes to any trait of the widget.
GROUND_SELECTION.observe(update_selection_feedback, names='value')
PREDICTED_SELECTION.observe(update_selection_feedback, names='value')

# Show feedback for the values restored from the preset straight away.
# (Widget.on_widget_constructed installs a class-wide hook that runs whenever any widget
# is created later in the session, so it is not suitable for this one-off initialisation.)
update_selection_feedback(None)

# Button to save the current selection
SAVE_CHANGES = Button(
    description='Save changes',
    button_style='success',
    tooltip='Save your current selection',
)

# Display output for saving
output_save_changes = widgets.Output()
display(SAVE_CHANGES, output_save_changes)

# Function to save the current selection back into the preset file
def save_changes(id):
    """Write the column and protected-attribute selection into the session parameter file.

    The parameter file is re-read right before updating it. Keys written by other
    cells since this cell ran (e.g. 'used_dataset' from the upload cell) are kept,
    instead of being overwritten with the stale `params` loaded when this cell ran.
    The global `params` is then replaced by the saved dictionary.

    Saved keys: 'protected_values' (bool tensor, one entry per checkbox),
    'ground_truth_column' and 'prediction_column' ('' when nothing is selected).

    Args:
        id: the clicked Button (passed by on_click); unused.
    """
    global params

    with output_save_changes:
        preset_file = utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET
        current = torch.load(preset_file, weights_only=True)
        current['protected_values'] = torch.tensor([protected_item.value for protected_item in protected_items])
        current['ground_truth_column'] = GROUND_SELECTION.value
        current['prediction_column'] = PREDICTED_SELECTION.value
        torch.save(current, preset_file)
        params = current

# Link the save function to the button click
SAVE_CHANGES.on_click(save_changes)


VBox(children=(Label(value='Please select the ground truth and the predicted probability (if available):'), Co…

Output()

Button(button_style='success', description='Save changes', style=ButtonStyle(), tooltip='Save your current sel…

Output()

In [15]:
# Metric selection

# Load saved parameters from the session parameter file
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

# Makes sure that at least one attribute is being analyzed.
# 'protected_values' is the tensor of checkbox values written by the previous cell's save button.
assert sum(params.get('protected_values', torch.zeros(len(dataset.i2c), dtype=bool))) != 0, 'You have to select at least one protected attribute'

# Style and layout
item_layout = Layout(width='auto')
slider_style = {'description_width': '200px'}

# Descriptions.
# fairness_groups_names are the titles of the outer accordion. metric_names must have the same
# nested shape and order as metrics_groups below and the metric classes of the calculation cell,
# because they are zipped position by position.
fairness_groups_names = [
    'Predicted and actual outcomes',
    'Predicted probibilities and actual outcomes',
    'Similarity based',
]
metric_names = [
    ['Error rate', 'Predictive value', 'Statistical parity'],
    ['Test-fairness', 'Balance in classes'],
    ['Lipschitz']
]

# Valid widgets: indicate whether a ground truth / prediction column was selected
VALID_GROUND = Valid(value=True, description='Ground truth')
VALID_PRED = Valid(value=True, description='Prediction')

# Load slider cache: tag (column name) -> slider, restored from params['condition'] (column name -> slider value).
# NOTE(review): how the slider value is used as a condition is defined in statistical_parity; confirm there.
slider_cache = {
    name:widgets.FloatSlider(
        value=value,
        min=0.1, max=5, step=0.1,
        description=name,
        tooltip=name,
        layout=item_layout,
        indent=False,
        style=slider_style)
    for name, value in params.get('condition', dict()).items()
}

# Tag selector: the user picks dataset columns; each tag gets a slider (see update_tags_sliders)
value = [name for (name, slider) in slider_cache.items()]
tags = widgets.TagsInput(
    value=value,
    allowed_tags=dataset.i2c,
    allow_duplicates=False
)

# Container for the sliders (scrolls once it exceeds 300px)
box_layout = Layout(
    overflow='hidden scroll',
    border='empty',
    width='auto',
    max_height='300px',
    flex_flow='column',
    display='flex',
    padding='0'
)

children = [slider for (name, slider) in slider_cache.items()]
TAGS_SLIDERS = Box(children=children, layout=box_layout)

# TAG_Wrapper for display
TAG_WRAPPER = VBox([
    tags,
    TAGS_SLIDERS,
])

# Dynamic update function
def update_tags_sliders(change):
    """Rebuild TAGS_SLIDERS so it holds one slider per currently selected tag.

    Sliders are cached per tag name in `slider_cache`. Removing a tag and adding
    it back therefore keeps its previous value; a tag seen for the first time
    gets a new slider starting at 1.0.

    Args:
        change: traitlets change dict from `tags`; only 'value' changes are handled.
    """
    if change['name'] != 'value':
        return

    def slider_for(tag):
        # Reuse the cached slider, creating and caching one on first use
        if tag not in slider_cache:
            slider_cache[tag] = widgets.FloatSlider(
                value=1.0,
                min=0.1, max=5, step=0.1,
                description=tag,
                tooltip=tag,
                layout=item_layout,
                indent=False,
                style=slider_style,
            )
        return slider_cache[tag]

    TAGS_SLIDERS.children = [slider_for(tag) for tag in change['new']]

# Observe changes to the tags input
tags.observe(update_tags_sliders, names='value')

# Metric widgets.
# Shared convention: children[0] is the "Use metric?" checkbox and children[-1] is a VBox of
# Valid indicators the metric depends on (read by validate_valid_valids). The same
# VALID_GROUND / VALID_PRED instances are placed in several boxes.
TEST_FAIRNESS = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='This metric check if individuals assigned the same score s have equal likelihoods of the positive outcome, independent of the protected attribute'),
    VBox([VALID_GROUND, VALID_PRED]),
])
# children[2] / children[3] are read and saved as 'balance_pos' / 'balance_neg'
BALANCE_IN_CLASS = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='Balance for the positive and negative classes are fairness criteria that focus on ensuring equitable scoring among individuals who share the same actual outcome.'),
    Checkbox(layout=item_layout, description='Show the positive balance', indent=False),
    Checkbox(layout=item_layout, description='Show the negative balance', indent=False),
    VBox([VALID_GROUND, VALID_PRED]),
])
ERROR_RATE_METRICS = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='Please, make a selections'),
    VBox([VALID_GROUND, VALID_PRED]),
])
PREDICTIVE_VALUE_METRICS = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='Please, make a selections'),
    VBox([VALID_GROUND, VALID_PRED]),
])
# Only needs predictions; the tag/slider wrapper selects the conditioning columns
STATISTICAL_PARITY = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='Please, make a selections'),
    TAG_WRAPPER,
    VBox([VALID_PRED]),
])
# children[2] / children[3] are read and saved as 'sample_limit' / 'distance_metric'
SIMILARITY_BASED = VBox([
    Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False),
    Label(value='Please, make a selections'),
    IntText(value=1000, description='Sample size:'),
    RadioButtons(options=['Manhattan Distance', 'Euclidean Distance', 'cosine'],
                 value='Manhattan Distance',
                 style={'description_width': 'initial'},
                 description='Choose distance metric:'),
    VBox([VALID_PRED]),
])

# Define metrics (same nested order as metric_names)
metrics_groups = [
    [ERROR_RATE_METRICS, PREDICTIVE_VALUE_METRICS, STATISTICAL_PARITY],
    [TEST_FAIRNESS, BALANCE_IN_CLASS],
    [SIMILARITY_BASED]
]

# Utility to simulate an accordion group with multi-open toggle buttons
def make_toggle_section(title, content):
    """Wrap `content` in a collapsible section headed by a ToggleButton.

    The content starts visible. Pressing the button (value True) hides it and
    releasing it (value False) shows it again. Unlike an Accordion, several
    sections can be open at the same time.

    Args:
        title: text shown on the toggle button.
        content: widget that is shown or hidden by the button.

    Returns:
        A VBox with the button followed by a VBox wrapping `content`.
    """
    header = widgets.ToggleButton(value=False, description=title, layout=Layout(width='200px'))
    body = VBox([content])
    body.layout.display = 'block'

    def on_toggle(change):
        # Pressed -> hidden, released -> visible
        if change['new']:
            body.layout.display = 'none'
        else:
            body.layout.display = 'block'

    header.observe(on_toggle, names='value')
    return VBox([header, body])

# Apply toggle accordion logic to each metrics group:
# every metric becomes a collapsible section titled with its name from metric_names
custom_groups = []
for group_metrics, titles in zip(metrics_groups, metric_names):
    group_box = VBox([
        make_toggle_section(title, content)
        for title, content in zip(titles, group_metrics)
    ])
    custom_groups.append(group_box)

# Create outer accordion using standard ipywidgets (since you only want one top-level group open).
# Each page is titled with the matching entry of fairness_groups_names.
outer_accordion = Accordion(children=custom_groups)
for i, title in enumerate(fairness_groups_names[:len(custom_groups)]):
    outer_accordion.set_title(i, title)

# Disables metric that are not usable with the current selection
def validate_valid_valids(metric_widgets):
    """Enable a metric only when all of its requirement indicators are valid.

    `metric_widgets.children[0]` is the metric's "Use metric?" checkbox and
    `metric_widgets.children[-1].children` are the Valid indicators it depends on.
    If every indicator is valid, the checkbox is ticked and enabled; otherwise it
    is unticked and disabled.

    Args:
        metric_widgets: one of the metric VBoxes built in this cell.
    """
    use_metric = metric_widgets.children[0]

    satisfied = True
    for indicator in metric_widgets.children[-1].children:
        if not indicator.value:
            satisfied = False
            break

    use_metric.value = satisfied
    use_metric.disabled = not satisfied

# Restore state: the Valid indicators show whether the previous cell saved a
# ground truth / prediction column ('' counts as "not selected").
VALID_GROUND.value = bool(params.get('ground_truth_column', ''))
VALID_PRED.value = bool(params.get('prediction_column', ''))

# Disable every metric whose requirements are not met.
# A plain loop is used: validate_valid_valids is called only for its side effects.
for group in metrics_groups:
    for metric in group:
        validate_valid_valids(metric)

# Re-apply the saved metric selection (defaults to the current checkbox states),
# but never re-enable a metric that was just disabled.
default = [[metric.children[0].value for metric in group] for group in metrics_groups]
metric_selection = params.get('metric_selection', default)
for group, group_selection in zip(metrics_groups, metric_selection):
    for metric, selection in zip(group, group_selection):
        if not metric.children[0].disabled:
            metric.children[0].value = selection

# Restore the metric-specific options
BALANCE_IN_CLASS.children[2].value = params.get('balance_pos', True)
BALANCE_IN_CLASS.children[3].value = params.get('balance_neg', True)
SIMILARITY_BASED.children[2].value = params.get('sample_limit', 1000)
SIMILARITY_BASED.children[3].value = params.get('distance_metric', 'Manhattan Distance')

# Display the result
display(outer_accordion)

# Button to save the current selection
SAVE_CHANGES = Button(
    description='Save changes',
    button_style='success',
    tooltip='Save your current selection',
)

# Display output for saving
output_save_changes = widgets.Output()
display(SAVE_CHANGES, output_save_changes)

# Function to save the current selection back into the preset file
def save_changes(id):
    """Write the current metric settings into the session parameter file.

    The parameter file is re-read right before updating it. Keys written by other
    cells since this cell ran (e.g. the column selection or 'used_dataset') are
    kept, instead of being overwritten with the stale `params` loaded at the top
    of this cell. The global `params` is then replaced by the saved dictionary.
    This redefines the name `save_changes`; the previous cell's button keeps the
    function object it was given.

    Saved keys: 'metric_selection' ("Use metric?" values, shaped like
    metrics_groups), 'balance_pos', 'balance_neg', 'sample_limit',
    'distance_metric' and 'condition' (slider description -> slider value).

    Args:
        id: the clicked Button (passed by on_click); unused.
    """
    global params

    with output_save_changes:
        preset_file = utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET
        current = torch.load(preset_file, weights_only=True)
        current['metric_selection'] = [[metric.children[0].value for metric in group] for group in metrics_groups]
        current['balance_pos'] = BALANCE_IN_CLASS.children[2].value
        current['balance_neg'] = BALANCE_IN_CLASS.children[3].value
        current['sample_limit'] = SIMILARITY_BASED.children[2].value
        current['distance_metric'] = SIMILARITY_BASED.children[3].value
        current['condition'] = {slider.description: slider.value for slider in TAGS_SLIDERS.children}
        torch.save(current, preset_file)
        params = current

# Link the save function to the button click
SAVE_CHANGES.on_click(save_changes)


Accordion(children=(VBox(children=(VBox(children=(ToggleButton(value=False, description='Error rate', layout=L…

Button(button_style='success', description='Save changes', style=ButtonStyle(), tooltip='Save your current sel…

Output()

In [16]:
# Loading and calculating each metric 

# Load saved parameters (the session parameter file written by the save buttons above)
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

def flatten(xss):
    """Return the items of every inner iterable of `xss`, in order, as one flat list (one level deep)."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat

# Defining metrics: metric classes in the same nested order as metric_names / metrics_groups
# of the metric-selection cell (they are zipped position by position below)
metrics = [[Error_rate_metrics, Predictive_value_metrics, statistical_parity],
           [well_calibration, balance_in_pos_neg],
           [LipschitzFairness]]

# Select and calculate metrics.
# Without a saved 'metric_selection', the current "Use metric?" checkbox states are used.
# NOTE(review): zip() silently stops at the shortest list if a preset's selection has a different shape.
default = [[metric.children[0].value for metric in group] for group in metrics_groups]
metric_selection = params.get('metric_selection', default)
used_metrics = []
for metric, metric_name, selection in zip(flatten(metrics), flatten(metric_names), flatten(metric_selection)):
    if selection:
        # Each selected metric class is instantiated with the dataset and the session parameters;
        # presumably the constructor performs the calculation (hence the message) — confirm in the classes.
        used_metrics.append((metric(dataset, params), metric_name))
        print(f'{metric_name} calculated')

Error rate calculated
Predictive value calculated
Statistical parity calculated
Test-fairness calculated
Balance in classes calculated


### Results

Running the code cell below produces graphs for each selected metric, one per protected attribute:

The Error rate metric will produce a graph showing the False Positive Rate (FPR) and False Negative Rate (FNR) for each group in the attribute. This allows you to check fairness according to the following metrics:
- Predictive Equality, which requires equal FPRs across all groups.
- Equal Opportunity, which requires equal FNRs across all groups.
- Equalised Odds, which requires equal FPRs and FNRs across all groups.

The Predictive value metric will produce a graph showing the Positive Predictive Value (PPV) and Negative Predictive Value (NPV) for each group in the attribute. This allows you to check fairness according to the following metrics:
- Predictive Parity, which requires equal PPVs across all groups.
- Conditional Use Accuracy Equality, which requires equal PPVs and NPVs across all groups.

The (conditional) Statistical Parity metric will produce a graph showing the Positive Prediction rate for each group in the attribute. This allows you to check fairness according to the following metrics:
- Statistical Parity, which requires equal Positive Prediction rates across all groups.
- Conditional Statistical Parity, which requires equal Positive Prediction rates across all groups for the individuals that satisfy the condition.

The Test-fairness metric will produce a graph showing the actual and predicted probability of a positive outcome for each group in the attribute. This allows you to check fairness according to the following metrics:
- Test-fairness, which requires that for any given predicted risk score, the probability of the positive outcome is equal across all groups.
- Well-calibration, which requires that for any given predicted risk score, the probability of the positive outcome is equal across all groups and the predicted risk score accurately reflects the probability of the positive outcome.

The Balance in classes metric will produce a graph showing the predicted risk score of the positive and negative class for each group in the attribute. This allows you to check fairness according to the following metrics:
- Balance For The Positive Class, which requires equal predicted risk scores for the positive class across all groups.
- Balance For The Negative Class, which requires equal predicted risk scores for the negative class across all groups.

The Lipschitz metric will produce a graph showing the severity of Lipschitz violations for all samples in the attribute. This allows you to check fairness according to the following metrics:
- Fairness Through Awareness, which requires that the difference in outputs is no greater than a constant multiple of the difference in inputs.

In [17]:
# show results of each metric
for metric, metric_name in used_metrics:
    # Banner "# Metric <name> #" framed by '#' lines of the same width (9 + 2 extra characters).
    # Built outside the f-string: reusing the outer quote character inside a replacement
    # field (f'{'#' * n}') is only valid syntax from Python 3.12 (PEP 701).
    border = '#' * (len(metric_name) + 11)
    print(f'{border}\n# Metric {metric_name} #\n{border}')
    for image in metric.show():
        image.show()

#####################
# Metric Error rate #
#####################


###########################
# Metric Predictive value #
###########################


#############################
# Metric Statistical parity #
#############################


########################
# Metric Test-fairness #
########################


#############################
# Metric Balance in classes #
#############################


In [18]:
# Saving parameters to file

# Info Widgets 
# Text box for the preset file name (relative to utils.SAVED_PRESET_PATH)
OUTPUT_FILE_NAME = Text(
    placeholder='Type output file name',
    description='File name:',
    tooltip='do not forget the file extension'
)
# Shows the result of the last save
FEEDBACK_INFO = Label(
    value='',
)
# Starts disabled because the file name is empty; update_file_name_info keeps its
# style, tooltip and enabled state in sync with the file name
SUBMIT_BUTTON = Button(
    description='Save preset',
    disabled=True,
    button_style='danger',
    tooltip='No file name',
)
info_box = VBox([
    Label(value=f'Select a name for your preset file here, This will be saved in the folder "{utils.SAVED_PRESET_PATH}"'),
    OUTPUT_FILE_NAME,
    SUBMIT_BUTTON,
    FEEDBACK_INFO,
])

# Display widgets
output_savefile = widgets.Output()
display(info_box, output_savefile)

def update_file_name_info(id):
    """Save the session preset on button click and refresh the save button state.

    Called in two ways:
    - by SUBMIT_BUTTON.on_click with the Button as `id`: the session parameter
      file is copied to utils.SAVED_PRESET_PATH + <file name>;
    - by OUTPUT_FILE_NAME.observe with a traitlets change dict as `id`: the file
      name was edited, so the previous feedback message is cleared.
    In both cases the button style, tooltip and enabled state are recomputed from
    the current file name.

    Args:
        id: the clicked Button, or any other value for a file-name change.
    """
    with output_savefile:
        file_name = OUTPUT_FILE_NAME.value
        preset_path = utils.SAVED_PRESET_PATH + file_name

        if isinstance(id, Button):
            # Copy the session parameter file to the chosen preset name
            params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)
            torch.save(params, preset_path)
            FEEDBACK_INFO.value = f'Preset saved to "{file_name}"'
        else:
            # The file name changed, so an earlier "Preset saved" message is stale
            FEEDBACK_INFO.value = ''

        # (button_style, tooltip, disabled) for the current file name, checked in priority order
        if file_name == '':
            state = ('danger', 'No file name', True)
        elif file_name == utils.RESEVERD_PRESET:
            # Never overwrite the session parameter file itself
            state = ('danger', 'This file name is reserved', True)
        elif os.path.basename(file_name) != file_name:
            # Keep presets inside utils.SAVED_PRESET_PATH (no folder components such as '../')
            state = ('danger', 'The file name must not contain a folder', True)
        elif not file_name.endswith('.pt'):
            state = ('warning', 'Please make it a .pt file', False)
        elif os.path.exists(preset_path):
            state = ('info', 'A file already exists with this name, overwrite is possible', False)
        else:
            state = ('success', 'Save file', False)

        SUBMIT_BUTTON.button_style, SUBMIT_BUTTON.tooltip, SUBMIT_BUTTON.disabled = state

# Re-validate on every edit of the file name (observe passes a change dict) and
# save on click (on_click passes the Button). The deprecated on_trait_change passed
# only the trait name, so the feedback message was never cleared.
OUTPUT_FILE_NAME.observe(update_file_name_info, names='value')
SUBMIT_BUTTON.on_click(update_file_name_info)


VBox(children=(Label(value='Select a name for your preset file here, This will be saved in the folder "presets…

Output()