In [None]:
import pandas as pd
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from metric_utils import *

In [None]:
# Name of the model sub-folder that the result files are read from.
model_name = 'bert-base-uncased'
# Epsilon values to evaluate, ordered from largest to smallest.
epsilon_list = sorted([0.5, 1.0, 3.0, 6.0, 9.0], reverse=True)

In [None]:
# Report, per dataset, the share of each predicted label averaged over runs:
# overall and per protected subgroup, for the baseline ("None") model and for
# every epsilon in `epsilon_list`.
#
# NOTE(review): `target_column`, `id_column`, `prediction_column`,
# `probability_column` and `binarize` are not defined in this notebook; they
# are presumably provided by `from metric_utils import *` -- confirm.

# Protected-subgroup column names, keyed by dataset.
subgroup_map = {
    'jigsaw': ['male', 'female', 'transgender', 'white', 'black', 'asian'],
    'ucberkeley': [
        'target_gender_men', 'target_gender_women', 'target_gender_transgender',
        'target_race_white', 'target_race_black', 'target_race_asian',
    ],
}


def _prepare_predictions(frame, test_features, binarizing_columns):
    """Join predictions with the test-set columns and threshold the probabilities."""
    merged = frame.merge(test_features, on=id_column, how='inner').reset_index(drop=True)
    merged[prediction_column] = merged[probability_column] >= 0.5
    return binarize(merged, binarizing_columns)


def _record_prediction_rates(store, prefix, frame, groups):
    """Append the normalized prediction counts (total and per subgroup) to `store`."""
    store.setdefault(f'{prefix}_total', []).append(
        frame[prediction_column].value_counts(normalize=True)
    )
    for group in groups:
        # After `binarize`, the subgroup column is used directly as a row mask.
        store.setdefault(f'{prefix}_{group}', []).append(
            frame[frame[group]][prediction_column].value_counts(normalize=True)
        )


for dataset_name in ['jigsaw', 'ucberkeley']:
    print(dataset_name)
    dataset_directory = f'../results/{dataset_name}/'
    # Maps '<epsilon or None>_<total or subgroup>' to a list with one
    # value_counts Series per run.
    avg_result = {}
    protected_subgroups = subgroup_map[dataset_name]
    binarizing_columns = [target_column] + protected_subgroups
    # Sorted once here instead of on every run.
    descending_epsilons = sorted(epsilon_list, reverse=True)

    for run in range(1, 4):
        run_folder = os.path.join(dataset_directory, f'run {run}')
        model_folder = os.path.join(run_folder, model_name)
        result_filepath = os.path.join(model_folder, 'normal', 'results.csv')

        result = pd.read_csv(result_filepath)
        # Keep only the test rows; the split column is not needed afterwards.
        # Non in-place drop avoids mutating a filtered view of the frame.
        result = result[result['split'] == 'test'].drop(columns=['split'])

        # ucberkeley reads test.csv from the run folder, jigsaw from the
        # dataset folder.
        if dataset_name == 'ucberkeley':
            test_csv_filepath = os.path.join(run_folder, 'test.csv')
        else:
            test_csv_filepath = os.path.join(dataset_directory, 'test.csv')

        test_df = pd.read_csv(test_csv_filepath)
        test_df.fillna(0, inplace=True)

        # result has id column which is the same as the text ids from raw dataset
        if dataset_name == 'ucberkeley':
            test_df.rename({'comment_id': id_column}, axis=1, inplace=True)

        # Columns present on both sides (other than the id) are dropped from
        # the test frame so the merge does not produce suffixed duplicates.
        extra_columns = [
            col for col in test_df.columns
            if col in result.columns and col != id_column
        ]
        test_features = test_df.drop(columns=extra_columns)

        result = _prepare_predictions(result, test_features, binarizing_columns)
        _record_prediction_rates(avg_result, 'None', result, protected_subgroups)

        for epsilon in descending_epsilons:
            dp_result_filepath = os.path.join(
                model_folder, f'epsilon {epsilon}', 'results.csv'
            )
            dp_result = pd.read_csv(dp_result_filepath)

            # only calculate test result
            dp_result = dp_result[dp_result['split'] == 'test']
            dp_result = _prepare_predictions(dp_result, test_features, binarizing_columns)
            _record_prediction_rates(avg_result, epsilon, dp_result, protected_subgroups)

    for key, per_run in avg_result.items():
        # One column per run, aligned on the union of predicted labels. A label
        # that never occurs in a run has proportion 0 there, so it is filled
        # with 0 rather than being left out of the average (which would
        # overstate its mean and let the averaged proportions exceed 1).
        aligned = pd.concat(per_run, axis=1, keys=range(len(per_run))).fillna(0)
        # Aggregating on the index itself avoids relying on the column name
        # that reset_index() derives from value_counts(), which differs between
        # pandas 1.x and 2.x. The axis is renamed so the printed table keeps
        # the 'index' / prediction-column layout.
        mean = (
            aligned.mean(axis=1)
            .sort_index()
            .rename_axis('index')
            .reset_index(name=prediction_column)
        )
        print(key)
        print(mean)
        print()
    print('\n')
