In [3]:
import pickle 
import numpy as np
import os

from fairsoft_utils import formal_fairness_name, formal_metric_name, formal_model_name

In [4]:
def shorten_se(se):
    """Format a standard error as a LaTeX ``\\nicefrac`` expressed in thousandths.

    ``se`` is rounded (not truncated) to the nearest thousandth, matching the
    ``:.3f`` rounding used for the table means; truncation would turn e.g.
    ``1.005 * 1000 == 1004.999...`` into 1004. A value that rounds to zero
    thousandths is rendered as ``'0'`` rather than ``\\nicefrac{ 0 }{10^3}``.

    Examples: ``0`` -> ``'0'``; ``0.0129`` -> ``'\\nicefrac{ 13 }{10^3}'``.
    """
    # int() guards against numpy scalars whose round() may return a float type.
    thousandths = int(round(se * 1000))
    if thousandths == 0:
        return '0'
    return f'\\nicefrac{{ {thousandths} }}{{10^3}}'

In [16]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced evaluation files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _existing_run_files(prefix, kind, reg_norm, fair_coeff, num_runs):
    """Yield the per-run evaluation pickle paths (runs 1..num_runs) that exist on disk.

    ``kind`` is ``'fair'`` or ``'perform'``. Missing runs are skipped silently.
    """
    for run in range(1, num_runs + 1):
        path = f'{prefix}/finetune_{kind}_eval_{reg_norm}_reg_lambda={fair_coeff:.2f}_{run:04d}.pkl'
        if os.path.exists(path):
            yield path


def _collect_fair_results(paths):
    """Merge fairness pickles into ``{metric: {model: [value per run, ...]}}``.

    Each pickle is indexed as ``fairs[metric][model]``.
    """
    fair_results = {}
    for path in paths:
        fairs = _load_pickle(path)
        for met, by_model in fairs.items():
            per_model = fair_results.setdefault(met, {})
            for mod, value in by_model.items():
                per_model.setdefault(mod, []).append(value)
    return fair_results


def _collect_perform_results(paths):
    """Merge performance pickles into ``{metric: {model: [value per run, ...]}}``.

    Each pickle is indexed as ``performs[model][metric]``. The metric list is taken
    from the first model of the first non-empty pickle and reused for later files.
    """
    perform_results = {}
    perform_metrics = []
    for path in paths:
        performs = _load_pickle(path)
        if len(performs) == 0:
            # A pickle with no models contributes nothing; skip it instead of
            # failing when looking up its first model.
            continue
        if not perform_metrics:
            first_model = next(iter(performs))
            perform_metrics = list(performs[first_model].keys())
        for met in perform_metrics:
            per_model = perform_results.setdefault(met, {})
            for mod in performs:
                per_model.setdefault(mod, []).append(performs[mod][met])
    return perform_results


def _order_fair_metrics(metric_keys):
    """Return fairness metric keys grouped in the order constant, jaccard, indication.

    The family of a key is the text before its first ``'_'``. A family with several
    keys is sorted by ``float`` of the text after the last ``'_'``; a lone key is
    kept as-is (it need not end in a number). Keys of any other family are dropped.
    """
    families = {}
    for key in metric_keys:
        families.setdefault(key.split('_')[0], []).append(key)
    ordered = []
    for family in ('constant', 'jaccard', 'indication'):
        keys = families.get(family, [])
        if len(keys) > 1:
            keys = sorted(keys, key=lambda k: float(k.split('_')[-1]))
        ordered.extend(keys)
    return ordered


def show_valid_replication(dataset, reg_norm, target_label_idx=0, masked=False, fair_coeff=1, num_runs=10):
    r"""Print LaTeX table rows summarising fine-tuned fairness/performance evaluations.

    Reads the pickles
    ``fair_through_postprocess/model/{dataset}/evaluation-{target_label_idx}[_masked]/``
    ``finetune_{fair,perform}_eval_{reg_norm}_reg_lambda={fair_coeff:.2f}_{run:04d}.pkl``
    for runs ``1..num_runs``; runs whose file is missing are skipped.

    Output (printed, one line each): a header of model names ending in ``\\``,
    ``\midrule``, one row per fairness metric (mean over runs, 3 decimals, leading
    zeros stripped), one row per performance metric (mean over runs wrapped in
    ``$...$``), then ``\bottomrule``.

    Args:
        dataset: dataset directory name, e.g. ``'adult'``.
        reg_norm: regularisation-norm tag used in the file names, e.g. ``'l2'``.
        target_label_idx: label index used in the evaluation directory name.
        masked: read from the ``_masked`` evaluation directory when True.
        fair_coeff: fairness weight; formatted with two decimals in the file names.
        num_runs: number of run indices (starting at 1) to look for.

    Returns:
        None; the table is printed to stdout.
    """
    prefix = f'fair_through_postprocess/model/{dataset}/evaluation-{target_label_idx}'
    if masked:
        prefix += '_masked'

    fair_results = _collect_fair_results(
        _existing_run_files(prefix, 'fair', reg_norm, fair_coeff, num_runs))
    perform_results = _collect_perform_results(
        _existing_run_files(prefix, 'perform', reg_norm, fair_coeff, num_runs))

    # The same ordered metric keys are used both as row keys (fairness metrics)
    # and as column keys (models) when indexing the aggregated results.
    fair_metrics = _order_fair_metrics(fair_results)
    fair_models = [formal_model_name(met).replace('\\', '') for met in fair_metrics]

    print(' & ' + ' & '.join(fair_models) + '\\\\')
    print('\\midrule')

    for row_idx, met in enumerate(fair_metrics):
        # Mean over runs; [0] takes the first component of the averaged value
        # (presumably each stored value is a 1-element array -- TODO confirm).
        cells = [f"{np.mean(fair_results[met][mod], 0)[0]:.3f}" for mod in fair_metrics]
        label = formal_fairness_name(met).replace('\\', '')
        if row_idx > 0:
            # Only rows after the first get the SimFair -> SF abbreviation.
            label = label.replace('SimFair', 'SF')
        row = label + ' & ' + ' & '.join(cells)
        # Strip redundant zeros: "1.0 " -> "1 " in labels, "0.116" -> ".116" in values.
        row = row.replace('.0 ', ' ').replace('0.', '.')
        print(row + '\\\\')

    for perform_metric, by_model in perform_results.items():
        cells = [f"${np.mean(by_model[mod], 0)[0]:.3f}$" for mod in fair_metrics]
        # Build the row once per metric (after all cells), so it is always defined
        # even when there are no model columns.
        label = formal_metric_name(perform_metric).replace('\\', '').replace('SimFair', 'SF')
        print(label + ' & ' + ' & '.join(cells) + '\\\\')
    print('\\bottomrule')

In [17]:
# Adult, reg_norm='l2', label index '0', fairness weight lambda = 1.00.
# NOTE(review): the recorded output of this cell is an empty table, i.e. no
# evaluation pickles matched lambda=1.00 for this configuration when it ran.
show_valid_replication('adult', 'l2', '0', fair_coeff=1)
print('\n' * 5)

 & \\
\midrule
\bottomrule








In [18]:
# Adult, reg_norm='l2', label index '0', fairness weight lambda = 0.10.
show_valid_replication('adult', 'l2', '0', fair_coeff=0.1)
print('\n' * 5)

 & w/ DP reg & w/ $ s_{ 0.01 } $-SF reg & w/ $ s_{ 1.0 } $-SF reg & w/ $ s_{ 5.0 } $-SF reg & w/ $ s_{ 10.0 } $-SF reg & w/ EOp reg\\
\midrule
DP & .116 & .116 & .117 & .118 & .127 & .165\\
$ s_{ .01 } $-SF & .116 & .120 & .123 & .125 & .121 & .165\\
$ s_{ 1 } $-SF & .119 & .119 & .120 & .127 & .122 & .166\\
$ s_{ 5 } $-SF & .124 & .129 & .122 & .132 & .131 & .177\\
$ s_{ 10 } $-SF & .147 & .148 & .158 & .158 & .162 & .216\\
EOp & .161 & .174 & .172 & .192 & .160 & .206\\
instance-F1 & $0.555$ & $0.554$ & $0.551$ & $0.553$ & $0.551$ & $0.572$\\
micro-F1 & $0.502$ & $0.498$ & $0.498$ & $0.499$ & $0.497$ & $0.538$\\
macro-F1 & $0.214$ & $0.212$ & $0.209$ & $0.213$ & $0.209$ & $0.219$\\
\bottomrule








In [19]:
# Credit, reg_norm='l2', label index '0', fairness weight lambda = 1.00.
show_valid_replication('credit', 'l2', '0', fair_coeff=1)
print('\n' * 5)

 & w/ DP reg & w/ $ s_{ 0.01 } $-SF reg & w/ $ s_{ 1.0 } $-SF reg & w/ $ s_{ 5.0 } $-SF reg & w/ $ s_{ 10.0 } $-SF reg & w/ EOp reg\\
\midrule
DP & .406 & .413 & .407 & .415 & .422 & .473\\
$ s_{ .01 } $-SF & .416 & .415 & .409 & .417 & .422 & .481\\
$ s_{ 1 } $-SF & .419 & .402 & .399 & .419 & .414 & .474\\
$ s_{ 5 } $-SF & .481 & .475 & .478 & .498 & .503 & .597\\
$ s_{ 10 } $-SF & .578 & .573 & .565 & .578 & .568 & .691\\
EOp & .539 & .552 & .539 & .556 & .572 & .651\\
instance-F1 & $0.531$ & $0.520$ & $0.522$ & $0.519$ & $0.525$ & $0.517$\\
micro-F1 & $0.549$ & $0.541$ & $0.540$ & $0.537$ & $0.544$ & $0.536$\\
macro-F1 & $0.366$ & $0.357$ & $0.357$ & $0.354$ & $0.363$ & $0.356$\\
\bottomrule






