In [1]:
import os
import pickle
import re

import numpy as np

from fairsoft_utils import formal_fairness_name, formal_metric_name, formal_model_name

In [2]:
def shorten_se(se):
    r"""Render a standard error as a LaTeX ``\nicefrac`` in thousandths.

    Args:
        se: Standard error value. An exact zero is rendered as ``'0'``.

    Returns:
        ``'0'`` when ``se == 0``; otherwise the string
        ``\nicefrac{ k }{10^3}`` where ``k`` is ``se * 1000`` rounded to the
        nearest integer.
    """
    if se == 0:
        return '0'
    # Round instead of truncating: int() alone under-reports the value
    # (0.0296 -> 29) and turns float noise such as 28.999999999999996 into 28.
    # The outer int() keeps the result a plain integer for numpy scalars too.
    thousandths = int(round(se * 1000))
    return f'\\nicefrac{{ {thousandths} }}{{10^3}}'

In [3]:
def _load_eval_runs(prefix, kind, reg_norm, fair_coeff, n_runs=10):
    """Return the unpickled ``{kind}_eval`` results of every run file that exists.

    Run indices 1..n_runs are probed in order; missing files are skipped.
    """
    loaded = []
    for run in range(1, n_runs + 1):
        target_file = f'{prefix}/{kind}_eval_{reg_norm}_reg_lambda={fair_coeff:.2f}_{run:04d}.pkl'
        if not os.path.exists(target_file):
            continue
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # evaluation files produced by a trusted run.
        with open(target_file, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return loaded


def show_valid_replication(dataset, reg_norm, target_label_idx=0, masked=False, fair_coeff=1):
    r"""Print a LaTeX table body of mean fairness results across replications.

    Reads the ``fair_eval_*`` and ``perform_eval_*`` pickles for runs 1..10
    from ``fair_through_postprocess/model/{dataset}/evaluation-{target_label_idx}``
    (with ``_masked`` appended when ``masked`` is true), averages each
    (metric, model) entry over the runs that exist, and prints a header row,
    ``\midrule``, one row per fairness metric and ``\bottomrule``.

    Args:
        dataset: Dataset directory name used in the evaluation path.
        reg_norm: Regulariser tag embedded in the file names (e.g. ``'l1'``).
        target_label_idx: Value interpolated into the ``evaluation-...``
            directory name.
        masked: When true, read from the ``..._masked`` evaluation directory.
        fair_coeff: Regularisation coefficient; formatted with two decimals
            in the file names.

    Returns:
        None. The table is written to stdout. When no run file exists only
        the empty header and the rules are printed.
    """
    prefix = f'fair_through_postprocess/model/{dataset}/evaluation-{target_label_idx}'
    if masked:
        prefix += '_masked'

    # fair_results[metric][model] -> list with one entry per loaded run.
    fair_results = {}
    for fairs in _load_eval_runs(prefix, 'fair', reg_norm, fair_coeff):
        for met in fairs:
            per_model = fair_results.setdefault(met, {})
            for mod in fairs[met]:
                per_model.setdefault(mod, []).append(fairs[met][mod])

    # perform_results[metric][model] -> list with one entry per loaded run.
    # Only consumed by the disabled performance rows at the end of this
    # function; kept so they can be re-enabled without further changes.
    perform_metrics = []
    perform_results = {}
    for performs in _load_eval_runs(prefix, 'perform', reg_norm, fair_coeff):
        perform_models = list(performs.keys())
        if not perform_metrics:
            # The metric names are taken from the first model of the first run.
            perform_metrics = list(performs[perform_models[0]].keys())
        for met in perform_metrics:
            per_model = perform_results.setdefault(met, {})
            for mod in performs:
                per_model.setdefault(mod, []).append(performs[mod][met])

    # Group the metric keys by the part before the first underscore.
    fair_metrics_nested = {}
    for met_hparam in fair_results:
        fair_metrics_nested.setdefault(met_hparam.split('_')[0], []).append(met_hparam)

    # Fixed family order; inside a family, order by the numeric suffix.
    fair_metrics = []
    for family in ['constant', 'jaccard', 'indication']:
        members = fair_metrics_nested.get(family, [])
        if len(members) > 1:
            members = sorted(members, key=lambda name: float(name.split('_')[-1]))
        fair_metrics += members

    fair_models = [formal_model_name(fair_metric).replace('\\', '') for fair_metric in fair_metrics]

    colnames = ' & ' + ' & '.join(fair_models)
    print(colnames + '\\\\')
    print('\\midrule')

    for row_idx, met in enumerate(fair_metrics):
        cells = []
        for mod in fair_metrics:
            results = fair_results[met][mod]
            # Average over runs, then keep the first component of the result.
            mean = np.mean(results, 0)[0]
            cells.append(f"{mean:.3f}")

        label = formal_fairness_name(met).replace('\\', '')
        if row_idx > 0:
            # Every row except the first abbreviates SimFair to SF.
            label = label.replace('SimFair', 'SF')
        resultrow = label + ' & ' + ' & '.join(cells)

        # Compact the numbers: "10.0 " -> "10 " and "0.111" -> ".111".
        # The lookbehind keeps a zero that follows another digit, so "10.5"
        # is no longer mangled into "1.5".
        resultrow = resultrow.replace('.0 ', ' ')
        resultrow = re.sub(r'(?<!\d)0\.', '.', resultrow)
        print(resultrow + '\\\\')

#     for perform_metric in list(perform_results.keys()):

#         result = []
#         for mod in fair_metrics:
#             results = perform_results[perform_metric][mod]
#             mean = np.mean(results, 0)[0]
#             se = np.std(results, 0)[0] / np.sqrt(len(results))
#             result.append(f"${mean:.3f}$")
#             resultrow = formal_metric_name(perform_metric).replace('\\', '').replace('SimFair', 'SF') + ' & ' + ' & '.join(result)
#         print(resultrow + '\\\\')
    print('\\bottomrule')

In [8]:
# Adult dataset, L1 regulariser, lambda = 1.
show_valid_replication(
    dataset='adult',
    reg_norm='l1',
    target_label_idx='0',
    fair_coeff=1,
)
print('\n' * 5)
# NOTE(review): the disabled call below does not match the current parameter
# order (reg_norm comes second) - confirm before re-enabling.
# show_valid_replication('adult', 0, True)

 & w/ DP reg & w/ $ s_{ 0.01 } $-SF reg & w/ $ s_{ 1.0 } $-SF reg & w/ $ s_{ 5.0 } $-SF reg & w/ $ s_{ 10.0 } $-SF reg & w/ EOp reg\\
\midrule
DP & .111 & .111 & .118 & .155 & .175 & .186\\
$ s_{ .01 } $-SF & .111 & .111 & .117 & .155 & .175 & .186\\
$ s_{ 1 } $-SF & .111 & .111 & .118 & .155 & .175 & .186\\
$ s_{ 5 } $-SF & .120 & .119 & .127 & .163 & .182 & .187\\
$ s_{ 10 } $-SF & .133 & .132 & .139 & .173 & .194 & .196\\
EOp & .134 & .135 & .141 & .178 & .194 & .196\\
\bottomrule








In [5]:
# Adult dataset, L2 regulariser, lambda = 0.1.
show_valid_replication(
    dataset='adult',
    reg_norm='l2',
    target_label_idx='0',
    fair_coeff=0.1,
)
print('\n' * 5)
# NOTE(review): the disabled call below does not match the current parameter
# order (reg_norm comes second) - confirm before re-enabling.
# show_valid_replication('adult', 0, True)

 & w/ DP reg & w/ $ s_{ 0.01 } $-SF reg & w/ $ s_{ 1.0 } $-SF reg & w/ $ s_{ 5.0 } $-SF reg & w/ $ s_{ 10.0 } $-SF reg & w/ EOp reg\\
\midrule
DP & .120 & .121 & .119 & .123 & .123 & .167\\
$ s_{ .01 } $-SF & .120 & .115 & .123 & .120 & .125 & .164\\
$ s_{ 1 } $-SF & .119 & .116 & .118 & .121 & .120 & .164\\
$ s_{ 5 } $-SF & .129 & .123 & .145 & .146 & .125 & .177\\
$ s_{ 10 } $-SF & .158 & .147 & .161 & .161 & .156 & .215\\
EOp & .143 & .135 & .155 & .163 & .197 & .212\\
\bottomrule






