In [9]:
import torch
import pickle
import numpy as np
import pandas as pd
import plotly.express as px
from copy import deepcopy


from model import LR
from data import SyntheticDataset, FairnessDataset
from ei_model_dev import FairBatch
from ei_effort import Optimal_Effort, PGD_Effort
from ei_utils import *

In [2]:
dataset = SyntheticDataset(seed=0)

In [3]:
def append_res(d, acc, ei):
    d['accuracy'].append(acc)
    d['ei_disparity'].append(ei)
        
def get_res(d, id, hp):
    res = {'id': [], 'alpha': [], 'lambda': [], 'delta': [], 'accuracy_mean': [], 'accuracy_std': [], 'ei_disparity_mean': [], 'ei_disparity_std': []}
    res['id'].append(id)
    res['alpha'].append(hp['alpha'])
    res['lambda'].append(hp['lambda'])
    res['delta'].append(hp['delta'])
    res['accuracy_mean'].append(np.mean(d['accuracy']))
    res['accuracy_std'].append(np.std(d['accuracy']))
    res['ei_disparity_mean'].append(np.mean(d['ei_disparity']))
    res['ei_disparity_std'].append(np.std(d['ei_disparity']))
    
    return res

def get_model(models):
    weights = []
    bias = []
    for model in models:
        for module in model.layers:
            if hasattr(module, 'weight'):
                weights.append(module.weight.data)
            if hasattr(module, 'bias'):
                bias.append(module.bias.data)
            
    weights = torch.cat(weights).mean(dim=0)
    bias = torch.cat(bias).mean(dim=0)
    return weights, bias

In [4]:
def lr_fb_model_runner(dataset, hp, seeds):
    tau = 0.5
    train_metrics = {'alpha': [], 'accuracy': [], 'ei_disparity': []}
    val_metrics = deepcopy(train_metrics)
    test_metrics = deepcopy(train_metrics)
    ei_models = []
    
    if hp['optimal_effort']:
        effort = Optimal_Effort(hp['delta'])
    else:
        effort = PGD_Effort(hp['delta'])
    
    for seed in seeds:
        train_tensors, val_tensors, test_tensors = dataset.tensor(fold=seed, z_blind=hp['z_blind'])
        train_dataset = FairnessDataset(*train_tensors, dataset.imp_feats)
        val_dataset = FairnessDataset(*val_tensors, dataset.imp_feats)
        test_dataset = FairnessDataset(*test_tensors, dataset.imp_feats)
        
        model = LR(num_features=train_dataset.X.shape[1])
        ei_model = FairBatch(model, effort, tau)
        
        ei_model.train(
            train_dataset, 
            sensitive_attrs=dataset.sensitive_attrs,
            lamb=hp['lambda'],
            lr=hp['learning_rate'],
            alpha=hp['alpha']
            )
 
        Y_hat, Y_hat_max = ei_model.predict(train_dataset, hp['alpha'], dataset.sensitive_attrs)
        train_acc, train_ei = model_performance(train_dataset.Y.detach().numpy(), train_dataset.Z.detach().numpy(), Y_hat, Y_hat_max, tau)
        append_res(train_metrics, train_acc, train_ei)
        
        Y_hat, Y_hat_max = ei_model.predict(val_dataset, hp['alpha'], dataset.sensitive_attrs)
        val_acc, val_ei = model_performance(val_dataset.Y.detach().numpy(), val_dataset.Z.detach().numpy(), Y_hat, Y_hat_max, tau)
        append_res(val_metrics, val_acc, val_ei)
        
        Y_hat, Y_hat_max = ei_model.predict(test_dataset, hp['alpha'], dataset.sensitive_attrs)
        test_acc, test_ei = model_performance(test_dataset.Y.detach().numpy(), test_dataset.Z.detach().numpy(), Y_hat, Y_hat_max, tau)
        append_res(test_metrics, test_acc, test_ei)
    
        ei_models.append(ei_model.model)
    
    res_train = get_res(train_metrics, 'train', hp)
    res_val = get_res(val_metrics, 'val', hp)
    res_test = get_res(test_metrics, 'test', hp)
    
    
    return res_train, res_val, res_test, ei_models

In [5]:
def fb_tradeoff(dataset, hyper_params, seeds):
    hp = hyper_params.copy()
    result = pd.DataFrame()
    ei_models = []
    
    for alpha in hyper_params['alpha']:
        for lamb in hyper_params['lambda']:
            for delta in hyper_params['delta']:
                hp['alpha'] = alpha
                hp['lambda'] = lamb
                hp['delta'] = delta
                
                train, val, test, models = lr_fb_model_runner(dataset, hp, seeds)
                result = pd.concat((result, pd.DataFrame(train), pd.DataFrame(val), pd.DataFrame(test)))
                ei_models.extend(models)
                print()
    
    return result, ei_models

In [6]:
hyper_params = {}
hyper_params['learning_rate'] = 0.01
hyper_params['delta'] = [0.5]
hyper_params['alpha'] = [0., 0.5, 1.4]
hyper_params['lambda'] = np.linspace(0.,.25, 10).round(3)
hyper_params['z_blind'] = False
hyper_params['optimal_effort'] = False
seeds = list(range(5))

results, ei_models = fb_tradeoff(dataset, hyper_params, seeds)
results['loss_mean'] = 1 - results['accuracy_mean']
results['alpha'] = results['alpha'].astype(str)

Training [alpha=0.00; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.04s/epochs]
Training [alpha=0.00; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.04s/epochs]
Training [alpha=0.00; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]





Training [alpha=0.00; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.04s/epochs]
Training [alpha=0.00; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:45<00:00,  1.05s/epochs]
Training [alpha=0.00; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:45<00:00,  1.05s/epochs]
Training [alpha=0.00; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:46<00:00,  1.07s/epochs]





Training [alpha=0.00; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:48<00:00,  1.09s/epochs]
Training [alpha=0.00; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.03s/epochs]





Training [alpha=0.00; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=0.00; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]
Training [alpha=0.00; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=0.00; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]





Training [alpha=0.00; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.00; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]
Training [alpha=0.00; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.01s/epochs]
Training [alpha=0.00; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.00; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.05s/epochs]





Training [alpha=0.00; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]
Training [alpha=0.00; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.00; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]





Training [alpha=0.00; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.00; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.04s/epochs]
Training [alpha=0.00; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]





Training [alpha=0.00; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]
Training [alpha=0.00; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.00; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.01s/epochs]
Training [alpha=0.00; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.00; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]





Training [alpha=0.00; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.00; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=0.00; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=0.00; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=0.00; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]





Training [alpha=0.00; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]
Training [alpha=0.00; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.00; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:35<00:00,  1.05epochs/s]
Training [alpha=0.00; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:35<00:00,  1.04epochs/s]
Training [alpha=0.00; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.04epochs/s]





Training [alpha=0.50; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=0.50; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.04epochs/s]
Training [alpha=0.50; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]





Training [alpha=0.50; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:35<00:00,  1.04epochs/s]
Training [alpha=0.50; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:31<00:00,  1.09epochs/s]
Training [alpha=0.50; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]
Training [alpha=0.50; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:31<00:00,  1.09epochs/s]
Training [alpha=0.50; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:35<00:00,  1.05epochs/s]





Training [alpha=0.50; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=0.50; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.03epochs/s]





Training [alpha=0.50; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]
Training [alpha=0.50; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=0.50; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.00epochs/s]
Training [alpha=0.50; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]





Training [alpha=0.50; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]
Training [alpha=0.50; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:49<00:00,  1.09s/epochs]
Training [alpha=0.50; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.04s/epochs]
Training [alpha=0.50; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:47<00:00,  1.07s/epochs]





Training [alpha=0.50; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]
Training [alpha=0.50; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.50; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.00epochs/s]





Training [alpha=0.50; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.50; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=0.50; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.02s/epochs]
Training [alpha=0.50; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]





Training [alpha=0.50; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.04s/epochs]
Training [alpha=0.50; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=0.50; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:44<00:00,  1.04s/epochs]





Training [alpha=0.50; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.00epochs/s]
Training [alpha=0.50; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]
Training [alpha=0.50; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=0.50; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.01s/epochs]
Training [alpha=0.50; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]





Training [alpha=0.50; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=0.50; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:34<00:00,  1.05epochs/s]
Training [alpha=0.50; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:34<00:00,  1.06epochs/s]
Training [alpha=0.50; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:34<00:00,  1.06epochs/s]





Training [alpha=1.40; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:34<00:00,  1.06epochs/s]
Training [alpha=1.40; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:34<00:00,  1.06epochs/s]
Training [alpha=1.40; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:31<00:00,  1.09epochs/s]
Training [alpha=1.40; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]
Training [alpha=1.40; lambda=0.00]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:30<00:00,  1.10epochs/s]





Training [alpha=1.40; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.04epochs/s]
Training [alpha=1.40; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:40<00:00,  1.00s/epochs]
Training [alpha=1.40; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]
Training [alpha=1.40; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.03s/epochs]
Training [alpha=1.40; lambda=0.03]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.04s/epochs]





Training [alpha=1.40; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=1.40; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.04epochs/s]
Training [alpha=1.40; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]
Training [alpha=1.40; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]
Training [alpha=1.40; lambda=0.06]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]





Training [alpha=1.40; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=1.40; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.03epochs/s]
Training [alpha=1.40; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=1.40; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.01epochs/s]
Training [alpha=1.40; lambda=0.08]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]





Training [alpha=1.40; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.02epochs/s]
Training [alpha=1.40; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]
Training [alpha=1.40; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:39<00:00,  1.00epochs/s]
Training [alpha=1.40; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=1.40; lambda=0.11]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]





Training [alpha=1.40; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]
Training [alpha=1.40; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]
Training [alpha=1.40; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]
Training [alpha=1.40; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:33<00:00,  1.07epochs/s]
Training [alpha=1.40; lambda=0.14]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:32<00:00,  1.08epochs/s]





Training [alpha=1.40; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.02epochs/s]
Training [alpha=1.40; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:42<00:00,  1.02s/epochs]
Training [alpha=1.40; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]
Training [alpha=1.40; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:35<00:00,  1.05epochs/s]
Training [alpha=1.40; lambda=0.17]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:43<00:00,  1.03s/epochs]





Training [alpha=1.40; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:36<00:00,  1.04epochs/s]
Training [alpha=1.40; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [02:14<00:00,  1.35s/epochs]
Training [alpha=1.40; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [02:19<00:00,  1.39s/epochs]
Training [alpha=1.40; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [06:31<00:00,  3.91s/epochs] 
Training [alpha=1.40; lambda=0.19]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [11:35<00:00,  6.96s/epochs] 





Training [alpha=1.40; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [38:53<00:00, 23.33s/epochs]  
Training [alpha=1.40; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [36:41<00:00, 22.02s/epochs]  
Training [alpha=1.40; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [57:43<00:00, 34.64s/epochs]  
Training [alpha=1.40; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [39:21<00:00, 23.61s/epochs]  
Training [alpha=1.40; lambda=0.22]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [1:11:02<00:00, 42.63s/epochs]  





Training [alpha=1.40; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [34:08<00:00, 20.49s/epochs]  
Training [alpha=1.40; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [24:41<00:00, 14.82s/epochs]  
Training [alpha=1.40; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:37<00:00,  1.03epochs/s]
Training [alpha=1.40; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:41<00:00,  1.01s/epochs]
Training [alpha=1.40; lambda=0.25]: 100%|[38;2;0;145;255m██████████[0m| 100/100 [01:38<00:00,  1.01epochs/s]





In [11]:
with open(f'tradeoff_robust_synthetic_5crossval.pkl', 'wb') as f:
    pickle.dump(results, f)
results.head()

Unnamed: 0,id,alpha,lambda,delta,accuracy_mean,accuracy_std,ei_disparity_mean,ei_disparity_std,loss_mean
0,train,0.0,0.0,0.5,0.784625,0.001904,0.131757,0.004442,0.215375
0,val,0.0,0.0,0.5,0.790188,0.004884,0.130056,0.010956,0.209812
0,test,0.0,0.0,0.5,0.7865,0.002525,0.131944,0.013618,0.2135
0,train,0.0,0.028,0.5,0.784703,0.001952,0.132062,0.004666,0.215297
0,val,0.0,0.028,0.5,0.790188,0.004884,0.129862,0.010854,0.209812


In [12]:
train_results = results[results['id'] == 'train']
val_results = results[results['id'] == 'val']
test_results = results[results['id'] == 'test']

In [13]:
test_results = test_results.sort_values(['alpha', 'lambda'])
test_results

Unnamed: 0,id,alpha,lambda,delta,accuracy_mean,accuracy_std,ei_disparity_mean,ei_disparity_std,loss_mean
0,test,0.0,0.0,0.5,0.7865,0.002525,0.131944,0.013618,0.2135
0,test,0.0,0.028,0.5,0.7865,0.002424,0.131754,0.013316,0.2135
0,test,0.0,0.056,0.5,0.7867,0.002467,0.131827,0.013261,0.2133
0,test,0.0,0.083,0.5,0.7866,0.002606,0.131785,0.013058,0.2134
0,test,0.0,0.111,0.5,0.78645,0.002517,0.131775,0.013392,0.21355
0,test,0.0,0.139,0.5,0.7866,0.002473,0.131972,0.013121,0.2134
0,test,0.0,0.167,0.5,0.7866,0.002591,0.131728,0.013074,0.2134
0,test,0.0,0.194,0.5,0.78645,0.002547,0.131703,0.013254,0.21355
0,test,0.0,0.222,0.5,0.7865,0.002424,0.131617,0.013106,0.2135
0,test,0.0,0.25,0.5,0.7864,0.002483,0.13158,0.013375,0.2136


In [14]:
test_results_pareto = pd.DataFrame()
for alpha in test_results['alpha'].unique():
    test_results_alpha = test_results[test_results['alpha'] == alpha]
    mask = pareto_frontier(test_results_alpha['loss_mean'], test_results_alpha['ei_disparity_mean'])
    results_alpha_pareto = test_results_alpha.iloc[mask]
    test_results_pareto = pd.concat((test_results_pareto, results_alpha_pareto.sort_values('ei_disparity_mean')))

In [19]:
test_results_pareto

Unnamed: 0,id,alpha,lambda,delta,accuracy_mean,accuracy_std,ei_disparity_mean,ei_disparity_std,loss_mean
0,test,0.0,0.25,0.5,0.7864,0.002483,0.13158,0.013375,0.2136
0,test,0.0,0.222,0.5,0.7865,0.002424,0.131617,0.013106,0.2135
0,test,0.0,0.167,0.5,0.7866,0.002591,0.131728,0.013074,0.2134
0,test,0.0,0.056,0.5,0.7867,0.002467,0.131827,0.013261,0.2133
0,test,0.5,0.25,0.5,0.7866,0.002591,0.125958,0.011767,0.2134
0,test,0.5,0.028,0.5,0.78665,0.002422,0.126212,0.011932,0.21335
0,test,1.4,0.167,0.5,0.7866,0.002611,0.126128,0.011763,0.2134
0,test,1.4,0.139,0.5,0.78665,0.002653,0.126203,0.011878,0.21335
0,test,1.4,0.083,0.5,0.7867,0.002638,0.126591,0.011549,0.2133


In [18]:
fig = px.line(test_results_pareto, x='ei_disparity_mean', y='loss_mean', color='alpha', hover_data='lambda', markers=True, error_y='ei_disparity_std')
fig.add_annotation(dict(font=dict(color='black',size=10),
                                        x=0.9,
                                        y=0.99,
                                        showarrow=False,
                                        text='dataset=synthetic',
                                        textangle=0,
                                        xanchor='left',
                                        xref="paper",
                                        yref="paper"))
fig.update_layout(title=dict(text='Fairness vs Loss Tradeoff', x=0.5))
# fig.update_traces(marker=dict(size=3))
fig.show()