In [1]:
import sys
sys.path.append('..')

In [2]:
import tqdm
import torch
import pickle
import warnings
import numpy as np
import pandas as pd
import torch.optim as optim
from copy import deepcopy
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from src.data import *
from src.utils import *
from src.model import *
from src.recourse import *

warnings.filterwarnings('ignore')

  (f'theta_s-2\epsilon', (theta_p_minus_2eps[:-1], theta_p_minus_2eps[[-1]])),
  (f'theta_s-\epsilon', (theta_p_minus_eps[:-1], theta_p_minus_eps[[-1]])),
  (f'theta_s+\epsilon', (theta_p_plus_eps[:-1], theta_p_plus_eps[[-1]])),
  (f'theta_s+2\epsilon', (theta_p_plus_2eps[:-1], theta_p_plus_2eps[[-1]])),
  (f'theta_s-2\epsilon', (theta_p_minus_2eps[:-1], theta_p_minus_2eps[[-1]])),
  (f'theta_s-\epsilon', (theta_p_minus_eps[:-1], theta_p_minus_eps[[-1]])),
  (f'theta_s+\epsilon', (theta_p_plus_eps[:-1], theta_p_plus_eps[[-1]])),
  (f'theta_s+2\epsilon', (theta_p_plus_2eps[:-1], theta_p_plus_2eps[[-1]])),


In [3]:
def append_result(d, rob, loss, cost, m1_validity, wc_validity, m1_expectation, wc_expectation):
    """Record one evaluated sample in the per-metric lists stored in ``d``.

    ``d`` must already hold an appendable sequence under every key written
    below; a missing key raises ``KeyError``. The dictionary is mutated in
    place and nothing is returned.
    """
    # (key in ``d``, value to append) -- kept in the original append order.
    entries = (
        ('Cost', cost),
        ('M1 Validity', m1_validity),
        ('WC Validity', wc_validity),
        ('M1 Expectation', m1_expectation),
        ('WC Expectation', wc_expectation),
        ('J', rob),
        ('Loss', loss),
    )
    for metric_name, value in entries:
        d[metric_name].append(value)
    
def get_result(d, alg, seed, alpha, lamb):
    """Summarise a run: identifying fields followed by the mean of each metric.

    Returns a new dict with the keys ``'alg'``, ``'seed'``, ``'alpha'`` and
    ``'lambda'`` (in that order), then one entry per key of ``d`` whose value
    is ``np.mean`` of the sequence stored under that key.
    """
    summary = dict(alg=alg, seed=seed, alpha=alpha)
    summary['lambda'] = lamb
    summary.update({metric: np.mean(values) for metric, values in d.items()})
    return summary

In [4]:
def get_model_adv_pga(X_0, X_r, cfr, alpha, lamb, pga_max_iter: int = 100):
    """Build a perturbed copy of ``cfr`` by projected gradient ascent.

    A deep copy of ``cfr`` is optimised with Adam (``maximize=True``) to
    increase ``BCE(cfr_adv(X_r), 1) + lamb * ||X_r - X_0||_1``. After every
    optimiser step, each weight and bias of the layers ``fc1``, ``fc2``,
    ``fc3`` and ``out`` is clamped element-wise to ``[ref - alpha, ref + alpha]``,
    where ``ref`` is the corresponding value in ``cfr``.

    Args:
        X_0: sequence of arrays accepted by ``np.stack`` (original inputs).
        X_r: sequence of arrays accepted by ``np.stack`` (recourse points);
            must stack to the same shape as ``X_0``.
        cfr: model exposing ``fc1``, ``fc2``, ``fc3`` and ``out`` sub-modules
            with ``weight`` and ``bias``; its output is fed to ``BCELoss``,
            so it is expected to lie in [0, 1].
        alpha: maximum absolute change allowed for each clamped parameter.
        lamb: multiplier of the L1 cost term in the objective.
        pga_max_iter: number of ascent steps to run.

    Returns:
        The perturbed deep copy. ``cfr`` itself is not modified.
    """
    X_0 = torch.tensor(np.stack(X_0)).float()
    X_r = torch.tensor(np.stack(X_r)).float()

    loss_fn = torch.nn.BCELoss(reduction='mean')
    cfr_adv = deepcopy(cfr)
    optimizer = optim.Adam(cfr_adv.parameters(), maximize=True)

    # For every clamped parameter keep (adversarial parameter, lower, upper).
    # The bounds are taken from the reference model once, before any update.
    clamp_specs = []
    for layer_name in ('fc1', 'fc2', 'fc3', 'out'):
        ref_layer = getattr(cfr, layer_name)
        adv_layer = getattr(cfr_adv, layer_name)
        for param_name in ('weight', 'bias'):
            ref_value = getattr(ref_layer, param_name).data
            clamp_specs.append(
                (getattr(adv_layer, param_name), ref_value - alpha, ref_value + alpha)
            )

    # The cost term is computed from the inputs only, so it is the same in
    # every iteration and contributes no gradient to the model parameters.
    cost = torch.dist(X_r, X_0, 1)

    for _ in range(pga_max_iter):
        optimizer.zero_grad()

        f_x = cfr_adv(X_r)
        y_target = torch.ones(f_x.shape).float()
        loss = loss_fn(f_x, y_target) + lamb * cost

        loss.backward()
        optimizer.step()

        # Projection step: pull each parameter back into its alpha-box.
        with torch.no_grad():
            for param, lower, upper in clamp_specs:
                param.clamp_(lower, upper)

    return cfr_adv

In [5]:
def evaluate_performance(X_0, X_r, cfr, alpha, lamb, seed, alg, method='search', pga_max_iter: int = 99):
    """Evaluate recourse points against the given model and a perturbed copy.

    A perturbed model is obtained once from ``get_model_adv_pga`` using all of
    ``X_0`` / ``X_r``. Then, for every index ``i``, the recourse point
    ``X_r[i]`` is scored with ``RecourseCost(x_0, lamb).eval_nonlinear`` on
    the perturbed model and with ``predict`` / ``predict_proba`` on both
    ``cfr`` and the perturbed model. The per-sample values are averaged by
    ``get_result``.

    Args:
        X_0: indexable collection of numpy arrays (original inputs).
        X_r: indexable collection of numpy arrays (recourse points), same
            length as ``X_0``.
        cfr: model exposing ``predict`` and ``predict_proba``, and whatever
            ``get_model_adv_pga`` requires.
        alpha: perturbation bound forwarded to ``get_model_adv_pga``.
        lamb: cost weight forwarded to ``get_model_adv_pga`` and
            ``RecourseCost``.
        seed: stored in the returned dict under ``'seed'``.
        alg: stored in the returned dict under ``'alg'``.
        method: accepted for interface compatibility; not used here.
        pga_max_iter: number of ascent steps forwarded to
            ``get_model_adv_pga`` (defaults to the previously hard-coded 99).

    Returns:
        The dict produced by ``get_result``: run identifiers plus the mean of
        each recorded metric.
    """
    results = {'Cost': [], 'M1 Validity': [], 'WC Validity': [], 'M1 Expectation': [], 'WC Expectation': [], 'J': [], 'Loss': []}

    cfr_adv = get_model_adv_pga(X_0, X_r, cfr, alpha, lamb, pga_max_iter)
    n = len(X_r)

    for i in tqdm.trange(n, desc=f'Eval alpha={alpha:.4f}; lambda={lamb}', colour='#0091ff'):
        x_0 = torch.from_numpy(X_0[i]).float()
        x_r = torch.from_numpy(X_r[i]).float()
        J = RecourseCost(x_0, lamb)

        # eval_nonlinear is unpacked as (loss, cost, objective) on the perturbed model.
        bce_loss_opt, cost_opt, rob_opt = J.eval_nonlinear(x_r.reshape((1, len(x_r))), cfr_adv, True)

        # "M1" metrics use the unperturbed model.
        # NOTE(review): [0, 1] is presumably the positive-class probability -- confirm in src.model.
        m1_validity_opt = cfr.predict(x_r.reshape(1, -1))[0]
        m1_expectation_opt = cfr.predict_proba(x_r.reshape(1, -1))[0, 1]

        # "WC" metrics use the perturbed model.
        wc_validity_opt = cfr_adv.predict(x_r.reshape(1, -1))[0]
        wc_expectation_opt = cfr_adv.predict_proba(x_r.reshape(1, -1))[0, 1]

        append_result(results, rob_opt, bce_loss_opt, cost_opt, m1_validity_opt, wc_validity_opt, m1_expectation_opt, wc_expectation_opt)

    return get_result(results, alg, seed, alpha, lamb)

In [6]:
# Display labels keyed by the method identifiers used in result file names.
method_map = {
    f'alg1_lamb{lam}': f'Alg1 (λ={lam})'
    for lam in ('0.05', '0.1', '0.2', '0.3')
}
# NOTE(review): the 'ROAR ' label carries a trailing space -- confirm it is intentional.
method_map.update({'roar': 'ROAR ', 'rbr': 'RBR', 'wachter': 'WACHTER'})

# Display names for datasets and base models.
data_map = dict(
    synthetic='Synthetic',
    sba='Small Business Administration',
    german='German',
    income='ACS Income',
)
model_map = dict(lr='Logistic Regression', nn='Neural Network')

In [7]:
torch.manual_seed(0)
params = {}
# 'synthetic', 'german', 'sba'
params['data'] = 'sba'
# 'lr', 'nn'
params['base_model'] = 'nn'
params['algs'] = ['alg1', 'roar']
# NOTE: params['alpha'] is not read below; alpha is derived per run from the
# stored 'delta_max' (or fixed to 0.02 for 'rbr').
params['alpha'] = 0.2
lambdas = [0.1, 0.2, 0.3]


# One row per (method, parameter setting) is accumulated here.
results = {
    'Algorithm': [],
    'Cost': [],
    'Current Validity': [],
    'Worst Case Validity': []
}

# Pair every result file with the lambda it is evaluated under. 'alg1' has one
# file per lambda; every other method has a single file and is evaluated with
# the first lambda. Keeping the pair explicit avoids recovering lambda from the
# list position, which is only correct when 'alg1' comes first in params['algs'].
method_lambdas = []
for method in params['algs']:
    if method == 'alg1':
        method_lambdas.extend((f'{method}_lamb{lamb}', lamb) for lamb in lambdas)
    elif lambdas:
        method_lambdas.append((f'{method}', lambdas[0]))
methods = [mname for mname, _ in method_lambdas]

for mname, lamb in method_lambdas:
    # SECURITY: pickle.load can execute arbitrary code; only load result files
    # produced by this project.
    with open(f"../results/cost_validity/{params['base_model']}_{params['data']}_{mname}.pickle", 'rb') as f:
        data = pickle.load(f)

    for i in range(len(data["params"])):
        # alpha = 0.002
        if mname != 'rbr':
            alpha = data['params'][i]['delta_max'] * 0.001
        else:
            alpha = 0.02

        X_0 = data['x_0'][i]
        X_r = data['x_r'][i]
        theta_0 = data['theta_0']
        # The input width is read from the first sample; `i` indexes parameter
        # settings, not samples, so it must not be used to index X_0 here.
        cfr = Net0(len(X_0[0]))
        cfr.load_state_dict(theta_0)

        res = evaluate_performance(X_0, X_r, deepcopy(cfr), alpha, lamb, i, mname)

        results['Algorithm'].append(method_map[mname])
        results['Cost'].append(res['Cost'])
        results['Current Validity'].append(res['M1 Validity'])
        results['Worst Case Validity'].append(res['WC Validity'])

df_results = pd.DataFrame(results)

Eval alpha=0.0000; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4384.05it/s]
Eval alpha=0.0000; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4337.44it/s]
Eval alpha=0.0000; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4671.39it/s]
Eval alpha=0.0001; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4523.48it/s]
Eval alpha=0.0001; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4582.94it/s]
Eval alpha=0.0001; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4432.04it/s]
Eval alpha=0.0001; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4492.57it/s]
Eval alpha=0.0001; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4495.79it/s]
Eval alpha=0.0002; lambda=0.1: 100%|[38;2;0;145;255m██████████[0m| 100/100 [00:00<00:00, 4347.83it/s]
Eval alpha=0.0002; lambda=0.1: 100%|[38;2;0;145;255m██████████

In [8]:
print(f'{data_map[params["data"]]}  |  {params["base_model"].upper()}')

# Per-algorithm summary table: each metric is shown as "mean±std" over the
# rows of df_results, both rounded to two decimals.
metric_cols = ['Cost', 'Current Validity', 'Worst Case Validity']
by_algorithm = df_results.groupby('Algorithm')

df_results_avg = by_algorithm.mean(numeric_only=True)
mean_text = df_results_avg[metric_cols].round(2).astype(str)
std_text = by_algorithm.std(numeric_only=True)[metric_cols].round(2).astype(str)
df_results_avg[metric_cols] = mean_text + '±' + std_text

df_results_avg

Small Business Administration  |  NN


Unnamed: 0_level_0,Cost,Current Validity,Worst Case Validity
Algorithm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Alg1 (λ=0.1),2.52±0.86,0.99±0.02,0.99±0.02
Alg1 (λ=0.2),1.35±0.67,0.92±0.12,0.92±0.12
Alg1 (λ=0.3),2.4±0.76,0.99±0.01,0.99±0.01
ROAR,0.32±0.11,0.56±0.0,0.56±0.0


In [9]:
# Per-dataset palettes; exactly one assignment should be active.
# colors = ['#1f77b4', '#17becf', '#9467bd', '#e377c2', '#2ca02c'] # Synthesis
# colors = [ '#17becf', '#9467bd', '#e377c2', '#2ca02c'] # German
colors = ['#17becf', '#9467bd', '#e377c2', '#2ca02c'] # SBA

nc = len(colors)
font_family = 'Times New Roman'
font_color = 'black'
width, height = 720, 540

# One marker style per trace: the first len(lambdas) traces share a style,
# followed by two further styles.
symbols = ['x'] * len(lambdas) + ['circle', 'triangle-up']
size = [7] * len(lambdas) + [5, 8]

show_errors = False
y_axis = 'Worst Case Validity'
# y_axis = 'Current Validity'

fig = go.Figure()
for trace_idx, alg in enumerate(df_results['Algorithm'].unique()):
    alg_rows = df_results[(df_results['Algorithm'] == alg)].sort_values(['Cost'], ascending=True)
    # NOTE(review): find_pareto comes from src.utils; presumably it keeps only
    # the non-dominated (cost, validity) points -- confirm there.
    # To plot every point instead, use alg_rows['Cost'], alg_rows[y_axis].
    pareto_cost, pareto_validity = find_pareto(alg_rows['Cost'], alg_rows[y_axis])
    frontier = pd.DataFrame({
        'Algorithm': [alg] * len(pareto_cost),
        'Cost': pareto_cost,
        y_axis: pareto_validity,
    })

    marker_style = dict(
        color=colors[trace_idx],
        symbol=symbols[trace_idx],
        size=size[trace_idx],
    )
    # 'WACHTER' is drawn as isolated markers; every other algorithm gets a line.
    trace_mode = 'markers' if alg == 'WACHTER' else 'lines+markers'

    fig.add_trace(go.Scatter(
        x=frontier['Cost'],
        y=frontier[y_axis],
        marker=marker_style,
        mode=trace_mode,
        name=alg,
        hovertemplate='Cost: %{x}<br>Validity: %{y}',
        showlegend=True,
    ))

# Settings shared by both axes.
title_font = dict(family=font_family, color=font_color, size=25)
tick_font = dict(family=font_family, color=font_color, size=20)
axis_frame = dict(
    showline=True,
    mirror=True,
    linecolor='black',
    gridcolor='lightgrey',
    zerolinewidth=1,
    zerolinecolor='lightgrey',
)

fig.update_xaxes(title=dict(text='Cost', font=dict(title_font)), **axis_frame)
fig.update_yaxes(title=dict(text=y_axis, font=dict(title_font)), **axis_frame)

legend_style = dict(
    x=0.975,
    y=0.025,
    orientation='v',
    xanchor='right',
    font=dict(family=font_family, color=font_color, size=15),
    bgcolor='rgba(255, 255, 255, 0.7)',
    bordercolor='lightgrey',
    borderwidth=1,
    entrywidth=100.5,
)

fig.update_layout(
    legend=legend_style,
    width=width,
    height=height,
    plot_bgcolor='white',
    paper_bgcolor='white',
    xaxis=dict(tickfont=dict(tick_font)),
    yaxis=dict(tickfont=dict(tick_font), range=[-0.1, 1.1]),
)

fig.show()

In [10]:
# fig.write_image(f'../figs/cost_validity_tradeoff_{params["base_model"]}_{params["data"]}.pdf')