# AutoFair Automated Toolkit

## Necessary Imports

In [2]:
# Imports
import ipywidgets as widgets
from ipywidgets import VBox, Layout, Button
#from IPython.display import HTML, display
from IPython.display import display
from ipywidgets import HTML

import pandas as pd
import time
import os
from omegaconf import DictConfig, OmegaConf
import hydra
import mlflow
import mlflow.sklearn

# AIF 360
from aif360.detectors.mdss.ScoringFunctions import Bernoulli
from aif360.detectors.mdss.MDSS import MDSS
from aif360.sklearn.datasets.openml_datasets import fetch_adult
from aif360.sklearn.detectors.facts.clean import clean_dataset
from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult
from aif360.sklearn.detectors.facts import FACTS_bias_scan, FACTS
from aif360.sklearn.metrics import ot_distance
from sklearn.linear_model import LogisticRegression

# Plots
import matplotlib.pyplot as plt
import seaborn as sns

# Sklearn imports
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestClassifier

# DiCE
import dice_ml
from dice_ml.utils import helpers  # helper functions
import io
import base64

import warnings
warnings.filterwarnings("ignore")




2025-05-07 07:25:15.713001: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-07 07:25:15.719497: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-07 07:25:15.726554: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-07 07:25:15.728682: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-07 07:25:15.734313: I tensorflow/core/platform/cpu_feature_guar

## Configuration-related information

In [3]:
from __future__ import annotations

import os
from typing import Any, Dict, List, Mapping, Optional, Tuple

DATA_HELPER_PATH = os.path.join("..", "autochoicebackend", "data_helper.py")

from data_helper import *
#init_mlflow_from_cfg, mlflow_client_from_cfg


def mlflow_init_cfg(
    config_dir: Optional[Union[str, Path]] = None,
    *,
    config_name: str = "config",
    honor_env_var: bool = True,
) -> Tuple[DictConfig, MlflowClient]:
    """
    Initialize MLflow from a Hydra config and return (cfg, MlflowClient).

    Parameters
    ----------
    config_dir : str | Path | None, optional
        Directory containing the Hydra YAML (e.g., ``humancompatible`` if your
        file lives at ``humancompatible/config.yaml``). Relative paths are
        resolved against the current working directory. If ``None``, this
        function uses the environment variable ``AUTOCHOICE_CONFIG_DIR`` when
        ``honor_env_var=True`` and the variable is set; otherwise it defaults
        to ``Path.cwd() / 'humancompatible'``.
    config_name : str, default="config"
        Base name of the YAML file (without extension).
    honor_env_var : bool, default=True
        If True, use ``AUTOCHOICE_CONFIG_DIR`` when ``config_dir`` is not provided.

    Returns
    -------
    cfg : omegaconf.DictConfig
        The composed Hydra configuration.
    client : mlflow.tracking.MlflowClient
        An MLflow client configured according to the YAML.

    Raises
    ------
    FileNotFoundError
        If ``<config_dir>/<config_name>.yaml`` does not exist.

    Notes
    -----
    - Hydra's ``initialize()`` only accepts a ``config_path`` relative to the
      calling module and rejects absolute paths. The resolved, absolute
      directory is therefore passed to ``initialize_config_dir()`` instead.
    - This function clears any previous Hydra global state so it can be safely
      re-run in notebooks without raising "Hydra is already initialized".
    - It calls :func:`init_mlflow_from_cfg` to set tracking/registry URIs and env
      variables, then builds a :class:`MlflowClient` via
      :func:`mlflow_client_from_cfg`.
    """
    from pathlib import Path  # needed at runtime below, not only in the annotations

    # Resolve the config directory: explicit argument > env var (if honoured) > default
    if config_dir is not None:
        config_dir = Path(config_dir)
    else:
        env_dir = os.environ.get("AUTOCHOICE_CONFIG_DIR") if honor_env_var else None
        config_dir = Path(env_dir) if env_dir else Path.cwd() / "humancompatible"
    # initialize_config_dir() requires an absolute directory
    config_dir = config_dir.resolve()

    cfg_path = config_dir / f"{config_name}.yaml"
    if not cfg_path.exists():
        raise FileNotFoundError(f"Config not found: {cfg_path}")

    from hydra import compose, initialize_config_dir
    from hydra.core.global_hydra import GlobalHydra

    # Hydra-safe init for notebooks: drop any state left by a previous run of this cell
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
        cfg: DictConfig = compose(config_name=config_name)

    # Initialize MLflow from YAML and return a matching client
    init_mlflow_from_cfg(cfg)
    client = mlflow_client_from_cfg(cfg)
    return cfg, client



def load_hydra_config(config_path="config.yaml"):
    """
    Read a YAML configuration file with OmegaConf (no Hydra composition), so it
    can be used from a Voila-rendered notebook.

    Raises FileNotFoundError when ``config_path`` does not exist.
    """
    if os.path.exists(config_path):
        return OmegaConf.load(config_path)
    raise FileNotFoundError(f"Configuration file '{config_path}' not found. Ensure it exists.")

# Load the notebook configuration (config.yaml in the current working directory)
config = load_hydra_config()

# Point MLflow at the tracking server declared in the config (config.mlflow.tracking_uri);
# the server address belongs in config.yaml, not in the notebook.
mlflow.set_tracking_uri(config.mlflow.tracking_uri)

# Backend script that runs experiments in parallel (path relative to the current working directory)
PARALLEL_SCRIPT_PATH = os.path.join("..", "autochoicebackend", "run_experiments_parallel.py")

In [4]:
# Include the Font Awesome CSS so the "fa ..." icon classes used in the HTML panels render
font_awesome_link = HTML(
    '<link rel="stylesheet" '
    'href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css" />'
)
display(font_awesome_link)

HTML(value='<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.mi…

In [5]:
# Embed the CSS file: read the project stylesheet, then inject it as an inline <style> block
with open('style.css', 'r') as css_file:
    css_content = css_file.read()
display(HTML(f'<style>{css_content}</style>'))

HTML(value='<style>.my-font {\n    font-family: optima;\n    font-size: medium;\n}\n\n.circular {\n    border-…

In [6]:
def create_navigation_buttons():
    """Build the pipeline's 'Back' and 'Next' navigation buttons.

    Displays a small stylesheet defining the ``custom-nav-button`` class, then
    returns ``(back_button, next_button)``: two 100x40px buttons that both
    carry that class. Click handlers are attached by the caller.
    """
    # Custom CSS for the buttons
    custom_css = """
    <style>
        .custom-nav-button {
            background-color: MediumSeaGreen !important;  /* Green background */
            color: white !important;               /* White text */
            border: none;                          /* No border */
            padding: 10px 20px;                    /* Padding */
            text-align: center;                    /* Centered text */
            text-decoration: none;                 /* No underline */
            display: inline-block;                 /* Inline-block display */
            font-size: 16px;                       /* Font size */
            margin: 4px 2px;                       /* Margin */
            cursor: pointer;                       /* Pointer cursor on hover */
            border-radius: 5px;                    /* Rounded corners */
            line-height: 1.2;                      /* Line height */
            vertical-align: middle;                /* Vertical alignment */
        }
        .custom-nav-button:hover {
            background-color: #45a049;             /* Darker green on hover */
        }
    </style>
    """
    display(HTML(custom_css))

    def make_nav_button(label):
        # Each button gets its own Layout instance and the shared CSS class
        nav_button = widgets.Button(
            description=label,
            layout=widgets.Layout(width='100px', height='40px', margin='5px 5px 5px 5px'),
            style={'button_color': 'lightgray'},
        )
        nav_button.add_class('custom-nav-button')
        return nav_button

    return make_nav_button('Back'), make_nav_button('Next')



# Pipeline navigation state.
# current_state: name of the step currently on screen (None at start-up).
current_state = None

# Step reached by "Back" from each step; "Data" is the first step, so it maps to None.
prev_state = {
    "Data": None,
    "Analyze": "Data",
    "Features": "Analyze",
    "Model": "Features",
    "Parameters": "Model",
    "Run": "Parameters",
}

# Step reached by "Next" from each step; "Run" is the last step and has no entry.
next_state = {
    "Data": "Analyze",
    "Analyze": "Features",
    "Features": "Model",
    "Model": "Parameters",
    "Parameters": "Run",
}



def on_back_button_click(b):
    """Handle a click on the 'Back' navigation button.

    Looks up the step preceding ``current_state`` in ``prev_state`` and renders
    it by calling that step's page function with the clicked button ``b``.
    When there is no previous step (the first step, or ``current_state`` is
    None or not a known step), the click is ignored instead of raising
    ``KeyError`` inside the widget callback.
    """
    target = prev_state.get(current_state)
    if target == "Data":
        upload_bias_detection_data(b)
    elif target == "Analyze":
        dataset_analysis(b)
    elif target == "Features":
        select_features(b)
    elif target == "Model":
        give_model_to_detect_bias_updated_checkboxes_new(b)
    elif target == "Parameters":
        give_parameters_updated_and_extended_all_newer_integr(b)
        
def on_next_button_click(b):
    """Handle a click on the 'Next' navigation button.

    Looks up the step following ``current_state`` in ``next_state`` and renders
    it by calling that step's page function with the clicked button ``b``.
    When there is no next step (the last step "Run", or ``current_state`` is
    None or not a known step), the click is ignored instead of raising
    ``KeyError`` inside the widget callback.
    """
    target = next_state.get(current_state)
    if target == "Analyze":
        dataset_analysis(b)
    elif target == "Features":
        select_features(b)
    elif target == "Model":
        give_model_to_detect_bias_updated_checkboxes_new(b)
    elif target == "Parameters":
        give_parameters_updated_and_extended_all_newer_integr(b)
    elif target == "Run":
        run_bias_detection_newer_version(b)

In [7]:
# Global variables shared by the toolkit's widget callbacks
global_output = widgets.Output()
header = widgets.VBox()
steps_box = widgets.HBox()

back_button, next_button = create_navigation_buttons()
back_button.on_click(on_back_button_click)
next_button.on_click(on_next_button_click)

nav_buttons_box = widgets.HBox([back_button, next_button], layout=widgets.Layout(justify_content='space-between'))

output_status = widgets.HTML(
    value='',
    layout=widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
)
features_status_output = widgets.Output()

df = None  # Store the dataset here
selected_dataset = None
selected_algorithm = None
selected_algorithm_parameters = dict()

x = None
y = None
sample_weight = None

home_grid = None
features_4_scanning = None
# Initialised here so the flag exists before any feature-selection callback runs
# (mirrors has_selected_not_to_change_features below).
has_selected_features = False
random_seed = 131313  # for reproducibility

not_to_change_features = None
has_selected_not_to_change_features = False
not_to_change_features_status_output = widgets.Output(layout=widgets.Layout(margin='10px 10px', padding='5px 0px 5px 0px'))

# Map the dataset names to dataset files.
# NOTE(review): 'Adult (Census Income)' has no file entry; presumably it is loaded
# through aif360's fetch_adult instead — confirm.
dataset_name_2_file_name = {
    'Ad Campaign': "ad_campaign_data_small.csv",
    'Workable': "dataset1M.parquet",
}
# Dataset options with "Learn more" links
dataset_options = [
    {'name': 'Ad Campaign', 'url': 'https://developer.ibm.com/exchanges/data/all/bias-in-advertising/'},
    {'name': 'Adult (Census Income)', 'url': 'https://archive.ics.uci.edu/dataset/2/adult'},
    {'name': 'Workable', 'url': 'https://workable.com'}
]

def set_variables():
    """Reset the per-session dataset state.

    Clears the loaded dataset (``df``) and the features chosen for scanning
    (``features_4_scanning``), and marks that no features are selected.
    """
    global df, features_4_scanning, has_selected_features
    df = features_4_scanning = None
    has_selected_features = False

HTML(value='\n    <style>\n        .custom-nav-button {\n            background-color: MediumSeaGreen !importa…

In [8]:
# Display functions
def display_info_box(algorithm, dataset):
    """Display a grey explanation panel naming the bias-detection ``algorithm`` and ``dataset`` in use."""
    panel_html = f"""
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px">How to interpret the bias detection results?</span>
        <p>We are using the algorithm <b>{algorithm}</b> on the dataset '<b>{dataset}</b>'.</p>
        <p>The algorithm found the above subset as biased. Following is the distribution of values for each feature of the biased subset to help you analyze and understand their impact on the overall model performance.</p>
    </div>
    """
    display(HTML(panel_html))

def display_message(message, color='black', font_size='16px', font_weight='normal'):
    """Display ``message`` as an inline-styled HTML widget spanning 95% of the width."""
    css = f"color: {color}; font-size: {font_size}; font-weight: {font_weight};"
    display(widgets.HTML(
        value=f"<span style='{css}'>{message}</span>",
        layout=widgets.Layout(width='95%'),
    ))
    
def display_running_animation(color='black', duration=5, interval=0.5):
    """Show an animated "Running..." banner for roughly ``duration`` seconds.

    Displays an HTML widget and cycles its text through "Running", "Running.",
    "Running.." and "Running..." with one frame every ``interval`` seconds.
    The call blocks (it uses ``time.sleep``) until the time is up.

    Parameters
    ----------
    color : str
        CSS colour of the banner text.
    duration : float
        Seconds to animate. The deadline is checked before every frame, so the
        animation overruns by at most one ``interval``.
    interval : float
        Seconds between frames.
    """
    # Create the initial (empty) message widget
    running_msg = widgets.HTML(
        value='',
        layout=widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
    )
    display(running_msg)

    # monotonic() is immune to wall-clock adjustments while we wait
    deadline = time.monotonic() + duration
    frame = 0
    while time.monotonic() < deadline:
        dots = '.' * (frame % 4)  # 0..3 dots, then wrap around
        running_msg.value = f"""
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: {color}; padding: 10px; border-radius: 5px;">
                <span style="font-weight: bold;">Running{dots}</span>
            </div>
            """
        time.sleep(interval)
        frame += 1

def display_dataframe_styled(dataframe):
    """Display ``dataframe`` as a pandas Styler table.

    Cells are white, bordered, 14px and left-aligned. Header cells (``th``)
    get a light-grey background and left alignment.
    """
    cell_properties = {
        'background-color': 'white',
        'border': '1px solid lightgray',
        'font-size': '14px',
        'text-align': 'left',
        'width': '95%',
    }
    header_styles = [{
        'selector': 'th',
        'props': [('background-color', '#f4f4f4'), ('text-align', 'left')],
    }]
    styled_df = dataframe.style.set_properties(**cell_properties)
    display(styled_df.set_table_styles(header_styles))


import seaborn as sns
import matplotlib.pyplot as plt
import base64
import io
from IPython.display import HTML, display


def plot_histogram_grid_top10_spaced(df):
    """Plot one panel per column of ``df`` in a 2-column grid and display it inline.

    Numeric columns get a 20-bin histogram with a KDE overlay. Any other column
    gets a bar chart of its 10 most frequent values (x labels rotated 45°).
    NaNs are dropped per column. If plotting a panel raises, the panel shows
    the error text instead. The figure is rendered to PNG, base64-encoded,
    displayed in a centred 90%-wide ``<div>``, and then closed.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot; every column is plotted. If it has no columns, a short
        message is displayed instead of letting ``plt.subplots`` raise on a
        0-row grid.
    """
    sns.set_theme(style="whitegrid")
    features = list(df.columns)
    if not features:
        display(HTML('<p>No features to plot.</p>'))
        return

    palette = sns.color_palette("Set2")
    num_cols = 2  # Fewer columns to give more width per plot
    num_rows = -(-len(features) // num_cols)  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() is safe for any grid size
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    for i, feature in enumerate(features):
        ax = axes[i]
        series = df[feature].dropna()
        color = palette[i % len(palette)]

        try:
            if pd.api.types.is_numeric_dtype(series):
                sns.histplot(data=series, bins=20, kde=True, ax=ax, color=color)
            else:
                # Non-numeric column: bar chart of the 10 most frequent values
                top10 = series.value_counts().nlargest(10)
                sns.barplot(x=top10.index, y=top10.values, ax=ax, color=color)
                ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)

            ax.set_title(f'Histogram of {feature}', fontsize=14, fontweight='bold')
            ax.set_xlabel(feature, fontsize=12)
            ax.set_ylabel('Frequency', fontsize=12)

        except Exception as e:
            # Keep the rest of the grid usable; report the failure in this panel
            ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center')

    # Remove the unused trailing axes of the last row
    for ax in axes[len(features):]:
        fig.delaxes(ax)

    fig.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=3.0, w_pad=3.0)

    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png', bbox_inches='tight')
    plt.close(fig)

    img_base64 = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    img_html = f'<div style="width: 90%; margin: auto; text-align: center;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"/></div>'
    display(HTML(img_html))


def plot_histogram_grid_new(df):
    """Plot a 20-bin histogram (with KDE) for every column of ``df`` in a 3-column grid.

    Each column is coerced with ``pd.to_numeric(errors='coerce')`` and NaNs are
    dropped, so non-numeric values are ignored. A column with no numeric
    values shows "No valid data". For columns with more than 10 distinct
    values, only every third x tick is kept to reduce clutter. If plotting a
    panel raises, the panel shows the error text. The figure is displayed as
    an inline base64 PNG (90% width) and then closed.

    If ``df`` has no columns, a short message is displayed instead of letting
    ``plt.subplots`` raise on a 0-row grid.
    """
    sns.set_theme(style="whitegrid")
    features = list(df.columns)
    if not features:
        display(HTML('<p>No features to plot.</p>'))
        return

    palette = sns.color_palette("Set2")
    num_cols = 3
    num_rows = -(-len(features) // num_cols)  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() is safe for any grid size
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    for i, feature in enumerate(features):
        ax = axes[i]

        try:
            # Ensure the column is numeric and drop NaNs
            data = pd.to_numeric(df[feature], errors='coerce').dropna()
            if data.empty:
                ax.text(0.5, 0.5, 'No valid data', ha='center', va='center')
                continue

            sns.histplot(data=data, bins=20, kde=True, ax=ax, color=palette[i % len(palette)])
            ax.set_title(f'Histogram of {feature}', fontsize=14, fontweight='bold')
            ax.set_xlabel(feature, fontsize=12)
            ax.set_ylabel('Frequency', fontsize=12)

            # Thin out x ticks for high-cardinality columns
            if data.nunique() > 10:
                ax.set_xticks(ax.get_xticks()[::3])

            ax.tick_params(axis='x', rotation=45)

        except Exception as e:
            # Keep the rest of the grid usable; report the failure in this panel
            ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center')

    # Remove unused axes
    for ax in axes[len(features):]:
        fig.delaxes(ax)

    fig.tight_layout()

    # Encode and display the image
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png', bbox_inches='tight')
    plt.close(fig)

    img_base64 = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    img_html = f'<div style="width: 90%; margin: auto; text-align: center;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"/></div>'
    display(HTML(img_html))



# Old function (legacy helper; plot_histogram_grid_new takes the frame directly)
def plot_histogram_grid(features, data=None):
    """Plot 20-bin histograms (with KDE) for the given columns in a 3-column grid.

    For columns with more than 10 distinct values (NaN counted as a value),
    only every third x tick is kept. The figure is displayed as an inline
    base64 PNG (90% width) and then closed.

    Parameters
    ----------
    features : iterable of str
        Column names to plot, one panel each. If empty, a short message is
        displayed instead of letting ``plt.subplots`` raise on a 0-row grid.
    data : pandas.DataFrame, optional
        Frame containing those columns. Defaults to the module-level ``df``
        (the currently loaded dataset).
    """
    sns.set_theme(style="whitegrid")  # Use a white grid theme
    features = list(features)
    if not features:
        display(HTML('<p>No features to plot.</p>'))
        return

    source = df if data is None else data
    palette = sns.color_palette("Set2")
    num_cols = 3
    num_rows = -(-len(features) // num_cols)  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() is safe for any grid size
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    for i, feature in enumerate(features):
        ax = axes[i]
        sns.histplot(data=source, x=feature, bins=20, kde=True, ax=ax, color=palette[i % len(palette)])
        ax.set_title(f'Histogram of {feature}', fontsize=14, fontweight='bold')
        ax.set_xlabel(feature, fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)

        # Reduce the number of x-axis ticks for high-cardinality columns
        if source[feature].nunique(dropna=False) > 10:  # Adjust the threshold as needed
            ax.set_xticks(ax.get_xticks()[::3])

        # Rotate labels once, after the tick positions are final
        ax.tick_params(axis='x', rotation=45)

    # Remove any empty subplots
    for ax in axes[len(features):]:
        fig.delaxes(ax)

    fig.tight_layout()

    # Convert the figure to a PNG image in memory
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png', bbox_inches='tight')
    plt.close(fig)

    # Encode the image in base64 and display it within an HTML block that centers it
    img_base64 = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    img_html = f'<div style="width: 90%; margin: auto; text-align: center;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"/></div>'
    display(HTML(img_html))


    
def plot_distribution(data, title, ax):
    """Draw a Set2-coloured bar chart of ``data`` (category -> count mapping) on ``ax``.

    The x axis is labelled 'Category' and the y axis 'Count'. X tick labels are rotated 45°.
    """
    categories, counts = list(data.keys()), list(data.values())
    sns.barplot(x=categories, y=counts, palette="Set2", ax=ax)
    ax.set_title(title, fontsize=14, fontweight='bold')
    for set_label, text in ((ax.set_xlabel, 'Category'), (ax.set_ylabel, 'Count')):
        set_label(text, fontsize=12)
    ax.tick_params(axis='x', rotation=45)

def print_report(_, subset):  # From AIF360 Notebook (modified for prettier printing)
    """Display an HTML summary of a subgroup, followed by bar charts of its feature values.

    Adapted from an AIF360 example notebook. Reads the module-level dataset
    ``df``, which must contain ``true_conversion`` and ``predicted_conversion``
    columns. When ``subset`` is empty, it also reads the module-level
    ``features_4_scanning`` list.

    Parameters
    ----------
    _ : object
        Unused; kept so existing call sites keep working.
    subset : mapping of column name -> list of values, or None
        The subgroup to report: rows of ``df`` whose value in every listed
        column is one of the listed values. When empty or None, the whole
        dataset is reported and every unique value of each column in
        ``features_4_scanning`` is listed. The caller's mapping is never
        modified.

    The HTML table shows the subgroup definition, its size, and the summed
    ``true_conversion`` / ``predicted_conversion`` values. It is followed by
    one bar chart per column (at most 4 per row) showing how often each
    listed value occurs in the full ``df``.
    """
    import html

    # Work on a copy so the caller's mapping is not filled in as a side effect
    subset = dict(subset) if subset else {}
    outcome_cols = ['true_conversion', 'predicted_conversion']

    if subset:
        # Rows whose value in every subset column is one of the allowed values
        to_choose = df[list(subset)].isin(subset).all(axis=1)
        filtered_df = df.loc[to_choose, outcome_cols]
    else:
        for col in features_4_scanning or []:
            subset[col] = list(df[col].unique())
        filtered_df = df[outcome_cols]

    true_total = filtered_df['true_conversion'].sum()
    pred_total = filtered_df['predicted_conversion'].sum()

    # Enhanced subset information; cell contents come from the data, so escape them
    subset_html = '''
    <div>
        <h2>The algorithm found the following subset:</h2>
        <table style="width: 100%; border-collapse: collapse;">
            <thead>
                <tr>
                    <th style="border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;">Feature</th>
                    <th style="border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;">Values</th>
                </tr>
            </thead>
            <tbody>
    '''
    for key, values in subset.items():
        values_str = ', '.join(map(str, values))
        subset_html += f'''
                <tr>
                    <td style="border: 1px solid #ddd; padding: 8px;">{html.escape(str(key))}</td>
                    <td style="border: 1px solid #ddd; padding: 8px;">{html.escape(values_str)}</td>
                </tr>
        '''
    subset_html += f'''
            </tbody>
        </table>
        <p><strong>Subset Size:</strong> {len(filtered_df)}</p>
        <p><strong>True Clicks:</strong> {true_total}</p>
        <p><strong>Predicted Clicks:</strong> {pred_total}</p>
    </div>
    '''
    display(HTML(subset_html))

    num_plots = len(subset)
    if num_plots == 0:
        return  # nothing to chart (no subset and no scanned features)

    # Calculate number of rows and columns for the grid
    num_cols = 4  # keep it to at most 4 plots per row
    num_rows = (num_plots + num_cols - 1) // num_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() is safe for any grid size
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 6 * num_rows), squeeze=False)
    axes = axes.ravel()

    # Plot distribution for each subset key: counts in the full dataset, 0 for unseen values
    for ax, (key, values) in zip(axes, subset.items()):
        value_counts = df[key].value_counts().to_dict()
        subset_counts = {value: value_counts.get(value, 0) for value in values}
        plot_distribution(subset_counts, f"Distribution of {key}", ax)

    # Remove any empty subplots
    for ax in axes[num_plots:]:
        fig.delaxes(ax)

    fig.tight_layout()
    plt.show()

In [9]:
def clean_toolkit_content():
    """Reset the main output area to the toolkit's header, step bar and navigation buttons."""
    with global_output:
        # wait=True delays the clear until new output arrives, for smoother updates
        global_output.clear_output(wait=True)
        for section in (header, steps_box, nav_buttons_box):
            display(section)

In [10]:
# Event handlers
    
def on_button_click(b):
    """Dispatch a home-screen button click to the matching toolkit page.

    The target page is chosen from ``b.description`` (the button's label).

    Raises
    ------
    ValueError
        If the label does not match a known toolkit; the message names the
        unexpected label.
    """
    label = b.description
    if label == 'Detect Bias':
        show_bias_detection_toolkit()
    elif label == 'Explainability Toolkit':
        show_explainability_toolkit()
    elif label == 'Fairness Trade-offs':
        show_tradeoff_toolkit()
    elif label == 'Apriori Certify Fairness':
        show_apriori_certification()
    else:
        raise ValueError(f"Unknown home-screen button: {label!r}")

def update_status(loaded=False):
    """Display a grey status panel saying whether a dataset is loaded (green) or not (red)."""
    if loaded:
        status_color, status_text = 'green', 'Dataset loaded successfully.'
    else:
        status_color, status_text = 'red', 'No dataset is loaded.'
    status_message = (
        '<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
        '<span style="font-weight: bold; color: #333333;">Status:</span> '
        f'<span style="color: {status_color}; font-weight: bold;">{status_text}</span>'
        '</div>'
    )
    display(HTML(status_message))


        
def _selection_status_html(selected, selected_label, empty_label):
    """Build the grey status panel: ``selected_label`` plus the comma-joined items, or ``empty_label`` in red when empty."""
    status_message = '<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
    if selected:
        status_message += f'<span style="color: black; font-weight: bold;">{selected_label}</span>'
        status_message += f'<span style="color: black;">{", ".join(selected)}</span>'
    else:
        status_message += f'<span style="color: red; font-weight: bold;">{empty_label}</span>'
    return status_message + '</div>'


def update_selected_features(change):
    """Observer for the feature-selection widget.

    Stores the newly selected names (``change.new``) in ``features_4_scanning``,
    sets ``has_selected_features`` accordingly, and redraws the status panel in
    ``features_status_output``.
    """
    global features_4_scanning, has_selected_features
    features_4_scanning = list(change.new)
    has_selected_features = bool(features_4_scanning)

    with features_status_output:
        features_status_output.clear_output(wait=True)
        display(HTML(_selection_status_html(
            features_4_scanning,
            'Selected features/protected attributes: ',
            'No selected features/protected attributes.',
        )))


def update_selected_not_to_change_features(change):
    """Observer for the "features not to change" widget.

    Stores the newly selected names (``change.new``) in
    ``not_to_change_features``, sets ``has_selected_not_to_change_features``
    accordingly, and redraws the status panel in
    ``not_to_change_features_status_output``.
    """
    global not_to_change_features, has_selected_not_to_change_features
    not_to_change_features = list(change.new)
    has_selected_not_to_change_features = bool(not_to_change_features)

    with not_to_change_features_status_output:
        not_to_change_features_status_output.clear_output(wait=True)
        display(HTML(_selection_status_html(
            not_to_change_features,
            'Selected features not to change: ',
            'No selected features.',
        )))

In [11]:
# Styling

# Common image dimensions (pixels) for the home-screen step icons
image_height = 100
image_width = 100

# Common layouts
# CSS rule for the "text-button-class" class (added to the home-screen text
# buttons) that forces a transparent background
text_button_css = """
<style>
    .text-button-class {
        background-color: transparent !important;
    }
</style>
"""

display(HTML(text_button_css))

# NOTE(review): background_color and box_shadow do not appear to be
# ipywidgets.Layout traits and are probably ignored; the transparent
# background actually comes from the CSS class above — confirm.
text_button_layout = widgets.Layout(
    height='auto',
    width='250px',
    margin='10px 5px',
    padding='10px',
    border='1px solid #ccc',
    background_color='transparent',
    flex_flow='row wrap',
    justify_content='flex-start',
    align_items='center',
    box_shadow='0 2px 4px 0 rgba(0,0,0,0.1)',
)

step_layout = widgets.Layout(display='flex', flex_flow='column', align_items='center', justify_content='center')

pipeline_title_style = {'font_size': '30px',
                        'color': '#333333',                 # Dark grey text
                        'background_color': '#F0F0F0',      # Light background color
                        'border_radius': '5px',             # Slightly rounded corners
                        'border': '1px solid #CCCCCC'       # Light grey border
                        }

# Layout for step buttons
# NOTE(review): border_radius, font_weight, color and background_color are
# likely not Layout traits either (they are button *style* / CSS properties) — confirm.
button_layout_wide = widgets.Layout(
    width='300px',
    height='50px',
    margin='5px',
    padding='5px',
    border='1px solid #CCCCCC',     # Light grey border
    border_radius='5px',            # Slightly rounded corners
    font_weight='bold',
    color='#333333',                # Dark text color
    background_color='#F0F0F0',     # Light background color
    align_items='center',
    justify_content='center'
)

# Layout for the main content area. ipywidgets 8 Layout has no overflow_y
# trait (unknown keywords are silently dropped), so scrolling is requested
# through the generic `overflow` trait, which maps to the CSS property.
main_content_layout = widgets.Layout(
    overflow='auto',  # let content that overflows the area scroll
    margin='0px',
    padding='10px',
    width='100%'
)
global_output.layout = main_content_layout
display(global_output)

HTML(value='\n<style>\n    .text-button-class {\n        background-color: transparent !important;\n    }\n</s…

Output(layout=Layout(margin='0px', padding='10px', width='100%'))

In [12]:
def _make_home_step(icon_path, label):
    """Build one home-screen step: a PNG icon stacked above a clickable text button.

    The icon is read from ``icon_path``. The button shows ``label``, is routed
    to ``on_button_click`` (which dispatches on the label), and carries the
    ``my-font`` and ``text-button-class`` CSS classes.
    """
    with open(icon_path, 'rb') as icon_file:
        icon_bytes = icon_file.read()
    image = widgets.Image(value=icon_bytes, format='png', width=image_width, height=image_height)
    button = widgets.Button(description=label, layout=text_button_layout, style={'button_color': 'white'})
    button.on_click(on_button_click)
    button.add_class('my-font')
    button.add_class('text-button-class')
    return widgets.VBox([image, button], layout=step_layout)


def display_home_screen():
    """Render the landing page into ``global_output``.

    Clears the output area, then shows the title, two subtitles and a row
    containing the four toolkit steps (Detect Bias, Explainability Toolkit,
    Fairness Trade-offs, Apriori Certify Fairness) separated by arrow icons.
    The step icons are read from PNG files in the current working directory.
    """
    with global_output:
        global_output.clear_output(wait=True)  # Clear existing content

        # Title
        title = widgets.Label(value='humancompatible.autochoice',
                              layout=widgets.Layout(margin='120px 0px 20px 0px'),
                              style={'font_size': '38px'})
        title.add_class('my-font')

        # Subtitles
        subtitle = widgets.Label(value='An automated tool for evaluating fairness in AI/ML models',
                                 layout=widgets.Layout(margin='10px 0px 70px 0px'),
                                 style={'font_size': '24px'})
        subtitle.add_class('my-font')

        subtitle2 = widgets.Label(value='Developed and maintained by NKUA team',
                                  layout=widgets.Layout(margin='10px 0px 70px 0px'),
                                  style={'font_size': '20px'})
        subtitle2.add_class('my-font')

        # Define layout for rows and columns
        vbox_layout = widgets.Layout(display='flex', flex_flow='column',
                                     align_items='center', justify_content='center')

        hbox_layout = widgets.Layout(display='flex',
                                     flex_flow='row',
                                     align_items='center',
                                     justify_content='space-between')  # even spacing between steps and arrows

        # One icon + button per toolkit, in pipeline order
        steps = [
            _make_home_step('detect_bias_icon.png', 'Detect Bias'),
            _make_home_step('explainability_toolkit_icon.png', 'Explainability Toolkit'),
            _make_home_step('tradeoff_icon.png', 'Fairness Trade-offs'),
            _make_home_step('certification_icon.png', 'Apriori Certify Fairness'),
        ]

        arrow_button = widgets.Button(
            description='',         # No text, only icon
            icon='arrow-right',     # FontAwesome icon name
            layout=widgets.Layout(width='auto', height='auto', overflow='visible'),
            style={'button_color': 'transparent'}  # Attempt to set the button background to transparent
        )
        arrow_button.add_class('my-font')

        # Interleave the steps with the (shared) arrow widget: step, arrow, step, ...
        row_children = [steps[0]]
        for step in steps[1:]:
            row_children.extend([arrow_button, step])
        row = widgets.HBox(row_children, layout=hbox_layout)

        # Combine title, subtitles and the step row in a VBox
        home_screen_layout = widgets.VBox([title, subtitle, subtitle2, row], layout=vbox_layout)

        display(home_screen_layout)  # Display home screen layout in global_output

In [13]:
def display_bias_detection_welcome_page():
    """Show the introductory info panel of the bias detection toolkit.

    The panel is an ``ipywidgets.HTML`` widget rendered inside the shared
    ``global_output`` area (an Output widget defined elsewhere in the notebook).
    Nothing is returned; the only effect is the display call.
    """
    html_body = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the automated bias detection toolkit.
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            Follow the pipeline to load your desired data, analyze them, and run bias detection algorithms.
        </p>
        <h3 style="font-size: 16px; color: #333333;">Bias Detection</h3>
        <p style="font-size: 14px;">
            Bias occurs in data used to train a model. We provide sample datasets, metrics and algorithms that you can use to explore bias checking.
        </p>
        <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
        <ol style="font-size: 14px; padding-left: 20px;">
            <li>Load your dataset by following the pipeline steps. Currently you can only select from available options.</li>
            <li>Analyze the dataset and plot metrics for potential biases.</li>
            <li>Select and run the bias detection algorithm to identify biased subgroups.</li>
            <li>Review the results and take necessary actions to mitigate bias.</li>
        </ol>
    </div>
    """
    # Full-width panel, no outer margin, 5px vertical padding.
    panel_layout = widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
    panel = widgets.HTML(value=html_body, layout=panel_layout)
    with global_output:
        display(panel)


def display_explainability_toolkit_welcome_page():
    """Show the introductory info panel of the explainability toolkit.

    Renders a static ``ipywidgets.HTML`` panel into the shared ``global_output``
    area (defined elsewhere in the notebook). Returns nothing.
    """
    html_body = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the automated explainability toolkit.
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            Follow the pipeline to load your desired data, analyze them, and run explainability algorithms.
        </p>
        <h3 style="font-size: 16px; color: #333333;">Explainability</h3>
        <p style="font-size: 14px;">
            Explainability focuses on the reasoning behind the decisions or predictions made by ML models to make them more understandable and transparent. We provide sample datasets, metrics and algorithms that you can use to explore explainability.
        </p>
        <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
        <ol style="font-size: 14px; padding-left: 20px;">
            <li>Load your dataset by following the pipeline steps. Currently you can upload your own or select from available options.</li>
            <li>Analyze the dataset and plot metrics for potential biases.</li>
            <li>Select and run explainability algorithms with different parameters to identify biased subgroups.</li>
            <li>Review the results and take necessary actions to mitigate bias.</li>
        </ol>
    </div>
    """
    # Full-width panel, no outer margin, 5px vertical padding.
    panel_layout = widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
    panel = widgets.HTML(value=html_body, layout=panel_layout)
    with global_output:
        display(panel)



def display_fairness_measures_welcome_page():
    """Show the introductory info panel of the fairness-measures trade-off toolkit.

    Renders a static ``ipywidgets.HTML`` panel into the shared ``global_output``
    area (defined elsewhere in the notebook). Returns nothing.
    """
    # The description paragraph previously contained text garbled by lost
    # ff/fi ligatures ("tradeo ", "di erent", "sacri cing"); it is spelled out here.
    info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the fairness measures trade-off toolkit.
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            Follow the pipeline to load your desired data, analyze them, and run fairness measures trade-off algorithms.
        </p>
        <h3 style="font-size: 16px; color: #333333;">Fairness Measures Trade-offs</h3>
        <p style="font-size: 14px;">
            The tradeoff between different measures of fairness is captured by the so-called Pareto front, a visual representing the most one can achieve of a particular measure without sacrificing another.
        </p>
        <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
        <ol style="font-size: 14px; padding-left: 20px;">
            <li>Load your dataset by following the pipeline steps. Currently you can upload your own or select from available options.</li>
            <li>Analyze the dataset and plot metrics for potential biases.</li>
            <li>Select and run exploration of trade-off measures algorithms with different parameters.</li>
            <li>Review the results and take necessary actions.</li>
        </ol>
    </div>
    """

    info_box = widgets.HTML(
        value=info_message,
        layout=widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
    )
    with global_output:
        display(info_box)



In [14]:
def create_arrow():
    """Build the grey right-pointing arrow placed between pipeline step widgets.

    Returns:
        widgets.HTML: a Font Awesome ``fa-arrow-right`` icon (18px, grey),
        centred in a flex container with 10px horizontal margins.
    """
    arrow_layout = widgets.Layout(
        display='flex',
        justify_content='center',
        align_items='center',
        height='auto',
        width='auto',
        margin='0px 10px',
    )
    arrow_markup = '<i class="fa fa-arrow-right" style="font-size:18px;color:grey;"></i>'
    return widgets.HTML(value=arrow_markup, layout=arrow_layout)

In [1]:
### Work Packages

def create_step_widget(icon, text):
    """Build one pipeline step: a round icon button with a grey caption below it.

    Args:
        icon: Font Awesome icon name shown on the button (e.g. ``'upload'``).
        text: caption rendered centred underneath the button.

    Returns:
        VBox whose ``children[0]`` is the Button (callers attach ``on_click``
        handlers to it) and whose ``children[1]`` is the HTML caption.
    """
    icon_button = widgets.Button(icon=icon, layout=Layout(width='70px', height='70px', margin='0px 5px'))
    button_style = icon_button.style
    button_style.button_color = 'lightgray'
    # A 35px radius on a 70x70 button makes it a circle.
    button_style.border_radius = '35px'
    icon_button.add_class('circular')

    caption = widgets.HTML(
        value=f"<div style='text-align:center; color:grey;'>{text}</div>",
        layout=Layout(height='auto', width='auto'),
    )
    caption.add_class('my-font')

    column_layout = Layout(align_items='center', justify_content='center', margin='0px 10px', overflow='hidden')
    return VBox([icon_button, caption], layout=column_layout)

def show_bias_detection_toolkit():
    """Open the Bias Detection Toolkit landing page in ``global_output``.

    Clears the shared output area and resets notebook state with
    ``set_variables()``. It then rebuilds the two shared container globals:

    * ``header``: a Home button (back to ``display_home_screen``) plus the page
      title.
    * ``steps_box``: the clickable pipeline
      Data -> Analyze -> Features -> Model -> Parameters -> Run.

    Finally it displays the header, an introductory info panel and a *Start*
    button. Like the Data step, the Start button calls
    ``upload_bias_detection_data``. ``steps_box`` is populated here but not
    displayed by this function (presumably the later pipeline pages show it).
    """
    global header, steps_box

    with global_output:
        global_output.clear_output(wait=True)
        set_variables()

        # ---- Home button, left-aligned in its own header row ----
        home_button = Button(
            description='Home',
            layout=Layout(width='auto', height='auto', margin='0px 0px 5px 0px'),
            style={'button_color': 'white', 'font_size': '18px'},
            icon='home',
        )
        home_button.add_class('my-font')
        home_button.on_click(lambda b: display_home_screen())

        home_button_box = widgets.Box(
            children=[home_button],
            layout=Layout(
                overflow='hidden',
                width='100%',
                height='auto',
                display='flex',
                flex_flow='column',
                align_items='flex-start',
            ),
        )
        home_button_box.add_class('header-box')

        page_title = widgets.Label(
            'Bias Detection Toolkit',
            layout=widgets.Layout(margin='0px 0px 15px 0px', padding='0px', justify_content='center'),
            style=pipeline_title_style,
        )
        page_title.add_class('my-font')

        # ---- Pipeline steps: (icon, caption, click handler) in display order ----
        # Per the author's note, the Model step is meant to offer the toolkit's own
        # algorithms plus AIF360's Reweighing (pre-processing) and
        # EqOddsPostprocessing / CalibratedEqOddsPostprocessing /
        # RejectOptionClassification (post-processing).
        pipeline_spec = [
            ('upload', 'Data', upload_bias_detection_data),
            ('area-chart', 'Analyze', dataset_analysis),
            ('tasks', 'Features', select_features),
            ('cogs', 'Model', give_model_to_detect_bias_updated_checkboxes_new),
            ('sliders-h', 'Parameters', give_parameters_updated_and_extended_all_newer_integr),
            ('play', 'Run', run_bias_detection_newer_version),
        ]
        step_widgets = []
        for icon, caption, handler in pipeline_spec:
            step = create_step_widget(icon, caption)
            step.children[0].on_click(handler)
            step_widgets.append(step)
        arrows = [create_arrow() for _ in range(len(step_widgets) - 1)]

        # Interleave as step, arrow, step, arrow, ..., step.
        pipeline_children = [step_widgets[0]]
        for arrow, step in zip(arrows, step_widgets[1:]):
            pipeline_children.extend([arrow, step])

        header.children = [home_button_box, page_title]
        header.add_class('fixed-header')

        steps_box.children = pipeline_children
        steps_box.layout = Layout(display='flex', flex_flow='row',
                                  justify_content='center', align_items='center',
                                  margin='0px', width='100%', height='20vh')
        steps_box.add_class('fixed-steps-box')

        # ---- Introductory panel ----
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the automated bias detection toolkit.
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Follow the pipeline to load your desired data, analyze them, and run bias detection algorithms.
            </p>
            <h3 style="font-size: 16px; color: #333333;">Bias Detection</h3>
            <p style="font-size: 14px;">
                Bias occurs in data used to train a model. We provide sample datasets, metrics and algorithms that you can use to explore bias checking.
            </p>
            <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
            <ol style="font-size: 14px; padding-left: 20px;">
                <li>Load your dataset by following the pipeline steps. Currently you can only select from available options.</li>
                <li>Analyze the dataset and plot metrics for potential biases.</li>
                <li>Select and run the bias detection algorithm to identify biased subgroups.</li>
                <li>Review the results and take necessary actions to mitigate bias.</li>
            </ol>
        </div>
        """
        info_box = widgets.HTML(
            value=info_message,
            layout=widgets.Layout(margin='15px 0px', width='100%', padding='5px 0px 5px 0px'),
        )

        # ---- Start button, horizontally centred ----
        start_button = widgets.Button(
            description='Start',
            layout=widgets.Layout(width='150px', height='50px', margin='10px 0px', border_radius='10px'),
            style={'button_color': 'MediumSeaGreen', 'font_size': '16px', 'text_color': 'white', 'font_weight': 'bold', 'color': 'white'},
        )
        start_button.add_class('my-font')
        start_button.on_click(upload_bias_detection_data)
        start_row = widgets.HBox([start_button], layout=widgets.Layout(justify_content='center', width='100%'))

        landing_page = widgets.VBox([header, info_box, start_row],
                                    layout=widgets.Layout(width='100%', flex_flow='column'))
        display(landing_page)
   
def show_explainability_toolkit():
    """Open the Explainability Toolkit landing page in ``global_output``.

    Clears the shared output area, resets notebook state with
    ``set_variables()``, and rebuilds the shared globals:

    * ``header``: a Home button (back to ``display_home_screen``) plus the page
      title.
    * ``steps_box``: the pipeline Upload -> Action -> Model -> Run. The Upload
      step is a real ``FileUpload`` widget (CSV only) whose value changes go to
      ``select_explainability_dataset``.

    Displays the header, an introductory info panel and a centred *Start*
    button. ``steps_box`` is populated here but not displayed by this function.
    """
    global header, steps_box

    with global_output:
        global_output.clear_output(wait=True)
        set_variables()

        # ---- Home button, left-aligned in its own header row ----
        back_button = Button(
            description='Home',
            layout=Layout(width='auto', height='auto', margin='0px 0px 5px 0px'),
            style={'button_color': 'white', 'font_size': '18px'},
            icon='home',
        )
        back_button.add_class('my-font')
        back_button.on_click(lambda b: display_home_screen())

        back_button_box = widgets.Box(
            children=[back_button],
            layout=Layout(
                overflow='hidden',
                width='100%',
                height='auto',
                display='flex',
                flex_flow='column',
                align_items='flex-start',
            ),
        )
        back_button_box.add_class('header-box')

        page_title = widgets.Label(
            'Explainability Toolkit',
            layout=widgets.Layout(margin='0px 0px 15px 0px', padding='0px', justify_content='center'),
            style=pipeline_title_style,
        )
        page_title.add_class('my-font')

        # ---- Introductory panel ----
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the automated explainability toolkit.
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Follow the pipeline to load your desired data, analyze them, and run explainability algorithms.
            </p>
            <h3 style="font-size: 16px; color: #333333;">Explainability</h3>
            <p style="font-size: 14px;">
                Explainability focuses on the reasoning behind the decisions or predictions made by ML models to make them more understandable and transparent. We provide sample datasets, metrics and algorithms that you can use to explore explainability.
            </p>
            <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
            <ol style="font-size: 14px; padding-left: 20px;">
                <li>Load your dataset by following the pipeline steps. Currently you can upload your own or select from available options.</li>
                <li>Analyze the dataset and plot metrics for potential biases.</li>
                <li>Select and run explainability algorithms with different parameters to identify biased subgroups.</li>
                <li>Review the results and take necessary actions to mitigate bias.</li>
            </ol>
        </div>
        """
        info_box = widgets.HTML(
            value=info_message,
            layout=widgets.Layout(margin='15px 0px', width='100%', padding='5px 0px 5px 0px'),
        )

        # ---- Start button, horizontally centred ----
        # NOTE(review): Start calls upload_bias_detection_data, while the Upload
        # step below routes to select_explainability_dataset; confirm that the
        # bias-detection loader is the intended entry point here.
        start_button = widgets.Button(
            description='Start',
            layout=widgets.Layout(width='150px', height='50px', margin='10px 0px', border_radius='10px'),
            style={'button_color': 'MediumSeaGreen', 'font_size': '16px', 'text_color': 'white', 'font_weight': 'bold', 'color': 'white'},
        )
        start_button.add_class('my-font')
        start_button.on_click(upload_bias_detection_data)
        start_row = widgets.HBox([start_button], layout=widgets.Layout(justify_content='center', width='100%'))

        # ---- Upload step: a circular FileUpload styled like the other steps ----
        upload_button = widgets.FileUpload(
            accept='.csv',
            description='',
            multiple=False,
            layout=Layout(width='70px', height='70px', margin='0px 5px'),
        )
        upload_style = upload_button.style
        upload_style.button_color = 'lightgray'
        upload_style.border_radius = '35px'
        upload_button.add_class('circular')
        upload_button.observe(select_explainability_dataset, names='value')
        upload_button.add_class('my-font')

        upload_caption = widgets.HTML(value="<div style='text-align:center; color:grey;'>Upload</div>",
                                      layout=Layout(height='auto', width='auto'))
        upload_caption.add_class('my-font')
        upload_step = VBox([upload_button, upload_caption],
                           layout=Layout(align_items='center', justify_content='center', margin='0px 10px', overflow='hidden'))

        # ---- Remaining steps, in display order ----
        remaining_spec = [
            ('tasks', 'Action', select_explainability_action),
            ('cogs', 'Model', give_parameters_and_explainability_model),
            ('play', 'Run', run_explainability),
        ]
        step_widgets = [upload_step]
        for icon, caption, handler in remaining_spec:
            step = create_step_widget(icon, caption)
            step.children[0].on_click(handler)
            step_widgets.append(step)
        arrows = [create_arrow() for _ in range(len(step_widgets) - 1)]

        # Interleave as step, arrow, step, arrow, ..., step.
        pipeline_children = [step_widgets[0]]
        for arrow, step in zip(arrows, step_widgets[1:]):
            pipeline_children.extend([arrow, step])

        header.children = [back_button_box, page_title]
        header.add_class('fixed-header')

        steps_box.children = pipeline_children
        steps_box.layout = Layout(display='flex', flex_flow='row',
                                  justify_content='center', align_items='center',
                                  margin='0px', width='100%', height='20vh')
        steps_box.add_class('fixed-steps-box')

        landing_page = widgets.VBox([header, info_box, start_row],
                                    layout=widgets.Layout(width='100%', flex_flow='column'))
        display(landing_page)

def show_tradeoff_toolkit():
    """Open the fairness trade-off toolkit landing page in ``global_output``.

    Clears the shared output area, resets notebook state with
    ``set_variables()``, and rebuilds the shared globals:

    * ``header``: a Home button (back to ``display_home_screen``) plus the page
      title.
    * ``steps_box``: the pipeline Data -> Analyze -> Features -> Run.

    Step handlers are looked up by name in the notebook namespace. A handler is
    attached only if it is defined, so the page still renders when, for
    example, ``run_tradeoff`` does not exist yet; that button is then inert.

    Displays the header, an introductory info panel and a centred *Start*
    button. ``steps_box`` is populated here but not displayed by this function.
    """
    global header, steps_box

    def _bind(button, handler_name):
        # Attach the named notebook-level callback if it exists; otherwise leave
        # the button inert instead of raising NameError.
        handler = globals().get(handler_name)
        if handler is not None:
            button.on_click(handler)

    with global_output:
        global_output.clear_output(wait=True)
        set_variables()

        # Home button
        back_button = widgets.Button(
            description='Home',
            layout=Layout(width='auto', height='auto', margin='0px 0px 5px 0px'),
            style={'button_color': 'white', 'font_size': '18px'},
            icon='home',
        )
        back_button.add_class('my-font')
        back_button.on_click(lambda b: display_home_screen())

        back_button_box = widgets.Box(
            children=[back_button],
            layout=Layout(
                overflow='hidden',
                width='100%',
                height='auto',
                display='flex',
                flex_flow='column',
                align_items='flex-start',
            ),
        )
        back_button_box.add_class('header-box')

        page_title = widgets.Label(
            'Examine Trade-off Between Measures of Fairness',
            layout=widgets.Layout(margin='0px 0px 15px 0px', padding='0px', justify_content='center'),
            style=pipeline_title_style
        )
        page_title.add_class('my-font')

        # Info box
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This is the automated fairness toolkit.
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Follow the pipeline to load your desired data, analyze them, and run fairness trade-offs algorithms.
            </p>
            <h3 style="font-size: 16px; color: #333333;">Fairness Trade-offs</h3>
            <p style="font-size: 14px;">
                The tradeoff  between different measures of fairness is captured by the so-called Pareto front, a visual representing the most one can achieve of a particular measure without sacrificing another.
            </p>
            <h3 style="font-size: 16px; color: #333333;">How to Use This Toolkit</h3>
            <ol style="font-size: 14px; padding-left: 20px;">
                <li>Load your dataset by following the pipeline steps. Currently you can upload your own or select from available options.</li>
                <li>Analyze the dataset and plot metrics to explore.</li>
                <li>Select and run fairness measure trade-off algorithms with different parameters.</li>
                <li>Review the results and take necessary actions.</li>
            </ol>
        </div>
        """
        info_box = widgets.HTML(
            value=info_message,
            layout=widgets.Layout(margin='15px 0px', width='100%', padding='5px 0px 5px 0px')
        )

        # Start button (launches the first step), horizontally centred
        start_button = widgets.Button(
            description='Start',
            layout=widgets.Layout(width='150px', height='50px', margin='10px 0px', border_radius='10px'),
            style={'button_color': 'MediumSeaGreen', 'font_size': '16px', 'text_color': 'white', 'font_weight': 'bold', 'color': 'white'}
        )
        start_button.add_class('my-font')
        _bind(start_button, 'upload_bias_detection_data')

        start_button_hbox = widgets.HBox([start_button], layout=widgets.Layout(justify_content='center', width='100%'))

        # ---- PIPELINE STEPS ----
        # TODO: 'Measures' (select_tradeoff_measures) and 'Parameters' steps were
        # sketched for this toolkit; insert them here once their handlers exist.
        data_step = create_step_widget('upload', 'Data')
        _bind(data_step.children[0], 'upload_bias_detection_data')

        analyze_step = create_step_widget('area-chart', 'Analyze')
        _bind(analyze_step.children[0], 'dataset_analysis')

        features_step = create_step_widget('tasks', 'Features')
        _bind(features_step.children[0], 'select_features')

        run_step = create_step_widget('play', 'Run')
        _bind(run_step.children[0], 'run_tradeoff')

        # Arrows
        arrow_1 = create_arrow()
        arrow_2 = create_arrow()
        arrow_3 = create_arrow()

        # Header + steps box (fixed)
        header.children = [back_button_box, page_title]
        header.add_class('fixed-header')

        # Only widgets that were actually built above are referenced here; the
        # previous version listed commented-out steps and raised NameError.
        steps_box.children = [data_step, arrow_1, analyze_step, arrow_2, features_step, arrow_3, run_step]
        steps_box.layout = Layout(display='flex', flex_flow='row',
                                  justify_content='center', align_items='center',
                                  margin='0px', width='100%', height='20vh')
        steps_box.add_class('fixed-steps-box')

        # Display layout
        toolkit_layout = widgets.VBox([header, info_box, start_button_hbox],
                                      layout=widgets.Layout(width='100%', flex_flow='column'))
        display(toolkit_layout)


def show_apriori_certification():
    """Open the *Apriori Certify Fairness* landing page in ``global_output``.

    Clears the shared output area, resets notebook state with
    ``set_variables()``, and rebuilds the shared globals:

    * ``header``: a Home button (back to ``display_home_screen``) plus the page
      title.
    * ``steps_box``: the pipeline Data -> Features -> Parameters -> Model -> Run.

    Each step button is bound to the first handler, from an ordered list of
    candidate names, that exists in the notebook namespace. Certification
    specific handlers are preferred, with bias-detection handlers as fallbacks.
    If no candidate exists, the button stays inert.

    Displays the header, an introductory info panel and a centred *Start*
    button. ``steps_box`` is populated here but not displayed by this function.
    """
    global header, steps_box
    namespace = globals()  # live module dict; later definitions are visible

    def _bind_first(button, *candidate_names):
        # Attach the first candidate callback defined in the notebook namespace;
        # do nothing when none of them exist.
        for candidate in candidate_names:
            if candidate in namespace:
                button.on_click(namespace[candidate])
                return

    with global_output:
        global_output.clear_output(wait=True)
        set_variables()

        # ---- Home button, left-aligned in its own header row ----
        back_button = widgets.Button(
            description='Home',
            layout=Layout(width='auto', height='auto', margin='0px 0px 5px 0px'),
            style={'button_color': 'white', 'font_size': '18px'},
            icon='home',
        )
        back_button.add_class('my-font')
        back_button.on_click(lambda b: display_home_screen())

        back_button_box = widgets.Box(
            children=[back_button],
            layout=Layout(
                overflow='hidden',
                width='100%',
                height='auto',
                display='flex',
                flex_flow='column',
                align_items='flex-start',
            ),
        )
        back_button_box.add_class('header-box')

        page_title = widgets.Label(
            'Apriori Certify Fairness',
            layout=widgets.Layout(margin='0px 0px 15px 0px', padding='0px', justify_content='center'),
            style=pipeline_title_style,
        )
        page_title.add_class('my-font')

        # ---- Introductory panel ----
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-info-circle" style="margin-right: 10px;"></i> This toolkit certifies fairness a priori.
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Provide a model and constraints to verify fairness properties before deployment, using data schema and selected features.
            </p>
            <h3 style="font-size: 16px; color: #333333;">How to Use</h3>
            <ol style="font-size: 14px; padding-left: 20px;">
                <li>Load or select a dataset (or schema) as context.</li>
                <li>Select/inspect features relevant to certification.</li>
                <li>Adjust input parameters (thresholds, protected attributes, constraints).</li>
                <li>Review the results and take necessary actions.</li>
            </ol>
        </div>
        """
        info_box = widgets.HTML(
            value=info_message,
            layout=widgets.Layout(margin='15px 0px', width='100%', padding='5px 0px 5px 0px'),
        )

        # ---- Start button (launches the first step), horizontally centred ----
        start_button = widgets.Button(
            description='Start',
            layout=widgets.Layout(width='150px', height='50px', margin='10px 0px', border_radius='10px'),
            style={'button_color': 'MediumSeaGreen', 'font_size': '16px', 'text_color': 'white', 'font_weight': 'bold', 'color': 'white'},
        )
        start_button.add_class('my-font')
        _bind_first(start_button, 'upload_bias_detection_data')
        start_row = widgets.HBox([start_button], layout=widgets.Layout(justify_content='center', width='100%'))

        # ---- Pipeline: (icon, caption, candidate handler names in priority order) ----
        step_specs = (
            ('upload', 'Data', ('upload_bias_detection_data',)),
            ('tasks', 'Features', ('select_features',)),
            # Prefer the unified parameters panel, else the simpler input form.
            ('sliders-h', 'Parameters', ('give_parameters_updated_and_extended_all_newer_integr',
                                         'give_input_parameters')),
            # Prefer a certification-specific model picker, else the bias-detection one.
            ('cogs', 'Model', ('give_model_to_certify',
                               'give_model_to_detect_bias_updated_checkboxes_new')),
            ('play', 'Run', ('run_apriori_certification',
                             'run_bias_detection_newer_version')),
        )
        step_widgets = []
        for icon, caption, handler_names in step_specs:
            step = create_step_widget(icon, caption)
            _bind_first(step.children[0], *handler_names)
            step_widgets.append(step)
        arrows = [create_arrow() for _ in range(len(step_widgets) - 1)]

        # Interleave as step, arrow, step, arrow, ..., step.
        pipeline_children = [step_widgets[0]]
        for arrow, step in zip(arrows, step_widgets[1:]):
            pipeline_children.extend([arrow, step])

        header.children = [back_button_box, page_title]
        header.add_class('fixed-header')

        steps_box.children = pipeline_children
        steps_box.layout = Layout(display='flex', flex_flow='row',
                                  justify_content='center', align_items='center',
                                  margin='0px', width='100%', height='20vh')
        steps_box.add_class('fixed-steps-box')

        landing_page = widgets.VBox([header, info_box, start_row],
                                    layout=widgets.Layout(width='100%', flex_flow='column'))
        display(landing_page)



def show_tradeoff_old():
    """Legacy view: draw the fairness trade-off pipeline as a vertical flow chart.

    Replaces the contents of the module-level ``vbox`` with a "Back to Pipeline"
    button, a bold title and six step buttons separated by downward arrows.
    No click handlers are attached to the step buttons here; the back button
    restores ``vbox.children`` to ``home_grid``.
    """
    vbox.children = []

    back_button = widgets.Button(
        description='Back to Pipeline',
        layout=widgets.Layout(width='150px', height='30px', margin='5px'),
    )

    def back_to_pipeline(b):
        vbox.children = home_grid

    back_button.on_click(back_to_pipeline)

    title_label = widgets.Label(
        'Examine Trade-off Between Measures of Fairness',
        layout=widgets.Layout(margin='10px'),
        style={'font_size': '20px', 'font_weight': 'bold'},
    )

    # All step buttons share one wide, bordered layout.
    step_layout = widgets.Layout(width='300px', height='50px', margin='5px', border='solid black 1px')
    step_titles = [
        'Load Data',
        'Visualize/Select Features',
        'Give Parameters of Model',
        'Select Measures',
        'Run',
        'Return Results & Visualize',
    ]
    step_buttons = [widgets.Button(description=title, layout=step_layout) for title in step_titles]

    # Put a fresh centred '↓' label between every pair of consecutive steps.
    arrow_style = {'font_size': '24px'}
    flow = [step_buttons[0]]
    for next_step in step_buttons[1:]:
        flow.append(widgets.Label(value='↓', layout=widgets.Layout(width='auto'), style=arrow_style))
        flow.append(next_step)

    column_layout = widgets.Layout(display='flex', flex_flow='column', align_items='center', width='100%')
    vbox.children = [widgets.VBox([back_button, title_label, *flow], layout=column_layout)]
    
def show_apriori_certification_old():
    """Legacy view: draw the a-priori fairness certification pipeline as a flow chart.

    Replaces the contents of the module-level ``vbox`` with a "Back to Pipeline"
    button, a bold title and six step buttons joined by downward arrows. The
    step buttons get no click handlers here; the back button restores
    ``vbox.children`` to ``home_grid``.
    """
    vbox.children = []

    back_button = widgets.Button(
        description='Back to Pipeline',
        layout=widgets.Layout(width='150px', height='30px', margin='5px'),
    )

    def back_to_pipeline(b):
        vbox.children = home_grid

    back_button.on_click(back_to_pipeline)

    title_label = widgets.Label(
        'Apriori Certify Fairness',
        layout=widgets.Layout(margin='10px'),
        style={'font_size': '20px', 'font_weight': 'bold'},
    )

    shared_step_layout = widgets.Layout(width='300px', height='50px', margin='5px', border='solid black 1px')
    steps = [
        widgets.Button(description=caption, layout=shared_step_layout)
        for caption in (
            'Load Data',
            'Visualize/Select Features',
            'Give Input Parameters',
            'Give Model to Certify',
            'Run',
            'Return Results & Visualize',
        )
    ]

    # One independent arrow widget per gap between steps.
    arrow_style = {'font_size': '24px'}
    arrows = [
        widgets.Label(value='↓', layout=widgets.Layout(width='auto'), style=arrow_style)
        for _ in range(len(steps) - 1)
    ]

    # Interleave: step, arrow, step, arrow, ..., step.
    flow = [steps[0]]
    for arrow, step in zip(arrows, steps[1:]):
        flow += [arrow, step]

    centred_column = widgets.Layout(display='flex', flex_flow='column', align_items='center', width='100%')
    vbox.children = [widgets.VBox([back_button, title_label] + flow, layout=centred_column)]

In [16]:
def update_button_color(b):
    """Highlight the pipeline step whose label contains ``current_state``.

    Every VBox child of the module-level ``steps_box`` is treated as a step
    (button at index 0, label at index 1). Non-VBox children such as arrows
    are skipped. The matching step's button turns 'LightSkyBlue' and every
    other step's button turns 'lightgray'. ``b`` is accepted for the
    ``on_click`` signature but not used.
    """
    step_widgets = (child for child in steps_box.children if isinstance(child, VBox))
    for step in step_widgets:
        step_button, step_label = step.children[0], step.children[1]
        is_active = current_state in step_label.value
        step_button.style.button_color = 'LightSkyBlue' if is_active else 'lightgray'

In [1]:
from mlflow.tracking import MlflowClient

### Inner pipelines buttons

### for BIAS DETECTION TOOLKIT ###
# style2 = """
# <style>
#     .widget-select-multiple {
#         border: none !important;
#         box-shadow: 0 4px 6px rgba(0,0,0,0.1);
#         border-radius: 4px;
#         overflow: hidden;
#     }
#     .widget-label {
#         margin-right: 10px;
#         color: #333;
#     }
#     .custom-select-container {
#         display: flex;
#         align-items: center;
#         margin-bottom: 20px;
#     }
#     .custom-select-container .widget-html {
#         margin-right: 10px;
#     }
# </style>
# """

# display(HTML(style2))

# def upload_bias_detection_data(b):
#     global df, current_state
#     current_state = "Data"
    
#     custom_css = """
#     <style>
#         .custom-radio {
#             display: flex;
#             align-items: center;
#             margin-bottom: 10px;
#         }
#         .custom-radio input {
#             margin-right: 10px;
#         }
#         .custom-radio label {
#             font-size: 16px;
#             font-family: Optima, sans-serif;
#         }
#     </style>
#     """
    
#     # Custom HTML and JavaScript for the radio buttons
#     dataset_html = "<div class='radio-group'>"
#     for i, option in enumerate(dataset_options):
#         checked_attribute = "checked" if i == 0 else ""
#         dataset_html += f"<div class='custom-radio'><input type='radio' id='radio{i}' name='dataset' value='{option['name']}' {checked_attribute} onclick='select_dataset(\"{option['name']}\")'>"
#         dataset_html += f"<label for='radio{i}'>{option['name']} - <a href='{option['url']}' target='_blank' style='color: #0096FF; text-decoration: none; font-weight: bold;'>Learn more</a></label></div>"
#     dataset_html += "</div>"
    
#     full_html = custom_css + dataset_html
    
#     # Display the HTML and JavaScript
#     dataset_selector = widgets.HTML(value=full_html)

#     # Info message for dataset selection
#     info_message = """
#     <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
#         <span style="font-weight: bold; color: #333333; font-size: 16px;">
#             <i class="fa fa-database" style="margin-right: 10px;"></i> Dataset Information
#         </span>
#         <p style="font-size: 14px; margin-top: 10px;">
#             Choose a sample dataset to perform bias detection. This toolkit allows you to analyze and detect biases in various datasets. 
#             By selecting a dataset, you can explore how different biases may affect your model's performance and fairness. Select one of the available datasets to start your analysis and understand the biases present in the data.
#         </p>
#     </div>
#     """
    
#     combined_box = widgets.VBox([widgets.HTML(value=info_message), dataset_selector])
    
#     with global_output:
#         # Display the combined box
#         update_button_color(b)
#         clean_toolkit_content()
#         display(combined_box)

def on_dataset_selector_change(change):
    """Observer for the dataset radio buttons: load the newly chosen dataset into ``df``.

    Parameters
    ----------
    change : dict
        ipywidgets change notification; ``change['new']`` is the selected
        dataset name.

    Side effects
    ------------
    Rebinds the module-level ``selected_dataset`` and ``df``. For
    "Adult (Census Income)" it also rebinds ``X_adult``, ``y_adult`` and
    ``sample_weight`` from ``fetch_adult()``. Any other name only updates
    ``selected_dataset`` and leaves ``df`` as it was.
    """
    global df, X_adult, y_adult, sample_weight, selected_dataset
    selected_dataset = change['new']
    if selected_dataset == "Ad Campaign":
        df = pd.read_csv(dataset_name_2_file_name["Ad Campaign"])
    elif selected_dataset == "Adult (Census Income)":
        X_adult, y_adult, sample_weight = fetch_adult()
        df = clean_dataset(X_adult.assign(income=y_adult), "adult")
    elif selected_dataset == "Workable":
        # Columns not wanted for the analysis.
        exclude = ["skills"]
        # Read the file once and drop the excluded columns. Previously the whole
        # file (excluded columns included) was read just to learn the column
        # names and then read a second time. That doubled the I/O with no memory
        # saving, because the first read had already loaded every column.
        # errors="ignore" keeps the old behaviour when a column is absent.
        df = pd.read_parquet(dataset_name_2_file_name["Workable"], engine="pyarrow").drop(
            columns=exclude, errors="ignore"
        )


def on_algorithm_selector_change(change):
    """Observer for the algorithm selector: remember the chosen algorithm name
    in the module-level ``selected_algorithm``."""
    global selected_algorithm
    chosen_algorithm = change['new']
    selected_algorithm = chosen_algorithm

# def on_metric_selector_change(change):
#     global selected_algorithm_parameters
#     selected_algorithm_parameters['metric'] = change['new'].lower().replace(' ', '-')
    
# Function to handle changes in the metric selector
def on_metric_selector_change(change):
    """Observer for the metric selector.

    Stores the chosen metric in ``selected_algorithm_parameters['metric']``,
    lower-cased with spaces replaced by hyphens. It then shows only the
    parameter widget that metric uses and removes unused parameters:

    * "Equal choice for recourse" / "Equal cost of effectiveness": show
      ``phi_widget``, hide ``c_widget``, drop ``'c'``.
    * "Equal effectiveness within budget": show ``c_widget``, hide
      ``phi_widget``, drop ``'phi'``.
    * anything else: hide both widgets, drop both ``'phi'`` and ``'c'``.

    Parameters
    ----------
    change : dict
        ipywidgets change notification; ``change['new']`` is the metric label.
    """
    global phi_widget, c_widget, selected_algorithm_parameters
    metric_label = change['new']
    selected_algorithm_parameters['metric'] = metric_label.lower().replace(' ', '-')

    if metric_label in {"Equal choice for recourse", "Equal cost of effectiveness"}:
        uses_phi, uses_c = True, False
    elif metric_label == "Equal effectiveness within budget":
        uses_phi, uses_c = False, True
    else:
        uses_phi, uses_c = False, False

    phi_widget.layout.display = 'block' if uses_phi else 'none'
    c_widget.layout.display = 'block' if uses_c else 'none'

    # Remove unused parameters with pop(..., None) rather than `del`. The key may
    # already be gone from an earlier change, e.g. switching between two metrics
    # that both drop 'c', and `del` would then raise KeyError inside the observer.
    # NOTE(review): a dropped 'phi'/'c' is not re-added here when the user switches
    # back; presumably the phi/c widgets' own observers restore it — TODO confirm.
    if not uses_phi:
        selected_algorithm_parameters.pop('phi', None)
    if not uses_c:
        selected_algorithm_parameters.pop('c', None)
    
def on_viewpoint_selector_change(change):
    """Observer for the viewpoint selector: store the choice, lower-cased,
    under ``selected_algorithm_parameters['viewpoint']``."""
    global selected_algorithm_parameters
    viewpoint_label = change['new']
    normalized_viewpoint = viewpoint_label.lower()
    selected_algorithm_parameters['viewpoint'] = normalized_viewpoint

def on_score_function_selector_change(change):
    """Observer for the scoring-function selector.

    Deliberately a no-op: this handler does not store the selection anywhere.
    """
    return None

def on_penalty_function_selector_change(change):
    """Observer for the penalty selector: store the new value unchanged
    under ``selected_algorithm_parameters['penalty']``."""
    global selected_algorithm_parameters
    new_penalty = change['new']
    selected_algorithm_parameters['penalty'] = new_penalty

# Function to upload bias detection data
def upload_bias_detection_data(b):
    """Data step: show the dataset chooser and load the default dataset.

    Renders an info panel and a radio-button group listing the names in
    ``dataset_options`` into ``global_output``. It loads "Ad Campaign" into the
    module-level ``df`` and records it in ``selected_dataset``. Later radio
    changes are handled by ``on_dataset_selector_change``. Each click re-reads
    the default CSV and resets the selection, discarding any earlier choice.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked step button; forwarded to ``update_button_color``.
    """
    global df, current_state, selected_dataset
    current_state = "Data"
    # Single source of truth for the dataset that is both loaded and pre-selected.
    default_dataset = "Ad Campaign"
    with global_output:
        # Info message for dataset selection
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-database" style="margin-right: 10px;"></i> Dataset Information
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Choose a sample dataset to perform bias detection. This toolkit allows you to analyze and detect biases in various datasets. 
                By selecting a dataset, you can explore how different biases may affect your model's performance and fairness. Select one of the available datasets to start your analysis and understand the biases present in the data.
            </p>
            <p style="font-size: 14px; font-weight: bold; color: red">
                Note that a dataset may only support a specific algorithm for bias detection. 
                For example, the dataset "Ad Campaign" can only be used with the MDSS algorithm.
            </p>
            <p style="font-size: 14px; margin-top: 10px;">
                <span style="font-weight: bold;">About supported datasets (click on dataset name to learn more):</span>
                <ul style="font-size: 14px;">
                    <li><b><a href="https://developer.ibm.com/exchanges/data/all/bias-in-advertising/" target="_blank" style="color: #0096FF; text-decoration: none;">Ad Campaign</a>:</b> A synthetic dataset for users who were shown a certain advertisement.</li>
                    <li><b><a href="https://archive.ics.uci.edu/dataset/2/adult" target="_blank" style="color: #0096FF; text-decoration: none;">Adult</a>:</b> A dataset to predict whether income exceeds $50K/yr based on census data. Also known as the "Census Income" dataset.</li>
                    <li><b><a href="https://workable.com" target="_blank" style="color: #0096FF; text-decoration: none;">Workable HR automation Data</a>:</b> A dataset provided by Workable.com</li>
                </ul>
            </p>
        </div>
        """

        # Radio-button group for dataset selection. `value` is pinned to the dataset
        # loaded below. Without it, RadioButtons starts on the first entry of
        # `dataset_options`, and nothing here guarantees that entry is "Ad Campaign",
        # so the UI could show one dataset while `df` holds another.
        # NOTE: the widget rejects the value if "Ad Campaign" is not one of the options.
        dataset_selector = widgets.RadioButtons(
            options=[option['name'] for option in dataset_options],
            value=default_dataset,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
        )
        dataset_selector.add_class('my-radio-style')

        # Load the default dataset so later steps have data even if the user
        # never touches the radio buttons.
        df = pd.read_csv(dataset_name_2_file_name[default_dataset])
        selected_dataset = default_dataset
        # Attached after the initial value is set, so only user changes trigger a reload.
        dataset_selector.observe(on_dataset_selector_change, names='value')

        # Combine the info box and the dataset selector
        combined_box = widgets.VBox([widgets.HTML(value=info_message), dataset_selector])

        # Display the combined box
        update_button_color(b)
        clean_toolkit_content()
        display(combined_box)

# def upload_bias_detection_data(b):
#     global df, current_state
#     current_state = "Data"
#     with global_output:
#         info_message = """
#         <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
#             <span style="font-weight: bold; color: #333333; font-size: 16px;">
#                 <i class="fa fa-database" style="margin-right: 10px;"></i> Dataset Information
#             </span>
#             <p style="font-size: 14px; margin-top: 10px;">
#                 Choose a sample dataset to perform bias detection. This toolkit allows you to analyze and detect biases in various datasets. 
#                 By selecting a dataset, you can explore how different biases may affect your model's performance and fairness. Select one of the available datasets to start your analysis and understand the biases present in the data.
#             </p>
#         </div>
#         """
        
#         dataset_html = "<div style='font-size: 16px; margin-top: 10px;'>"
#         for i, option in enumerate(dataset_options):
#             checked_attribute = "checked" if i == 0 else ""
#             dataset_html += f"<div style='margin-bottom: 10px;'><input type='radio' name='dataset' value='{option['name']}' style='margin-right: 10px;' {checked_attribute}>"
#             dataset_html += f"<label style='font-size: 16px;'>{option['name']} - <a href='{option['url']}' target='_blank' style='color: #0096FF; text-decoration: none; font-weight: bold;'>Learn more</a></label></div>"
#         dataset_html += "</div>"

#         dataset_selector = widgets.HTML(
#             value=dataset_html,
#             layout=widgets.Layout(margin='10px 0 0 0', width='auto')
#         )
#         dataset_selector.add_class('my-font')
#         # By default the dataset is 'Ad Campaign'
        
#         # Combine the info box and the custom dataset selector
#         combined_box = widgets.VBox([widgets.HTML(value=info_message), dataset_selector])
        
#         update_button_color(b)
#         clean_toolkit_content()
#         display(combined_box)

# Path to the loading-bar GIF shown while slow toolkit steps run.
# Relative path: resolved against the kernel's current working directory.
gif_address = 'loading_bar.gif'

# Read the GIF bytes once, when this cell executes. open() raises
# FileNotFoundError if the file is missing, which aborts the rest of the cell.
with open(gif_address, 'rb') as f:
    img = f.read()

# A single shared Image widget; dataset_analysis wraps this same instance in a
# VBox every time it shows its loading indicator.
loading_bar = widgets.Image(value=img, format='gif')


def plot_histogram_grid_all(df):
    """Display one histogram per column of ``df`` in a 3-wide grid, as inline HTML.

    Numeric columns get a 20-bin histogram with a KDE overlay. Every other
    column is cast to ``str`` and drawn as discrete bars; if it has more than
    10 distinct values, every second x-tick is dropped. Missing values are
    dropped per column. If a column fails to plot, its panel shows an
    "Error: ..." note and the rest of the grid is still drawn. The figure is
    rendered to PNG, base64-embedded in an ``<img>`` tag and shown with
    ``display(HTML(...))``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to summarise; every column is plotted.
    """
    features = df.columns
    num_features = len(features)
    if num_features == 0:
        # plt.subplots(0, 3) raises ValueError; show a note instead of crashing.
        display(HTML('<div style="text-align: center;">No columns to plot.</div>'))
        return

    sns.set_theme(style="whitegrid")
    # Build the palette once instead of twice per column; cycle through it by its real length.
    palette = sns.color_palette("Set2")
    num_rows = num_features // 3 + (num_features % 3 > 0)  # ceil(num_features / 3)

    fig, axes = plt.subplots(num_rows, 3, figsize=(18, 5 * num_rows))
    axes = axes.flatten()

    for i, feature in enumerate(features):
        ax = axes[i]
        color = palette[i % len(palette)]

        try:
            # Drop missing values
            data = df[feature].dropna()
            is_numeric = pd.api.types.is_numeric_dtype(data)

            if is_numeric:
                sns.histplot(data=data, bins=20, kde=True, ax=ax, color=color)
            else:
                # Non-numeric (categorical) data: one bar per distinct value.
                sns.histplot(data=data.astype(str), discrete=True, shrink=0.8, ax=ax, color=color)

            ax.set_title(f'Histogram of {feature}', fontsize=14, fontweight='bold')
            ax.set_xlabel(feature, fontsize=12)
            ax.set_ylabel('Frequency', fontsize=12)

            # Thin out x-tick labels for categorical columns with many distinct values.
            if not is_numeric and data.nunique() > 10:
                ax.set_xticks(ax.get_xticks()[::2])
            ax.tick_params(axis='x', rotation=45)

        except Exception as e:
            # Best-effort: report the failure in this panel and keep plotting the others.
            ax.text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center')

    # Remove the unused panels in the last row.
    for i in range(num_features, len(axes)):
        fig.delaxes(axes[i])

    fig.tight_layout()

    # Render this figure explicitly (not pyplot's "current" figure) and release it.
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png', bbox_inches='tight')
    plt.close(fig)

    img_base64 = base64.b64encode(img_buf.getvalue()).decode('utf-8')
    img_html = f'<div style="width: 90%; margin: auto; text-align: center;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"/></div>'
    display(HTML(img_html))


def dataset_analysis(b):
    """Analyze step: preview the loaded dataset and plot its feature distributions.

    A centred loading GIF is shown in ``global_output`` while the preview is
    built. The output (``display_dataframe_styled(df.head())`` plus
    histograms) is captured in a temporary Output widget. Once finished, the
    loading GIF is cleared and the captured output is shown below an info
    panel naming ``selected_dataset``.

    "Adult (Census Income)" is plotted with ``plot_histogram_grid`` on every
    column name except the first. Any other dataset goes through
    ``plot_histogram_grid_top10_spaced``.

    If no dataset has been loaded yet, a "No dataset loaded." message is shown
    instead, matching ``select_features``.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked step button; forwarded to ``update_button_color``.
    """
    global current_state
    current_state = "Analyze"

    if 'df' not in globals() or 'selected_dataset' not in globals():
        # Without this guard, the f-string below raised NameError before anything
        # was written to global_output, so the user got no feedback in the toolkit.
        with global_output:
            update_button_color(b)
            clean_toolkit_content()
            display_message("No dataset loaded.", color='red')
        return

    info_message_analyze = f"""
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-chart-bar" style="margin-right: 10px;"></i> Dataset Analysis
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>Dataset loaded:</b> {selected_dataset}.
        </p>
        <p style="font-size: 14px; margin-top: 10px;">
            We are displaying the first few rows of the dataset to understand its structure. 
            Also, histograms are generated for numerical and categorical features to visualize their distributions.
            Histograms help in understanding the range, central tendency, and variability of the data.
        </p>    
    </div>
    """
    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        # Show the centred loading bar while the analysis runs.
        loading_box = widgets.VBox([loading_bar], layout=widgets.Layout(align_items='center', width='100%'))
        display(loading_box)

        # Capture the preview in a separate Output widget so it can replace the
        # loading bar in one step.
        tmp_output_widget = widgets.Output()
        with tmp_output_widget:
            display_dataframe_styled(df.head())
            if selected_dataset == "Adult (Census Income)":
                plot_histogram_grid(list(df)[1:])
            else:
                plot_histogram_grid_top10_spaced(df)

        # Swap the loading bar for the info panel and the captured output.
        clean_toolkit_content()
        display(widgets.HTML(value=info_message_analyze))
        display(tmp_output_widget)

def select_features(b):
    """Features step: let the user choose the columns to use for bias detection.

    Shows an info panel, a multi-select listing every column of ``df``, and the
    shared ``features_status_output`` area, which is reset to a
    "No selected features/protected attributes." notice. Selection changes are
    forwarded to ``update_selected_features``. Without a loaded ``df``, only a
    "No dataset loaded." message is shown.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked step button; forwarded to ``update_button_color``.
    """
    global features_4_scanning, features_status_output, current_state
    current_state = "Features"

    if 'df' not in globals():
        with global_output:
            update_button_color(b)
            clean_toolkit_content()
            display_message("No dataset loaded.", color='red')
        return

    # The literal keeps its original inner indentation so the emitted HTML is unchanged.
    info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-list" style="margin-right: 10px;"></i> Feature/Protected Attribute Selection
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Select the features or protected attributes (depending on the bias detection model you will select next) you want to include in the bias detection analysis. 
                The selected features/protected attributes will be used to identify potential biases in the dataset.
            </p>
        </div>
        """

    info_box = widgets.HTML(
        value=info_message,
        layout=widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px'),
    )

    column_picker = widgets.SelectMultiple(
        options=df.columns,
        disabled=False,
        layout=widgets.Layout(width='20%', height='180px'),
    )
    column_picker.add_class('my-select-multiple')

    centred_picker = widgets.VBox(
        [column_picker],
        layout=widgets.Layout(align_items='center', justify_content='center', width='100%'),
    )

    status_message = (
        '<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
        '<span style="color: red; font-weight: bold;">No selected features/protected attributes.</span></div>'
    )

    with global_output:
        update_button_color(b)
        clean_toolkit_content()
        display(info_box, centred_picker, features_status_output)

    # Reset the status area to "nothing selected" for this fresh selector.
    with features_status_output:
        features_status_output.clear_output(wait=True)
        display(HTML(status_message))

    column_picker.observe(update_selected_features, names='value')


# def display_parameters_hydra(cfg):
#     """
#     Display all parameters from the configuration file in a VBox.
#     """
#     parameter_widgets = []
    
#     # Iterate over all parameters and add them to widgets
#     for param_name, values in cfg.parameters.items():
#         description = widgets.HTML(
#             value=f"<b>{param_name}:</b> {values}"
#         )
#         parameter_widgets.append(description)
    
#     # Create a VBox to display parameters
#     vbox = widgets.VBox(parameter_widgets)
#     return vbox


def _build_mdss_parameters_box():
    """Return a read-only VBox describing the MDSS parameters listed in ``configs/mdss.yaml``."""
    mdss_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for MDSS
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>MDSS requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Scoring function:</b> The scoring function evaluates the quality of subsets by measuring how much the observed data deviates from the expected data under a null hypothesis.</li>
            <li><b>Number of iterations:</b> The number of iterations determines how many times the algorithm will run to find the optimal subset.</li>
            <li><b>Penalty:</b> The penalty parameter controls the trade-off between the size of the subset and the scoring function value, helping to prevent overfitting.</li>
        </ul>
    </div>
    """
    cfg = OmegaConf.load("configs/mdss.yaml")

    penalty_display = widgets.HTML(value=f"<b>Penalty:</b> {', '.join(map(str, cfg.parameters.penalty))}")
    scoring_display = widgets.HTML(value="<b>Scoring Function:</b> Bernoulli")
    iterations_display = widgets.HTML(value=f"<b>Maximum number of iterations:</b> {', '.join(map(str, cfg.parameters.num_iters))}")

    return widgets.VBox([widgets.HTML(value=mdss_info_message),
                         widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),
                         penalty_display,
                         scoring_display,
                         iterations_display])


def _build_ot_parameters_box():
    """Return a read-only VBox describing the OT (logistic regression) parameters in ``configs/config.yaml``."""
    ot_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for OT
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>OT requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Penalty:</b> The type of regularization penalty applied by the logistic regression model.</li>
            <li><b>C:</b> The parameter <b>C</b> is a regularization parameter which controls the trade-off between achieving a low error on the training data and minimizing the complexity of the model.</li>
            <li><b>Maximum number of iterations:</b> The maximum number of iterations taken by the logistic regression solver to converge.</li>
        </ul>
    </div>
    """
    cfg = OmegaConf.load("configs/config.yaml")

    # map(str, ...) so non-string YAML entries (e.g. null) do not break the join,
    # matching how the other parameter lists are rendered.
    penalty_display = widgets.HTML(value=f"<b>Penalty:</b> {', '.join(map(str, cfg.parameters.penalty))}")
    c_display = widgets.HTML(value=f"<b>C (Regularization):</b> {', '.join(map(str, cfg.parameters.c))}")
    iterations_display = widgets.HTML(value=f"<b>Maximum number of iterations:</b> {', '.join(map(str, cfg.parameters.n_iter))}")

    return widgets.VBox([widgets.HTML(value=ot_info_message),
                         widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),
                         penalty_display,
                         c_display,
                         iterations_display])


def _facts_param_layout(width='30%', **extra):
    """Return a fresh Layout for one FACTS widget (never shared: phi/c visibility is toggled per widget)."""
    return widgets.Layout(width=width, margin='10px 0 0 10px', **extra)


def _mirror_facts_parameter(widget, key, seed=True):
    """Keep ``selected_algorithm_parameters[key]`` in sync with ``widget.value``.

    Args:
        widget: Any value widget.
        key: Key written into the global ``selected_algorithm_parameters`` dict.
        seed: When true, store the widget's current value immediately as well.
    """
    if seed:
        selected_algorithm_parameters[key] = widget.value
    widget.observe(lambda change: selected_algorithm_parameters.update({key: change['new']}), names='value')


def _build_facts_parameters_box():
    """Build the editable FACTS parameter form and return it as a VBox.

    Side effects:
      - seeds default values into the global ``selected_algorithm_parameters``;
      - rebinds the globals ``phi_widget`` and ``c_widget`` (presumably so that
        ``on_metric_selector_change`` can toggle them — TODO confirm);
      - writes the initial status message into ``not_to_change_features_status_output``.
    """
    global phi_widget, c_widget

    facts_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for FACTS
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>FACTS requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Maximum number of iterations:</b> The maximum number of iterations allowed for convergence of the solver of the demonstration ML model, here a simple logistic regression.</li>
            <li><b>Frequent Itemset Minimum Support:</b> All groups of individuals examined are constrained to cover at least this fraction of the whole population (e.g. 0.1 = 10%).</li>
            <li><b>Viewpoint:</b> One of two viewpoints defined for counterfactual actions by our framework. "macro" means that all individuals in a group must receive the same action, while "micro" means that each individual can choose, from a set of actions, the one that flips their class and has the minimum cost for the specific individual.</li>
            <li><b>Metric:</b> The fairness metric/definition to be used by the algorithm.</li>
            <li><b>Phi:</b> This is the parameter that determines whether we consider an action sufficiently effective or not. Specifically, an action is considered effective if it manages to flip the prediction of the individuals under study with at least <i>phi</i> probability, and ineffective otherwise.</li>
            <li><b>Top count:</b> The number of (most) biased groups to detect based on the given metric.</li>
            <li><b>Features not allowed to change:</b> The features that are not allowed to be changed by the algorithm. By default none. You may select up to two.</li>
        </ul>
    </div>
    """
    style = {'description_width': 'initial'}

    iterations_widget = widgets.BoundedIntText(
        value=1500, min=1, max=10000,
        description='Maximum number of iterations:',
        style=style, layout=_facts_param_layout())
    iterations_widget.add_class('my-int-text')
    _mirror_facts_parameter(iterations_widget, 'num_iterations')

    # Support is a fraction of the population, so cap it at 1
    # (the widget default max of 100 allowed meaningless values).
    itemset_min_support_widget = widgets.BoundedFloatText(
        value=0.1, min=0, max=1,
        description='Frequent Itemset Minimum Support:',
        style=style, layout=_facts_param_layout())
    itemset_min_support_widget.add_class('my-int-text')
    _mirror_facts_parameter(itemset_min_support_widget, 'itemset_min_support')

    viewpoint_selector = widgets.RadioButtons(
        description='Viewpoint:',
        value="macro",
        options=["micro", "macro"],
        style=style, layout=_facts_param_layout(width='50%'))
    viewpoint_selector.add_class('my-radio-style')
    selected_algorithm_parameters['viewpoint'] = viewpoint_selector.value.lower()
    viewpoint_selector.observe(on_viewpoint_selector_change, names='value')

    metric_selector = widgets.RadioButtons(
        description='Fairness metric:',
        value="Equal choice for recourse",
        options=["Equal choice for recourse",
                 "Equal effectiveness",
                 "Equal effectiveness within budget",
                 "Equal cost of effectiveness",
                 "Equal mean recourse",
                 "Fair tradeoff"],
        style=style, layout=_facts_param_layout(width='50%'))
    metric_selector.add_class('my-radio-style')
    selected_algorithm_parameters['metric'] = metric_selector.value.lower().replace(' ', '-')
    metric_selector.observe(on_metric_selector_change, names='value')

    # Phi is a probability threshold, so cap it at 1.
    phi_widget = widgets.BoundedFloatText(
        value=0.1, min=0, max=1,
        description='Phi:',
        style=style, layout=_facts_param_layout(display='block'))
    phi_widget.add_class('my-int-text')
    _mirror_facts_parameter(phi_widget, 'phi')

    # Initially hidden and deliberately not seeded; only observed for edits.
    c_widget = widgets.BoundedIntText(
        value=1, min=1,
        description='Cost budget (c):',
        style=style, layout=_facts_param_layout(display='none'))
    c_widget.add_class('my-int-text')
    _mirror_facts_parameter(c_widget, 'c', seed=False)

    features_not_to_change_selector = widgets.SelectMultiple(
        description="Features not to change:",
        options=df.columns,
        layout=widgets.Layout(width='30%', height='180px', margin='10px 0 0 10px'),
        style=style)
    features_not_to_change_selector.add_class('my-select-multiple')
    features_not_to_change_selector.observe(update_selected_not_to_change_features, names='value')

    top_count_widget = widgets.BoundedIntText(
        value=3, min=1,
        description='Number of subgroups to show:',
        style=style, layout=_facts_param_layout())
    top_count_widget.add_class('my-int-text')
    _mirror_facts_parameter(top_count_widget, 'top_count')

    status_message = (
        '<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
        '<span style="color: red; font-weight: bold;">No selected not to change features.</span></div>'
    )

    combined_box = widgets.VBox([widgets.HTML(value=facts_info_message),
                                 iterations_widget,
                                 itemset_min_support_widget,
                                 viewpoint_selector,
                                 metric_selector,
                                 phi_widget,
                                 c_widget,
                                 top_count_widget,
                                 features_not_to_change_selector,
                                 not_to_change_features_status_output])

    with not_to_change_features_status_output:
        not_to_change_features_status_output.clear_output(wait=True)
        display(HTML(status_message))

    return combined_box


def give_parameters_updated_and_extended(b):
    """Show the parameter panel for the algorithm chosen in the 'Model' stage.

    MDSS and OT show read-only values loaded from YAML configs. FACTS shows an
    editable form whose values are mirrored into ``selected_algorithm_parameters``.
    An unrecognised algorithm now gets an explanatory message instead of
    raising on an unbound ``combined_box``.

    Args:
        b: The button that triggered the callback (passed to ``update_button_color``).
    """
    global current_state, selected_algorithm_parameters
    current_state = "Parameters"

    builders = {
        "Multi-Dimensional Subset Scan (MDSS)": _build_mdss_parameters_box,
        "Bias Detection via Optimal Transport (Logistic Regression)": _build_ot_parameters_box,
        "Fairness Aware Counterfactuals for Subgroups (FACTS)": _build_facts_parameters_box,
    }

    with global_output:
        builder = builders.get(selected_algorithm)
        # Build before clearing the toolkit area, as the original did.
        combined_box = builder() if builder is not None else None

        update_button_color(b)
        clean_toolkit_content()
        if combined_box is None:
            display_message("No parameters available for the selected algorithm.", color='red')
        else:
            display(combined_box)


def give_parameters_updated_and_extended_new(b):
    """Render an editable parameter form for the selected bias-detection algorithm.

    Instantiates the estimator behind ``selected_algorithm`` with default
    arguments. When it exposes ``get_params()``, the function shows one input
    widget per parameter. The global ``selected_algorithm_parameters`` is reset,
    seeded with the defaults, and kept in sync with the widgets.

    Args:
        b: The button that triggered the callback (passed to ``update_button_color``).
    """
    global current_state, selected_algorithm_parameters
    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        # Info banner
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                You may review and optionally edit the configuration parameters of the selected algorithm. If supported, you may enable Auto-Tuning using FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            description="Auto-Tune Parameters with FLAML"
        )
        # The original read the checkbox once at render time, when it was always
        # False, so the note could never appear. Toggle it live instead.
        auto_tune_note = widgets.HTML(
            value="<i>FLAML Auto-tuning will be applied during execution.</i>",
            layout=widgets.Layout(display='none'),
        )

        def _toggle_auto_tune_note(change):
            auto_tune_note.layout.display = None if change['new'] else 'none'

        auto_tune_checkbox.observe(_toggle_auto_tune_note, names='value')
        # NOTE(review): the checkbox state is not persisted anywhere visible here,
        # so the execution stage cannot read it yet — TODO confirm / wire up.
        display(auto_tune_checkbox, auto_tune_note)

        if selected_algorithm not in general_algorithms:
            display(widgets.HTML("<b>This algorithm is handled by a custom pre/post-processing pipeline.</b>"))
            return

        # Use the module-level aif360 / sklearn / omegaconf imports instead of
        # re-importing from other packages inside the function.
        if selected_algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            # Only OT reads the YAML config, so MDSS/FACTS no longer fail when it is missing.
            cfg = OmegaConf.load("configs/config.yaml")
            model = LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_"),
            )
        elif selected_algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            model = MDSS(scoring_function=Bernoulli(direction='negative'))
        else:  # "Fairness Aware Counterfactuals for Subgroups (FACTS)"
            model = FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None,
            )

        defaults = model.get_params() if hasattr(model, "get_params") else {}
        param_widgets = {}
        # Names whose non-string default was rendered via str() in a Text widget.
        stringified = set()
        for name, value in defaults.items():
            if isinstance(value, bool):  # must precede int: bool is an int subclass
                w = widgets.Checkbox(value=value, description=name)
            elif isinstance(value, int):
                # Unbounded: the former min=1 bound silently clamped defaults such as verbose=0.
                w = widgets.IntText(value=value, description=name)
            elif isinstance(value, float):
                w = widgets.FloatText(value=value, description=name)
            else:
                w = widgets.Text(value=str(value), description=name)
                if not isinstance(value, str):
                    stringified.add(name)
            param_widgets[name] = w

        if param_widgets:
            display(widgets.VBox(list(param_widgets.values())))

        def update_selected_algorithm_parameters(change=None):
            """Copy every widget value into ``selected_algorithm_parameters``."""
            for name, w in param_widgets.items():
                # Keep the original object (None, dict, estimator, ...) unless the
                # user actually edited its text; otherwise "None" would be passed on.
                if name in stringified and w.value == str(defaults[name]):
                    selected_algorithm_parameters[name] = defaults[name]
                else:
                    selected_algorithm_parameters[name] = w.value

        for w in param_widgets.values():
            w.observe(update_selected_algorithm_parameters, names='value')

        # Seed with defaults so a run without edits still sees every parameter.
        update_selected_algorithm_parameters()



def give_parameters_updated(b):
    global current_state, selected_algorithm_parameters
    current_state = "Parameters"
    
    with global_output:
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Each algorithm requires a different set of configuration parameters. You are seeing the parameters associated with algorithm you chose in the 'Model' stage.
            </p>
        </div>
        """

        if selected_algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            # Info message for scoring function selection
            mdss_info_message = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: #333333; font-size: 16px;">
                    <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for MDSS
                </span>
                <p style="font-size: 14px; margin-top: 10px;">
                    <b>MDSS requires the following parameters:</b>
                </p>
                <ul style="font-size: 14px; margin-top: 5px;">
                    <li><b>Scoring function:</b> The scoring function evaluates the quality of subsets by measuring how much the observed data deviates from the expected data under a null hypothesis.</li>
                    <li><b>Number of iterations:</b> The number of iterations determines how many times the algorithm will run to find the optimal subset.</li>
                    <li><b>Penalty:</b> The penalty parameter controls the trade-off between the size of the subset and the scoring function value, helping to prevent overfitting.</li>
                </ul>
            </div>
            """

            from omegaconf import OmegaConf

            # Load the config.yaml manually
            config_path = "configs/mdss.yaml"
            cfg = OmegaConf.load(config_path)

            # Create HTML widgets to display the parameters
            C_display = widgets.HTML(value=f"<b>Scoring Function:</b> Bernoulli")
            penalty_display = widgets.HTML(value=f"<b>Penalty:</b> {', '.join(map(str, cfg.parameters.penalty))}")
            iterations_display = widgets.HTML(value=f"<b>Maximum number of iterations:</b> {', '.join(map(str, cfg.parameters.num_iters))}")
            
            # Combine all parameter displays into a VBox
            combined_box = widgets.VBox([widgets.HTML(value=mdss_info_message),widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),penalty_display,C_display,iterations_display])
            
            # # Create Dropdown for dataset selection
            # scoring_function_selector = widgets.RadioButtons(
            #     description='Scoring function:',
            #     options=["Bernoulli"],
            #     style={'description_width': 'initial'},
            #     layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
            # )
            # scoring_function_selector.add_class('my-radio-style')

            # # By default the selected algorithm is the first option
            # selected_algorithm_parameters['scoring_function'] = scoring_function_selector.options[0]
            # scoring_function_selector.observe(on_score_function_selector_change, names='value')
            
            # # Create IntText widgets for Penalty and Number of iterations
            # penalty_widget = widgets.BoundedIntText(
            #     value=1,
            #     min=1,
            #     description='Penalty:',
            #     style={'description_width': 'initial'},
            #     layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
            # )
            # penalty_widget.add_class('my-int-text')
            
            # iterations_widget = widgets.BoundedIntText(
            #     value=1,
            #     min=1,
            #     description='Number of iterations:',
            #     style={'description_width': 'initial'},
            #     layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
            # )
            # iterations_widget.add_class('my-int-text')
            
            # # Store the default values in the selected_algorithm_parameters dictionary
            # selected_algorithm_parameters['penalty'] = penalty_widget.value
            # selected_algorithm_parameters['num_iterations'] = iterations_widget.value

            # # Observe changes in the widgets
            # penalty_widget.observe(lambda change: selected_algorithm_parameters.update({'penalty': change['new']}), names='value')
            # iterations_widget.observe(lambda change: selected_algorithm_parameters.update({'num_iterations': change['new']}), names='value')

            # Combine the info box, scoring function selector, and parameter input widgets
            # combined_box = widgets.VBox([widgets.HTML(value=mdss_info_message), scoring_function_selector, penalty_widget, iterations_widget])
        elif selected_algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            # Info message for scoring function selection
            ot_info_message = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: #333333; font-size: 16px;">
                    <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for OT
                </span>
                <p style="font-size: 14px; margin-top: 10px;">
                    <b>OT requires the following parameters:</b>
                </p>
                <ul style="font-size: 14px; margin-top: 5px;">
                    <li><b>Penalty:</b> The penalty parameter controls the trade-off between the size of the subset and the scoring function value, helping to prevent overfitting.</li>
                    <li><b>C:</b> The parameter <b>C</b> is a regularization parameter which controls the trade-off between achieving a low error on the training data and minimizing the complexity of the model.</li>
                    <li><b>Maximum number of iterations:</b> The number of iterations determines how many times the algorithm will run to find the optimal subset.</li>
                </ul>
            </div>
            """
            from omegaconf import OmegaConf

            # Load the config.yaml manually
            config_path = "configs/config.yaml"
            cfg = OmegaConf.load(config_path)

            # Create HTML widgets to display the parameters
            penalty_display = widgets.HTML(value=f"<b>Penalty:</b> {', '.join(cfg.parameters.penalty)}")

            C_display = widgets.HTML(value=f"<b>C (Regularization):</b> {', '.join(map(str, cfg.parameters.c))}")

            iterations_display = widgets.HTML(value=f"<b>Maximum number of iterations:</b> {', '.join(map(str, cfg.parameters.n_iter))}")

            # Combine all parameter displays into a VBox
            combined_box = widgets.VBox([widgets.HTML(value=ot_info_message),widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),penalty_display,C_display,iterations_display])
        elif selected_algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Info message for scoring function selection
            facts_info_message = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: #333333; font-size: 16px;">
                    <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for FACTS
                </span>
                <p style="font-size: 14px; margin-top: 10px;">
                    <b>FACTS requires the following parameters:</b>
                </p>
                <ul style="font-size: 14px; margin-top: 5px;">
                    <li><b>Maximum number of iterations:</b> The maximum number of iterations allowed for convergence of the solver of the demonstration ML model, here a simple logistic regression.</li>
                    <li><b>Frequent Itemset Minimum Support:</b>All groups of individuals examined are constrained to cover at least this percentage of the whole population.</li>
                    <li><b>Viewpoint:</b>One of two viewpoints defined for counterfactual actions by our framework. "macro" means that all individuals in a group must receive the same action, while "micro" means that each individual can choose, from a set of actions, the one that flips their class and has the minimum cost for the specific individual.</li>
                    <li><b>Metric:</b> The fairness metric/definition to be used by the algorithm.</li>
                    <li><b>Phi:</b> This is the parameter that determines whether we consider an action sufficiently effective or not.</li>
                    <li><b>Top count:</b> The number of (most) biased groups to detect based on the given metric.</li>
                    Specifically, an action is considered effective if it manages to flip the prediction of the individuals under study with at least <it>phi</it> probability, and ineffective otherwise.</li>
                    <li><b>Features not allowed to change:</b> The features that are not allowed to be changed from the algorithm. By default none. You may select up to two.</li>
                </ul>
            </div>
            """
            
            iterations_widget = widgets.BoundedIntText(
                value=1500,
                min=1,
                max=10000,
                description='Maximum number of iterations:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
            )
            iterations_widget.add_class('my-int-text')
            selected_algorithm_parameters['num_iterations'] = iterations_widget.value
            # Observe changes in the widgets
            iterations_widget.observe(lambda change: selected_algorithm_parameters.update({'num_iterations': change['new']}), names='value')

            itemset_min_support_widget = widgets.BoundedFloatText(
                value=0.1,
                min=0,
                description='Frequent Itemset Minimum Support:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
            )
            itemset_min_support_widget.add_class('my-int-text')
            selected_algorithm_parameters['itemset_min_support'] = itemset_min_support_widget.value
            # Observe changes in the widgets
            itemset_min_support_widget.observe(lambda change: selected_algorithm_parameters.update({'itemset_min_support': change['new']}), names='value')

            # Add viewpoint
            viewpoint_selector = widgets.RadioButtons(
                description='Viewpoint:',
                value="macro",
                options=["micro", "macro"],
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
            )
            viewpoint_selector.add_class('my-radio-style')
            selected_algorithm_parameters['viewpoint'] = viewpoint_selector.value.lower()
            viewpoint_selector.observe(on_viewpoint_selector_change, names='value')

            metric_selector = widgets.RadioButtons(
                description='Fairness metric:',
                value="Equal choice for recourse",
                options=["Equal choice for recourse", 
                         "Equal effectiveness", 
                         "Equal effectiveness within budget", 
                         "Equal cost of effectiveness",
                         "Equal mean recourse",
                         "Fair tradeoff",],
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
            )
            metric_selector.add_class('my-radio-style')
            selected_algorithm_parameters['metric'] = metric_selector.value.lower().replace(' ','-')
            metric_selector.observe(on_metric_selector_change, names='value')            
            global phi_widget, c_widget
            phi_widget = widgets.BoundedFloatText(
                value=0.1,
                min=0,
                description='Phi:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='30%', margin='10px 0 0 10px', display='block')
            )
            phi_widget.add_class('my-int-text')
            selected_algorithm_parameters['phi'] = phi_widget.value

            # Observe changes in the widgets
            phi_widget.observe(lambda change: selected_algorithm_parameters.update({'phi': change['new']}), names='value')
            
            c_widget = widgets.BoundedIntText(
                value=1,
                min=1,
                description='Cost budget (c):',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='30%', margin='10px 0 0 10px', display='none')  # Initially hidden
            )
            c_widget.add_class('my-int-text')
            # selected_algorithm_parameters['c'] = c_widget.value
            c_widget.observe(lambda change: selected_algorithm_parameters.update({'c': change['new']}), names='value')

            features_not_to_change_selector = widgets.SelectMultiple(
                description="Features not to change:",
                options=df.columns,
                layout=widgets.Layout(width='30%', height='180px', margin='10px 0 0 10px'),
                style={'description_width': 'initial'}
            )

            features_not_to_change_selector.add_class('my-select-multiple')
            features_not_to_change_selector.observe(update_selected_not_to_change_features, names='value')
            
            top_count_widget = widgets.BoundedIntText(
                value=3,
                min=1,
                description='Number of subgroups to show:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
            )
            top_count_widget.add_class('my-int-text')
            selected_algorithm_parameters['top_count'] = top_count_widget.value
            # Observe changes in the widgets
            top_count_widget.observe(lambda change: selected_algorithm_parameters.update({'top_count': change['new']}), names='value')
            
            status_message = f'<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
            status_message += f'<span style="color: red; font-weight: bold;">No selected not to change features.</span></div>'
        
            # Combine the info box, scoring function selector, and parameter input widgets
            combined_box = widgets.VBox([widgets.HTML(value=facts_info_message), 
                                         iterations_widget,
                                         itemset_min_support_widget,
                                         viewpoint_selector,
                                         metric_selector, 
                                         phi_widget,
                                         c_widget,
                                         top_count_widget,
                                         features_not_to_change_selector,
                                         not_to_change_features_status_output])
            
            with not_to_change_features_status_output:
                not_to_change_features_status_output.clear_output(wait=True)
                display(HTML(status_message))   
        elif selected_algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            facts_info_message = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: #333333; font-size: 16px;">
                    <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for FACTS
                </span>
                <p style="font-size: 14px; margin-top: 10px;">
                    <b>FACTS requires the following parameters:</b>
                </p>
                <ul style="font-size: 14px; margin-top: 5px;">
                    <li><b>Maximum number of iterations:</b> The maximum number of iterations allowed for convergence of the solver of the demonstration ML model, here a simple logistic regression.</li>
                    <li><b>Frequent Itemset Minimum Support:</b>All groups of individuals examined are constrained to cover at least this percentage of the whole population.</li>
                    <li><b>Viewpoint:</b>One of two viewpoints defined for counterfactual actions by our framework. "macro" means that all individuals in a group must receive the same action, while "micro" means that each individual can choose, from a set of actions, the one that flips their class and has the minimum cost for the specific individual.</li>
                    <li><b>Metric:</b> The fairness metric/definition to be used by the algorithm.</li>
                    <li><b>Phi:</b> This is the parameter that determines whether we consider an action sufficiently effective or not.</li>
                    <li><b>Top count:</b> The number of (most) biased groups to detect based on the given metric.</li>
                    Specifically, an action is considered effective if it manages to flip the prediction of the individuals under study with at least <it>phi</it> probability, and ineffective otherwise.</li>
                    <li><b>Features not allowed to change:</b> The features that are not allowed to be changed from the algorithm. By default none. You may select up to two.</li>
                </ul>
            </div>
            """

            
        update_button_color(b)
        clean_toolkit_content()
        display(combined_box)


def _give_parameters_join(values):
    """Render a sequence of configuration values as a comma-separated string.

    Every item goes through ``str`` so non-string config entries (numbers, ``None``) are accepted.
    """
    return ', '.join(map(str, values))


def _give_parameters_track(widget, key, store_initial=True):
    """Mirror ``widget.value`` into ``selected_algorithm_parameters[key]`` on every change.

    Args:
        widget: Any ipywidgets value widget.
        key: Key under which the value is stored in the module-level ``selected_algorithm_parameters``.
        store_initial: When true, the widget's current value is stored immediately so the
            parameter has a value even if the user never edits the widget.
    """
    if store_initial:
        selected_algorithm_parameters[key] = widget.value
    widget.observe(lambda change: selected_algorithm_parameters.update({key: change['new']}), names='value')


def _give_parameters_mdss_box(config_path="configs/mdss.yaml"):
    """Return a read-only summary of the MDSS parameters listed in ``config_path``.

    Reads ``parameters.penalty`` and ``parameters.num_iters`` from the YAML file; the scoring
    function is always shown as Bernoulli.
    """
    mdss_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for MDSS
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>MDSS requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Scoring function:</b> The scoring function evaluates the quality of subsets by measuring how much the observed data deviates from the expected data under a null hypothesis.</li>
            <li><b>Number of iterations:</b> The number of iterations determines how many times the algorithm will run to find the optimal subset.</li>
            <li><b>Penalty:</b> The penalty parameter controls the trade-off between the size of the subset and the scoring function value, helping to prevent overfitting.</li>
        </ul>
    </div>
    """
    cfg = OmegaConf.load(config_path)
    return widgets.VBox([
        widgets.HTML(value=mdss_info_message),
        widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),
        widgets.HTML(value=f"<b>Penalty:</b> {_give_parameters_join(cfg.parameters.penalty)}"),
        widgets.HTML(value="<b>Scoring Function:</b> Bernoulli"),
        widgets.HTML(value=f"<b>Maximum number of iterations:</b> {_give_parameters_join(cfg.parameters.num_iters)}"),
    ])


def _give_parameters_ot_box(config_path="configs/config.yaml"):
    """Return a read-only summary of the Optimal Transport parameters listed in ``config_path``.

    Reads ``parameters.penalty``, ``parameters.c`` and ``parameters.n_iter`` from the YAML file.
    """
    ot_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for OT
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>OT requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Penalty:</b> The type of regularization penalty applied to the logistic regression model.</li>
            <li><b>C:</b> The parameter <b>C</b> is a regularization parameter which controls the trade-off between achieving a low error on the training data and minimizing the complexity of the model.</li>
            <li><b>Maximum number of iterations:</b> The maximum number of iterations allowed for the logistic regression solver to converge.</li>
        </ul>
    </div>
    """
    cfg = OmegaConf.load(config_path)
    return widgets.VBox([
        widgets.HTML(value=ot_info_message),
        widgets.HTML(value="<b>Algorithm Configuration Parameters:</b>"),
        # map(str, ...) so non-string penalties (e.g. a YAML null) do not break the join.
        widgets.HTML(value=f"<b>Penalty:</b> {_give_parameters_join(cfg.parameters.penalty)}"),
        widgets.HTML(value=f"<b>C (Regularization):</b> {_give_parameters_join(cfg.parameters.c)}"),
        widgets.HTML(value=f"<b>Maximum number of iterations:</b> {_give_parameters_join(cfg.parameters.n_iter)}"),
    ])


def _give_parameters_facts_box():
    """Build the editable FACTS parameter form and wire it to ``selected_algorithm_parameters``.

    Side effects: rebinds the module-level ``phi_widget`` and ``c_widget`` (presumably so that
    ``on_metric_selector_change`` can show/hide them — confirm), seeds
    ``selected_algorithm_parameters`` with the widgets' defaults (except ``'c'``, which is only
    stored once the user edits it), and writes the initial status message into
    ``not_to_change_features_status_output``.

    Returns:
        The ``widgets.VBox`` holding the info box and all parameter widgets.
    """
    global phi_widget, c_widget

    facts_info_message = """
    <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
        <span style="font-weight: bold; color: #333333; font-size: 16px;">
            <i class="fa fa-sliders" style="margin-right: 10px;"></i> Select the parameters for FACTS
        </span>
        <p style="font-size: 14px; margin-top: 10px;">
            <b>FACTS requires the following parameters:</b>
        </p>
        <ul style="font-size: 14px; margin-top: 5px;">
            <li><b>Maximum number of iterations:</b> The maximum number of iterations allowed for convergence of the solver of the demonstration ML model, here a simple logistic regression.</li>
            <li><b>Frequent Itemset Minimum Support:</b> All groups of individuals examined are constrained to cover at least this percentage of the whole population.</li>
            <li><b>Viewpoint:</b> One of two viewpoints defined for counterfactual actions by our framework. "macro" means that all individuals in a group must receive the same action, while "micro" means that each individual can choose, from a set of actions, the one that flips their class and has the minimum cost for the specific individual.</li>
            <li><b>Metric:</b> The fairness metric/definition to be used by the algorithm.</li>
            <li><b>Phi:</b> This is the parameter that determines whether we consider an action sufficiently effective or not.
                Specifically, an action is considered effective if it manages to flip the prediction of the individuals under study with at least <i>phi</i> probability, and ineffective otherwise.</li>
            <li><b>Top count:</b> The number of (most) biased groups to detect based on the given metric.</li>
            <li><b>Features not allowed to change:</b> The features that are not allowed to be changed from the algorithm. By default none. You may select up to two.</li>
        </ul>
    </div>
    """

    iterations_widget = widgets.BoundedIntText(
        value=1500,
        min=1,
        max=10000,
        description='Maximum number of iterations:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
    )
    iterations_widget.add_class('my-int-text')
    _give_parameters_track(iterations_widget, 'num_iterations')

    itemset_min_support_widget = widgets.BoundedFloatText(
        value=0.1,
        min=0,
        description='Frequent Itemset Minimum Support:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
    )
    itemset_min_support_widget.add_class('my-int-text')
    _give_parameters_track(itemset_min_support_widget, 'itemset_min_support')

    viewpoint_selector = widgets.RadioButtons(
        description='Viewpoint:',
        value="macro",
        options=["micro", "macro"],
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
    )
    viewpoint_selector.add_class('my-radio-style')
    selected_algorithm_parameters['viewpoint'] = viewpoint_selector.value.lower()
    # Changes are handled by a callback defined elsewhere in the notebook.
    viewpoint_selector.observe(on_viewpoint_selector_change, names='value')

    metric_selector = widgets.RadioButtons(
        description='Fairness metric:',
        value="Equal choice for recourse",
        options=["Equal choice for recourse",
                 "Equal effectiveness",
                 "Equal effectiveness within budget",
                 "Equal cost of effectiveness",
                 "Equal mean recourse",
                 "Fair tradeoff",],
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
    )
    metric_selector.add_class('my-radio-style')
    # Stored as a lowercase, dash-separated identifier, e.g. "equal-choice-for-recourse".
    selected_algorithm_parameters['metric'] = metric_selector.value.lower().replace(' ', '-')
    metric_selector.observe(on_metric_selector_change, names='value')

    phi_widget = widgets.BoundedFloatText(
        value=0.1,
        min=0,
        description='Phi:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='30%', margin='10px 0 0 10px', display='block')
    )
    phi_widget.add_class('my-int-text')
    _give_parameters_track(phi_widget, 'phi')

    c_widget = widgets.BoundedIntText(
        value=1,
        min=1,
        description='Cost budget (c):',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='30%', margin='10px 0 0 10px', display='none')  # Initially hidden
    )
    c_widget.add_class('my-int-text')
    # 'c' is deliberately not seeded: it is only recorded once the user edits the budget.
    _give_parameters_track(c_widget, 'c', store_initial=False)

    features_not_to_change_selector = widgets.SelectMultiple(
        description="Features not to change:",
        options=df.columns,
        layout=widgets.Layout(width='30%', height='180px', margin='10px 0 0 10px'),
        style={'description_width': 'initial'}
    )
    features_not_to_change_selector.add_class('my-select-multiple')
    features_not_to_change_selector.observe(update_selected_not_to_change_features, names='value')

    top_count_widget = widgets.BoundedIntText(
        value=3,
        min=1,
        description='Number of subgroups to show:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='30%', margin='10px 0 0 10px')
    )
    top_count_widget.add_class('my-int-text')
    _give_parameters_track(top_count_widget, 'top_count')

    status_message = (
        '<div style="background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 5px; padding: 10px; width: auto; margin: 10px auto; text-align: left;">'
        '<span style="color: red; font-weight: bold;">No selected not to change features.</span></div>'
    )

    combined_box = widgets.VBox([widgets.HTML(value=facts_info_message),
                                 iterations_widget,
                                 itemset_min_support_widget,
                                 viewpoint_selector,
                                 metric_selector,
                                 phi_widget,
                                 c_widget,
                                 top_count_widget,
                                 features_not_to_change_selector,
                                 not_to_change_features_status_output])

    with not_to_change_features_status_output:
        not_to_change_features_status_output.clear_output(wait=True)
        display(HTML(status_message))

    return combined_box


def give_parameters(b):
    """Render the "Parameters" stage for the algorithm chosen in the "Model" stage.

    MDSS and Optimal Transport show read-only values loaded from their YAML configs; FACTS
    shows an editable form whose values are mirrored into ``selected_algorithm_parameters``.
    If ``selected_algorithm`` matches none of them, a short notice is displayed instead of
    failing on an unbound ``combined_box``.

    Args:
        b: The navigation button that triggered this stage; forwarded to ``update_button_color``.
    """
    global current_state, selected_algorithm_parameters
    current_state = "Parameters"
    
    with global_output:
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Each algorithm requires a different set of configuration parameters. You are seeing the parameters associated with algorithm you chose in the 'Model' stage.
            </p>
        </div>
        """

        if selected_algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            combined_box = _give_parameters_mdss_box()
        elif selected_algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            combined_box = _give_parameters_ot_box()
        elif selected_algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            combined_box = _give_parameters_facts_box()
        else:
            # No (or an unknown) algorithm selected: explain instead of raising on an unbound name.
            combined_box = widgets.VBox([
                widgets.HTML(value=info_message),
                widgets.HTML(value='<span style="color: red; font-weight: bold;">No supported algorithm is selected. Please choose one in the \'Model\' stage first.</span>'),
            ])

        update_button_color(b)
        clean_toolkit_content()
        display(combined_box)


# Module-level selection state for the "Model" stage. The general / pre / post lists are
# kept in sync with the checkboxes by update_selected_algorithms; selected_algorithms is
# not written in this section (presumably used by other cells — confirm).
selected_algorithms, selected_algorithm_general, selected_algorithm_pre, selected_algorithm_post = [], [], [], []

def create_checkbox_group(options, on_change_callback):
    """Build a vertical group of unticked checkboxes, one per option label.

    Each checkbox is 90% wide, labelled with its option, and calls ``on_change_callback``
    whenever its ``value`` trait changes.

    Returns:
        A ``widgets.VBox`` whose children are the checkboxes, in ``options`` order.
    """
    def _make_checkbox(label):
        box = widgets.Checkbox(value=False, description=label, layout=widgets.Layout(width='90%'))
        box.observe(on_change_callback, names='value')
        return box

    return widgets.VBox([_make_checkbox(label) for label in options])


def update_selected_algorithms(change, group):
    """Mirror a checkbox toggle into the selection list of its toolkit section.

    Args:
        change: traitlets change dict; only ``change['owner']`` (the toggled checkbox) is read.
        group: ``"general"``, ``"pre"`` or ``"post"``; any other value is ignored.

    A ticked checkbox adds its description to the matching list (at most once); an unticked
    one removes it if present.
    """
    checkbox = change['owner']

    if group == "general":
        target = selected_algorithm_general
    elif group == "pre":
        target = selected_algorithm_pre
    elif group == "post":
        target = selected_algorithm_post
    else:
        return

    label = checkbox.description
    if checkbox.value:
        if label not in target:
            target.append(label)
    elif label in target:
        target.remove(label)


def _sync_algorithm_checkboxes(group_box, chosen):
    """Make a freshly built checkbox group agree with a persistent selection list.

    Entries of ``chosen`` that are not offered in ``group_box`` (e.g. left over from another
    dataset) are dropped in place, so the list object keeps its identity. Every checkbox whose
    description is still in ``chosen`` is then ticked; its change observer
    (``update_selected_algorithms``) only appends missing labels, so ``chosen`` is unchanged.
    """
    offered = {checkbox.description for checkbox in group_box.children}
    chosen[:] = [name for name in chosen if name in offered]
    for checkbox in group_box.children:
        checkbox.value = checkbox.description in chosen


def give_model_to_detect_bias_updated_checkboxes(b):
    """Render the "Model" stage: checkbox groups for detection, pre- and post-processing algorithms.

    Options come from per-dataset tables keyed by the global ``selected_dataset``. Ticking or
    unticking a box updates ``selected_algorithm_general`` / ``selected_algorithm_pre`` /
    ``selected_algorithm_post`` through ``update_selected_algorithms``. Because those lists
    outlive this view, the rebuilt checkboxes are re-ticked from them (and stale entries
    dropped), so revisiting the stage never shows a UI that disagrees with the stored state.

    Args:
        b: The navigation button that triggered this stage; forwarded to ``update_button_color``.
    """
    global current_state
    current_state = "Model"
    with global_output:
        update_button_color(b)
        clean_toolkit_content()
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-cogs" style="margin-right: 10px;"></i> Algorithm Selection
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Select the algorithm to be used for detecting bias in the provided dataset. 
                The chosen algorithm will analyze the data and identify any potential biases that may affect the model's performance and fairness.
            </p>
            <p style="font-size: 14px; font-weight: bold; color: red">
                Note that a dataset may only support a specific algorithm for bias detection. 
                For example, if you chose the dataset "Ad Campaign", you must select the MDSS algorithm.
                You are only seeing the supported algorithms per dataset as options, as those defined next.
            </p>
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
               Supported algorithms per dataset:
            </span>
            <span style="font-size: 14px;">
                <ul>
                    <li>Ad Campaign: Multi-Dimensional Subset Scan (MDSS)</li>
                    <li>Adult: Bias Detection via Optimal Transport, FACTS</li>
                </ul>
            </span>
        </div>
        """

        # Algorithms offered for each dataset, one table per toolkit section.
        algorithm_options_per_dataset = {
            'Ad Campaign': ['Multi-Dimensional Subset Scan (MDSS)'],
            'Adult (Census Income)': ['Bias Detection via Optimal Transport (Logistic Regression)', 'Fairness Aware Counterfactuals for Subgroups (FACTS)']
        }

        algorithm_options_for_preprocessing = {
            'Ad Campaign' : ['Reweighing Pre-processing technique (Reweighing)'],
            'Adult (Census Income)' : ['Reweighing Pre-processing technique (Reweighing)']
        }

        algorithm_options_for_postprocessing = {
            'Ad Campaign' : ['Equalized Odds Post-processing technique (EqOddsPostprocessing)','Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)','Reject Option Classification Post-processing technique (RejectOptionClassification)'],
            'Adult (Census Income)' : ['Equalized Odds Post-processing technique (EqOddsPostprocessing)','Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)','Reject Option Classification Post-processing technique (RejectOptionClassification)']
        }

        algorithm_selector = create_checkbox_group(
            algorithm_options_per_dataset[selected_dataset],
            lambda change: update_selected_algorithms(change, group="general")
        )

        algorithm_selector_multiple_preprocess = create_checkbox_group(
            algorithm_options_for_preprocessing[selected_dataset],
            lambda change: update_selected_algorithms(change, group="pre")
        )

        algorithm_selector_multiple_postprocess = create_checkbox_group(
            algorithm_options_for_postprocessing[selected_dataset],
            lambda change: update_selected_algorithms(change, group="post")
        )

        # Selection is tracked per checkbox (see create_checkbox_group); the VBox containers
        # have no 'value' trait, so no container-level observer is attached. Restore ticks
        # from the persistent lists so a revisit reflects what is actually selected.
        for group_box, chosen in (
            (algorithm_selector, selected_algorithm_general),
            (algorithm_selector_multiple_preprocess, selected_algorithm_pre),
            (algorithm_selector_multiple_postprocess, selected_algorithm_post),
        ):
            _sync_algorithm_checkboxes(group_box, chosen)

        combined_box = widgets.VBox([widgets.HTML(value=info_message),
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>General Toolkit Algorithms</h3>"),algorithm_selector,
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>Pre-processing Algorithms</h3>"),algorithm_selector_multiple_preprocess,
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>Post-processing Algorithms</h3>"),algorithm_selector_multiple_postprocess,])
        
        display(combined_box)



def give_model_to_detect_bias_updated(b):
    """Render the algorithm-selection step with general, pre- and post-processing groups.

    Inside ``global_output`` this shows an info box followed by three
    ``RadioButtons`` groups whose options depend on the global ``selected_dataset``:
    general bias-detection algorithms, pre-processing techniques and
    post-processing techniques. Every group reports value changes to
    ``on_algorithm_selector_change``.

    Side effects:
        Sets the global ``current_state`` to ``"Model"`` and the global
        ``selected_algorithm`` to the first general algorithm for the dataset,
        which is also shown pre-selected in the general group so the UI and the
        global agree.

    Args:
        b: the navigation button that triggered this step (forwarded to
            ``update_button_color``).
    """
    global current_state, selected_algorithm
    current_state = "Model"

    # Supported algorithms per dataset, one mapping per radio group.
    algorithm_options_per_dataset = {
        'Ad Campaign': ['Multi-Dimensional Subset Scan (MDSS)'],
        'Adult (Census Income)': ['Bias Detection via Optimal Transport (Logistic Regression)', 'Fairness Aware Counterfactuals for Subgroups (FACTS)']
    }

    algorithm_options_for_preprocessing = {
        'Ad Campaign' : ['Reweighing Pre-processing technique (Reweighing)'],
        'Adult (Census Income)' : ['Reweighing Pre-processing technique (Reweighing)']
    }

    algorithm_options_for_postprocessing = {
        'Ad Campaign' : ['Equalized Odds Post-processing technique (EqOddsPostprocessing)','Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)','Reject Option Classification Post-processing technique (RejectOptionClassification)'],
        'Adult (Census Income)' : ['Equalized Odds Post-processing technique (EqOddsPostprocessing)','Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)','Reject Option Classification Post-processing technique (RejectOptionClassification)']
    }

    def _make_selector(options, value=None):
        """Build one styled RadioButtons group; ``value=None`` leaves it unselected."""
        selector = widgets.RadioButtons(
            options=list(options),
            value=value,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='50%', margin='10px 0 0 10px')
        )
        selector.add_class('my-radio-style')
        return selector

    with global_output:
        update_button_color(b)
        clean_toolkit_content()
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-cogs" style="margin-right: 10px;"></i> Algorithm Selection
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Select the algorithm to be used for detecting bias in the provided dataset. 
                The chosen algorithm will analyze the data and identify any potential biases that may affect the model's performance and fairness.
            </p>
            <p style="font-size: 14px; font-weight: bold; color: red">
                Note that a dataset may only support a specific algorithm for bias detection. 
                For example, if you chose the dataset "Ad Campaign", you must select the MDSS algorithm.
                You are only seeing the supported algorithms per dataset as options, as those defined next.
            </p>
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
               Supported algorithms per dataset:
            </span>
            <span style="font-size: 14px;">
                <ul>
                    <li>Ad Campaign: Multi-Dimensional Subset Scan (MDSS)</li>
                    <li>Adult: Bias Detection via Optimal Transport, FACTS</li>
                </ul>
            </span>
        </div>
        """

        general_options = algorithm_options_per_dataset[selected_dataset]
        # The general group shows its default as selected, matching the global set below;
        # the pre-/post-processing groups are optional and start unselected.
        algorithm_selector = _make_selector(general_options, value=general_options[0])
        algorithm_selector_multiple_preprocess = _make_selector(
            algorithm_options_for_preprocessing[selected_dataset])
        algorithm_selector_multiple_postprocess = _make_selector(
            algorithm_options_for_postprocessing[selected_dataset])

        # By default the selected algorithm is the first general option.
        selected_algorithm = algorithm_selector.options[0]

        # NOTE(review): all three groups share one handler, so a pre-/post-processing
        # choice is reported exactly like a general one — confirm that
        # on_algorithm_selector_change distinguishes them.
        for selector in (algorithm_selector,
                         algorithm_selector_multiple_preprocess,
                         algorithm_selector_multiple_postprocess):
            selector.observe(on_algorithm_selector_change, names='value')

        # Info box followed by the three labelled groups.
        combined_box = widgets.VBox([widgets.HTML(value=info_message),
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>General Toolkit Algorithms</h3>"), algorithm_selector,
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>Pre-processing Algorithms</h3>"), algorithm_selector_multiple_preprocess,
                                     widgets.HTML(value="<h3 style='margin-top:20px;'>Post-processing Algorithms</h3>"), algorithm_selector_multiple_postprocess])

        display(combined_box)


def give_model_to_detect_bias(b):
    """Render the single-group bias-detection algorithm selector.

    Shows, inside ``global_output``, an explanatory info box and one
    ``RadioButtons`` group listing the bias-detection algorithms supported by the
    global ``selected_dataset``. Value changes are forwarded to
    ``on_algorithm_selector_change``.

    Side effects:
        Sets the global ``current_state`` to ``"Model"`` and the global
        ``selected_algorithm`` to the group's first option.

    Args:
        b: the navigation button that triggered this step (forwarded to
            ``update_button_color``).
    """
    global current_state, selected_algorithm
    current_state = "Model"

    # Bias-detection algorithms that each dataset supports.
    detectors_by_dataset = {
        'Ad Campaign': ['Multi-Dimensional Subset Scan (MDSS)'],
        'Adult (Census Income)': [
            'Bias Detection via Optimal Transport (Logistic Regression)',
            'Fairness Aware Counterfactuals for Subgroups (FACTS)',
        ],
    }

    with global_output:
        update_button_color(b)
        clean_toolkit_content()
        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-cogs" style="margin-right: 10px;"></i> Algorithm Selection
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Select the algorithm to be used for detecting bias in the provided dataset. 
                The chosen algorithm will analyze the data and identify any potential biases that may affect the model's performance and fairness.
            </p>
            <p style="font-size: 14px; font-weight: bold; color: red">
                Note that a dataset may only support a specific algorithm for bias detection. 
                For example, if you chose the dataset "Ad Campaign", you must select the MDSS algorithm.
                You are only seeing the supported algorithms per dataset as options, as those defined next.
            </p>
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
               Supported algorithms per dataset:
            </span>
            <span style="font-size: 14px;">
                <ul>
                    <li>Ad Campaign: Multi-Dimensional Subset Scan (MDSS)</li>
                    <li>Adult: Bias Detection via Optimal Transport, FACTS</li>
                </ul>
            </span>
        </div>
        """

        # One radio button per detector supported by the active dataset.
        detector_choices = list(detectors_by_dataset[selected_dataset])
        detector_radio = widgets.RadioButtons(
            options=detector_choices,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='50%', margin='10px 0 0 10px'),
        )
        detector_radio.add_class('my-radio-style')

        # With no explicit value, the radio group starts on its first option;
        # mirror that default in the global.
        selected_algorithm = detector_radio.options[0]
        detector_radio.observe(on_algorithm_selector_change, names='value')

        display(widgets.VBox([widgets.HTML(value=info_message), detector_radio]))

####################################################################

def give_parameters_updated_and_extended_all_newer_integr(b):
    """Render editable parameter forms for every algorithm in ``selected_algorithms``.

    For each selected algorithm a form is displayed inside ``global_output``:

    * general (bias-detection) algorithms get one widget per default parameter
      (bool -> Checkbox, int -> BoundedIntText, float -> FloatText, anything
      else -> Text holding ``str(value)``);
    * the supported pre-/post-processing techniques get a fixed set of widgets;
    * any other algorithm gets a "No parameters available" notice.

    Every widget's value is mirrored into the global
    ``selected_algorithm_parameters[algorithm][param]``. A live HTML summary
    below the forms is refreshed on each change.

    Side effects:
        Sets the globals ``current_state`` (to ``"Parameters"``) and
        ``selected_algorithm_parameters`` (reset to a new dict). Reads the global
        ``selected_algorithms``. The Logistic Regression defaults are loaded from
        ``configs/config.yaml``.

    Args:
        b: the navigation button that triggered this step (forwarded to
            ``update_button_color``).
    """
    global current_state, selected_algorithm_parameters, selected_algorithms
    import html as html_lib

    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    # Fixed forms for the pre-/post-processing techniques. Each value is a factory,
    # so new widget instances are created every time this step is rendered.
    # NOTE(review): the Reject Option defaults use 'sex' while the others use
    # 'job.experience_no' — confirm which dataset these defaults target.
    fixed_parameter_forms = {
        "Reweighing Pre-processing technique (Reweighing)": lambda: {
            "unprivileged_groups": widgets.Text(value="[{'job.experience_no'<= 2}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'job.experience_no'> 2}]", description="Privileged groups:")
        },
        "Equalized Odds Post-processing technique (EqOddsPostprocessing)": lambda: {
            "protected_attr": widgets.Text(value="job.experience_no", description="Protected attribute:"),
            "label_attr": widgets.Text(value="job.remote", description="Label column:"),
            "score_attr": widgets.Text(value="score", description="Score column:")
        },
        "Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)": lambda: {
            "cost_constraint": widgets.Dropdown(options=["fnr", "fpr", "weighted"], value="weighted", description="Cost Constraint:"),
            "unprivileged_groups": widgets.Text(value="[{'job.experience_no' <= 2}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'job.experience_no'> 2}]", description="Privileged groups:")
        },
        "Reject Option Classification Post-processing technique (RejectOptionClassification)": lambda: {
            "unprivileged_groups": widgets.Text(value="[{'sex': 0}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'sex': 1}]", description="Privileged groups:"),
            "low_class_thresh": widgets.FloatSlider(value=0.01, min=0, max=1, step=0.01, description="Low threshold:"),
            "high_class_thresh": widgets.FloatSlider(value=0.99, min=0, max=1, step=0.01, description="High threshold:")
        },
    }

    # Create summary HTML widget placeholder
    summary_widget = widgets.HTML()

    def update_summary(_=None):
        """Re-render the HTML summary of every recorded parameter value (escaped)."""
        parts = ["<h4>Selected Algorithm Parameters Summary</h4><ul>"]
        for alg in selected_algorithms:
            parts.append(f"<li><b>{html_lib.escape(str(alg))}</b><ul>")
            for key, value in selected_algorithm_parameters.get(alg, {}).items():
                parts.append(f"<li>{html_lib.escape(str(key))}: {html_lib.escape(str(value))}</li>")
            parts.append("</ul></li>")
        parts.append("</ul>")
        summary_widget.value = "".join(parts)

    def _general_defaults(algorithm):
        """Return the default ``{param: value}`` dict shown for a general algorithm."""
        if algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            from omegaconf import OmegaConf
            from sklearn.linear_model import LogisticRegression
            # Only this estimator is configured from YAML, so the config is loaded here
            # instead of once per general algorithm.
            cfg = OmegaConf.load("configs/config.yaml")
            return LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_"),
            ).get_params()
        if algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            # The MDSS class the notebook imports (aif360.detectors.mdss.MDSS) is built
            # from a ScoringFunction object and is not a scikit-learn estimator, so the
            # editable settings are listed directly rather than read via get_params().
            return {"scoring_function": "Bernoulli", "max_iters": 50, "penalty": 1.0}
        if algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Same import path as the notebook header.
            from aif360.sklearn.detectors.facts import FACTS
            return FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None,
            ).get_params()
        raise ValueError(f"Unknown general algorithm: {algorithm}")

    def _param_widget(name, value):
        """Build an editing widget matching the Python type of ``value``."""
        if isinstance(value, bool):  # must precede the int check: bool subclasses int
            return widgets.Checkbox(value=value, description=name)
        if isinstance(value, int):
            # Widen the 1..10000 guard to include the default so BoundedIntText never
            # silently clamps it (e.g. LogisticRegression's verbose=0).
            return widgets.BoundedIntText(value=value, description=name,
                                          min=min(1, value), max=max(10000, value))
        if isinstance(value, float):
            return widgets.FloatText(value=value, description=name)
        # NOTE(review): non-scalar defaults (None, dicts, estimators) are recorded as
        # their str() text, e.g. "None"; consumers must parse them back.
        return widgets.Text(value=str(value), description=name)

    def _on_param_change(alg, name):
        """Return an observer that records a widget's new value and refreshes the summary."""
        def _observer(change):
            selected_algorithm_parameters[alg][name] = change['new']
            update_summary()
        return _observer

    def _bind_widgets(alg, param_widgets):
        """Seed ``selected_algorithm_parameters[alg]`` from the widgets and keep it in sync."""
        selected_algorithm_parameters[alg] = {}
        for name, widget in param_widgets.items():
            selected_algorithm_parameters[alg][name] = widget.value
            widget.observe(_on_param_change(alg, name), names='value')

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                You may review and optionally edit the configuration parameters of the selected algorithm(s). General algorithms can also be auto-tuned via FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            layout=widgets.Layout(width='90%'),
            description="Auto-Tune Parameters with FLAML"
        )
        display(auto_tune_checkbox)

        # FLAML notes are toggled reactively: reading the checkbox value at render
        # time would always see False because the user has not interacted yet.
        flaml_notes = []

        def _toggle_flaml_notes(change):
            for note in flaml_notes:
                note.layout.display = None if change['new'] else "none"

        auto_tune_checkbox.observe(_toggle_flaml_notes, names='value')
        # NOTE(review): the checkbox state is not stored anywhere later steps can read;
        # wire it into the execution step if FLAML tuning should actually run.

        for algorithm in selected_algorithms:
            display(widgets.HTML(value=f"<h4>{algorithm}</h4>"))

            if algorithm in general_algorithms:
                param_widgets = {
                    name: _param_widget(name, value)
                    for name, value in _general_defaults(algorithm).items()
                }
                _bind_widgets(algorithm, param_widgets)
                display(widgets.VBox(list(param_widgets.values())))

                flaml_note = widgets.HTML(
                    value="<i>FLAML Auto-tuning will be applied during execution.</i>",
                    layout=widgets.Layout(display="none"),
                )
                flaml_notes.append(flaml_note)
                display(flaml_note)

            elif algorithm in fixed_parameter_forms:
                param_widgets = fixed_parameter_forms[algorithm]()
                _bind_widgets(algorithm, param_widgets)
                for widget in param_widgets.values():
                    display(widget)

            else:
                display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))

        # Initial summary render
        update_summary()
        display(summary_widget)





###################################################################
def give_parameters_updated_and_extended_all_newer_integ(b):
    """Render editable parameter forms for ``selected_algorithms`` with an on-demand summary.

    For each selected algorithm a form is displayed inside ``global_output``:

    * general (bias-detection) algorithms get one widget per default parameter
      (bool -> Checkbox, int -> BoundedIntText, float -> FloatText, anything
      else -> Text holding ``str(value)``);
    * the supported pre-/post-processing techniques get a fixed set of widgets;
    * any other algorithm gets a "No parameters available" notice.

    Every widget's value is mirrored into the global
    ``selected_algorithm_parameters[algorithm][param]`` (always nested per
    algorithm). A "Show Selected Parameters Summary" button renders those values
    into an HTML widget below it.

    Side effects:
        Sets the globals ``current_state`` (to ``"Parameters"``) and
        ``selected_algorithm_parameters`` (reset to a new dict). Reads the global
        ``selected_algorithms``. The Logistic Regression defaults are loaded from
        ``configs/config.yaml``.

    Args:
        b: the navigation button that triggered this step (forwarded to
            ``update_button_color``).
    """
    global current_state, selected_algorithm_parameters, selected_algorithms
    import html as html_lib

    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    # Fixed forms for the pre-/post-processing techniques. Each value is a factory,
    # so new widget instances are created every time this step is rendered.
    fixed_parameter_forms = {
        "Reweighing Pre-processing technique (Reweighing)": lambda: {
            "unprivileged_groups": widgets.Text(value="[{'sex': 0}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'sex': 1}]", description="Privileged groups:")
        },
        "Equalized Odds Post-processing technique (EqOddsPostprocessing)": lambda: {
            "protected_attr": widgets.Text(value="sex", description="Protected attribute:"),
            "label_attr": widgets.Text(value="label", description="Label column:"),
            "score_attr": widgets.Text(value="score", description="Score column:")
        },
        "Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)": lambda: {
            "cost_constraint": widgets.Dropdown(options=["fnr", "fpr", "weighted"], value="weighted", description="Cost Constraint:"),
            "unprivileged_groups": widgets.Text(value="[{'sex': 0}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'sex': 1}]", description="Privileged groups:")
        },
        "Reject Option Classification Post-processing technique (RejectOptionClassification)": lambda: {
            "unprivileged_groups": widgets.Text(value="[{'sex': 0}]", description="Unprivileged groups:"),
            "privileged_groups": widgets.Text(value="[{'sex': 1}]", description="Privileged groups:"),
            "low_class_thresh": widgets.FloatSlider(value=0.01, min=0, max=1, step=0.01, description="Low threshold:"),
            "high_class_thresh": widgets.FloatSlider(value=0.99, min=0, max=1, step=0.01, description="High threshold:")
        },
    }

    def _general_defaults(algorithm):
        """Return the default ``{param: value}`` dict shown for a general algorithm."""
        if algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            from omegaconf import OmegaConf
            from sklearn.linear_model import LogisticRegression
            # Only this estimator is configured from YAML, so the config is loaded here
            # instead of once per general algorithm.
            cfg = OmegaConf.load("configs/config.yaml")
            return LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_"),
            ).get_params()
        if algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            # The MDSS class the notebook imports (aif360.detectors.mdss.MDSS) is built
            # from a ScoringFunction object and is not a scikit-learn estimator, so the
            # editable settings are listed directly rather than read via get_params().
            return {"scoring_function": "Bernoulli", "max_iters": 50, "penalty": 1.0}
        if algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Same import path as the notebook header.
            from aif360.sklearn.detectors.facts import FACTS
            return FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None,
            ).get_params()
        raise ValueError(f"Unknown general algorithm: {algorithm}")

    def _param_widget(name, value):
        """Build an editing widget matching the Python type of ``value``."""
        if isinstance(value, bool):  # must precede the int check: bool subclasses int
            return widgets.Checkbox(value=value, description=name)
        if isinstance(value, int):
            # Widen the 1..10000 guard to include the default so BoundedIntText never
            # silently clamps it (e.g. LogisticRegression's verbose=0).
            return widgets.BoundedIntText(value=value, description=name,
                                          min=min(1, value), max=max(10000, value))
        if isinstance(value, float):
            return widgets.FloatText(value=value, description=name)
        # NOTE(review): non-scalar defaults (None, dicts, estimators) are recorded as
        # their str() text, e.g. "None"; consumers must parse them back.
        return widgets.Text(value=str(value), description=name)

    def _on_param_change(alg, name):
        """Return an observer that records a widget's new value under ``alg``/``name``.

        ``alg`` and ``name`` are bound per call, so the observer does not pick up a
        later value of the enclosing loop variable.
        """
        def _observer(change):
            selected_algorithm_parameters[alg][name] = change['new']
        return _observer

    def _bind_widgets(alg, param_widgets):
        """Seed ``selected_algorithm_parameters[alg]`` from the widgets and keep it in sync."""
        selected_algorithm_parameters[alg] = {}
        for name, widget in param_widgets.items():
            selected_algorithm_parameters[alg][name] = widget.value
            widget.observe(_on_param_change(alg, name), names='value')

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                You may review and optionally edit the configuration parameters of the selected algorithm(s). General algorithms can also be auto-tuned via FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            layout=widgets.Layout(width='90%'),
            description="Auto-Tune Parameters with FLAML"
        )
        display(auto_tune_checkbox)

        # FLAML notes are toggled reactively: reading the checkbox value at render
        # time would always see False because the user has not interacted yet.
        flaml_notes = []

        def _toggle_flaml_notes(change):
            for note in flaml_notes:
                note.layout.display = None if change['new'] else "none"

        auto_tune_checkbox.observe(_toggle_flaml_notes, names='value')
        # NOTE(review): the checkbox state is not stored anywhere later steps can read;
        # wire it into the execution step if FLAML tuning should actually run.

        for algorithm in selected_algorithms:
            display(widgets.HTML(value=f"<h4>{algorithm}</h4>"))

            if algorithm in general_algorithms:
                param_widgets = {
                    name: _param_widget(name, value)
                    for name, value in _general_defaults(algorithm).items()
                }
                # Stored nested under the algorithm, like the pre-/post-processing forms.
                _bind_widgets(algorithm, param_widgets)
                display(widgets.VBox(list(param_widgets.values())))

                flaml_note = widgets.HTML(
                    value="<i>FLAML Auto-tuning will be applied during execution.</i>",
                    layout=widgets.Layout(display="none"),
                )
                flaml_notes.append(flaml_note)
                display(flaml_note)

            elif algorithm in fixed_parameter_forms:
                param_widgets = fixed_parameter_forms[algorithm]()
                for widget in param_widgets.values():
                    display(widget)
                _bind_widgets(algorithm, param_widgets)

            else:
                display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))

        # The summary is written into a widget that is already displayed: the button
        # callback runs after this `with global_output:` block has exited, so a bare
        # display() inside it would not render here.
        summary_widget = widgets.HTML()

        def print_selected_parameters(_=None):
            """Render the recorded parameters of every selected algorithm (escaped)."""
            parts = ["<h4>Selected Algorithm Parameters Summary</h4><ul>"]
            for alg in selected_algorithms:
                parts.append(f"<li><b>{html_lib.escape(str(alg))}</b><ul>")
                for key, value in selected_algorithm_parameters.get(alg, {}).items():
                    parts.append(f"<li>{html_lib.escape(str(key))}: {html_lib.escape(str(value))}</li>")
                parts.append("</ul></li>")
            parts.append("</ul>")
            summary_widget.value = "".join(parts)

        # Button to show final selected parameters
        show_summary_button = widgets.Button(
            description="Show Selected Parameters Summary",
            button_style='info',
            layout=widgets.Layout(margin="20px 0px")
        )
        show_summary_button.on_click(print_selected_parameters)
        display(show_summary_button)
        display(summary_widget)






def give_parameters_updated_and_extended_all_newer_integrated(b):
    """Show editable parameter widgets for each selected algorithm (button callback).

    Sets ``current_state`` to ``"Parameters"``, resets the global
    ``selected_algorithm_parameters`` and, inside ``global_output``, renders one
    section per entry of the global ``selected_algorithms``:

    * general algorithms: one widget per ``get_params()`` entry of a freshly
      constructed estimator (Logistic Regression is seeded from
      ``configs/config.yaml``);
    * pre-processing (Reweighing): protected-attribute and label text fields;
    * post-processing: protected-attribute, label and score text fields.

    For every algorithm that gets widgets, the current values are stored in
    ``selected_algorithm_parameters[algorithm]`` (a dict keyed by parameter
    name) and kept in sync on every widget edit.

    Args:
        b: The clicked ``ipywidgets.Button``; forwarded to ``update_button_color``.
    """
    global current_state, selected_algorithm_parameters, selected_algorithms
    import html as html_lib

    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    preprocessing_algorithms = [
        "Reweighing Pre-processing technique (Reweighing)"
    ]

    postprocessing_algorithms = [
        "Equalized Odds Post-processing technique (EqOddsPostprocessing)",
        "Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)",
        "Reject Option Classification Post-processing technique (RejectOptionClassification)"
    ]

    def _param_widget(name, value):
        """Return an input widget matching the Python type of ``value``."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return widgets.Checkbox(value=value, description=name)
        if isinstance(value, int):
            # Unbounded IntText: BoundedIntText(min=1, max=10000) silently clamped
            # legitimate defaults such as LogisticRegression(verbose=0) up to 1.
            return widgets.IntText(value=value, description=name)
        if isinstance(value, float):
            return widgets.FloatText(value=value, description=name)
        # Anything else (str, None, dict, ...) is edited as its string form.
        return widgets.Text(value=str(value), description=name)

    def _track(algorithm, param_widgets):
        """Store the widgets' current values for ``algorithm`` and keep them in sync."""
        values = {name: w.value for name, w in param_widgets.items()}
        selected_algorithm_parameters[algorithm] = values
        for name, w in param_widgets.items():
            # ``values``/``name`` are bound as defaults: a plain closure over the loop
            # variables would route every widget's edits to the last algorithm.
            w.observe(
                lambda change, values=values, name=name: values.update({name: change['new']}),
                names='value'
            )

    def _auto_tune_note(checkbox):
        """FLAML note whose visibility follows ``checkbox`` (hidden while unchecked)."""
        def _display_for(checked):
            return "block" if checked else "none"

        note = widgets.HTML(
            value="<i>FLAML Auto-tuning will be applied during execution.</i>",
            layout=widgets.Layout(display=_display_for(checkbox.value))
        )
        checkbox.observe(
            lambda change: setattr(note.layout, "display", _display_for(change['new'])),
            names='value'
        )
        return note

    def _general_estimator(algorithm):
        """Construct the estimator whose ``get_params()`` defines the widgets, or None."""
        if algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            from sklearn.linear_model import LogisticRegression
            from omegaconf import OmegaConf

            # The config file is only needed (and only read) for this algorithm.
            cfg = OmegaConf.load("configs/config.yaml")
            return LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_")
            )
        if algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Same FACTS class the notebook imports in its first cell.
            from aif360.sklearn.detectors.facts import FACTS

            return FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None
            )
        # MDSS: the former BinaryLabelDataset(...) placeholders could never be
        # constructed, and the MDSS metric class presumably exposes no
        # get_params() (TODO confirm), so there is nothing to introspect.
        return None

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                You may review and optionally edit the configuration parameters of the selected algorithm(s). General algorithms can also be auto-tuned via FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            layout=widgets.Layout(width='90%'),
            description="Auto-Tune Parameters with FLAML"
        )
        display(auto_tune_checkbox)

        for algorithm in selected_algorithms:
            display(widgets.HTML(value=f"<h4>{algorithm}</h4>"))

            if algorithm in general_algorithms:
                try:
                    estimator = _general_estimator(algorithm)
                except Exception as exc:
                    # Report config/import/constructor failures in the panel instead
                    # of aborting the sections of the remaining algorithms.
                    display(widgets.HTML(
                        f"<i>Could not load parameters for this algorithm: {html_lib.escape(str(exc))}</i>"
                    ))
                    continue

                if estimator is None or not hasattr(estimator, "get_params"):
                    display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))
                    continue

                param_widgets = {
                    name: _param_widget(name, value)
                    for name, value in estimator.get_params().items()
                }
                display(widgets.VBox(list(param_widgets.values())))
                display(_auto_tune_note(auto_tune_checkbox))
                _track(algorithm, param_widgets)

            elif algorithm in preprocessing_algorithms:
                param_widgets = {
                    "reweighing_attr": widgets.Text(value="sex", description="Protected attribute:"),
                    "reweighing_label": widgets.Text(value="label", description="Label column:")
                }
                display(*param_widgets.values())
                _track(algorithm, param_widgets)

            elif algorithm in postprocessing_algorithms:
                param_widgets = {
                    "protected_attr": widgets.Text(value="sex", description="Protected attribute:"),
                    "label_attr": widgets.Text(value="label", description="Label column:"),
                    "score_attr": widgets.Text(value="score", description="Score column:")
                }
                display(*param_widgets.values())
                _track(algorithm, param_widgets)

            else:
                display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))



def give_parameters_updated_and_extended_all_newer(b):
    """Show parameter widgets for general algorithms and notes for pre/post ones.

    Button callback. Sets ``current_state`` to ``"Parameters"``, resets the
    global ``selected_algorithm_parameters`` and, inside ``global_output``,
    renders one section per entry of the global ``selected_algorithms``.

    General algorithms get one widget per ``get_params()`` entry of a freshly
    constructed estimator (Logistic Regression is seeded from
    ``configs/config.yaml``). Their values are stored in
    ``selected_algorithm_parameters[algorithm]`` and kept in sync on every edit.
    Pre- and post-processing algorithms only show an informational note.

    Args:
        b: The clicked ``ipywidgets.Button``; forwarded to ``update_button_color``.
    """
    global current_state, selected_algorithm_parameters, selected_algorithms
    import html as html_lib

    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    preprocessing_algorithms = [
        "Reweighing Pre-processing technique (Reweighing)"
    ]

    postprocessing_algorithms = [
        "Equalized Odds Post-processing technique (EqOddsPostprocessing)",
        "Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)",
        "Reject Option Classification Post-processing technique (RejectOptionClassification)"
    ]

    def _param_widget(name, value):
        """Return an input widget matching the Python type of ``value``."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return widgets.Checkbox(value=value, description=name)
        if isinstance(value, int):
            # Unbounded IntText: BoundedIntText(min=1, max=10000) silently clamped
            # legitimate defaults such as LogisticRegression(verbose=0) up to 1.
            return widgets.IntText(value=value, description=name)
        if isinstance(value, float):
            return widgets.FloatText(value=value, description=name)
        # Anything else (str, None, dict, ...) is edited as its string form.
        return widgets.Text(value=str(value), description=name)

    def _track(algorithm, param_widgets):
        """Store the widgets' current values for ``algorithm`` and keep them in sync."""
        values = {name: w.value for name, w in param_widgets.items()}
        selected_algorithm_parameters[algorithm] = values
        for name, w in param_widgets.items():
            # Bound as defaults so each observer keeps its own dict and key.
            w.observe(
                lambda change, values=values, name=name: values.update({name: change['new']}),
                names='value'
            )

    def _auto_tune_note(checkbox):
        """FLAML note whose visibility follows ``checkbox`` (hidden while unchecked)."""
        def _display_for(checked):
            return "block" if checked else "none"

        note = widgets.HTML(
            value="<i>FLAML Auto-tuning will be applied during execution.</i>",
            layout=widgets.Layout(display=_display_for(checkbox.value))
        )
        checkbox.observe(
            lambda change: setattr(note.layout, "display", _display_for(change['new'])),
            names='value'
        )
        return note

    def _general_estimator(algorithm):
        """Construct the estimator whose ``get_params()`` defines the widgets, or None."""
        if algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            from sklearn.linear_model import LogisticRegression
            from omegaconf import OmegaConf

            # The config file is only needed (and only read) for this algorithm.
            cfg = OmegaConf.load("configs/config.yaml")
            return LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_")
            )
        if algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Same FACTS class the notebook imports in its first cell.
            from aif360.sklearn.detectors.facts import FACTS

            return FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None
            )
        # MDSS: the former BinaryLabelDataset(...) placeholders could never be
        # constructed, and the MDSS metric class presumably exposes no
        # get_params() (TODO confirm), so there is nothing to introspect.
        return None

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        info_message = """
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-sliders" style="margin-right: 10px;"></i> Algorithm Parameters
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                You may review and optionally edit the configuration parameters of the selected algorithm(s). General algorithms can also be auto-tuned via FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            layout=widgets.Layout(width='90%'),
            description="Auto-Tune Parameters with FLAML"
        )
        display(auto_tune_checkbox)

        for algorithm in selected_algorithms:
            display(widgets.HTML(value=f"<h4>{algorithm}</h4>"))

            # General algorithms with introspectable params
            if algorithm in general_algorithms:
                try:
                    estimator = _general_estimator(algorithm)
                except Exception as exc:
                    # Report config/import/constructor failures in the panel instead
                    # of aborting the sections of the remaining algorithms.
                    display(widgets.HTML(
                        f"<i>Could not load parameters for this algorithm: {html_lib.escape(str(exc))}</i>"
                    ))
                    continue

                if estimator is None or not hasattr(estimator, "get_params"):
                    display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))
                    continue

                param_widgets = {
                    name: _param_widget(name, value)
                    for name, value in estimator.get_params().items()
                }
                display(widgets.VBox(list(param_widgets.values())))
                display(_auto_tune_note(auto_tune_checkbox))
                _track(algorithm, param_widgets)

            # Pre-processing algorithms
            elif algorithm in preprocessing_algorithms:
                # TODO: Replace below with dynamic inspection if your pipeline class is accessible
                display(widgets.HTML("<b>Reweighing will be applied. Parameters can be configured in the preprocessing pipeline module.</b>"))

            # Post-processing algorithms
            elif algorithm in postprocessing_algorithms:
                # TODO: Replace below with dynamic inspection if your pipeline class is accessible
                display(widgets.HTML("<b>Post-processing (e.g., EqOdds, ROC) will be applied. Tune inside your pipeline or add FLAML tuning support.</b>"))

            else:
                display(widgets.HTML("<i>No parameters available for this algorithm.</i>"))


def give_model_to_detect_bias_updated_checkboxes_new(b):
    """Show the algorithm-selection screen as three checkbox groups (button callback).

    Sets ``current_state`` to ``"Model"`` and rebinds the global selection lists
    to empty lists. Inside ``global_output`` it renders the general,
    pre-processing and post-processing algorithms offered for the global
    ``selected_dataset``. Every checkbox toggle rebuilds, in place and in display
    order:

    * ``selected_algorithms``: every checked algorithm;
    * ``selected_algorithm_general``, ``selected_algorithm_pre`` and
      ``selected_algorithm_post``: the checked algorithms of each group.

    The "Supported algorithms per dataset" banner is generated from the same
    option tables as the checkboxes, so the two always agree.

    Args:
        b: The clicked ``ipywidgets.Button``; forwarded to ``update_button_color``.
    """
    global current_state, selected_algorithms, selected_algorithm_general, selected_algorithm_pre, selected_algorithm_post
    current_state = "Model"
    selected_algorithms = []
    selected_algorithm_general = []
    selected_algorithm_pre = []
    selected_algorithm_post = []

    preprocessing_options = ['Reweighing Pre-processing technique (Reweighing)']
    postprocessing_options = [
        'Equalized Odds Post-processing technique (EqOddsPostprocessing)',
        'Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)',
        'Reject Option Classification Post-processing technique (RejectOptionClassification)'
    ]

    algorithm_options_per_dataset = {
        'Ad Campaign': ['Multi-Dimensional Subset Scan (MDSS)'],
        'Adult (Census Income)': ['Bias Detection via Optimal Transport (Logistic Regression)', 'Fairness Aware Counterfactuals for Subgroups (FACTS)'],
        'Workable': ['Bias Detection via Optimal Transport (Logistic Regression)', 'Fairness Aware Counterfactuals for Subgroups (FACTS)']
    }
    # Every supported dataset currently offers the same pre-/post-processing techniques.
    algorithm_options_for_preprocessing = {
        name: list(preprocessing_options) for name in algorithm_options_per_dataset
    }
    algorithm_options_for_postprocessing = {
        name: list(postprocessing_options) for name in algorithm_options_per_dataset
    }

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        # Build the banner list from the option tables. A hand-written list drifted
        # from the options actually offered.
        supported_items = []
        for dataset_name, general_options in algorithm_options_per_dataset.items():
            offered = (
                general_options
                + algorithm_options_for_preprocessing[dataset_name]
                + algorithm_options_for_postprocessing[dataset_name]
            )
            supported_items.append(f"<li>{dataset_name}: {', '.join(offered)}</li>")
        supported_html = "\n".join(supported_items)

        info_message = f"""
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-cogs" style="margin-right: 10px;"></i> Algorithm Selection
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                Select the algorithm to be used for detecting bias in the provided dataset. 
                The chosen algorithm will analyze the data and identify any potential biases that may affect the model's performance and fairness.
            </p>
            <p style="font-size: 14px; font-weight: bold; color: red">
                Note that a dataset may only support a specific algorithm for bias detection. 
                For example, if you chose the dataset "Ad Campaign", you must select the MDSS algorithm.
                You are only seeing the supported algorithms per dataset as options, as those defined next.
            </p>
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
               Supported algorithms per dataset:
            </span>
            <span style="font-size: 14px;">
                <ul>
                    {supported_html}
                </ul>
            </span>
        </div>
        """

        def create_checkbox_group(options, on_change):
            """Return a VBox with one unchecked Checkbox per option, each observed by ``on_change``."""
            checkboxes = []
            for opt in options:
                cb = widgets.Checkbox(value=False, description=opt, layout=widgets.Layout(width='auto'))
                cb.observe(on_change, names='value')
                checkboxes.append(cb)
            return widgets.VBox(checkboxes)

        def update_selected_algorithms(_change=None):
            """Rebuild every selection list from the current checkbox states, in display order."""
            # Mutate in place (not rebind) so existing references to the lists stay valid.
            selected_algorithms.clear()
            for box, group_list in (
                (general_box, selected_algorithm_general),
                (preprocess_box, selected_algorithm_pre),
                (postprocess_box, selected_algorithm_post),
            ):
                group_list.clear()
                checked = [cb.description for cb in box.children if cb.value]
                group_list.extend(checked)
                selected_algorithms.extend(checked)

        # Unknown datasets get empty groups instead of a KeyError.
        general_box = create_checkbox_group(
            algorithm_options_per_dataset.get(selected_dataset, []),
            update_selected_algorithms
        )
        preprocess_box = create_checkbox_group(
            algorithm_options_for_preprocessing.get(selected_dataset, []),
            update_selected_algorithms
        )
        postprocess_box = create_checkbox_group(
            algorithm_options_for_postprocessing.get(selected_dataset, []),
            update_selected_algorithms
        )

        sections = [widgets.HTML(value=info_message)]
        if selected_dataset not in algorithm_options_per_dataset:
            sections.append(widgets.HTML(
                value="<p style='font-size: 14px; font-weight: bold; color: red'>No algorithms are available for the selected dataset.</p>"
            ))
        sections += [
            widgets.HTML(value="<h3 style='margin-top:20px;'>General Toolkit Algorithms</h3>"),
            general_box,
            widgets.HTML(value="<h3 style='margin-top:20px;'>Pre-processing Algorithms</h3>"),
            preprocess_box,
            widgets.HTML(value="<h3 style='margin-top:20px;'>Post-processing Algorithms</h3>"),
            postprocess_box
        ]

        display(widgets.VBox(sections))



def give_parameters_updated_and_extended_all(b):
    """Show parameter widgets for general algorithms and notes for pre/post ones.

    Button callback. Sets ``current_state`` to ``"Parameters"``, resets the
    global ``selected_algorithm_parameters`` and, inside ``global_output``,
    renders one section per entry of the global ``selected_algorithms``.

    General algorithms get one widget per ``get_params()`` entry of a freshly
    constructed estimator (Logistic Regression is seeded from
    ``configs/config.yaml``). Their values are stored in
    ``selected_algorithm_parameters[algorithm]`` and kept in sync on every edit.
    Pre- and post-processing algorithms only show an informational note.

    Args:
        b: The clicked ``ipywidgets.Button``; forwarded to ``update_button_color``.
    """
    global current_state, selected_algorithm_parameters, selected_algorithms
    import html as html_lib

    current_state = "Parameters"
    selected_algorithm_parameters = {}

    general_algorithms = [
        "Multi-Dimensional Subset Scan (MDSS)",
        "Bias Detection via Optimal Transport (Logistic Regression)",
        "Fairness Aware Counterfactuals for Subgroups (FACTS)"
    ]

    preprocessing_algorithms = [
        "Reweighing Pre-processing technique (Reweighing)"
    ]

    postprocessing_algorithms = [
        "Equalized Odds Post-processing technique (EqOddsPostprocessing)",
        "Calibrated Equalized Odds Post-processing technique (CalibratedEqOddsPostprocessing)",
        "Reject Option Classification Post-processing technique (RejectOptionClassification)"
    ]

    def _param_widget(name, value):
        """Return an input widget matching the Python type of ``value``."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return widgets.Checkbox(value=value, description=name)
        if isinstance(value, int):
            # Unbounded IntText: BoundedIntText(min=1, max=10000) silently clamped
            # legitimate defaults such as LogisticRegression(verbose=0) up to 1.
            return widgets.IntText(value=value, description=name)
        if isinstance(value, float):
            return widgets.FloatText(value=value, description=name)
        # Anything else (str, None, dict, ...) is edited as its string form.
        return widgets.Text(value=str(value), description=name)

    def _track(algorithm, param_widgets):
        """Store the widgets' current values for ``algorithm`` and keep them in sync."""
        values = {name: w.value for name, w in param_widgets.items()}
        selected_algorithm_parameters[algorithm] = values
        for name, w in param_widgets.items():
            # Bound as defaults so each observer keeps its own dict and key.
            w.observe(
                lambda change, values=values, name=name: values.update({name: change['new']}),
                names='value'
            )

    def _auto_tune_note(checkbox):
        """FLAML note whose visibility follows ``checkbox`` (hidden while unchecked)."""
        def _display_for(checked):
            return "block" if checked else "none"

        note = widgets.HTML(
            value="<i>FLAML Auto-tuning will be applied during execution.</i>",
            layout=widgets.Layout(display=_display_for(checkbox.value))
        )
        checkbox.observe(
            lambda change: setattr(note.layout, "display", _display_for(change['new'])),
            names='value'
        )
        return note

    def _general_estimator(algorithm):
        """Construct the estimator whose ``get_params()`` defines the widgets, or None."""
        if algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
            from sklearn.linear_model import LogisticRegression
            from omegaconf import OmegaConf

            # The config file is only needed (and only read) for this algorithm.
            cfg = OmegaConf.load("configs/config.yaml")
            return LogisticRegression(
                max_iter=cfg.parameters.n_iter[0],
                C=cfg.parameters.c[0],
                penalty=cfg.parameters.penalty[0].lower().replace(" ", "_")
            )
        if algorithm == "Multi-Dimensional Subset Scan (MDSS)":
            # The aif360 MDSS/Bernoulli classes the notebook imports in its first
            # cell (instead of an unrelated `subset_scanning` package path).
            from aif360.detectors.mdss.MDSS import MDSS
            from aif360.detectors.mdss.ScoringFunctions import Bernoulli

            return MDSS(scoring_function=Bernoulli(direction='negative'))
        if algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
            # Same FACTS class the notebook imports in its first cell.
            from aif360.sklearn.detectors.facts import FACTS

            return FACTS(
                clf=None,
                prot_attr="gender",
                freq_itemset_min_supp=0.1,
                feature_weights={},
                feats_not_allowed_to_change=None
            )
        return None

    with global_output:
        update_button_color(b)
        clean_toolkit_content()

        # Info banner
        info_message = """
        <div style=\"background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;\">
            <span style=\"font-weight: bold; color: #333333; font-size: 16px;\">
                <i class=\"fa fa-sliders\" style=\"margin-right: 10px;\"></i> Algorithm Parameters
            </span>
            <p style=\"font-size: 14px; margin-top: 10px;\">
                You may review and optionally edit the configuration parameters of the selected algorithm(s). If supported, you may enable Auto-Tuning using FLAML.
            </p>
        </div>
        """
        display(widgets.HTML(value=info_message))

        auto_tune_checkbox = widgets.Checkbox(
            value=False,
            description="Auto-Tune Parameters with FLAML"
        )
        display(auto_tune_checkbox)

        for algorithm in selected_algorithms:
            display(widgets.HTML(value=f"<h4>{algorithm}</h4>"))

            if algorithm in general_algorithms:
                try:
                    estimator = _general_estimator(algorithm)
                except Exception as exc:
                    # Report config/import/constructor failures in the panel instead
                    # of aborting the sections of the remaining algorithms.
                    display(widgets.HTML(
                        f"<i>Could not load parameters for this algorithm: {html_lib.escape(str(exc))}</i>"
                    ))
                    continue

                # e.g. MDSS is constructed but offers no sklearn-style get_params().
                if estimator is None or not hasattr(estimator, "get_params"):
                    display(widgets.HTML("<b>No parameters available for this algorithm.</b>"))
                    continue

                param_widgets = {
                    name: _param_widget(name, value)
                    for name, value in estimator.get_params().items()
                }
                display(widgets.VBox(list(param_widgets.values())))
                display(_auto_tune_note(auto_tune_checkbox))
                _track(algorithm, param_widgets)

            elif algorithm in preprocessing_algorithms:
                display(widgets.HTML("<b>Custom Pre-processing pipeline will be used for this algorithm.</b>"))

            elif algorithm in postprocessing_algorithms:
                display(widgets.HTML("<b>Custom Post-processing pipeline will be used for this algorithm.</b>"))

            else:
                display(widgets.HTML("<b>No parameters available for this algorithm.</b>"))


# CONFIG_PATH = os.path.join(os.getcwd(), "configs")

# # # Hydra Decorator to Load Configurations
# # @hydra.main(version_base=None, config_path="configs", config_name="config")
# def load_hydra_config():
#     hydra.initialize(config_path=CONFIG_PATH, version_base=None)
#     cfg = hydra.compose(config_name="config")
#     return cfg

# Directly load configuration using OmegaConf
def load_hydra_config(config_path="configs/config.yaml"):
    """
    Read the YAML configuration with OmegaConf (works in Voila, where Hydra's
    decorator-based initialisation is not available).

    Args:
        config_path (str): Location of the YAML file; defaults to configs/config.yaml.

    Returns:
        The configuration object produced by ``OmegaConf.load``.

    Raises:
        FileNotFoundError: If nothing exists at ``config_path``.
    """
    config_missing = not os.path.exists(config_path)
    if config_missing:
        message = f"Configuration file '{config_path}' not found. Ensure it exists."
        raise FileNotFoundError(message)
    loaded_config = OmegaConf.load(config_path)
    return loaded_config


def get_best_configuration(metric_name="mean_ot_val", maximize=True, experiment_name="Default", client=None):
    """
    Retrieve the best configuration from MLFlow based on a given metric.

    Every run of the experiment is scanned, following ``search_runs`` page tokens
    so that more runs than one page returns are all considered. Runs without
    ``metric_name`` are ignored. On ties, the first run encountered wins.

    Args:
        metric_name (str): Name of the metric to optimize.
        maximize (bool): Whether to maximize (True) or minimize (False) the metric.
        experiment_name (str): Experiment to search; defaults to "Default".
        client: Optional MLflow tracking client. A new ``MlflowClient`` is
            created when omitted.

    Returns:
        dict or None: ``{"metric": value, "parameters": params}`` for the best
        run. Returns None when the experiment does not exist or no run logged
        ``metric_name``.
    """
    if client is None:
        # Local import: MlflowClient is not among the notebook's top-level imports.
        from mlflow.tracking import MlflowClient
        client = MlflowClient()

    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        # Unknown experiment: nothing to rank (previously an AttributeError).
        return None

    best_run = None
    best_metric = float("-inf") if maximize else float("inf")
    page_token = None
    while True:
        page = client.search_runs(experiment_ids=[experiment.experiment_id], page_token=page_token)
        for run in page:
            metrics = run.data.metrics
            if metric_name not in metrics:
                continue
            current_metric = metrics[metric_name]
            # Strict comparison keeps the first run on ties.
            is_better = current_metric > best_metric if maximize else current_metric < best_metric
            if is_better:
                best_metric = current_metric
                best_run = {
                    "metric": best_metric,
                    "parameters": run.data.params
                }
        # search_runs returns a paged list; an empty/None token means the last page.
        page_token = getattr(page, "token", None)
        if not page_token:
            break

    return best_run

############################################ updated version below



def run_bias_detection_newer(b):
    """
    Button callback: run the selected bias-detection algorithm(s) on the loaded dataset.

    Reads notebook-level state: ``df``, ``selected_dataset``, ``selected_algorithms``,
    ``selected_algorithm_parameters``, ``X_adult``, ``y_adult``, ``random_seed``,
    ``features_4_scanning``, ``not_to_change_features`` and ``loading_bar``.
    Renders into ``global_output``:

    * a warning, followed by an early return, when ``df is None``;
    * a summary of the selected algorithms and their parameters;
    * the results, written into a fresh ``widgets.Output``. That widget is displayed
      after ``clean_toolkit_content()`` has run again.

    Dispatch (the first matching algorithm wins):

    * MDSS and Optimal Transport are currently disabled and only show a notice.
    * FACTS trains a one-hot + logistic-regression pipeline on the Adult data and
      prints its recourse report.
    * Anything else falls back to ``run_automl_pipeline`` and shows the best row.

    Args:
        b: The clicked widget (presumably an ``ipywidgets.Button``), forwarded to
            ``update_button_color``.
    """
    global global_output, current_state, selected_algorithm_parameters, selected_algorithms
    current_state = "Run"
    update_button_color(b)
    clean_toolkit_content()

    mdss_name = "Multi-Dimensional Subset Scan (MDSS)"
    ot_name = "Bias Detection via Optimal Transport (Logistic Regression)"
    facts_name = "Fairness Aware Counterfactuals for Subgroups (FACTS)"

    with global_output:
        if df is None:
            # A single balanced <div>. The previous markup opened two divs and closed one.
            warning_msg = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: red; font-size: 16px;">Warning:</span> 
                <span style="font-size: 14px">No dataset selected.</span>
            </div>"""
            display(HTML(warning_msg))
            return

        # Display the algorithm/parameter summary (parameters are looked up per algorithm name).
        parameter_summary = ""
        for alg in selected_algorithms:
            param_dict = selected_algorithm_parameters.get(alg, {})
            param_items = "".join([f"<li><b>{k}:</b> {v}</li>" for k, v in param_dict.items()])
            parameter_summary += f"""
                <li>
                    <b>{alg}</b>
                    <ul>{param_items}</ul>
                </li>
            """

        info_message_run = f"""
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-play-circle" style="margin-right: 10px;"></i> Run Algorithm
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                We have run the following algorithm(s) on the selected dataset <b>{selected_dataset}</b>:
            </p>
            <ul style="font-size: 13px;">
                {parameter_summary}
            </ul>
        </div>
        """
        display(widgets.HTML(value=info_message_run))

        # NOTE(review): run_detection() reads this dropdown synchronously, right after it is
        # displayed. The user therefore cannot change it for this run, and its default is used.
        metric_selector = widgets.Dropdown(
            options=["FLAML Best Run", "Accuracy", "Conformal Coverage"],
            value="FLAML Best Run",
            description="Best by:",
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='50%')
        )
        display(metric_selector)

        # Show loading bar
        display(widgets.VBox([loading_bar]))
        output_widget = widgets.Output()

        def run_detection():
            with output_widget:
                if mdss_name in selected_algorithms:
                    # The MDSS MLflow sweep that used to live here was commented out (it is
                    # still in version control). Tell the user instead of rendering nothing.
                    display(widgets.HTML(f"<b>{mdss_name} is currently disabled in this runner.</b>"))

                elif ot_name in selected_algorithms:
                    # The Optimal Transport sweep was likewise commented out; see version control.
                    display(widgets.HTML(f"<b>{ot_name} is currently disabled in this runner.</b>"))

                elif facts_name in selected_algorithms:
                    # The summary above reads parameters per algorithm name, while this branch
                    # originally read them flat. Accept both: per-algorithm values override
                    # flat ones.
                    facts_params = {
                        **selected_algorithm_parameters,
                        **(selected_algorithm_parameters.get(facts_name) or {}),
                    }

                    data = clean_dataset(X_adult.assign(income=y_adult), "adult")
                    y = data['income']
                    X = data.drop('income', axis=1)
                    mlflow.sklearn.autolog()
                    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)
                    # One-hot encode object/category columns; all other columns pass through.
                    categorical_features = X.select_dtypes(include=["object", "category"]).columns.to_list()
                    categorical_transformer = ColumnTransformer(
                        transformers=[("one-hot-encoder", OneHotEncoder(), categorical_features)],
                        remainder="passthrough"
                    )
                    n_iters = facts_params['num_iterations']
                    model = Pipeline([
                        ("one-hot-encoder", categorical_transformer),
                        ("clf", LogisticRegression(max_iter=n_iters))
                    ])
                    model.fit(X_train, y_train)
                    metric = facts_params['metric']
                    top_count = facts_params['top_count']
                    viewpoint = facts_params['viewpoint']
                    phi = facts_params.get('phi', 0.5)
                    c = facts_params.get('c', 0.5)
                    freq_itemset_min_supp = facts_params['itemset_min_support']
                    # The first scanning feature is used as the protected attribute.
                    protected_attribute = features_4_scanning[0]
                    detector = FACTS(
                        clf=model,
                        prot_attr=protected_attribute,
                        freq_itemset_min_supp=freq_itemset_min_supp,
                        feature_weights={f: 1 for f in X.columns},
                        feats_not_allowed_to_change=not_to_change_features,
                    )
                    detector.fit(X_test, verbose=False)
                    # The rule filters applied depend on the chosen metric and viewpoint.
                    filter_seq = ["remove-fair-rules"]
                    if metric == "equal-effectiveness-within-budget":
                        filter_seq.append("remove-above-thr-cost")
                    if metric in ["equal-choice-for-recourse", "equal-cost-of-effectiveness"] and viewpoint == "macro":
                        filter_seq.append("remove-below-thr-corr")
                    if metric == "equal-cost-of-effectiveness" and viewpoint == "micro":
                        filter_seq.append("keep-rules-until-thr-corr-reached")
                    detector.bias_scan(
                        metric=metric,
                        viewpoint=viewpoint,
                        sort_strategy="max-cost-diff-decr",
                        top_count=top_count,
                        filter_sequence=filter_seq,
                        phi=phi,
                        c=c,
                    )
                    correctness_metric = metric in ["equal-effectiveness", "equal-effectiveness-within-budget"]
                    detector.print_recourse_report(
                        show_subgroup_costs=True,
                        show_action_costs=True,
                        correctness_metric=correctness_metric,
                    )

                else:
                    # Fallback to AutoML pipeline
                    selected_metric = metric_selector.value
                    df_results = run_automl_pipeline(
                        search_algo="tpe",
                        selected_algorithms=selected_algorithms,
                        selected_algorithm_parameters=selected_algorithm_parameters
                    )
                    not_found_msg = f"<b>No valid configuration found or metric '{selected_metric}' missing.</b>"
                    # Guard against None as well as an empty frame or a missing column.
                    if df_results is None or df_results.empty or selected_metric not in df_results.columns:
                        display(widgets.HTML(not_found_msg))
                        return

                    # Coerce to numbers so that an all-NaN or non-numeric column yields a
                    # message instead of an idxmax() failure.
                    scores = pd.to_numeric(df_results[selected_metric], errors="coerce")
                    if scores.isna().all():
                        display(widgets.HTML(not_found_msg))
                        return

                    best_run = df_results.loc[scores.idxmax()]
                    metric_value = best_run[selected_metric]
                    try:
                        metric_text = f"{metric_value:.6f}"
                    except (TypeError, ValueError):
                        metric_text = str(metric_value)
                    parameters = {
                        "Preprocessing": best_run["Preprocessing"],
                        "Postprocessing": best_run["Postprocessing"],
                        "Accuracy": best_run.get("Accuracy", "N/A"),
                        "Conformal Coverage": best_run.get("Conformal Coverage", "N/A")
                    }
                    ot_best = widgets.VBox([
                        widgets.HTML("<br><b><h2>Best Configuration Identified</h2></b>"),
                        widgets.HTML(f"<b>Metric ({selected_metric}):</b> {metric_text}"),
                        widgets.HTML("<b>Best Parameters:</b>"),
                        widgets.HTML("<ul>" + "".join([f"<li>{key}: {value}</li>" for key, value in parameters.items()]) + "</ul>")
                    ])
                    display(ot_best)

        run_detection()
        # Clear the summary, selector and loading bar, then show only the results.
        clean_toolkit_content()
        display(output_widget)


##################################################################################################################


from IPython.display import display, HTML
import ipywidgets as widgets

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

def build_fairness_logs_from_mlflow_df_old(df_results):
    """
    Legacy extractor for fairness metrics stored as ``preprocessed_<M>`` / ``postprocessed_<M>`` columns.

    For each row, the key ``"<Preprocessing>+<Postprocessing>"`` maps to a dict of
    metrics (SPD, DI, EOD, AOD, Theil). A metric is recorded unless both of its
    lookups return ``None``, for example because both columns are absent. A later
    row with the same key replaces the earlier one.

    Returns:
        dict: { "Preprocessing+Postprocessing": {metric: {"Preprocessed": val, "Postprocessed": val}} }
    """
    metric_names = ("SPD", "DI", "EOD", "AOD", "Theil")
    logs = {}

    for _, record in df_results.iterrows():
        key = f"{record['Preprocessing']}+{record['Postprocessing']}"
        per_metric = {}
        for name in metric_names:
            before = record.get(f"preprocessed_{name}", None)
            after = record.get(f"postprocessed_{name}", None)
            if before is None and after is None:
                continue
            per_metric[name] = {"Preprocessed": before, "Postprocessed": after}
        logs[key] = per_metric

    return logs


def build_fairness_logs_from_best_run(best_run):
    """
    Turn a single result row into the fairness-log structure consumed by the plotting helpers.

    ``best_run`` only needs ``[...]`` and ``.get(...)`` access, so a dict or a pandas
    Series both work. Values are read from the ``"<M> (Pre)"`` and ``"<M> (Post)"``
    labels for M in SPD, DI, EOD, AOD and Theil. A metric is included unless both
    lookups return ``None``.

    Returns:
        dict: { "Preprocessing+Postprocessing": {metric: {"Preprocessed": val, "Postprocessed": val}} }
    """
    combo = f"{best_run['Preprocessing']}+{best_run['Postprocessing']}"
    triples = (
        (name, best_run.get(f"{name} (Pre)"), best_run.get(f"{name} (Post)"))
        for name in ("SPD", "DI", "EOD", "AOD", "Theil")
    )
    return {
        combo: {
            name: {"Preprocessed": pre, "Postprocessed": post}
            for name, pre, post in triples
            if not (pre is None and post is None)
        }
    }


def build_fairness_logs_from_mlflow_df(df_results):
    """
    Collect pre-/post-processing fairness metrics for every row of a results table.

    Values come from explicitly labelled columns such as ``"SPD (Pre)"`` and
    ``"SPD (Post)"``, for SPD, DI, EOD, AOD and Theil. A metric is recorded for a row
    unless both lookups return ``None``, for example because both columns are absent.
    Rows sharing the same Preprocessing/Postprocessing pair overwrite one another, so
    the last such row wins.

    Returns:
        dict: { "Preprocessing+Postprocessing": {metric: {"Preprocessed": val, "Postprocessed": val}} }
    """
    tracked = ("SPD", "DI", "EOD", "AOD", "Theil")
    logs = {}

    for _, record in df_results.iterrows():
        entries = {}
        for name in tracked:
            stage_values = {
                "Preprocessed": record.get(f"{name} (Pre)", None),
                "Postprocessed": record.get(f"{name} (Post)", None),
            }
            if any(v is not None for v in stage_values.values()):
                entries[name] = stage_values
        logs[f"{record['Preprocessing']}+{record['Postprocessing']}"] = entries

    return logs



import matplotlib.pyplot as plt
import ipywidgets as widgets

def plot_fairness_metrics_comparison(fairness_logs):
    """
    Plot one bar chart per fairness metric comparing its stage values (e.g. Preprocessed vs Postprocessed).

    Args:
        fairness_logs (dict): ``{combo_key: {metric: {stage_label: value}}}``, as built by
            the ``build_fairness_logs_*`` helpers. A value may be ``None`` or NaN when a
            stage was not logged; such bars are omitted rather than crashing the plot.

    For each combo with at least one metric, an HTML heading and a figure with one
    subplot per metric are displayed, and then the figure is closed to free memory.
    Y-limits are symmetric around zero and sized to the data, so no bar is clipped.
    """
    import math
    from html import escape

    for combo_key, metric_dict in fairness_logs.items():
        if not metric_dict:
            continue

        num_metrics = len(metric_dict)
        # squeeze=False always returns a 2-D array, so one metric needs no special case.
        fig, axes = plt.subplots(num_metrics, 1, figsize=(10, 5 * num_metrics), squeeze=False)

        for ax, (metric, stages) in zip(axes[:, 0], metric_dict.items()):
            # Keep only finite numeric values. The previous range check `-1 < v < 1`
            # raised TypeError on None, which the build helpers emit when only one
            # stage was logged.
            plotted = {}
            for label, value in stages.items():
                try:
                    number = float(value)
                except (TypeError, ValueError):
                    continue
                if math.isfinite(number):
                    plotted[label] = number

            ax.set_title(f"{combo_key} - {metric} (Pre vs Post)", fontsize=14)
            ax.set_ylabel(f"{metric} Value", fontsize=12)
            ax.tick_params(axis='both', labelsize=12)
            ax.grid(True, axis='y')

            if not plotted:
                ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
                continue

            ax.bar(list(plotted.keys()), list(plotted.values()))
            ax.axhline(0, color="black", linewidth=0.8)

            # Data-driven symmetric limits. The previous fixed windows (±0.2, or ±1.1 on a
            # symlog axis) clipped ordinary values such as SPD = 0.5 or DI = 1.3.
            bound = max(abs(v) for v in plotted.values())
            margin = max(bound * 0.15, 0.05)
            ax.set_ylim(-(bound + margin), bound + margin)

        fig.tight_layout()
        display(widgets.HTML(f"<h3>Fairness Metrics for: <code>{escape(str(combo_key))}</code></h3>"))
        display(fig)
        plt.close(fig)


def plot_fairness_metrics_comparison_old(fairness_logs: dict):
    """
    Legacy plot: one bar chart per fairness metric across the logged processing stages.

    Args:
        fairness_logs (dict): Maps a "Preprocessing+Postprocessing" combo name to a list
            of per-stage dicts. Stage ``i`` is labelled Raw / Preprocessed / Postprocessed /
            Postprocessed Fair, or "Stage i+1" beyond that. Each dict may hold the keys
            ``statistical_parity_difference``, ``disparate_impact``,
            ``equal_opportunity_difference``, ``average_odds_difference`` and
            ``theil_index``.

    One subplot is created per metric. A metric with no finite numeric value in any
    stage leaves its subplot empty. Combos with no stages are skipped.
    """
    import math

    metrics = [
        "statistical_parity_difference",
        "disparate_impact",
        "equal_opportunity_difference",
        "average_odds_difference",
        "theil_index"
    ]
    stage_labels = ["Raw", "Preprocessed", "Postprocessed", "Postprocessed Fair"]

    for combo_key, stages in fairness_logs.items():
        if len(stages) == 0:
            continue

        # One row per metric. The previous fixed 4 rows raised IndexError on the fifth
        # metric (theil_index).
        fig, axes = plt.subplots(len(metrics), 1, figsize=(8, 4 * len(metrics)))

        for ax, metric in zip(axes, metrics):
            values = []
            labels = []
            for idx, stage_dict in enumerate(stages):
                val = stage_dict.get(metric, None)
                try:
                    val = float(val)
                except (TypeError, ValueError):
                    continue
                if not math.isfinite(val):
                    continue
                values.append(val)
                labels.append(stage_labels[idx] if idx < len(stage_labels) else f"Stage {idx+1}")

            if values:
                ax.bar(labels, values)
                ax.set_title(f"{combo_key} - {metric}")
                ax.set_ylabel("Value")
                # Keep the historical [0, 1] window, but widen it when values fall outside,
                # e.g. negative parity differences or disparate impact above 1.
                ax.set_ylim(min(0.0, min(values) * 1.1), max(1.0, max(values) * 1.1))
                ax.grid(True, axis='y')

        fig.tight_layout()
        display(widgets.HTML(f"<h3>Fairness Metrics for: <code>{combo_key}</code></h3>"))
        display(fig)
        plt.close(fig)



def run_bias_detection_newer_version(b):
    """
    Button callback: run the selected bias-detection algorithm(s), then plot fairness metrics for AutoML results.

    Reads notebook-level state: ``df``, ``selected_dataset``, ``selected_algorithms``,
    ``selected_algorithm_parameters``, ``X_adult``, ``y_adult``, ``random_seed``,
    ``features_4_scanning``, ``not_to_change_features`` and ``loading_bar``.
    Renders into ``global_output``:

    * a warning, followed by an early return, when ``df is None``;
    * a summary of the selected algorithms and their parameters;
    * the results, written into a fresh ``widgets.Output``. That widget is displayed
      after ``clean_toolkit_content()`` has run again.

    Dispatch (the first matching algorithm wins):

    * MDSS and Optimal Transport are currently disabled and only show a notice.
    * FACTS trains a one-hot + logistic-regression pipeline on the Adult data and
      prints its recourse report.
    * Anything else runs ``run_automl_pipelinev2``, shows the best row for the
      selected metric and plots that row's pre/post fairness metrics.

    Args:
        b: The clicked widget (presumably an ``ipywidgets.Button``), forwarded to
            ``update_button_color``.
    """
    global global_output, current_state, selected_algorithm_parameters, selected_algorithms
    current_state = "Run"
    update_button_color(b)
    clean_toolkit_content()

    mdss_name = "Multi-Dimensional Subset Scan (MDSS)"
    ot_name = "Bias Detection via Optimal Transport (Logistic Regression)"
    facts_name = "Fairness Aware Counterfactuals for Subgroups (FACTS)"

    with global_output:
        if df is None:
            warning_msg = """
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: red; font-size: 16px;">Warning:</span> 
                <span style="font-size: 14px">No dataset selected.</span>
            </div>"""
            display(HTML(warning_msg))
            return

        # Algorithm + parameters summary (parameters are looked up per algorithm name).
        parameter_summary = ""
        for alg in selected_algorithms:
            param_dict = selected_algorithm_parameters.get(alg, {})
            param_items = "".join([f"<li><b>{k}:</b> {v}</li>" for k, v in param_dict.items()])
            parameter_summary += f"<li><b>{alg}</b><ul>{param_items}</ul></li>"

        info_message_run = f"""
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-play-circle" style="margin-right: 10px;"></i> Run Algorithm
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                We have run the following algorithm(s) on the selected dataset <b>{selected_dataset}</b>:
            </p>
            <ul style="font-size: 13px;">{parameter_summary}</ul>
        </div>"""
        display(widgets.HTML(value=info_message_run))

        # NOTE(review): run_detection() reads this dropdown synchronously, right after it is
        # displayed. The user therefore cannot change it for this run, and its default is used.
        metric_selector = widgets.Dropdown(
            options=["FLAML Best Run", "Accuracy", "Conformal Coverage"],
            value="FLAML Best Run",
            description="Best by:",
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='50%')
        )
        display(metric_selector)

        display(widgets.VBox([loading_bar]))
        output_widget = widgets.Output()

        def run_detection():
            with output_widget:
                if mdss_name in selected_algorithms:
                    # The MDSS MLflow sweep that used to live here was commented out (it is
                    # still in version control). Tell the user instead of rendering nothing.
                    display(widgets.HTML(f"<b>{mdss_name} is currently disabled in this runner.</b>"))

                elif ot_name in selected_algorithms:
                    # The Optimal Transport sweep was likewise commented out; see version control.
                    display(widgets.HTML(f"<b>{ot_name} is currently disabled in this runner.</b>"))

                elif facts_name in selected_algorithms:
                    # The summary above reads parameters per algorithm name, while this branch
                    # originally read them flat. Accept both: per-algorithm values override
                    # flat ones.
                    facts_params = {
                        **selected_algorithm_parameters,
                        **(selected_algorithm_parameters.get(facts_name) or {}),
                    }

                    data = clean_dataset(X_adult.assign(income=y_adult), "adult")
                    y = data['income']
                    X = data.drop('income', axis=1)
                    mlflow.sklearn.autolog()
                    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)
                    # One-hot encode object/category columns; all other columns pass through.
                    categorical_features = X.select_dtypes(include=["object", "category"]).columns.to_list()
                    categorical_transformer = ColumnTransformer(
                        transformers=[("one-hot-encoder", OneHotEncoder(), categorical_features)],
                        remainder="passthrough"
                    )
                    n_iters = facts_params['num_iterations']
                    model = Pipeline([
                        ("one-hot-encoder", categorical_transformer),
                        ("clf", LogisticRegression(max_iter=n_iters))
                    ])
                    model.fit(X_train, y_train)
                    metric = facts_params['metric']
                    top_count = facts_params['top_count']
                    viewpoint = facts_params['viewpoint']
                    phi = facts_params.get('phi', 0.5)
                    c = facts_params.get('c', 0.5)
                    freq_itemset_min_supp = facts_params['itemset_min_support']
                    # The first scanning feature is used as the protected attribute.
                    protected_attribute = features_4_scanning[0]
                    detector = FACTS(
                        clf=model,
                        prot_attr=protected_attribute,
                        freq_itemset_min_supp=freq_itemset_min_supp,
                        feature_weights={f: 1 for f in X.columns},
                        feats_not_allowed_to_change=not_to_change_features,
                    )
                    detector.fit(X_test, verbose=False)
                    # The rule filters applied depend on the chosen metric and viewpoint.
                    filter_seq = ["remove-fair-rules"]
                    if metric == "equal-effectiveness-within-budget":
                        filter_seq.append("remove-above-thr-cost")
                    if metric in ["equal-choice-for-recourse", "equal-cost-of-effectiveness"] and viewpoint == "macro":
                        filter_seq.append("remove-below-thr-corr")
                    if metric == "equal-cost-of-effectiveness" and viewpoint == "micro":
                        filter_seq.append("keep-rules-until-thr-corr-reached")
                    detector.bias_scan(
                        metric=metric,
                        viewpoint=viewpoint,
                        sort_strategy="max-cost-diff-decr",
                        top_count=top_count,
                        filter_sequence=filter_seq,
                        phi=phi,
                        c=c,
                    )
                    correctness_metric = metric in ["equal-effectiveness", "equal-effectiveness-within-budget"]
                    detector.print_recourse_report(
                        show_subgroup_costs=True,
                        show_action_costs=True,
                        correctness_metric=correctness_metric,
                    )

                else:
                    selected_metric = metric_selector.value
                    # The pipeline also returns its own best run and the raw MLflow runs. Only
                    # the results table is used here; the best row is re-selected below.
                    df_results, _pipeline_best_run, _mlflow_runs = run_automl_pipelinev2(
                        search_algo="tpe",
                        selected_algorithms=selected_algorithms,
                        selected_algorithm_parameters=selected_algorithm_parameters
                    )

                    if df_results is None or df_results.empty:
                        display(widgets.HTML("<b>No MLflow results returned.</b>"))
                        return

                    if selected_metric not in df_results.columns:
                        display(widgets.HTML(f"<b>Metric '{selected_metric}' not found in results.</b>"))
                        return

                    # Coerce to numbers so that an all-NaN or non-numeric column yields a
                    # message instead of an idxmax() failure.
                    scores = pd.to_numeric(df_results[selected_metric], errors="coerce")
                    if scores.isna().all():
                        display(widgets.HTML(f"<b>Metric '{selected_metric}' has no numeric values in results.</b>"))
                        return

                    best_run = df_results.loc[scores.idxmax()]
                    metric_value = best_run[selected_metric]
                    try:
                        metric_text = f"{metric_value:.6f}"
                    except (TypeError, ValueError):
                        metric_text = str(metric_value)
                    parameters = {
                        "Preprocessing": best_run["Preprocessing"],
                        "Postprocessing": best_run["Postprocessing"],
                        "Accuracy": best_run.get("Accuracy", "N/A"),
                        "Conformal Coverage": best_run.get("Conformal Coverage", "N/A")
                    }
                    ot_best = widgets.VBox([
                        widgets.HTML("<br><b><h2>Best Configuration Identified</h2></b>"),
                        widgets.HTML(f"<b>Metric ({selected_metric}):</b> {metric_text}"),
                        widgets.HTML("<b>Best Parameters:</b>"),
                        widgets.HTML("<ul>" + "".join([f"<li>{key}: {value}</li>" for key, value in parameters.items()]) + "</ul>")
                    ])
                    display(ot_best)

                    # Pre- vs post-processing fairness metrics of the selected configuration.
                    fairness_logs = build_fairness_logs_from_best_run(best_run)
                    plot_fairness_metrics_comparison(fairness_logs)

        run_detection()
        # Clear the summary, selector and loading bar, then show only the results.
        clean_toolkit_content()
        display(output_widget)




###################################################################################################################






def run_bias_detection_new(b):
    """Button callback: run the selected bias-detection algorithm and show its results.

    Sets ``current_state`` to "Run", recolours the button via
    ``update_button_color(b)`` and clears the toolkit area. Then, inside
    ``global_output``, it:

    * shows a warning and returns early when no dataset is loaded (``df is None``);
    * builds an HTML summary listing every entry of ``selected_algorithms`` with
      the parameters stored for it in ``selected_algorithm_parameters``;
    * displays ``loading_bar`` and runs the detection for ``selected_algorithm``
      (MDSS, Optimal Transport with Logistic Regression, or FACTS), capturing its
      output in a fresh ``widgets.Output``;
    * clears the toolkit area again and displays the summary followed by the
      captured output.

    The MDSS and OT branches sweep the parameter grids loaded by
    ``load_hydra_config`` from ``configs/mdss.yaml`` / ``configs/config.yaml``
    and open one MLflow run per grid point; the OT branch then displays the run
    returned by ``get_best_configuration(metric_name="mean_ot_val", maximize=False)``.

    Relies on module-level names defined elsewhere in the notebook, e.g. ``df``,
    ``selected_algorithm``, ``selected_dataset``, ``loading_bar``,
    ``features_4_scanning``, ``X_adult``/``y_adult``, ``random_seed``,
    ``not_to_change_features``, ``load_hydra_config``, ``print_report`` and
    ``get_best_configuration``.

    Args:
        b: The widget that triggered the callback; only forwarded to
            ``update_button_color`` (presumably the "Run" button -- confirm).

    NOTE(review): the summary iterates ``selected_algorithms`` (plural) but the
    dispatch in ``run_detection`` tests ``selected_algorithm`` (singular), so only
    one algorithm is actually executed -- confirm this is intended.
    """
    global global_output, current_state, selected_algorithm_parameters, selected_algorithms
    current_state = "Run"
    update_button_color(b)
    clean_toolkit_content()

    with global_output:
        # Nothing to scan without a loaded dataset: show a warning and stop.
        # NOTE(review): the outer <div> in this markup is never closed.
        if df is None:
            warning_msg = """
            <div>    
                <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: red; font-size: 16px;">Warning:</span> 
                <span style="font-size: 14px">No dataset selected.</span>
            </div>"""
            display(HTML(warning_msg))
            return

        # Build one <li> per selected algorithm with a nested list of its parameters.
        # NOTE(review): this reads selected_algorithm_parameters as {algorithm: {param: value}},
        # whereas the FACTS branch below indexes it directly (e.g. ['num_iterations']) as a
        # flat {param: value} dict -- both cannot hold; confirm the actual shape.
        parameter_summary = ""
        for alg in selected_algorithms:
            param_dict = selected_algorithm_parameters.get(alg, {})
            param_items = "".join([f"<li><b>{k}:</b> {v}</li>" for k, v in param_dict.items()])
            parameter_summary += f"""
                <li>
                    <b>{alg}</b>
                    <ul>{param_items}</ul>
                </li>
            """

        # Header block displayed above the captured results once detection has finished.
        info_message_run = f"""
        <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
            <span style="font-weight: bold; color: #333333; font-size: 16px;">
                <i class="fa fa-play-circle" style="margin-right: 10px;"></i> Run Algorithm
            </span>
            <p style="font-size: 14px; margin-top: 10px;">
                We have run the following algorithm(s) on the selected dataset <b>{selected_dataset}</b>:
            </p>
            <ul style="font-size: 13px;">
                {parameter_summary}
            </ul>
        </div>
        """

        # Show a centred loading indicator while the (synchronous) detection runs.
        # Presumably removed by the clean_toolkit_content() call after run_detection() -- confirm.
        loading_box = widgets.VBox([loading_bar], layout=widgets.Layout(align_items='center', width='100%'))
        display(loading_box)

        # Collects everything the detection prints/displays so it can be shown after the header.
        output_widget = widgets.Output()

        def run_detection():
            """Run the algorithm named by ``selected_algorithm``; output goes to ``output_widget``
            (except in the FACTS branch, which redirects to ``global_output`` -- see note there)."""
            with output_widget:
                pass  # leftover no-op placeholder
                if selected_algorithm == "Multi-Dimensional Subset Scan (MDSS)":
                    # Parameter grid (num_iters x penalty) comes from the Hydra config.
                    config_path = "configs/mdss.yaml"  
                    cfg = load_hydra_config(config_path)
                    # NOTE(review): the UI's scoring-function / iteration / penalty choices are
                    # ignored (see the commented-out reads below); Bernoulli is always used.
                    # selected_scoring_function = selected_algorithm_parameters['scoring_function']
                    # Init parameters
                    # scoring_function = None
                    # num_iters = selected_algorithm_parameters['num_iterations']
                    # penalty = selected_algorithm_parameters['penalty']
                    # if selected_scoring_function == "Bernoulli":
                    scoring_function = Bernoulli(direction='negative')
                    scanner = MDSS(scoring_function)
                    mlflow.sklearn.autolog()
                    # Close any run left active so each grid point starts its own run.
                    if mlflow.active_run():
                        mlflow.end_run()
                    for n_iter in cfg.parameters.num_iters:
                        for penalty in cfg.parameters.penalty:
                            with mlflow.start_run(run_name=f"n_iter={n_iter}_penalty={penalty}"):
                                try:
                                    # Assumes df has 'predicted_conversion' and 'true_conversion'
                                    # columns -- TODO confirm for every selectable dataset.
                                    scanned_subset, _ = scanner.scan(df[features_4_scanning], 
                                        expectations = df['predicted_conversion'],
                                        outcomes = df['true_conversion'], 
                                        penalty = penalty, 
                                        num_iters = n_iter,
                                        verbose = False)
                                    print(f"Run complete: n_iter={n_iter}, penalty={penalty}") 
                                    mlflow.log_param("num_iter", n_iter)
                                    # NOTE(review): penalty is an input, yet it is logged as a metric.
                                    mlflow.log_metric("penalty", penalty)
                                    print_report(df, scanned_subset)
                                except Exception as e:
                                    # Report and record the failure on the run; the sweep continues.
                                    print(f"Error for n_iter={n_iter}, penalty={penalty}: {e}")
                                    mlflow.log_param("error", str(e))    
                            #mlflow.log_metric("ot_val", bs1["ot_val"])
                            # mlflow.log_metric("least_penalty", bs1["ot_val"].mean())       
                elif selected_algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
                    # Uses AIF360's preprocessed Adult dataset (not the user-selected df),
                    # i.e. the version used in the ot_demo notebook.
                    
                    data_raw = load_preproc_data_adult()
                    data = data_raw.convert_to_dataframe()[0]
                    data.head()  # no-op here: the result is discarded
                    
                    X = data.drop('Income Binary',axis=1)
                    y = data['Income Binary']
                    
                    # Parameter grid (n_iter x c x penalty) comes from the Hydra config.
                    config_path = "configs/config.yaml"  
                    cfg = load_hydra_config(config_path)
                    if mlflow.active_run():
                        mlflow.end_run()
                    mlflow.autolog()
                    for n_iter in cfg.parameters.n_iter:
                        for c in cfg.parameters.c:
                            for penalty in cfg.parameters.penalty:
                                with mlflow.start_run(run_name=f"n_iter={n_iter}_c={c}_penalty={penalty}"):
                                    try:
                                        # Normalise the config's penalty name (lower-case, spaces -> '_');
                                        # the string "none" maps to penalty=None.
                                        penalty = penalty.lower().replace(' ', '_')
                                        penalty_param = penalty if penalty != "none" else None

                                        # Train Logistic Regression model
                                        clf = LogisticRegression(
                                            solver='lbfgs',
                                            max_iter=n_iter,
                                            C=c,
                                        penalty=penalty_param
                                        )
                                        clf.fit(X, y)

                                        # Column 0 of predict_proba is the probability of clf.classes_[0].
                                        preds = pd.Series(clf.predict_proba(X)[:, 0])
                                        # Assumes the first scanning feature is also a column of the
                                        # Adult frame -- TODO confirm.
                                        protected_attribute = features_4_scanning[0]
                                        ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=data[protected_attribute])
                                        # One row per protected-attribute value with its OT distance.
                                        bs1 = pd.DataFrame({protected_attribute: ot_val1.keys(), "ot_val": ot_val1.values()})
                                        display(bs1)
                                        # Log parameters and metrics to MLFlow
                                        mlflow.log_param("n_iter", n_iter)
                                        mlflow.log_param("C", c)
                                        mlflow.log_param("penalty", penalty)
                                        #mlflow.log_metric("ot_val", bs1["ot_val"])
                                        mlflow.log_metric("mean_ot_val", bs1["ot_val"].mean())
                                        print(f"Run complete: n_iter={n_iter}, C={c}, penalty={penalty}")
                                    except Exception as e:
                                        # Report and record the failure on the run; the sweep continues.
                                        print(f"Error for n_iter={n_iter}, C={c}, penalty={penalty}: {e}")
                                        mlflow.log_param("error", str(e))
                    # Pick the run with the best mean_ot_val (maximize=False) and show it.
                    # Expects a dict with "parameters" and "metric" keys; a None result would raise here.
                    metric_name="mean_ot_val"
                    best_run = get_best_configuration(metric_name="mean_ot_val", maximize=False)
                    parameters = best_run["parameters"]
                    metric = best_run["metric"]
                    ot_best = widgets.VBox([widgets.HTML("<br><b><h2>Best Configuration Identified</h2></b>"),
                                  widgets.HTML(f"<b>Metric ({metric_name}):</b> {metric:.20f}"),
                                  widgets.HTML("<b>Best Parameters:</b>"),
                                  widgets.HTML("<ul>" + "".join([f"<li>{key}: {value}</li>" for key, value in parameters.items()]) + "</ul>")])
                    display(ot_best)
                    # n_iter = selected_algorithm_parameters['num_iterations']
                    # c = selected_algorithm_parameters['C']
                    # penalty = selected_algorithm_parameters['penalty'].lower().replace(' ', '_')
                    
                    # clf = LogisticRegression(solver='lbfgs', 
                    #                          max_iter=n_iter, 
                    #                          C=c, 
                    #                          penalty=penalty)
                    # clf.fit(X, y)
                    # preds = pd.Series(clf.predict_proba(X)[:,0])
                    # protected_attribute = features_4_scanning[0]
                    # ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=data[protected_attribute])
                    # bs1 = pd.DataFrame({protected_attribute: ot_val1.keys(), "ot_val": ot_val1.values()})
                    # display(bs1)
                elif selected_algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
                    # Uses the Adult data in X_adult / y_adult (defined elsewhere), cleaned by clean_dataset.
                    # The example model is built incrementally: a one-hot encoding step for the
                    # categorical features followed by a scikit-learn logistic regressor.
                    data = clean_dataset(X_adult.assign(income=y_adult), "adult")

                    # 70/30 stratified train/test split, seeded with random_seed.
                    y = data['income']
                    X = data.drop('income', axis=1)
                    mlflow.sklearn.autolog()
                    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)

                    categorical_features = X.select_dtypes(include=["object", "category"]).columns.to_list()
                    
                    categorical_features_onehot_transformer = ColumnTransformer(
                        transformers=[
                            ("one-hot-encoder", OneHotEncoder(), categorical_features)
                        ],
                        remainder="passthrough"
                    )
                    n_iters = selected_algorithm_parameters['num_iterations']
                    
                    model = Pipeline([
                        ("one-hot-encoder", categorical_features_onehot_transformer),
                        ("clf", LogisticRegression(max_iter=n_iters))
                    ])

                    protected_attribute = features_4_scanning[0]
                    # NOTE(review): this nested redirect sends the FACTS output to global_output
                    # instead of output_widget; since clean_toolkit_content() runs right after
                    # run_detection(), the report may be cleared before it is seen -- confirm.
                    with global_output:
                        #### train the model
                        model = model.fit(X_train, y_train)
                        # showcase model's accuracy
                        # y_pred = model.predict(X_test)
                        # print(f"Accuracy = {(y_test.values == y_pred).sum() / y_test.shape[0]:.2%}")
                        
                        # Read the FACTS options; phi and c default to 0.5 when absent.
                        metric = selected_algorithm_parameters['metric']
                        top_count = selected_algorithm_parameters['top_count']
                        viewpoint = selected_algorithm_parameters['viewpoint']
                        phi = selected_algorithm_parameters['phi'] if 'phi' in selected_algorithm_parameters else 0.5
                        c = selected_algorithm_parameters['c'] if 'c' in selected_algorithm_parameters else 0.5
                        freq_itemset_min_supp = selected_algorithm_parameters['itemset_min_support']

                        # All features weighted 1; not_to_change_features are excluded from actions.
                        detector = FACTS(
                            clf=model,
                            prot_attr=protected_attribute,
                            freq_itemset_min_supp=freq_itemset_min_supp,
                            feature_weights={f: 1 for f in X.columns},
                            feats_not_allowed_to_change=not_to_change_features,
                        )
                        detector.fit(X_test, verbose=False)

                        # Always drop fair rules; add the metric/viewpoint-specific filter steps.
                        filter_seq = ["remove-fair-rules"]
                        if metric == "equal-effectiveness-within-budget":
                            filter_seq.append("remove-above-thr-cost")
                        if metric == "equal-choice-for-recourse" or (metric == "equal-cost-of-effectiveness" and viewpoint == "macro"):
                            filter_seq.append("remove-below-thr-corr")
                        if metric == "equal-cost-of-effectiveness" and viewpoint == "micro":
                            filter_seq.append("keep-rules-until-thr-corr-reached")
                        detector.bias_scan(
                            metric=metric,
                            viewpoint=viewpoint,
                            sort_strategy="max-cost-diff-decr",
                            top_count=top_count,
                            filter_sequence=filter_seq,
                            phi=phi,
                            c=c,
                        )
                        # The two effectiveness metrics are reported with correctness_metric=True.
                        correctness_metric = metric in ["equal-effectiveness", "equal-effectiveness-within-budget"]
                        detector.print_recourse_report(
                            show_subgroup_costs=True,
                            show_action_costs=True,
                            correctness_metric=correctness_metric,
                        )
        
        # Run synchronously, then replace the toolkit area (loading bar included, presumably)
        # with the summary header and the captured output.
        run_detection()

        clean_toolkit_content()
        display(widgets.HTML(value=info_message_run))
        display(output_widget)



#########adding interface to run_mlflow.py in the docker container

import subprocess
import time
import mlflow
import mlflow.tracking
import pandas as pd
from joblib import Parallel, delayed

def run_automl_pipelinev2(
    search_algo: str = "tpe",
    mlflow_uri: str = "http://192.168.1.151:5000",
    output_csv: str = "mlflow_results.csv",
    parallel_limit: int = 3,
    selected_algorithms: list = None,
    selected_algorithm_parameters: dict = None,
    experiment_name: str = "Default",
    timeout_s: float = 50,
    flush_wait_s: float = 20,
):
    """Run the selected pre/post-processing combinations in Docker and collect their MLflow runs.

    Each known (preprocessing, postprocessing) combination whose two names both
    appear in ``selected_algorithms`` is executed as
    ``docker run --rm --network=host workable-experiment <pre> <post> <search_algo>``,
    at most ``parallel_limit`` containers at a time. Afterwards the FINISHED,
    top-level (no ``mlflow.parentRunId`` tag) runs of the MLflow experiment
    ``experiment_name`` are gathered into a DataFrame with accuracy, conformal
    coverage and pre/post-processing fairness metrics (missing metrics become 0).

    Args:
        search_algo: Search algorithm name passed to the container (e.g. "tpe").
        mlflow_uri: MLflow tracking server URI.
        output_csv: Currently unused; kept for interface compatibility.
        parallel_limit: Maximum containers per batch (values < 1 are treated as 1).
        selected_algorithms: Algorithm names chosen by the user.
        selected_algorithm_parameters: Currently unused; kept for interface compatibility.
        experiment_name: MLflow experiment whose runs are collected.
        timeout_s: Seconds to wait for one container before killing it.
        flush_wait_s: Seconds to wait after all containers exit before querying MLflow.

    Returns:
        Always a 3-tuple ``(df, best_flaml_run, runs)``:
        ``(None, None, None)`` when no combination matches or the experiment does
        not exist; otherwise ``df`` (one row per collected run -- not limited to
        runs started by this call), ``best_flaml_run`` (row with the highest
        "FLAML Best Run" value, or ``None`` if ``df`` is empty) and the raw
        ``runs`` returned by ``MlflowClient.search_runs``.
    """
    import re
    import shlex
    import subprocess
    import time

    import pandas as pd

    if selected_algorithms is None:
        selected_algorithms = []
    # A step of 0 would make range() raise when batching below.
    parallel_limit = max(1, int(parallel_limit))

    # ✅ Known combinations
    base_commands = [
        ["Reweighing", "EqOddsPostprocessing"],
        ["Reweighing", "CalibratedEqOddsPostprocessing"],
        ["Reweighing", "RejectOptionClassification"]
    ]

    def is_selected(fragment):
        # The fragment must not be the tail of a longer CamelCase name: plain
        # substring matching made "EqOddsPostprocessing" match
        # "CalibratedEqOddsPostprocessing" and launch the wrong combination.
        pattern = re.compile(r"(?<![A-Za-z])" + re.escape(fragment))
        return any(pattern.search(sel_alg) for sel_alg in selected_algorithms)

    # Argument lists (no shell): search_algo is passed verbatim and cannot inject shell syntax.
    docker_commands = [
        ["docker", "run", "--rm", "--network=host", "workable-experiment", *combo, search_algo]
        for combo in base_commands
        if all(is_selected(fragment) for fragment in combo)
    ]
    if not docker_commands:
        print("❌ No matching Docker commands for selected algorithms.")
        return None, None, None

    # Heavy third-party imports are only needed once there is work to do.
    import mlflow
    import mlflow.tracking
    from joblib import Parallel, delayed

    # ✅ Set MLflow URI
    mlflow.set_tracking_uri(mlflow_uri)

    def stop_experiment_containers():
        # NOTE: stops every running container of the image, including parallel siblings.
        try:
            result = subprocess.run(["docker", "ps", "--filter", "ancestor=workable-experiment", "-q"], stdout=subprocess.PIPE, text=True)
            for container in result.stdout.strip().split("\n"):
                if container:
                    subprocess.run(["docker", "stop", container])
                    print(f"🛑 Stopped container: {container}")
        except Exception as e:
            print(f"❌ Error stopping containers: {e}")

    def run_docker(command):
        printable = shlex.join(command)
        print(f"🚀 Running: {printable}")
        try:
            # communicate() drains stderr while waiting; polling a child whose PIPEs
            # are never read can deadlock once the OS pipe buffer fills with logs.
            process = subprocess.Popen(
                command, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True, errors="replace"
            )
            try:
                _, stderr = process.communicate(timeout=timeout_s)
            except subprocess.TimeoutExpired:
                print(f"⏳ Timeout! Killing: {printable}")
                process.terminate()
                stop_experiment_containers()
                try:
                    process.communicate(timeout=10)  # reap the child and release its pipe
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.communicate()
                return
            if process.returncode != 0:
                print(f"❌ Error running: {printable}, exit code {process.returncode}: {(stderr or '').strip()[-500:]}")
                return
            print(f"✅ Completed: {printable}")
        except Exception as e:
            print(f"❌ Error running: {printable}, {e}")

    # Workers only wait on subprocesses, so threads suffice (and their prints reach the notebook).
    for start in range(0, len(docker_commands), parallel_limit):
        batch = docker_commands[start:start + parallel_limit]
        Parallel(n_jobs=len(batch), prefer="threads")(delayed(run_docker)(cmd) for cmd in batch)

    time.sleep(flush_wait_s)  # give the containers' MLflow logging time to land

    # ✅ Fetch MLflow runs
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if not experiment:
        print(f"❌ MLflow experiment '{experiment_name}' not found.")
        return None, None, None

    runs = client.search_runs([experiment.experiment_id])

    run_data = []
    for run in runs:
        # Skip unfinished runs and nested child runs.
        if run.info.status != "FINISHED" or "mlflow.parentRunId" in run.data.tags:
            continue
        params, metrics = run.data.params, run.data.metrics
        run_data.append({
            "Run ID": run.info.run_id,
            "Preprocessing": params.get("preprocessing_algorithm", "Unknown"),
            "Postprocessing": params.get("postprocessing_algorithm", "Unknown"),
            "Accuracy": metrics.get("accuracy", 0),
            "Conformal Coverage": metrics.get("conformal_coverage", 0),
            "FLAML Best Run": metrics.get("flaml.best_run", 0),
            # ✅ Postprocessed Fairness Metrics
            "SPD (Post)": metrics.get("postprocessed_statistical_parity_difference", 0),
            "DI (Post)": metrics.get("postprocessed_disparate_impact", 0),
            "EOD (Post)": metrics.get("postprocessed_equal_opportunity_difference", 0),
            "AOD (Post)": metrics.get("postprocessed_average_odds_difference", 0),
            "Theil (Post)": metrics.get("postprocessed_theil_index", 0),
            # ✅ Preprocessed Fairness Metrics
            "SPD (Pre)": metrics.get("preprocessed_statistical_parity_difference", 0),
            "DI (Pre)": metrics.get("preprocessed_disparate_impact", 0),
            "EOD (Pre)": metrics.get("preprocessed_equal_opportunity_difference", 0),
            "AOD (Pre)": metrics.get("preprocessed_average_odds_difference", 0),
            "Theil (Pre)": metrics.get("preprocessed_theil_index", 0),
        })

    df = pd.DataFrame(run_data)

    if df.empty:
        print("❌ No successful runs found.")
        best_flaml_run = None
    else:
        best_flaml_run = df.loc[df["FLAML Best Run"].idxmax()]
        print(f"🏆 Best Run (FLAML):\n{best_flaml_run}")

    return df, best_flaml_run, runs



def run_automl_pipeline(
    search_algo: str = "tpe",
    mlflow_uri: str = "http://192.168.1.151:5000",
    output_csv: str = "mlflow_results.csv",
    parallel_limit: int = 3,
    selected_algorithms: list = None,
    selected_algorithm_parameters: dict = None,
    experiment_name: str = "Default",
    timeout_s: float = 50,
    flush_wait_s: float = 20,
):
    """
    Run fairness-aware AutoML experiments with selected algorithms and parameters using Docker and aggregate MLflow results.

    Args:
        search_algo (str): 'tpe' or 'random'; passed verbatim to the container
        mlflow_uri (str): URI to the MLflow tracking server
        output_csv (str): Output file to save results
        parallel_limit (int): Number of parallel containers to run (values < 1 are treated as 1)
        selected_algorithms (list): List of selected algorithm names, e.g. ["Reweighing", "EqOddsPostprocessing"]
        selected_algorithm_parameters (dict): Parameter dict per algorithm (not used here, but for future logic)
        experiment_name (str): MLflow experiment whose runs are collected
        timeout_s (float): Seconds to wait for one container before killing it
        flush_wait_s (float): Seconds to wait after all containers exit before querying MLflow

    Returns:
        tuple: Always ``(df, best_flaml_run)``. ``df`` has one row per FINISHED run of
        the experiment (not only the runs started by this call) and is written to
        ``output_csv``; ``best_flaml_run`` is the row with the highest
        "FLAML Best Run" value, or ``None`` when there are no finished runs.
        ``(None, None)`` when no combination matches or the experiment is missing.
    """
    import re
    import shlex

    if selected_algorithms is None:
        selected_algorithms = []
    # A step of 0 would make range() raise when batching below.
    parallel_limit = max(1, int(parallel_limit))

    # ✅ All known docker command templates
    base_commands = [
        ["Reweighing", "EqOddsPostprocessing"],
        ["Reweighing", "CalibratedEqOddsPostprocessing"],
        ["Reweighing", "RejectOptionClassification"]
    ]

    def is_selected(fragment):
        # The fragment must not be the tail of a longer CamelCase name: plain
        # substring matching made "EqOddsPostprocessing" match
        # "CalibratedEqOddsPostprocessing" and launch the wrong combination.
        pattern = re.compile(r"(?<![A-Za-z])" + re.escape(fragment))
        return any(pattern.search(sel_alg) for sel_alg in selected_algorithms)

    # ✅ Build only the docker commands that match selected_algorithms.
    # Argument lists (no shell): search_algo cannot inject shell syntax.
    docker_commands = [
        ["docker", "run", "--rm", "--network=host", "workable-experiment", *combo, search_algo]
        for combo in base_commands
        if all(is_selected(fragment) for fragment in combo)
    ]
    if not docker_commands:
        print("❌ No matching Docker commands for selected algorithms.")
        return None, None

    # ✅ Set MLflow URI
    mlflow.set_tracking_uri(mlflow_uri)

    # ✅ Function to stop related Docker containers
    # NOTE: stops every running container of the image, including parallel siblings.
    def stop_experiment_containers():
        try:
            result = subprocess.run(["docker", "ps", "--filter", "ancestor=workable-experiment", "-q"], stdout=subprocess.PIPE, text=True)
            for container in result.stdout.strip().split("\n"):
                if container:
                    subprocess.run(["docker", "stop", container])
                    print(f"🛑 Stopped container: {container}")
        except Exception as e:
            print(f"❌ Error stopping containers: {e}")

    # ✅ Run each Docker command with timeout
    def run_docker(command):
        printable = shlex.join(command)
        try:
            # communicate() drains stderr while waiting; polling a child whose PIPEs
            # are never read can deadlock once the OS pipe buffer fills with logs.
            process = subprocess.Popen(
                command, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True, errors="replace"
            )
            try:
                _, stderr = process.communicate(timeout=timeout_s)
            except subprocess.TimeoutExpired:
                print(f"⏳ Timeout! Killing: {printable}")
                process.terminate()
                stop_experiment_containers()
                try:
                    process.communicate(timeout=10)  # reap the child and release its pipe
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.communicate()
                return
            if process.returncode != 0:
                print(f"❌ Error running: {printable}, exit code {process.returncode}: {(stderr or '').strip()[-500:]}")
        except Exception as e:
            print(f"❌ Error running: {printable}, {e}")

    # ✅ Run Docker commands in parallel batches.
    # Workers only wait on subprocesses, so threads suffice (and their prints reach the notebook).
    for start in range(0, len(docker_commands), parallel_limit):
        batch = docker_commands[start:start + parallel_limit]
        Parallel(n_jobs=len(batch), prefer="threads")(delayed(run_docker)(cmd) for cmd in batch)

    time.sleep(flush_wait_s)  # wait for MLflow logs to flush

    # ✅ Fetch completed MLflow runs
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if not experiment:
        print(f"❌ MLflow experiment '{experiment_name}' not found.")
        return None, None

    runs = client.search_runs([experiment.experiment_id])

    # ✅ Extract relevant run info
    run_data = [
        {
            "Run ID": run.info.run_id,
            "Preprocessing": run.data.params.get("preprocessing_algorithm", "Unknown"),
            "Postprocessing": run.data.params.get("postprocessing_algorithm", "Unknown"),
            "Accuracy": run.data.metrics.get("accuracy", 0),
            "Conformal Coverage": run.data.metrics.get("conformal_coverage", 0),
            "FLAML Best Run": run.data.metrics.get("flaml.best_run", 0),
        }
        for run in runs
        if run.info.status == "FINISHED"
    ]

    df = pd.DataFrame(run_data)
    df.to_csv(output_csv, index=False)
    print(f"✅ Results saved to `{output_csv}`")
    print(df)

    # best_flaml_run is bound on both branches (previously an empty df raised UnboundLocalError).
    if df.empty:
        print("❌ No successful runs found.")
        best_flaml_run = None
    else:
        best_flaml_run = df.loc[df["FLAML Best Run"].idxmax()]
        print(f"🏆 Best Run (FLAML):\n{best_flaml_run}")
    return df, best_flaml_run




#############################################
def run_bias_detection(b):
    """ Currently only MDSS supported for demo purposes. """
    global global_output, current_state
    current_state = "Run"
    update_button_color(b)
    clean_toolkit_content()
    with global_output:
        if df is None:
            warning_msg = """
            <div>    
                <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: red; font-size: 16px;">Warning:</span> 
                <span style="font-size: 14px">No dataset selected.</span>
            </div>"""
            display(HTML(warning_msg))
            return
        else:
            info_message_run = f"""
            <div style="background-color: #f0f0f0; border: 1px solid #ccc; color: #333333; padding: 15px; border-radius: 5px; width: auto; margin: 10px auto; text-align: left;">
                <span style="font-weight: bold; color: #333333; font-size: 16px;">
                    <i class="fa fa-play-circle" style="margin-right: 10px;"></i> Run Algorithm
                </span>
                <p style="font-size: 14px; margin-top: 10px;">
                    We have run the algorithm <b>{selected_algorithm}</b> with multiple parameters on the selected dataset <b>{selected_dataset}</b>.
                </p>                
            </div>
            """

        # Create a VBox to center the loading bar
        loading_box = widgets.VBox([loading_bar], layout=widgets.Layout(align_items='center', width='100%'))
        
        # Display the centered loading bar initially
        display(loading_box)
        
        # Create an Output widget to capture all the outputs
        output_widget = widgets.Output()

        # Define a function to run the actual bias detection
        def run_detection():
            with output_widget:
                if selected_algorithm == "Multi-Dimensional Subset Scan (MDSS)":
                    # Get selected parameters
                    config_path = "configs/mdss.yaml"  
                    cfg = load_hydra_config(config_path)
                    # selected_scoring_function = selected_algorithm_parameters['scoring_function']
                    # Init parameters
                    # scoring_function = None
                    # num_iters = selected_algorithm_parameters['num_iterations']
                    # penalty = selected_algorithm_parameters['penalty']
                    # if selected_scoring_function == "Bernoulli":
                    scoring_function = Bernoulli(direction='negative')
                    scanner = MDSS(scoring_function)
                    mlflow.sklearn.autolog()
                    if mlflow.active_run():
                        mlflow.end_run()
                    for n_iter in cfg.parameters.num_iters:
                        for penalty in cfg.parameters.penalty:
                            with mlflow.start_run(run_name=f"n_iter={n_iter}_penalty={penalty}"):
                                try:
                                    scanned_subset, _ = scanner.scan(df[features_4_scanning], 
                                        expectations = df['predicted_conversion'],
                                        outcomes = df['true_conversion'], 
                                        penalty = penalty, 
                                        num_iters = n_iter,
                                        verbose = False)
                                    print(f"Run complete: n_iter={n_iter}, penalty={penalty}") 
                                    mlflow.log_param("num_iter", n_iter)
                                    mlflow.log_metric("penalty", penalty)
                                    print_report(df, scanned_subset)
                                except Exception as e:
                                    print(f"Error for n_iter={n_iter}, penalty={penalty}: {e}")
                                    mlflow.log_param("error", str(e))    
                            #mlflow.log_metric("ot_val", bs1["ot_val"])
                            # mlflow.log_metric("least_penalty", bs1["ot_val"].mean())       
                elif selected_algorithm == "Bias Detection via Optimal Transport (Logistic Regression)":
                    # Load the version of Adult that is used in the ot_demo ipynb
                    
                    data_raw = load_preproc_data_adult()
                    data = data_raw.convert_to_dataframe()[0]
                    data.head()
                    
                    X = data.drop('Income Binary',axis=1)
                    y = data['Income Binary']
                    
                    config_path = "configs/config.yaml"  
                    cfg = load_hydra_config(config_path)
                    if mlflow.active_run():
                        mlflow.end_run()
                    mlflow.autolog()
                    for n_iter in cfg.parameters.n_iter:
                        for c in cfg.parameters.c:
                            for penalty in cfg.parameters.penalty:
                                with mlflow.start_run(run_name=f"n_iter={n_iter}_c={c}_penalty={penalty}"):
                                    try:
                                        # Configure penalty
                                        penalty = penalty.lower().replace(' ', '_')
                                        penalty_param = penalty if penalty != "none" else None

                                        # Train Logistic Regression model
                                        clf = LogisticRegression(
                                            solver='lbfgs',
                                            max_iter=n_iter,
                                            C=c,
                                        penalty=penalty_param
                                        )
                                        clf.fit(X, y)

                                        # Predict probabilities and compute OT distance
                                        preds = pd.Series(clf.predict_proba(X)[:, 0])
                                        protected_attribute = features_4_scanning[0]
                                        ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=data[protected_attribute])
                                        # Create results DataFrame and display
                                        bs1 = pd.DataFrame({protected_attribute: ot_val1.keys(), "ot_val": ot_val1.values()})
                                        display(bs1)
                                        # Log parameters and metrics to MLFlow
                                        mlflow.log_param("n_iter", n_iter)
                                        mlflow.log_param("C", c)
                                        mlflow.log_param("penalty", penalty)
                                        #mlflow.log_metric("ot_val", bs1["ot_val"])
                                        mlflow.log_metric("mean_ot_val", bs1["ot_val"].mean())
                                        print(f"Run complete: n_iter={n_iter}, C={c}, penalty={penalty}")
                                    except Exception as e:
                                        print(f"Error for n_iter={n_iter}, C={c}, penalty={penalty}: {e}")
                                        mlflow.log_param("error", str(e))
                    #identify the best configuration here
                    metric_name="mean_ot_val"
                    best_run = get_best_configuration(metric_name="mean_ot_val", maximize=False)
                    parameters = best_run["parameters"]
                    metric = best_run["metric"]
                    ot_best = widgets.VBox([widgets.HTML("<br><b><h2>Best Configuration Identified</h2></b>"),
                                  widgets.HTML(f"<b>Metric ({metric_name}):</b> {metric:.20f}"),
                                  widgets.HTML("<b>Best Parameters:</b>"),
                                  widgets.HTML("<ul>" + "".join([f"<li>{key}: {value}</li>" for key, value in parameters.items()]) + "</ul>")])
                    display(ot_best)
                    # n_iter = selected_algorithm_parameters['num_iterations']
                    # c = selected_algorithm_parameters['C']
                    # penalty = selected_algorithm_parameters['penalty'].lower().replace(' ', '_')
                    
                    # clf = LogisticRegression(solver='lbfgs', 
                    #                          max_iter=n_iter, 
                    #                          C=c, 
                    #                          penalty=penalty)
                    # clf.fit(X, y)
                    # preds = pd.Series(clf.predict_proba(X)[:,0])
                    # protected_attribute = features_4_scanning[0]
                    # ot_val1 = ot_distance(y_true=y, y_pred=preds, prot_attr=data[protected_attribute])
                    # bs1 = pd.DataFrame({protected_attribute: ot_val1.keys(), "ot_val": ot_val1.values()})
                    # display(bs1)
                elif selected_algorithm == "Fairness Aware Counterfactuals for Subgroups (FACTS)":
                    # Load the version of Adult that is used in the ot_demo ipynb
                    #### here, we incrementally build the example model. It consists of one preprocessing step,
                    #### which is to turn categorical features into the respective one-hot encodings, and
                    #### a simple scikit-learn logistic regressor.
                    data = clean_dataset(X_adult.assign(income=y_adult), "adult")

                    # split into train-test data
                    y = data['income']
                    X = data.drop('income', axis=1)
                    mlflow.sklearn.autolog()
                    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=random_seed, stratify=y)

                    categorical_features = X.select_dtypes(include=["object", "category"]).columns.to_list()
                    
                    categorical_features_onehot_transformer = ColumnTransformer(
                        transformers=[
                            ("one-hot-encoder", OneHotEncoder(), categorical_features)
                        ],
                        remainder="passthrough"
                    )
                    n_iters = selected_algorithm_parameters['num_iterations']
                    
                    model = Pipeline([
                        ("one-hot-encoder", categorical_features_onehot_transformer),
                        ("clf", LogisticRegression(max_iter=n_iters))
                    ])

                    protected_attribute = features_4_scanning[0]
                    with global_output:
                        #### train the model
                        model = model.fit(X_train, y_train)
                        # showcase model's accuracy
                        # y_pred = model.predict(X_test)
                        # print(f"Accuracy = {(y_test.values == y_pred).sum() / y_test.shape[0]:.2%}")
                        
                        # Retrieve selected algorithms
                        metric = selected_algorithm_parameters['metric']
                        top_count = selected_algorithm_parameters['top_count']
                        viewpoint = selected_algorithm_parameters['viewpoint']
                        phi = selected_algorithm_parameters['phi'] if 'phi' in selected_algorithm_parameters else 0.5
                        c = selected_algorithm_parameters['c'] if 'c' in selected_algorithm_parameters else 0.5
                        freq_itemset_min_supp = selected_algorithm_parameters['itemset_min_support']

                        detector = FACTS(
                            clf=model,
                            prot_attr=protected_attribute,
                            freq_itemset_min_supp=freq_itemset_min_supp,
                            feature_weights={f: 1 for f in X.columns},
                            feats_not_allowed_to_change=not_to_change_features,
                        )
                        detector.fit(X_test, verbose=False)

                        filter_seq = ["remove-fair-rules"]
                        if metric == "equal-effectiveness-within-budget":
                            filter_seq.append("remove-above-thr-cost")
                        if metric == "equal-choice-for-recourse" or (metric == "equal-cost-of-effectiveness" and viewpoint == "macro"):
                            filter_seq.append("remove-below-thr-corr")
                        if metric == "equal-cost-of-effectiveness" and viewpoint == "micro":
                            filter_seq.append("keep-rules-until-thr-corr-reached")
                        detector.bias_scan(
                            metric=metric,
                            viewpoint=viewpoint,
                            sort_strategy="max-cost-diff-decr",
                            top_count=top_count,
                            filter_sequence=filter_seq,
                            phi=phi,
                            c=c,
                        )
                        correctness_metric = metric in ["equal-effectiveness", "equal-effectiveness-within-budget"]
                        detector.print_recourse_report(
                            show_subgroup_costs=True,
                            show_action_costs=True,
                            correctness_metric=correctness_metric,
                        )
                        
        # Run the bias detection function
        run_detection()
        
        # Clear the loading bar and display all the captured outputs
        clean_toolkit_content()
        display(widgets.HTML(value=info_message_run))
        display(output_widget)
      
### for EXPLAINABILITY TOOLKIT ###
def select_explainability_dataset(b):
    """Button callback: load the default Adult dataset for the explainability toolkit.

    Stores the Adult income data returned by ``dice_ml.utils.helpers`` in the
    module-level ``df`` (read later by ``run_explainability``). Then, inside
    ``global_output``, it shows:

    - the status widget;
    - a warning that the default dataset was chosen;
    - the first rows and the summary statistics;
    - a grid of histograms for a fixed list of features.

    ``b`` is presumably the button that triggered the callback; it is unused.
    """
    global df

    with global_output:
        display_running_animation(duration=10)
        df = helpers.load_adult_income_dataset()
        clean_toolkit_content()
        display(output_status)

        banner_html = """
        <div style="background-color: #ffeb99; border: 1px solid #ffcc00; color: #cc6600; padding: 10px; border-radius: 5px;">
            <span style="font-weight: bold; color: #ff6600;">Warning:</span>
            Default dataset ("Adult") selected from UCI ML repository.
        </div>
        """
        banner_layout = widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
        display(widgets.HTML(value=banner_html, layout=banner_layout))

        # Data preview: first rows, then summary statistics.
        for summary_name in ("head", "describe"):
            display_dataframe_styled(getattr(df, summary_name)())

        plot_histogram_grid([
            "age", "workclass", "education", "marital_status",
            "occupation", "race", "gender", "hours_per_week", "income",
        ])
              
def _show_default_choice_warning(text):
    """Clear the toolkit area and display a yellow "Warning:" banner.

    Shared by the explainability button callbacks that only need to tell the
    user which default choice was applied.

    Parameters
    ----------
    text : str
        Message shown after the bold "Warning:" label. It is inserted into the
        HTML verbatim, so pass trusted text only.

    The callbacks in this section invoke it inside ``with global_output:`` so
    the banner is rendered in the toolkit output area.
    """
    warning_message = f"""
        <div style="background-color: #ffeb99; border: 1px solid #ffcc00; color: #cc6600; padding: 10px; border-radius: 5px;">
            <span style="font-weight: bold; color: #ff6600;">Warning:</span>
            {text}
        </div>
        """
    banner = widgets.HTML(
        value=warning_message,
        layout=widgets.Layout(margin='0px', width='100%', padding='5px 0px 5px 0px')
    )
    # The widget is built before the area is cleared, then shown on the clean area.
    clean_toolkit_content()
    display(banner)


def give_parameters_and_explainability_model(b):
    """Button callback: announce that the default parameters and DiCE model are used.

    ``b`` is presumably the ipywidgets button that triggered the callback; it
    is unused.
    """
    with global_output:
        _show_default_choice_warning(
            "Default input parameters and explainability model (DiCE) selected."
        )


def select_explainability_action(b):
    """Button callback: announce that the default action (Local Explanations) is used.

    ``b`` is presumably the ipywidgets button that triggered the callback; it
    is unused.
    """
    with global_output:
        _show_default_choice_warning("Default action selected: Local Explanations.")

def plot_feature_importance(feature_importance, title):
    """Draw a bar chart of feature importances, most important feature first.

    Parameters
    ----------
    feature_importance : mapping
        Feature name -> importance score. In this notebook it is called with
        DiCE's ``local_importance[0]`` and ``summary_importance``.
    title : str
        Chart title; also used in the "no data" notice.

    When the mapping is empty, an HTML notice is displayed instead of a plot.
    """
    ranked = sorted(feature_importance.items(), key=lambda item: item[1], reverse=True)

    if len(ranked) == 0:
        display(HTML(f"<p>No feature importance data available for {title}.</p>"))
        return

    names = [name for name, _ in ranked]
    scores = [score for _, score in ranked]

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=scores, y=names, palette="Set2", ax=ax)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Importance', fontsize=12)
    ax.set_ylabel('Features', fontsize=12)
    fig.tight_layout()
    plt.show()
    
def _fit_explainability_model(x_train, y_train, numerical, random_state=0):
    """Fit the demo prediction pipeline: passthrough numerics + one-hot categoricals + random forest.

    Parameters
    ----------
    x_train : pandas.DataFrame
        Training features (outcome column already removed).
    y_train : pandas.Series
        Training labels.
    numerical : list of str
        Continuous feature names. Every other column of ``x_train`` is treated
        as categorical.
    random_state : int, default 0
        Seed for the random forest, so repeated runs yield the same model.

    Returns
    -------
    sklearn.pipeline.Pipeline
        The fitted pipeline (preprocessor + classifier).
    """
    categorical = x_train.columns.difference(numerical)
    categorical_transformer = Pipeline(steps=[('onehot', OneHotEncoder(handle_unknown='ignore'))])

    # ColumnTransformer's default remainder='drop' would silently discard the
    # continuous columns. The model would then ignore age/hours_per_week even
    # though DiCE is told to vary them, so pass them through explicitly.
    # Tree ensembles need no feature scaling.
    transformations = ColumnTransformer(transformers=[
        ('num', 'passthrough', list(numerical)),
        ('cat', categorical_transformer, categorical),
    ])

    clf = Pipeline(steps=[('preprocessor', transformations),
                          ('classifier', RandomForestClassifier(random_state=random_state))])
    return clf.fit(x_train, y_train)


def run_explainability(b):
    """Train a demo model on the explainability dataset and explain it with DiCE.

    Currently only DiCE is supported for demo purposes. Reads the module-level
    ``df`` populated by ``select_explainability_dataset``. Inside
    ``global_output`` it shows:

    1. two counterfactuals for the first test instance (changes only, then full rows);
    2. the local feature importance for that instance (10 counterfactuals);
    3. the global feature importance over the first 20 test instances.

    ``b`` is presumably the button that triggered the callback; it is unused.
    """
    numerical = ["age", "hours_per_week"]

    with global_output:
        clean_toolkit_content()
        display_running_animation(duration=10)
        target = df["income"]
        train_dataset, test_dataset, y_train, _ = train_test_split(df,
                                                                   target,
                                                                   test_size=0.2,
                                                                   random_state=0,
                                                                   stratify=target)
        x_train = train_dataset.drop('income', axis=1)
        x_test = test_dataset.drop('income', axis=1)

        # Step 1: dice_ml.Data, DiCE's view of the data and its continuous features.
        d = dice_ml.Data(dataframe=train_dataset, continuous_features=list(numerical), outcome_name='income')

        # Step 2: dice_ml.Model, wrapping the fitted sklearn pipeline.
        model = _fit_explainability_model(x_train, y_train, numerical)
        m = dice_ml.Model(model=model, backend="sklearn")

        # Step 3: the explainer, using method="random" for generating CFs.
        exp = dice_ml.Dice(d, m, method="random")

        query_instance = x_test[0:1]

        e1 = exp.generate_counterfactuals(query_instance, total_CFs=2, desired_class="opposite")
        display_message("Counterfactuals", font_weight="bold")
        e1.visualize_as_dataframe(show_only_changes=True)
        e1.visualize_as_dataframe(show_only_changes=False)

        display_message("Local Feature Importance", font_weight="bold")
        imp = exp.local_feature_importance(query_instance, total_CFs=10)
        plot_feature_importance(imp.local_importance[0], "Local Feature Importance")

        display_message("Global Feature Importance", font_weight="bold")
        query_instances = x_test[0:20]
        imp = exp.global_feature_importance(query_instances)
        plot_feature_importance(imp.summary_importance, "Global Feature Importance")

NameError: name 'widgets' is not defined

In [40]:
# Entry point: render the toolkit's home screen. `display_home_screen` is defined
# elsewhere in this notebook. Run the imports cell and all definition cells first;
# the saved output above shows a NameError for `widgets`, which suggests the
# imports cell had not been run in that session.
display_home_screen()