In [1]:
# Add the parent directory to the module search path; presumably this is what
# makes the `src` package imported below resolvable when the notebook is run
# from its own folder (relative to the kernel's working directory).
import sys
sys.path.append('..')

In [2]:
import os
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import tqdm

from src.data import *
from src.model import *
from src.recourse import *
from src.utils import *

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this is global, so deprecation and numerical warnings raised by
# numpy/pandas/torch during the experiments are hidden as well.
warnings.filterwarnings('ignore')

In [3]:
def append_result(d, alg_name, seed, alpha, lamb, i, x_0, theta_0, beta, x_r, theta_r, p, theta_p, J_r, J_c, robustness, consistency):
    """Append one evaluation record to the column-oriented results dict ``d``.

    ``d`` maps column names to lists; exactly one value is appended to each of
    the sixteen columns written below. The array-valued arguments (``x_0``,
    ``theta_0``, ``x_r``, ``theta_r``, ``theta_p``) are stored rounded to four
    decimals via their ``.round`` method; every other value is stored as given.
    Note that ``lamb`` is stored under the key ``'lambda'``.

    Returns ``None``; ``d`` is modified in place.
    """
    def _fields():
        # Yield lazily so each value is computed right before it is stored,
        # in the same column order the lists are filled.
        yield 'alg', alg_name
        yield 'seed', seed
        yield 'alpha', alpha
        yield 'lambda', lamb
        yield 'i', i
        yield 'x_0', x_0.round(4)
        yield 'theta_0', theta_0.round(4)
        yield 'beta', beta
        yield 'x_r', x_r.round(4)
        yield 'theta_r', theta_r.round(4)
        yield 'p', p
        yield 'theta_p', theta_p.round(4)
        yield 'J_r', J_r
        yield 'J_c', J_c
        yield 'robustness', robustness
        yield 'consistency', consistency

    for column, value in _fields():
        d[column].append(value)

In [4]:
# Directories the per-run pickles are written to (relative to the notebook's working directory).
_TRADEOFF_HISTORY_DIR = '../results/rob_con_tradeoff/history'
_TRADEOFF_OUTPUT_DIR = '../results/rob_con_tradeoff/output'

# Column names of one evaluation record, in the order append_result fills them.
_RESULT_COLUMNS = ('alg', 'seed', 'alpha', 'lambda', 'i', 'x_0', 'theta_0', 'beta', 'x_r', 'theta_r',
                   'p', 'theta_p', 'J_r', 'J_c', 'robustness', 'consistency')


def _score_recourse(J, lar_recourse, x, weights_p, bias_p, J_r_opt, J_c_opt):
    """Score a recourse point ``x`` against its adversarial and predicted parameters.

    The adversarial parameters come from ``lar_recourse.calc_theta_adv(x)``.
    Returns ``(theta_adv, J_r, J_c, robustness, consistency)`` where ``J_r`` and
    ``J_c`` are the first elements of ``J.eval`` under the adversarial and the
    predicted parameters, and robustness / consistency are the first elements of
    their differences to ``J_r_opt`` / ``J_c_opt``.
    """
    weights_adv, bias_adv = lar_recourse.calc_theta_adv(x)
    theta_adv = np.hstack((weights_adv, bias_adv))

    J_r = J.eval(x, weights_adv, bias_adv)
    J_c = J.eval(x, weights_p, bias_p)
    robustness = J_r - J_r_opt
    consistency = J_c - J_c_opt
    return theta_adv, J_r[0], J_c[0], robustness[0], consistency[0]


def _summarize_and_save(history, label, file_tag, seed, params, dataset):
    """Turn one algorithm's raw records into a DataFrame, optionally pickle it, and aggregate.

    ``label`` is only used in the printed log line, ``file_tag`` only in the
    pickle file names. The full history is written when ``params['save_history']``
    is truthy, the aggregate when ``params['save_results']`` is truthy. The
    aggregate is the mean of the numeric columns per (alg, p, beta) group and is
    returned.
    """
    df_history = pd.DataFrame(history)
    if params['save_history']:
        print(f'[{label}] Saving history for {dataset.name} run {seed}')
        df_history.to_pickle(f'{_TRADEOFF_HISTORY_DIR}/nn_{dataset.name}_{file_tag}_{seed}.pkl')

    # numeric_only=True drops the array-valued columns (x_0, theta_0, ...) from the mean.
    df_agg = df_history.groupby(['alg', 'p', 'beta'], as_index=False).mean(numeric_only=True)
    if params['save_results']:
        print(f'[{label}] Saving results for {dataset.name} run {seed}')
        df_agg.to_pickle(f'{_TRADEOFF_OUTPUT_DIR}/nn_{dataset.name}_{file_tag}_{seed}.pkl')
    return df_agg


def recourse_runner(seed: int, X_train: np.ndarray, X: np.ndarray, lar_recourse: LARRecourse, roar_recourse: ROAR, base_model: NN, params: dict, dataset: Dataset) -> pd.DataFrame:
    """Evaluate the robustness/consistency trade-off of recourse for every row of ``X``.

    For each instance a local linear approximation of ``base_model`` is obtained
    with ``lime_explanation`` and installed into both recourse objects. The
    robust recourse (``beta=1``) gives the reference cost ``J_r_opt``; for each
    prediction returned by ``generate_nn_predictions`` the consistent recourse
    (``beta=0``) gives the reference cost ``J_c_opt``. Every beta in
    0.00, 0.01, ..., 1.00 is then evaluated with each enabled algorithm.

    Parameters
    ----------
    seed : recorded in every result row and used in the output file names.
    X_train : passed to ``lime_explanation`` together with each instance.
    X : instances to compute recourse for; iterated by position.
    lar_recourse, roar_recourse : recourse objects; their weights/bias are
        overwritten for every instance.
    base_model : its ``predict`` method is what gets explained.
    params : reads ``'alpha'``, ``'lamb'``, ``'algs'`` (case-insensitive names,
        ``'alg1'`` and/or ``'roar'``), ``'save_history'`` and ``'save_results'``.
        The dict is not modified.
    dataset : passed to ``generate_nn_predictions``; ``dataset.name`` is used in
        log lines and file names.

    Returns
    -------
    pandas.DataFrame
        The per-(alg, p, beta) aggregated results of all enabled algorithms,
        stacked; an empty DataFrame when no known algorithm is enabled.
    """
    alpha = params['alpha']
    lamb = params['lamb']
    # Normalise into a local set instead of overwriting params['algs']: the same
    # params dict is shared by the caller across seeds and datasets.
    algs = {alg.lower() for alg in params['algs']}
    run_alg1 = 'alg1' in algs
    run_roar = 'roar' in algs
    betas = np.arange(0., 1.01, 0.01).round(2)

    # Create the output directories up front so a missing folder is dealt with
    # before the evaluation loop runs rather than at the final to_pickle call.
    if run_alg1 or run_roar:
        if params['save_history']:
            os.makedirs(_TRADEOFF_HISTORY_DIR, exist_ok=True)
        if params['save_results']:
            os.makedirs(_TRADEOFF_OUTPUT_DIR, exist_ok=True)

    results_opt = {column: [] for column in _RESULT_COLUMNS}
    results_roar = {column: [] for column in _RESULT_COLUMNS}

    n = len(X)
    for i in tqdm.trange(n, desc=f'Evaluating recourse | alpha={alpha}; lambda={lamb}', colour='#0091ff'):
        x_0 = X[i]
        J = RecourseCost(x_0, lamb)

        # LIME approximation of the original NN.
        # NOTE(review): numpy's global RNG is seeded with the instance index, not
        # with `seed` -- confirm this is intended.
        np.random.seed(i)
        weights_0, bias_0 = lime_explanation(base_model.predict, X_train, x_0)
        weights_0, bias_0 = np.round(weights_0, 4), np.round(bias_0, 4)
        theta_0 = np.hstack((weights_0, bias_0))

        # Initialize both recourse methods with theta_0
        lar_recourse.weights = weights_0
        lar_recourse.bias = bias_0
        roar_recourse.set_weights(weights_0)
        roar_recourse.set_bias(bias_0)

        # Robust recourse (beta=1): reference point for the robustness gap
        x_robust = lar_recourse.get_recourse(x_0, beta=1.)
        weights_robust, bias_robust = lar_recourse.calc_theta_adv(x_robust)
        theta_robust = np.hstack((weights_robust, bias_robust))
        J_r_opt = J.eval(x_robust, weights_robust, bias_robust)

        # Predictions of the future model parameters
        predictions = generate_nn_predictions(dataset, theta_0, theta_robust, alpha)

        for p, prediction in enumerate(predictions):
            # The last entry of a prediction is the bias; it is kept as a length-1 array.
            weights_p, bias_p = prediction[:-1], prediction[[-1]]
            theta_p = (weights_p, bias_p)

            # Consistent recourse (beta=0): reference point for the consistency gap
            x_c = lar_recourse.get_recourse(x_0, beta=0., theta_p=theta_p)
            J_c_opt = J.eval(x_c, *theta_p)

            # Learning-augmented recourse for every trade-off weight
            for beta in betas:
                # Alg 1
                if run_alg1:
                    x = lar_recourse.get_recourse(x_0, beta=beta, theta_p=theta_p)
                    theta_adv, J_r, J_c, robustness, consistency = _score_recourse(
                        J, lar_recourse, x, weights_p, bias_p, J_r_opt, J_c_opt)
                    append_result(results_opt, 'OPT', seed, alpha, lamb, i, x_0, theta_0, beta, x,
                                  theta_adv, p, prediction, J_r, J_c, robustness, consistency)

                # ROAR
                if run_roar:
                    x, _ = roar_recourse.get_recourse(x_0, theta_p, beta)
                    theta_adv, J_r, J_c, robustness, consistency = _score_recourse(
                        J, lar_recourse, x, weights_p, bias_p, J_r_opt, J_c_opt)
                    append_result(results_roar, 'ROAR', seed, alpha, lamb, i, x_0, theta_0, beta, x,
                                  theta_adv, p, prediction, J_r, J_c, robustness, consistency)

    # Save history / aggregates and collect the aggregated frames
    aggregated = []
    if run_alg1:
        aggregated.append(_summarize_and_save(results_opt, 'Alg1', 'alg1', seed, params, dataset))
    if run_roar:
        aggregated.append(_summarize_and_save(results_roar, 'ROAR', 'roar', seed, params, dataset))

    # Concatenate once; stacking onto an empty placeholder DataFrame is
    # deprecated in newer pandas because of how result dtypes are inferred.
    return pd.concat(aggregated) if aggregated else pd.DataFrame()
        

In [5]:
def run_experiment(dataset: Dataset, params: dict, results: List):
    """Train a model per seed on ``dataset`` and run the recourse evaluation on it.

    For every seed in ``params['seeds']`` this fetches the train/test split from
    ``dataset.get_data``, trains a fresh ``NN`` on the training split, keeps the
    instances selected by ``recourse_needed``, picks lambda via
    ``LARRecourse.choose_lambda`` on the training instances and evaluates
    recourse on the test instances with ``recourse_runner``.

    Side effects: ``params['lamb']`` is overwritten with the lambda chosen for
    the current seed, and one aggregated DataFrame per seed is appended to
    ``results``. Nothing is returned.
    """
    alpha = params['alpha']

    for seed in params['seeds']:
        train_split, test_split = dataset.get_data(seed)
        X_train, y_train = train_split
        X_test, _ = test_split  # test labels are not used here

        base_model = NN(X_train.shape[1])
        base_model.train(X_train.values, y_train.values)

        needs_recourse_train = recourse_needed(base_model.predict, X_train.values)
        needs_recourse_test = recourse_needed(base_model.predict, X_test.values)

        # Weights/bias start unset; recourse_runner installs them per instance.
        lar_recourse = LARRecourse(weights=None, bias=None, alpha=alpha)
        roar_recourse = ROAR(weights=None, bias=None, alpha=alpha)

        chosen_lamb = lar_recourse.choose_lambda(needs_recourse_train, base_model.predict, X_train.values)
        params['lamb'] = chosen_lamb  # recourse_runner reads lambda from params
        for method in (lar_recourse, roar_recourse):
            method.lamb = chosen_lamb

        results.append(
            recourse_runner(seed, X_train.values, needs_recourse_test, lar_recourse, roar_recourse, base_model, params, dataset)
        )

In [6]:
torch.manual_seed(0)

# Experiment configuration shared by every dataset and seed.
params = {
    'alpha': 0.5,
    'lamb': None,                 # overwritten per seed by run_experiment
    'seeds': range(5),
    'algs': ['roar'],             # names checked by recourse_runner: 'alg1', 'roar'
    'save_results': True,         # pickle the per-seed aggregated results
    'save_history': True,         # pickle the per-seed raw records
    'save_final_results': False,  # pickle the per-dataset concatenation below
}

# Aggregated results of all seeds, keyed by dataset name.
d_results = {}

datasets = [SyntheticDataset(), GermanDataset(), SBADataset()]
for dataset in datasets:
    print(f'Running {dataset.name} data...')

    results = []
    run_experiment(dataset, params, results)  # appends one frame per seed to `results`

    d_results[dataset.name] = pd.concat(results)
    if params['save_final_results']:
        d_results[dataset.name].to_pickle(f'../results/rob_con_tradeoff/output/nn_{dataset.name}')

    print(f'Finished {dataset.name}\n')

Running synthetic data...
Choosing lambda


lambda=0.1: 100%|██████████| 404/404 [00:01<00:00, 206.32it/s]
lambda=0.2: 100%|██████████| 404/404 [00:01<00:00, 217.19it/s]
lambda=0.3: 100%|██████████| 404/404 [00:01<00:00, 210.31it/s]
lambda=0.4: 100%|██████████| 404/404 [00:01<00:00, 237.02it/s]
lambda=0.5: 100%|██████████| 404/404 [00:01<00:00, 231.03it/s]
lambda=0.6: 100%|██████████| 404/404 [00:01<00:00, 227.97it/s]
lambda=0.7: 100%|██████████| 404/404 [00:01<00:00, 224.59it/s]
lambda=0.8: 100%|██████████| 404/404 [00:01<00:00, 221.61it/s]
lambda=0.9: 100%|██████████| 404/404 [00:01<00:00, 230.53it/s]
lambda=1.0: 100%|██████████| 404/404 [00:01<00:00, 213.33it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 96/96 [3:22:57<00:00, 126.85s/it]  


[ROAR] Saving history for synthetic run 0
[ROAR] Saving results for synthetic run 0
Choosing lambda


lambda=0.1: 100%|██████████| 405/405 [00:01<00:00, 228.48it/s]
lambda=0.2: 100%|██████████| 405/405 [00:01<00:00, 227.39it/s]
lambda=0.3: 100%|██████████| 405/405 [00:01<00:00, 234.83it/s]
lambda=0.4: 100%|██████████| 405/405 [00:01<00:00, 236.16it/s]
lambda=0.5: 100%|██████████| 405/405 [00:01<00:00, 237.18it/s]
lambda=0.6: 100%|██████████| 405/405 [00:01<00:00, 235.27it/s]
lambda=0.7: 100%|██████████| 405/405 [00:01<00:00, 233.60it/s]
lambda=0.8: 100%|██████████| 405/405 [00:01<00:00, 235.83it/s]
lambda=0.9: 100%|██████████| 405/405 [00:01<00:00, 236.63it/s]
lambda=1.0: 100%|██████████| 405/405 [00:01<00:00, 231.36it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 95/95 [3:05:42<00:00, 117.29s/it]  


[ROAR] Saving history for synthetic run 1
[ROAR] Saving results for synthetic run 1
Choosing lambda


lambda=0.1: 100%|██████████| 397/397 [00:01<00:00, 226.85it/s]
lambda=0.2: 100%|██████████| 397/397 [00:01<00:00, 234.89it/s]
lambda=0.3: 100%|██████████| 397/397 [00:01<00:00, 232.18it/s]
lambda=0.4: 100%|██████████| 397/397 [00:01<00:00, 222.80it/s]
lambda=0.5: 100%|██████████| 397/397 [00:01<00:00, 233.68it/s]
lambda=0.6: 100%|██████████| 397/397 [00:01<00:00, 233.24it/s]
lambda=0.7: 100%|██████████| 397/397 [00:01<00:00, 229.57it/s]
lambda=0.8: 100%|██████████| 397/397 [00:01<00:00, 236.53it/s]
lambda=0.9: 100%|██████████| 397/397 [00:01<00:00, 231.98it/s]
lambda=1.0: 100%|██████████| 397/397 [00:01<00:00, 234.80it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 103/103 [3:30:37<00:00, 122.69s/it] 


[ROAR] Saving history for synthetic run 2
[ROAR] Saving results for synthetic run 2
Choosing lambda


lambda=0.1: 100%|██████████| 399/399 [00:01<00:00, 217.69it/s]
lambda=0.2: 100%|██████████| 399/399 [00:01<00:00, 223.85it/s]
lambda=0.3: 100%|██████████| 399/399 [00:01<00:00, 222.30it/s]
lambda=0.4: 100%|██████████| 399/399 [00:01<00:00, 219.32it/s]
lambda=0.5: 100%|██████████| 399/399 [00:01<00:00, 212.17it/s]
lambda=0.6: 100%|██████████| 399/399 [00:01<00:00, 218.09it/s]
lambda=0.7: 100%|██████████| 399/399 [00:01<00:00, 218.52it/s]
lambda=0.8: 100%|██████████| 399/399 [00:01<00:00, 221.89it/s]
lambda=0.9: 100%|██████████| 399/399 [00:01<00:00, 218.37it/s]
lambda=1.0: 100%|██████████| 399/399 [00:01<00:00, 211.92it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 101/101 [3:30:31<00:00, 125.06s/it] 


[ROAR] Saving history for synthetic run 3
[ROAR] Saving results for synthetic run 3
Choosing lambda


lambda=0.1: 100%|██████████| 395/395 [00:01<00:00, 219.99it/s]
lambda=0.2: 100%|██████████| 395/395 [00:01<00:00, 225.31it/s]
lambda=0.3: 100%|██████████| 395/395 [00:01<00:00, 217.14it/s]
lambda=0.4: 100%|██████████| 395/395 [00:01<00:00, 225.24it/s]
lambda=0.5: 100%|██████████| 395/395 [00:01<00:00, 223.30it/s]
lambda=0.6: 100%|██████████| 395/395 [00:01<00:00, 222.41it/s]
lambda=0.7: 100%|██████████| 395/395 [00:01<00:00, 219.60it/s]
lambda=0.8: 100%|██████████| 395/395 [00:01<00:00, 217.66it/s]
lambda=0.9: 100%|██████████| 395/395 [00:01<00:00, 216.55it/s]
lambda=1.0: 100%|██████████| 395/395 [00:01<00:00, 214.35it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 105/105 [4:23:59<00:00, 150.85s/it]  


[ROAR] Saving history for synthetic run 4
[ROAR] Saving results for synthetic run 4
Finished synthetic

Running german data...
Choosing lambda


lambda=0.1: 100%|██████████| 119/119 [00:00<00:00, 206.45it/s]
lambda=0.2: 100%|██████████| 119/119 [00:00<00:00, 211.70it/s]
Evaluating recourse | alpha=0.5; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 37/37 [1:31:44<00:00, 148.76s/it]


[ROAR] Saving history for german run 0
[ROAR] Saving results for german run 0
Choosing lambda


lambda=0.1: 100%|██████████| 105/105 [00:00<00:00, 208.76it/s]
lambda=0.2: 100%|██████████| 105/105 [00:00<00:00, 211.33it/s]
lambda=0.3: 100%|██████████| 105/105 [00:00<00:00, 206.42it/s]
Evaluating recourse | alpha=0.5; lambda=0.2: 100%|[38;2;0;145;255m██████████[0m| 20/20 [52:47<00:00, 158.39s/it] 


[ROAR] Saving history for german run 1
[ROAR] Saving results for german run 1
Choosing lambda


lambda=0.1: 100%|██████████| 120/120 [00:00<00:00, 212.12it/s]
lambda=0.2: 100%|██████████| 120/120 [00:00<00:00, 210.95it/s]
lambda=0.3: 100%|██████████| 120/120 [00:00<00:00, 213.47it/s]
lambda=0.4: 100%|██████████| 120/120 [00:00<00:00, 209.64it/s]
Evaluating recourse | alpha=0.5; lambda=0.3: 100%|[38;2;0;145;255m██████████[0m| 34/34 [56:54<00:00, 100.42s/it] 


[ROAR] Saving history for german run 2
[ROAR] Saving results for german run 2
Choosing lambda


lambda=0.1: 100%|██████████| 127/127 [00:00<00:00, 206.33it/s]
lambda=0.2: 100%|██████████| 127/127 [00:00<00:00, 207.67it/s]
lambda=0.3: 100%|██████████| 127/127 [00:00<00:00, 206.76it/s]
Evaluating recourse | alpha=0.5; lambda=0.2: 100%|[38;2;0;145;255m██████████[0m| 24/24 [49:40<00:00, 124.17s/it]


[ROAR] Saving history for german run 3
[ROAR] Saving results for german run 3
Choosing lambda


lambda=0.1: 100%|██████████| 122/122 [00:00<00:00, 214.47it/s]
lambda=0.2: 100%|██████████| 122/122 [00:00<00:00, 215.21it/s]
lambda=0.3: 100%|██████████| 122/122 [00:00<00:00, 216.65it/s]
Evaluating recourse | alpha=0.5; lambda=0.2: 100%|[38;2;0;145;255m██████████[0m| 29/29 [1:10:17<00:00, 145.42s/it]


[ROAR] Saving history for german run 4
[ROAR] Saving results for german run 4
Finished german

Running sba data...
Choosing lambda


lambda=0.1: 100%|██████████| 150/150 [00:01<00:00, 129.88it/s]
lambda=0.2: 100%|██████████| 150/150 [00:01<00:00, 130.39it/s]
lambda=0.3: 100%|██████████| 150/150 [00:01<00:00, 129.59it/s]
lambda=0.4: 100%|██████████| 150/150 [00:01<00:00, 120.45it/s]
lambda=0.5: 100%|██████████| 150/150 [00:01<00:00, 130.18it/s]
lambda=0.6: 100%|██████████| 150/150 [00:01<00:00, 130.08it/s]
lambda=0.7: 100%|██████████| 150/150 [00:01<00:00, 130.51it/s]
lambda=0.8: 100%|██████████| 150/150 [00:01<00:00, 129.25it/s]
lambda=0.9: 100%|██████████| 150/150 [00:01<00:00, 128.60it/s]
lambda=1.0: 100%|██████████| 150/150 [00:01<00:00, 128.23it/s]
Evaluating recourse | alpha=0.5; lambda=1.0: 100%|[38;2;0;145;255m██████████[0m| 39/39 [2:55:02<00:00, 269.29s/it]  


[ROAR] Saving history for sba run 0
[ROAR] Saving results for sba run 0
Choosing lambda


lambda=0.1: 100%|██████████| 153/153 [00:01<00:00, 123.24it/s]
lambda=0.2: 100%|██████████| 153/153 [00:01<00:00, 128.78it/s]
lambda=0.3: 100%|██████████| 153/153 [00:01<00:00, 130.16it/s]
lambda=0.4: 100%|██████████| 153/153 [00:01<00:00, 127.57it/s]
lambda=0.5: 100%|██████████| 153/153 [00:01<00:00, 129.96it/s]
lambda=0.6: 100%|██████████| 153/153 [00:01<00:00, 129.50it/s]
Evaluating recourse | alpha=0.5; lambda=0.5: 100%|[38;2;0;145;255m██████████[0m| 36/36 [4:16:20<00:00, 427.22s/it]  


[ROAR] Saving history for sba run 1
[ROAR] Saving results for sba run 1
Choosing lambda


lambda=0.1: 100%|██████████| 149/149 [00:01<00:00, 128.57it/s]
lambda=0.2: 100%|██████████| 149/149 [00:01<00:00, 127.86it/s]
lambda=0.3: 100%|██████████| 149/149 [00:01<00:00, 129.88it/s]
lambda=0.4: 100%|██████████| 149/149 [00:01<00:00, 125.74it/s]
lambda=0.5: 100%|██████████| 149/149 [00:01<00:00, 128.57it/s]
lambda=0.6: 100%|██████████| 149/149 [00:01<00:00, 128.13it/s]
Evaluating recourse | alpha=0.5; lambda=0.5: 100%|[38;2;0;145;255m██████████[0m| 39/39 [3:56:36<00:00, 364.01s/it]  


[ROAR] Saving history for sba run 2
[ROAR] Saving results for sba run 2
Choosing lambda


lambda=0.1: 100%|██████████| 153/153 [00:01<00:00, 131.46it/s]
lambda=0.2: 100%|██████████| 153/153 [00:01<00:00, 135.04it/s]
lambda=0.3: 100%|██████████| 153/153 [00:01<00:00, 133.65it/s]
lambda=0.4: 100%|██████████| 153/153 [00:01<00:00, 134.61it/s]
lambda=0.5: 100%|██████████| 153/153 [00:01<00:00, 132.53it/s]
lambda=0.6: 100%|██████████| 153/153 [00:01<00:00, 134.26it/s]
Evaluating recourse | alpha=0.5; lambda=0.5: 100%|[38;2;0;145;255m██████████[0m| 36/36 [6:01:09<00:00, 601.93s/it]    


[ROAR] Saving history for sba run 3
[ROAR] Saving results for sba run 3
Choosing lambda


lambda=0.1: 100%|██████████| 151/151 [00:01<00:00, 129.13it/s]
lambda=0.2: 100%|██████████| 151/151 [00:01<00:00, 134.15it/s]
lambda=0.3: 100%|██████████| 151/151 [00:01<00:00, 132.91it/s]
lambda=0.4: 100%|██████████| 151/151 [00:01<00:00, 133.38it/s]
lambda=0.5: 100%|██████████| 151/151 [00:01<00:00, 128.11it/s]
lambda=0.6: 100%|██████████| 151/151 [00:01<00:00, 132.72it/s]
Evaluating recourse | alpha=0.5; lambda=0.5: 100%|[38;2;0;145;255m██████████[0m| 38/38 [8:51:15<00:00, 838.83s/it]    


[ROAR] Saving history for sba run 4
[ROAR] Saving results for sba run 4
Finished sba

