In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings

# NOTE(review): this silences every warning category for the whole kernel, including
# pandas' SettingWithCopyWarning and Python's invalid-escape-sequence warnings.
# Consider narrowing the filter to the specific noisy categories.
warnings.filterwarnings('ignore')
# Also exported through the environment; presumably so that Python subprocesses
# started from this kernel inherit the same setting — TODO confirm.
os.environ["PYTHONWARNINGS"] = "ignore"

In [3]:
# Switch the working directory to the repository root unless we are already there.
# NOTE(review): the relative hop assumes the notebook sits three levels below the
# repository root — confirm if the notebook is moved.
# os.path.basename is used instead of splitting on '/' so the folder-name check
# also works with Windows path separators.
cur_folder_name = os.path.basename(os.getcwd())
if cur_folder_name != "virny-flow-experiments":
    os.chdir("../../..")

print('Current location: ', os.getcwd())

Current location:  /Users/denys_herasymuk/Research/NYU/VirnyFlow_Project/Code/virny-flow-experiments


# Case Studies Visualizations

In [4]:
import pandas as pd
from duckdb import query as sqldf
from virny_flow.core.custom_classes.core_db_client import CoreDBClient
from virny_flow.configs.constants import EXP_CONFIG_HISTORY_TABLE
from source.visualizations.use_case_queries import get_best_lps_per_exp_config
from source.visualizations.scalability_viz import create_performance_plot_v3

## Prepare data for visualizations

In [5]:
# Alternative single-metric setup, kept for reference. The non-folk_pubcov branch of the
# driver cell further below reads DISPARITY_METRIC and GROUP, so these lines must be
# uncommented (and the folk_pubcov DATASET_NAME below replaced) to run on another dataset.
# DATASET_NAME = 'heart'
# DISPARITY_METRIC = 'Equalized_Odds_TNR'
# GROUP = "gender"

# Active setup: folk_pubcov is evaluated with one disparity metric on two groups.
DATASET_NAME = 'folk_pubcov'
# First (disparity metric, group) pair.
DISPARITY_METRIC1 = 'Selection_Rate_Difference'
GROUP1 = "SEX"
# Second (disparity metric, group) pair.
DISPARITY_METRIC2 = 'Selection_Rate_Difference'
GROUP2 = "RAC1P"

In [6]:
# Path of the secrets file handed to the DB client and to the metric-extraction helper.
SECRETS_PATH = os.path.join(os.getcwd(), "scripts", "configs", "secrets.env")

# Identifiers of the systems referenced in this notebook.
VIRNY_FLOW, ALPINE, AUTOSKLEARN = 'virny_flow', 'alpine_meadow', 'autosklearn'

# Per system: experiment config name -> label of the halting schedule it was run with.
# The config names are numbered 1..8 in the order the schedules are listed here.
EXP_CONFIG_NAMES = {
    VIRNY_FLOW: {
        f'sensitivity_exp2_{DATASET_NAME}_w32_vf_halting_{config_idx}': halting_label
        for config_idx, halting_label in enumerate(
            [
                '[1.0]',
                '[0.25,1.0]',
                '[0.5,1.0]',
                '[0.75,1.0]',
                '[0.25,0.5,1.0]',
                '[0.5,0.75,1.0]',
                '[0.1,0.25,0.5,1.0]',
                '[0.1,0.5,0.75,1.0]',
            ],
            start=1,
        )
    },
}

# Open the DB connection that the cells below reuse.
db_client = CoreDBClient(SECRETS_PATH)
db_client.connect()

In [7]:
def get_virny_flow_metrics(db_client):
    """
    Collect the metrics of all configured VirnyFlow experiment configs together with their runtimes.

    The metric rows come from ``get_best_lps_per_exp_config`` and get an extra ``halting``
    column holding the halting-schedule label from ``EXP_CONFIG_NAMES``. The runtime records
    are read from ``EXP_CONFIG_HISTORY_TABLE`` (non-deleted records only), one query per
    experiment config.

    :param db_client: connected DB client exposing ``read_metric_df_from_db``.
    :return: DataFrame with the distinct metric rows plus ``exp_config_execution_time``.
        The join is an inner join on ``exp_config_name`` and ``run_num``, so metric rows
        without a matching runtime record are dropped.
    """
    halting_by_exp_config = EXP_CONFIG_NAMES[VIRNY_FLOW]
    exp_config_names = list(halting_by_exp_config.keys())

    # NOTE: the two DataFrame variable names below are referenced by name inside the SQL
    # query at the end of this function; rename them only together with the query.
    best_lp_metrics_per_exp_config_df = get_best_lps_per_exp_config(secrets_path=SECRETS_PATH,
                                                                    exp_config_names=exp_config_names)
    best_lp_metrics_per_exp_config_df['halting'] = (
        best_lp_metrics_per_exp_config_df['exp_config_name'].map(halting_by_exp_config)
    )

    # Read all runtime frames first and concatenate once: growing a DataFrame with
    # pd.concat inside a loop copies the accumulated data on every iteration and relies
    # on concatenation with an empty frame.
    runtime_dfs = [
        db_client.read_metric_df_from_db(collection_name=EXP_CONFIG_HISTORY_TABLE,
                                         query={'exp_config_name': exp_config_name,
                                                'deletion_flag': False})
        for exp_config_name in exp_config_names
    ]
    virny_flow_all_runtime_df = pd.concat(runtime_dfs) if runtime_dfs else pd.DataFrame()

    # Lower-case the column names so they match the names used in the SQL join below.
    virny_flow_all_runtime_df.columns = [col.lower() for col in virny_flow_all_runtime_df.columns]

    virny_flow_metrics_df = sqldf("""
        SELECT DISTINCT t1.*, t2.exp_config_execution_time
        FROM best_lp_metrics_per_exp_config_df AS t1
        JOIN virny_flow_all_runtime_df AS t2
          ON t1.exp_config_name = t2.exp_config_name
         AND t1.run_num = t2.run_num
    """).to_df()

    return virny_flow_metrics_df

In [8]:
# Query the DB for the metrics and runtimes of every configured halting experiment.
virny_flow_metrics_df = get_virny_flow_metrics(db_client)

Extracting metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_1...
best_pps_per_lp_and_run_num_df.shape: (252, 19)
best_lp_per_run_all.shape: (90, 19)
Extracted metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_1

Extracting metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_2...
best_pps_per_lp_and_run_num_df.shape: (144, 19)
best_lp_per_run_all.shape: (81, 19)
Extracted metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_2

Extracting metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_3...
best_pps_per_lp_and_run_num_df.shape: (198, 19)
best_lp_per_run_all.shape: (108, 19)
Extracted metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_3

Extracting metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_4...
best_pps_per_lp_and_run_num_df.shape: (207, 19)
best_lp_per_run_all.shape: (99, 19)
Extracted metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_4

Extracting metrics for sensitivity_exp2_folk_pubcov_w32_vf_halting_5...
best_pps_per_lp_and_run_num

## Display Results

In [9]:
from virny_flow.visualizations.use_case_queries import get_models_disparity_metric_df


def display_table_with_results(system_metrics_df, system_name: str, disparity_metric_name: str, group_name: str):
    """
    Build a per-run table with F1, one disparity metric and a combined score.

    The score is ``0.5 * F1 + 0.5 * (1 - |disparity|)``.

    :param system_metrics_df: metrics in long format; must provide the columns 'metric',
        'overall', 'dataset_name', 'halting', 'run_num' and the runtime column
        ('exp_config_execution_time' for VirnyFlow, 'optimization_time' otherwise).
        For systems other than VirnyFlow a 'system_name' column must already exist.
        The passed DataFrame is not modified.
    :param system_name: name of the system; ``VIRNY_FLOW`` selects the VirnyFlow-specific handling.
    :param disparity_metric_name: disparity metric forwarded to ``get_models_disparity_metric_df``;
        also used as the name of the resulting column.
    :param group_name: group forwarded to ``get_models_disparity_metric_df``.
    :return: DataFrame with the common columns, 'F1', the disparity column and 'score';
        the runtime column is always named 'optimization_time'.
    """
    if system_name == VIRNY_FLOW:
        # assign() returns a new frame, so the caller's DataFrame is left untouched.
        system_metrics_df = system_metrics_df.assign(system_name=system_name)
        common_cols = ['system_name', 'dataset_name', 'halting', 'run_num', 'exp_config_execution_time']
    else:
        common_cols = ['system_name', 'dataset_name', 'halting', 'run_num', 'optimization_time']

    # NOTE: f1_metrics_df and disparity_metric_df are referenced by name inside the SQL
    # query below; rename them only together with the query.
    # Copy the filtered rows before adding a column to avoid writing into a slice.
    f1_metrics_df = system_metrics_df[system_metrics_df['metric'] == 'F1'].copy()
    f1_metrics_df['F1'] = f1_metrics_df['overall']
    f1_metrics_df = f1_metrics_df[common_cols + ['F1']]

    disparity_metric_df = get_models_disparity_metric_df(system_metrics_df, disparity_metric_name, group_name)
    disparity_metric_df[disparity_metric_name] = disparity_metric_df['disparity_metric_value']
    disparity_metric_df = disparity_metric_df[common_cols + [disparity_metric_name]]

    # dataset_name is part of the join key (as in display_table_with_results_for_folk_pubcov)
    # so that runs of different datasets can never be paired with each other.
    final_metrics_df = sqldf(f"""
        SELECT t1.*, t2.{disparity_metric_name}
        FROM f1_metrics_df AS t1
        JOIN disparity_metric_df AS t2
          ON t1.run_num = t2.run_num
         AND t1.dataset_name = t2.dataset_name
         AND t1.halting = t2.halting
    """).to_df()
    final_metrics_df["score"] = final_metrics_df["F1"] * 0.5 + (1 - abs(final_metrics_df[disparity_metric_name])) * 0.5

    if system_name == VIRNY_FLOW:
        # Drop runs without a recorded execution time and expose the runtime under the common name.
        final_metrics_df = final_metrics_df[~final_metrics_df['exp_config_execution_time'].isna()]
        final_metrics_df = final_metrics_df.rename(columns={'exp_config_execution_time': 'optimization_time'})

    return final_metrics_df


def display_table_with_results_for_folk_pubcov(system_metrics_df, system_name: str,
                                               disparity_metric_name1: str, group_name1: str,
                                               disparity_metric_name2: str, group_name2: str):
    """
    Build a per-run table with F1, two disparity metrics and a combined score.

    The score is ``0.33 * F1 + 0.33 * (1 - |disparity1|) + 0.33 * (1 - |disparity2|)``.
    NOTE(review): the three weights sum to 0.99, so the maximum attainable score is 0.99;
    they are kept as-is because previously reported numbers depend on them.

    :param system_metrics_df: metrics in long format; must provide the columns 'metric',
        'overall', 'dataset_name', 'halting', 'run_num' and the runtime column
        ('exp_config_execution_time' for VirnyFlow, 'optimization_time' otherwise).
        For systems other than VirnyFlow a 'system_name' column must already exist.
        The passed DataFrame is not modified.
    :param system_name: name of the system; ``VIRNY_FLOW`` selects the VirnyFlow-specific handling.
    :param disparity_metric_name1: first disparity metric forwarded to ``get_models_disparity_metric_df``.
    :param group_name1: group for the first disparity metric.
    :param disparity_metric_name2: second disparity metric forwarded to ``get_models_disparity_metric_df``.
    :param group_name2: group for the second disparity metric.
    :return: DataFrame with the common columns, 'F1', one '<metric>_<group>' column per
        disparity metric and 'score'; the runtime column is always named 'optimization_time'.
    """
    if system_name == VIRNY_FLOW:
        # assign() returns a new frame, so the caller's DataFrame is left untouched.
        system_metrics_df = system_metrics_df.assign(system_name=system_name)
        common_cols = ['system_name', 'dataset_name', 'halting', 'run_num', 'exp_config_execution_time']
    else:
        common_cols = ['system_name', 'dataset_name', 'halting', 'run_num', 'optimization_time']

    # NOTE: f1_metrics_df, disparity_metric_df1 and disparity_metric_df2 are referenced by
    # name inside the SQL query below; rename them only together with the query.
    # Copy the filtered rows before adding a column to avoid writing into a slice.
    f1_metrics_df = system_metrics_df[system_metrics_df['metric'] == 'F1'].copy()
    f1_metrics_df['F1'] = f1_metrics_df['overall']
    f1_metrics_df = f1_metrics_df[common_cols + ['F1']]

    disparity_metric_df1 = get_models_disparity_metric_df(system_metrics_df, disparity_metric_name1, group_name1)
    disparity_metric_df1[disparity_metric_name1] = disparity_metric_df1['disparity_metric_value']
    disparity_metric_df1 = disparity_metric_df1[common_cols + [disparity_metric_name1]]

    disparity_metric_df2 = get_models_disparity_metric_df(system_metrics_df, disparity_metric_name2, group_name2)
    disparity_metric_df2[disparity_metric_name2] = disparity_metric_df2['disparity_metric_value']
    disparity_metric_df2 = disparity_metric_df2[common_cols + [disparity_metric_name2]]

    # The output columns are suffixed with the group name so the same metric can be
    # reported for two different groups.
    final_metrics_df = sqldf(f"""
            SELECT t1.*, t2.{disparity_metric_name1} AS {disparity_metric_name1}_{group_name1}, t3.{disparity_metric_name2} AS {disparity_metric_name2}_{group_name2}
            FROM f1_metrics_df AS t1
            JOIN disparity_metric_df1 AS t2
              ON t1.run_num = t2.run_num
             AND t1.dataset_name = t2.dataset_name
             AND t1.halting = t2.halting
            JOIN disparity_metric_df2 AS t3
              ON t1.run_num = t3.run_num
             AND t1.dataset_name = t3.dataset_name
             AND t1.halting = t3.halting
        """).to_df()

    final_metrics_df["score"] = (final_metrics_df["F1"] * 0.33 +
                                 (1 - abs(final_metrics_df[f"{disparity_metric_name1}_{group_name1}"])) * 0.33 +
                                 (1 - abs(final_metrics_df[f"{disparity_metric_name2}_{group_name2}"])) * 0.33)

    if system_name == VIRNY_FLOW:
        # Drop runs without a recorded execution time and expose the runtime under the common name.
        final_metrics_df = final_metrics_df[~final_metrics_df['exp_config_execution_time'].isna()]
        final_metrics_df = final_metrics_df.rename(columns={'exp_config_execution_time': 'optimization_time'})

    return final_metrics_df


def create_latex_table(df):
    # Compute mean and std for each system
    summary = df.groupby('halting').agg(['mean', 'std']).round(4)

    # Combine mean and std into "mean ± std" format
    def format_metric(mean, std):
        if pd.isna(std):
            return f"{mean:.4f} ± n/a"
        return f"{mean:.4f} \scriptsize{{$\pm${std:.4f}}}"

    def format_runtime(mean, std):
        if pd.isna(std):
            return f"{mean:.0f} ± n/a"
        return f"{mean:.0f} \scriptsize{{$\pm${std:.0f}}}"

    # Create formatted DataFrame
    latex_df = pd.DataFrame({
        'Halting': summary.index,
        'Score': [format_metric(m, s) for m, s in zip(summary['score']['mean'], summary['score']['std'])],
        'Runtime': [format_runtime(m, s) for m, s in zip(summary['optimization_time']['mean'], summary['optimization_time']['std'])],
    })

    # Reorder rows: virny_flow first
    latex_df = latex_df.set_index('Halting').loc[['[1.0]','[0.25,1.0]','[0.5,1.0]','[0.75,1.0]','[0.25,0.5,1.0]','[0.5,0.75,1.0]','[0.1,0.25,0.5,1.0]','[0.1,0.5,0.75,1.0]']].reset_index()
    
    # Generate LaTeX table
    latex_table = latex_df.to_latex(index=False,
                                    caption='Sensitivity to the Dataset Fraction for Halting',
                                    label='tab:sensitivity_halting',
                                    column_format='lcc',
                                    escape=False)

    print(latex_table)

In [10]:
# folk_pubcov is scored on two (disparity metric, group) pairs; any other dataset on one.
# The system is identified through the VIRNY_FLOW constant, the same name the helper
# functions compare against.
if DATASET_NAME == 'folk_pubcov':
    virny_flow_final_metrics_df = display_table_with_results_for_folk_pubcov(virny_flow_metrics_df, VIRNY_FLOW,
                                                                             DISPARITY_METRIC1, GROUP1,
                                                                             DISPARITY_METRIC2, GROUP2)
else:
    # DISPARITY_METRIC and GROUP are only defined when the alternative (commented-out)
    # setup in the parameters cell is enabled.
    virny_flow_final_metrics_df = display_table_with_results(virny_flow_metrics_df, VIRNY_FLOW, DISPARITY_METRIC, GROUP)

In [11]:
# Spot check: show the per-run rows of a single halting schedule.
virny_flow_final_metrics_df[virny_flow_final_metrics_df['halting'] == '[0.5,0.75,1.0]']

Unnamed: 0,system_name,dataset_name,halting,run_num,optimization_time,F1,Selection_Rate_Difference_SEX,Selection_Rate_Difference_RAC1P,score
9,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",10,753.698984,0.639837,-0.013441,0.1135,0.829256
19,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",6,931.2766,0.61416,-0.015979,0.102539,0.823562
20,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",8,646.357094,0.632572,-0.028131,0.107819,0.823885
26,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",5,867.018093,0.626223,-0.018969,0.104553,0.825891
27,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",11,1273.977963,0.627462,-0.023862,0.096978,0.827185
28,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",12,903.285316,0.630746,-0.028188,0.079344,0.832661
47,virny_flow,folk_pubcov,"[0.5,0.75,1.0]",9,724.449237,0.635802,-0.022513,0.100435,0.829242


In [12]:
# Print the LaTeX table with mean and std of score and runtime per halting schedule.
create_latex_table(virny_flow_final_metrics_df)

\begin{table}
\centering
\caption{Sensitivity to the Dataset Fraction for Halting}
\label{tab:sensitivity_halting}
\begin{tabular}{lcc}
\toprule
           Halting &                           Score &                    Runtime \\
\midrule
             [1.0] & 0.8277 \scriptsize{$\pm$0.0029} &  782 \scriptsize{$\pm$492} \\
        [0.25,1.0] & 0.8264 \scriptsize{$\pm$0.0031} &  765 \scriptsize{$\pm$146} \\
         [0.5,1.0] & 0.8267 \scriptsize{$\pm$0.0027} &   682 \scriptsize{$\pm$70} \\
        [0.75,1.0] & 0.8278 \scriptsize{$\pm$0.0044} &  789 \scriptsize{$\pm$120} \\
    [0.25,0.5,1.0] & 0.8291 \scriptsize{$\pm$0.0043} &  970 \scriptsize{$\pm$205} \\
    [0.5,0.75,1.0] & 0.8274 \scriptsize{$\pm$0.0033} &  871 \scriptsize{$\pm$205} \\
[0.1,0.25,0.5,1.0] & 0.8284 \scriptsize{$\pm$0.0044} & 1104 \scriptsize{$\pm$166} \\
[0.1,0.5,0.75,1.0] & 0.8249 \scriptsize{$\pm$0.0043} &  1007 \scriptsize{$\pm$81} \\
\bottomrule
\end{tabular}
\end{table}

