In [None]:
%%capture --no-display
# %%capture hides the output of this cell. If the cell fails to run, remove the %%capture line above to see the error.
%pip install -r requirements.txt;


### Welcome
Welcome to the main program. We worked with Python 3.12.2 and used the packages specified in `requirements.txt`. You can install them using the cell above. If this doesn't work, try creating a new virtual environment first.

This notebook has the following cells
- Parameter file and dataset selection
- Dataset and parameter file loading
- Ground truth, prediction, and feature selection
- Metric selection
- Calculations and Results
- Saving parameters to file

It is recommended to run each cell and fill out its selections before running the next cell. There should be no need to change any code; all selections are made through [ipywidgets](https://ipywidgets.readthedocs.io/en/latest/how-to/index.html). A cell can be collapsed by double-clicking the left edge of the cell.


If you possess the model that was used, you may also use SHAP to find the feature importance of the model. The file [shap fairness_explainer](https://github.com/bytnater/bias-analyse/blob/main/fairness_metrics/Fairness_explainer/shap_fairness_explainer.py) contains an example of how to use this module.


In [None]:
import utils
import os
from IPython.display import display

import warnings
import torch

import ipywidgets as widgets
from ipywidgets import Layout, Box, VBox, Accordion, Label, Valid, Text, IntText, Combobox, Checkbox, RadioButtons, Button

## part of the project
from fairness_metrics.Predicted_outcomes.Error_rate_metrics import Error_rate_metrics
from fairness_metrics.Predicted_outcomes.Predictive_value_metrics import Predictive_value_metrics
from fairness_metrics.Predicted_outcomes.statistical_parity import statistical_parity
from fairness_metrics.predicted_probs.balance_in_pos_neg import balance_in_pos_neg
from fairness_metrics.predicted_probs.well_callibrated import well_calibration
from fairness_metrics.similarity_based.similarity_based import LipschitzFairness

# warnings.filterwarnings('ignore')

In [None]:
# Input file selection
# Builds and displays the widgets used to choose an optional saved parameter
# file and the dataset file. The "found" indicators are initialised from the
# (still empty) text boxes here; the change handlers defined further down in
# this cell keep them up to date while the user types.

# 1 -> load an existing parameter file, 0 -> start with empty parameters.
USE_PRESET = RadioButtons(
    options=[('Load file', 1), ('Create new', 0)],
    value=0,
    style={'description_width': 'initial'},
    description='Load params file:',
)
INPUT_FILE_NAME = Text(
    placeholder='Type file name for parameters',
    description='File name:',
    tooltip='do not forget the file extension'
)
PRESET_FILENAME = INPUT_FILE_NAME.value if INPUT_FILE_NAME.value else 'No file'
file_exits = os.path.exists(utils.SAVED_PRESET_PATH + PRESET_FILENAME)
FILE_FOUND = Valid(
    value=file_exits,
    style={'description_width': 'initial'},
    description=f'Preset file "{PRESET_FILENAME}"',
    tooltip='Indicates if the chosen file has been found'
)
FILE_INFO_BOX = VBox([
    Label(value=f'Select your params file here, They must be in the folder "{utils.SAVED_PRESET_PATH}"'),
    INPUT_FILE_NAME,
    FILE_FOUND,
])
# Hidden initially because USE_PRESET starts on 'Create new'.
FILE_INFO_BOX.layout.display = 'none'

INPUT_DATASET_NAME = Text(
    placeholder='Type file name for dataset',
    description='File name:',
    tooltip='do not forget the file extension'
)
PRESET_DATASET_FILENAME = INPUT_DATASET_NAME.value if INPUT_DATASET_NAME.value else 'No file'
# Datasets are looked up in SAVED_DATASET_PATH (the folder named in the label
# below and used when the dataset is loaded), not in the preset folder.
file_exits = os.path.exists(utils.SAVED_DATASET_PATH + PRESET_DATASET_FILENAME)
DATASET_FOUND = Valid(
    value=file_exits,
    style={'description_width': 'initial'},
    description=f'Preset file "{PRESET_DATASET_FILENAME}"',
    tooltip='Indicates if the chosen file has been found'
)


info_box = VBox([
    USE_PRESET,
    FILE_INFO_BOX,
    Label(value=f'Select your dataset file here, They must be in the folder "{utils.SAVED_DATASET_PATH}"'),
    INPUT_DATASET_NAME,
    DATASET_FOUND,
])

# Output area that the handlers in this cell use as their context manager.
output_file_selection = widgets.Output()
display(info_box, output_file_selection)

def update_file_name_info(_):
    """Refresh the parameter-file indicator and pre-fill the dataset name.

    Registered as a change handler on INPUT_FILE_NAME; the argument supplied
    by the widget framework is ignored. When the typed file exists and has a
    '.pt' extension, it is loaded and its 'used_dataset' entry (or '' when
    absent) is copied into INPUT_DATASET_NAME.
    """
    with output_file_selection:
        typed_name = INPUT_FILE_NAME.value
        shown_name = typed_name or 'No file'
        found = os.path.exists(utils.SAVED_PRESET_PATH + shown_name)

        FILE_FOUND.description = f'Preset file "{shown_name}"'
        FILE_FOUND.value = found

        if not (found and typed_name.endswith('.pt')):
            return
        loaded = torch.load(utils.SAVED_PRESET_PATH + typed_name, weights_only=True)
        INPUT_DATASET_NAME.value = loaded.get('used_dataset', '')

def update_dataset_info(_):
    """Refresh the dataset-file indicator from the current dataset text box.

    Registered as a change handler on INPUT_DATASET_NAME; the argument
    supplied by the widget framework is ignored.
    """
    with output_file_selection:
        shown_name = INPUT_DATASET_NAME.value or 'No file'
        found = os.path.exists(utils.SAVED_DATASET_PATH + shown_name)
        DATASET_FOUND.description = f'Preset file "{shown_name}"'
        DATASET_FOUND.value = found

def toggle_visibility(_):
    """Show the parameter-file box only while USE_PRESET has a truthy value."""
    with output_file_selection:
        if USE_PRESET.value:
            FILE_INFO_BOX.layout.display = 'block'
        else:
            FILE_INFO_BOX.layout.display = 'none'


# React to changes of each widget's `value` only. `observe` replaces the
# deprecated `on_trait_change`, which (without a trait name) fired the handler
# for every trait of the widget. Each handler takes a single positional
# argument, which `observe` fills with the change notification.
USE_PRESET.observe(toggle_visibility, names='value')
INPUT_FILE_NAME.observe(update_file_name_info, names='value')
INPUT_DATASET_NAME.observe(update_dataset_info, names='value')

In [None]:
# Data and file loading
# Load the chosen dataset, start from either the chosen parameter file or an
# empty parameter dict, and write the result to the session file that the
# following cells reload.
dataset_path = utils.SAVED_DATASET_PATH + INPUT_DATASET_NAME.value
dataset = utils.Dataset(dataset_path)
display(Label(value=f'dataset at "{dataset_path}" has been loaded'))

if USE_PRESET.value:
    params = torch.load(utils.SAVED_PRESET_PATH + INPUT_FILE_NAME.value, weights_only=True)
    display(Label(value=f'"{INPUT_FILE_NAME.value}" has been loaded'))
else:
    params = {}

# Remember which dataset these parameters belong to.
params['used_dataset'] = INPUT_DATASET_NAME.value
torch.save(params, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)



In [None]:
# Ground truth, prediction, and protected atrributes selection
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

style = {'background':'white'}
item_layout = Layout(width='auto')

# One checkbox per dataset column, used to mark the protected attributes.
protected_items = [
    Checkbox(layout=item_layout, description=column_name, indent=False, style=style)
    for column_name in dataset.i2c
]

# Fixed-height column for the checkbox list.
box_layout = Layout(
    overflow='hidden scroll',
    border='empty',
    width='auto',
    height='300px',
    flex_flow='column',
    display='flex',
    padding='0',
)
PROTECTED_SELECTION = Box(children=protected_items, layout=box_layout)


def _column_combobox(description):
    """Return a Combobox listing every dataset column plus the empty choice."""
    return Combobox(
        value=None,
        placeholder='Choose a feature',
        options=dataset.i2c+[''],
        description=description,
        ensure_option=True,
        style={'description_width':'150px'},
        layout=Layout(width='auto'),
    )


GROUND_SELECTION = _column_combobox('Ground truth:')
PREDICTED_SELECTION = _column_combobox('Predicted by model:')
FEEDBACK_INFO_SLECETION = Label(value='')

SELECTION_BOX = VBox([
    Label('Please select the ground truth and the predicted probability (if available):'),
    GROUND_SELECTION,
    PREDICTED_SELECTION,
    FEEDBACK_INFO_SLECETION,
    Label('Please, select the protected attributes you want to analyse:'),
    PROTECTED_SELECTION,
])

## Restore data
# Fall back to "nothing selected" when the session file has no saved values.
nothing_selected = torch.zeros(len(dataset.i2c), dtype=bool)
preset_protected_values = params.get('protected_values', nothing_selected)
for item, value in zip(protected_items, preset_protected_values):
    item.value = bool(value)
GROUND_SELECTION.value = params.get('ground_truth_column', '')
PREDICTED_SELECTION.value = params.get('prediction_column', '')

output_selection = widgets.Output()
display(SELECTION_BOX, output_selection)

def update_selection_feedback(id):
    """Update the feedback label to describe the current column choices.

    `id` is whatever the widget framework passes to the handler; it is unused.
    A warning is shown when both comboboxes hold the same non-empty column;
    otherwise the label summarises which of the two columns are set.
    """
    with output_selection:
        ground = GROUND_SELECTION.value
        predicted = PREDICTED_SELECTION.value

        if ground == predicted and predicted != '':
            FEEDBACK_INFO_SLECETION.value = f'Both are now selected as "{ground}", Please make sure that the ground truth is not the same as the predictions'
            return

        ground_part = f'ground truth as "{ground}"' if ground else 'no ground truth'
        predicted_part = f'predictions as "{predicted}"' if predicted else 'no predictions'
        FEEDBACK_INFO_SLECETION.value = 'You have selected ' + ground_part + ' and ' + predicted_part

# Refresh the feedback label whenever either selection changes. `observe`
# replaces the deprecated `on_trait_change` and limits the trigger to `value`.
GROUND_SELECTION.observe(update_selection_feedback, names='value')
PREDICTED_SELECTION.observe(update_selection_feedback, names='value')
# Fill the label once for the restored selection. This replaces
# `SELECTION_BOX.on_widget_constructed(...)`: that method is static and
# installs a global hook run for every widget constructed afterwards, and it
# never fires for the already-constructed SELECTION_BOX.
update_selection_feedback(None)

SAVE_CHANGES = Button(
    description='Save changes',
    button_style='success',
    tooltip='Save your current selection',
)
# Output area used as the context manager of the save handler.
output_save_changes = widgets.Output()
def save_changes(id):
    """Copy the current selections into `params` and write the session file.

    `id` is the object passed by the button's click callback; it is unused.
    """
    with output_save_changes:
        ticked = [checkbox.value for checkbox in protected_items]
        params.update({
            'protected_values': torch.tensor(ticked),
            'ground_truth_column': GROUND_SELECTION.value,
            'prediction_column': PREDICTED_SELECTION.value,
        })
        torch.save(params, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)


SAVE_CHANGES.on_click(save_changes)

display(SAVE_CHANGES, output_save_changes)

In [None]:
# Metric selection
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

# Stop here unless at least one protected attribute was saved in the previous cell.
assert sum(params.get('protected_values', torch.zeros(len(dataset.i2c), dtype=bool))) != 0, 'You have to select at least one protected attribute'


# Style and layout
item_layout = Layout(width='auto')
slider_style = {'description_width': '200px'}

# Descriptions
# Titles of the outer accordion sections.
fairness_groups_names = [
    'Predicted and actual outcomes',
    'Predicted probibilities and actual outcomes',
    'Causal Discrimination',
    'Causal Reasoning',
]

# Metric titles, grouped in parallel with the metric widget groups defined later in this cell.
metric_names = [
    ['Error rate', 'Predictive value', 'Statistical parity'],
    ['Test-fairness', 'Balance in classes'],
    ['Lipschitz'],
]

# Valid widgets
VALID_GROUND = Valid(value=True, description='Ground truth')
VALID_PRED = Valid(value=True, description='Prediction')

# Load slider cache: tag -> slider
# One slider per entry of the saved 'condition' dict, starting at the saved value.
slider_cache = {
    name: widgets.FloatSlider(
        value=value,
        min=0.1, max=5, step=0.1,
        description=name,
        tooltip=name,
        layout=item_layout,
        indent=False,
        style=slider_style,
    )
    for name, value in params.get('condition', dict()).items()
}

# Tag selector
# Starts with the restored tags; only dataset columns may be entered.
value = list(slider_cache)
tags = widgets.TagsInput(
    value=value,
    allowed_tags=dataset.i2c,
    allow_duplicates=False,
)

# Container for the sliders
box_layout = Layout(
    overflow='hidden scroll',
    border='empty',
    width='auto',
    max_height='300px',
    flex_flow='column',
    display='flex',
    padding='0',
)

children = list(slider_cache.values())
TAGS_SLIDERS = Box(children=children, layout=box_layout)

# TAG_Wrapper for display
TAG_WRAPPER = VBox([tags, TAGS_SLIDERS])

# Dynamic update function
def update_tags_sliders(change):
    """Rebuild the slider list so that it matches the current tags.

    `change` is the change notification of the tags widget. Sliders are
    reused from `slider_cache`, so a tag that is removed and re-added gets
    its previous slider back; unseen tags get a new slider starting at 1.0.
    """
    if change['name'] != 'value':
        return

    sliders = []
    for tag in change['new']:
        if tag not in slider_cache:
            slider_cache[tag] = widgets.FloatSlider(
                value=1.0,
                min=0.1, max=5, step=0.1,
                description=tag,
                tooltip=tag,
                layout=item_layout,
                indent=False,
                style=slider_style,
            )
        sliders.append(slider_cache[tag])

    TAGS_SLIDERS.children = sliders

# Observe changes to the tags input
tags.observe(update_tags_sliders, names='value')


# Metric widgets
# Every metric panel is a VBox whose first child is the "Use metric?" checkbox
# and whose last child is a VBox of the Valid indicators shown for that metric;
# later code in this cell addresses these children by position.
def _use_metric_checkbox():
    """Return a new, pre-ticked 'Use metric?' checkbox."""
    return Checkbox(layout=item_layout, value=True, description='Use metric?', indent=False)


_MAKE_SELECTION_TEXT = 'Please, make a selections'

TEST_FAIRNESS = VBox([
    _use_metric_checkbox(),
    Label(value='This metric check if individuals assigned the same score s have equal likelihoods of the positive outcome, independent of the protected attribute'),
    VBox([VALID_GROUND, VALID_PRED]),
])

BALANCE_IN_CLASS = VBox([
    _use_metric_checkbox(),
    Label(value='Balance for the positive and negative classes are fairness criteria that focus on ensuring equitable scoring among individuals who share the same actual outcome.'),
    Checkbox(layout=item_layout, description='Show the positive balance', indent=False),
    Checkbox(layout=item_layout, description='Show the negative balance', indent=False),
    VBox([VALID_GROUND, VALID_PRED]),
])

ERROR_RATE_METRICS = VBox([
    _use_metric_checkbox(),
    Label(value=_MAKE_SELECTION_TEXT),
    VBox([VALID_GROUND, VALID_PRED]),
])

PREDICTIVE_VALUE_METRICS = VBox([
    _use_metric_checkbox(),
    Label(value=_MAKE_SELECTION_TEXT),
    VBox([VALID_GROUND, VALID_PRED]),
])

STATISTICAL_PARITY = VBox([
    _use_metric_checkbox(),
    Label(value=_MAKE_SELECTION_TEXT),
    TAG_WRAPPER,
    VBox([VALID_PRED]),
])

SIMILARITY_BASED = VBox([
    _use_metric_checkbox(),
    Label(value=_MAKE_SELECTION_TEXT),
    IntText(value=1000, description='Sample size:'),
    RadioButtons(
        options=['Manhattan Distance', 'Euclidean Distance', 'cosine'],
        value='Manhattan Distance',
        style={'description_width': 'initial'},
        description='Choose distance metric:',
    ),
    VBox([VALID_PRED]),
])

# Define metrics
# Grouped in the same order as `metric_names`.
metrics_groups = [
    [ERROR_RATE_METRICS, PREDICTIVE_VALUE_METRICS, STATISTICAL_PARITY],
    [TEST_FAIRNESS, BALANCE_IN_CLASS],
    [SIMILARITY_BASED],
]

# Utility to simulate an accordion group with multi-open toggle buttons
def make_toggle_section(title, content):
    """Wrap `content` in a section headed by a toggle button labelled `title`.

    The content starts visible; it is hidden while the button is pressed and
    shown again when the button is released. Returns a VBox holding the
    button followed by the content wrapper.
    """
    header_button = widgets.ToggleButton(value=False, description=title, layout=Layout(width='200px'))
    body = VBox([content])
    body.layout.display = 'block'

    def _on_toggle(change):
        # A pressed button (new value True) hides the section.
        if change['new']:
            body.layout.display = 'none'
        else:
            body.layout.display = 'block'

    header_button.observe(_on_toggle, names='value')
    return VBox([header_button, body])

# Apply toggle accordion logic to each metrics group
# One VBox of toggle sections per metric group, titles paired by position.
custom_groups = [
    VBox([
        make_toggle_section(title, content)
        for title, content in zip(titles, group_metrics)
    ])
    for group_metrics, titles in zip(metrics_groups, metric_names)
]

# Create outer accordion using standard ipywidgets (since you only want one top-level group open)
outer_accordion = Accordion(children=custom_groups)
# Only as many titles as there are groups are applied; surplus names are ignored.
for i, title in zip(range(len(custom_groups)), fairness_groups_names):
    outer_accordion.set_title(i, title)

# Disables metric that are not usable with the current selection
def validate_valid_valids(metric_widgets):
    """Tick and enable, or untick and disable, a metric's first child.

    The first child of `metric_widgets` is set to checked and enabled when
    every child of its last child has a truthy `value`; otherwise it is
    unchecked and disabled.
    """
    use_checkbox = metric_widgets.children[0]
    indicators = metric_widgets.children[-1].children
    usable = all(indicator.value for indicator in indicators)
    use_checkbox.value = usable
    use_checkbox.disabled = not usable

# Restore state
# A metric input counts as available when the corresponding column name saved
# in the session parameters is non-empty.
VALID_GROUND.value = bool(params.get('ground_truth_column', ''))
VALID_PRED.value = bool(params.get('prediction_column', ''))

# Untick and disable every metric whose indicators are not all valid.
# (Plain loops: the calls are made for their side effects only.)
for group in metrics_groups:
    for metric in group:
        validate_valid_valids(metric)

# Apply the saved per-metric selection, but never re-enable a metric that was
# just disabled above.
default = [[metric.children[0].value for metric in group] for group in metrics_groups]
metric_selection = params.get('metric_selection', default)
for group, group_selection in zip(metrics_groups, metric_selection):
    for metric, selection in zip(group, group_selection):
        if not metric.children[0].disabled:
            metric.children[0].value = selection

# Metric-specific options, addressed by their position inside each panel.
BALANCE_IN_CLASS.children[2].value = params.get('balance_pos', True)
BALANCE_IN_CLASS.children[3].value = params.get('balance_neg', True)
SIMILARITY_BASED.children[2].value = params.get('sample_limit', 1000)
SIMILARITY_BASED.children[3].value = params.get('distance_metric', 'Manhattan Distance')

# Display the result
display(outer_accordion)

# Save parameters button
SAVE_CHANGES = Button(
    description='Save changes',
    button_style='success',
    tooltip='Save your current selection',
)
# Output area used as the context manager of the save handler.
output_save_changes = widgets.Output()

def save_changes(id):
    """Copy the metric settings into `params` and write the session file.

    `id` is the object passed by the button's click callback; it is unused.
    """
    with output_save_changes:
        use_metric = [[panel.children[0].value for panel in group] for group in metrics_groups]
        conditions = {slider.description: slider.value for slider in TAGS_SLIDERS.children}

        params['metric_selection'] = use_metric
        params['balance_pos'] = BALANCE_IN_CLASS.children[2].value
        params['balance_neg'] = BALANCE_IN_CLASS.children[3].value
        params['sample_limit'] = SIMILARITY_BASED.children[2].value
        params['distance_metric'] = SIMILARITY_BASED.children[3].value
        params['condition'] = conditions
        torch.save(params, utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET)


SAVE_CHANGES.on_click(save_changes)

display(SAVE_CHANGES, output_save_changes)

In [None]:
# Loading and calculating each metric
# Reload the session parameters from the session file written by the earlier cells.
params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)

def flatten(xss):
    """Return one flat list with the items of every inner iterable of `xss`, in order."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

# Metric classes, grouped in the same order as `metric_names`.
metrics = [
    [Error_rate_metrics, Predictive_value_metrics, statistical_parity],
    [well_calibration, balance_in_pos_neg],
    [LipschitzFairness],
]

# Use the saved selection; fall back to the current state of the checkboxes.
default = [[metric.children[0].value for metric in group] for group in metrics_groups]
metric_selection = params.get('metric_selection', default)

# Instantiate every selected metric and keep it together with its title.
used_metrics = []
for metric, metric_name, selection in zip(flatten(metrics), flatten(metric_names), flatten(metric_selection)):
    if not selection:
        continue
    used_metrics.append((metric(dataset, params), metric_name))
    print(f'{metric_name} calculated')

In [None]:
# show results of each metric
# For every calculated metric, print a boxed title and show each image
# yielded by the metric's show() method.
for metric, metric_name in used_metrics:
    # '# Metric ' plus ' #' add 11 characters around the name.
    # The banner is built outside the f-string because nesting the same quote
    # type inside an f-string is a syntax error before Python 3.12.
    banner = '#' * (len(metric_name) + 11)
    print(f'{banner}\n# Metric {metric_name} #\n{banner}')
    for image in metric.show():
        image.show()

In [None]:
# Saving parameters to file
# Widgets for writing the session parameters to a named preset file.
OUTPUT_FILE_NAME = Text(
    description='File name:',
    placeholder='Type output file name',
    tooltip='do not forget the file extension',
)
FEEDBACK_INFO = Label(value='')
# The button starts disabled because the file name box starts empty.
SUBMIT_BUTTON = Button(
    description='Save preset',
    button_style='danger',
    tooltip='No file name',
    disabled=True,
)

info_box = VBox([
    Label(value=f'Select a name for your preset file here, This will be saved in the folder "{utils.SAVED_PRESET_PATH}"'),
    OUTPUT_FILE_NAME,
    SUBMIT_BUTTON,
    FEEDBACK_INFO,
])

# Output area used as the context manager of the handler below.
output_savefile = widgets.Output()
display(info_box, output_savefile)


def update_file_name_info(id):
    """Save the preset on a button click and refresh the save button's state.

    `id` is the object supplied by the widget framework: the clicked Button
    when called through `on_click`, or the change notification when the file
    name text changes. A click copies the session parameters to the typed
    file name and shows a confirmation; any other call clears a stale
    confirmation. In both cases the button's style, tooltip and enabled state
    are updated to reflect the current file name.

    NOTE: this reuses the name of the handler in the file-selection cell.
    Widgets keep a reference to the function object they were registered
    with, so that earlier handler keeps working.
    """
    with output_savefile:
        file_name = OUTPUT_FILE_NAME.value
        target_path = utils.SAVED_PRESET_PATH + file_name

        if isinstance(id, Button):
            session_params = torch.load(utils.SAVED_PRESET_PATH + utils.RESEVERD_PRESET, weights_only=True)
            torch.save(session_params, target_path)
            FEEDBACK_INFO.value = f'Preset saved to "{file_name}"'
        else:
            # The file name was edited: the previous confirmation no longer applies.
            FEEDBACK_INFO.value = ''

        if file_name == '':
            button_style, tooltip, disabled = 'danger', 'No file name', True
        elif file_name in ('session_save.pt', utils.RESEVERD_PRESET):
            # The session file itself must not be chosen as a preset name.
            button_style, tooltip, disabled = 'danger', 'This file name is reserved', True
        elif not file_name.endswith('.pt'):
            # Warn only: saving under another extension stays possible.
            button_style, tooltip, disabled = 'warning', 'Please make it a .pt file', False
        elif os.path.exists(target_path):
            button_style, tooltip, disabled = 'info', 'A file already exists with this name, overwrite is possible', False
        else:
            button_style, tooltip, disabled = 'success', 'Save file', False

        SUBMIT_BUTTON.button_style = button_style
        SUBMIT_BUTTON.tooltip = tooltip
        SUBMIT_BUTTON.disabled = disabled


# `observe` replaces the deprecated `on_trait_change`, which passed the trait
# name (a str) to a one-argument handler, so the handler could never tell
# that the call came from the text box.
OUTPUT_FILE_NAME.observe(update_file_name_info, names='value')
SUBMIT_BUTTON.on_click(update_file_name_info)
