In [1]:
import sys
sys.path.append('..')

In [2]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from src.utils import pareto_frontier

In [3]:
# Load the saved robustness/consistency trade-off histories for every
# (algorithm, seed) pair and stack them into a single frame.
dataset = 'synthetic'
base_model = 'lr'
algs = ['alg1']
seeds = range(5)

# NOTE: pd.read_pickle can execute arbitrary code; only load result files
# produced by this project's own experiment runs.
history_frames = [
    pd.read_pickle(f'../results/rob_con_tradeoff/history/{base_model}_{dataset}_{alg}_{seed}.pkl')
    for alg in algs
    for seed in seeds
]

# Concatenate once (repeated concat inside the loop is quadratic). The original
# row indices are kept, as before. An empty DataFrame is produced when no
# (alg, seed) pairs are configured.
df_results = pd.concat(history_frames) if history_frames else pd.DataFrame()

In [4]:
# Per-row L1 distance ||x_0 - x_r||_1 between the original point (x_0) and its
# recourse (x_r). Each cell of these columns holds a vector, so the row-wise
# differences are stacked into a 2-D array before taking the norm along axis 1.
df_results['L1x'] = np.linalg.norm(
    np.stack((df_results['x_0'] - df_results['x_r']).to_numpy()),
    ord=1,
    axis=1,
)

In [5]:
# Restrict the results to a single (seed, i, p) configuration for the plots below.
selection = (
    df_results['seed'].eq(0)
    & df_results['i'].eq(50)
    & df_results['p'].eq(4)
)
df_im = df_results[selection]
df_im

Unnamed: 0,alg,seed,alpha,lambda,i,x_0,theta_0,beta,x_r,theta_r,p,theta_p,J_r,J_c,robustness,consistency,L1x
25654,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.00,"[-1.8037, 1.9094]","[2.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",5.574971,3.888719,1.367814,0.000000,3.3699
25655,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.01,"[-1.8037, 1.9]","[2.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",5.577876,3.888784,1.370719,0.000065,3.3605
25656,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.02,"[-1.8037, 1.9]","[2.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",5.577876,3.888784,1.370719,0.000065,3.3605
25657,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.03,"[0.0, 0.1]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.219932,3.928169,0.012775,0.039449,3.3642
25658,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.04,"[0.0, 0.1]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.219932,3.928169,0.012775,0.039449,3.3642
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25750,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.96,"[0.0, 0.0]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.207157,3.942199,0.000000,0.053480,3.2642
25751,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.97,"[0.0, 0.0]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.207157,3.942199,0.000000,0.053480,3.2642
25752,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.98,"[0.0, 0.0]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.207157,3.942199,0.000000,0.053480,3.2642
25753,OPT,0,0.5,1.0,50,"[-1.8037, -1.4605]","[1.9661, 1.9713, 0.0506]",0.99,"[0.0, 0.0]","[1.4661, 1.4713, -0.4494]",4,"[2.4184, 2.4705, 0.0307]",4.207157,3.942199,0.000000,0.053480,3.2642


In [6]:
# For each chosen beta, take the first matching row of df_im and collect its
# original point (x_0) and recourse (x_r). Row k of X_0 / X_r corresponds to betas[k].
betas = [0, 0.5, 1]
X_0_rows = []
X_r_rows = []
beta_values = df_im['beta'].to_numpy(dtype=float)
for beta in betas:
    # Betas may come from a float grid (e.g. 0.01 steps), so exact `==` can miss;
    # match with a small absolute tolerance instead.
    rows_beta = df_im[np.isclose(beta_values, beta, rtol=0.0, atol=1e-9)]
    if rows_beta.empty:
        raise ValueError(
            f'No rows with beta={beta} in df_im; '
            f'available betas range from {beta_values.min()} to {beta_values.max()}'
        )
    X_0_rows.append(rows_beta['x_0'].iloc[0])
    X_r_rows.append(rows_beta['x_r'].iloc[0])

X_0 = np.array(X_0_rows)
X_r = np.array(X_r_rows)

In [8]:
# Display the original points selected per beta (one row per beta).
X_0

array([[-1.8037, -1.4605],
       [-1.8037, -1.4605],
       [-1.8037, -1.4605]])

In [9]:
# Display the recourse points selected per beta (row k pairs with row k of X_0).
X_r

array([[-1.8037,  1.9094],
       [ 0.    ,  0.1   ],
       [ 0.    ,  0.    ]])

In [80]:
# Scatter of the original point(s) x_0 and the recourse x_r obtained for each
# selected beta. An arrow is drawn from every x_0 to its x_r.
font_family = 'Times New Roman'
font_color = 'black'
width, height = 720, 540

x_val = 'Feature 1'
y_val = 'Feature 2'

# One legend label and marker colour per row of X_r (i.e. per selected beta, in order).
names = ['x_c', "x'", 'x_r']
colors = ['#009AFF', '#8700FF', '#D20D8C', '#FE2201']
if len(X_0) != len(X_r) or len(X_r) > min(len(names), len(colors)):
    raise ValueError(
        f'Expected matching X_0/X_r with at most {min(len(names), len(colors))} rows, '
        f'got {len(X_0)} and {len(X_r)}'
    )


def _font(size):
    """Font spec shared by all text in this figure, at the given size."""
    return dict(family=font_family, color=font_color, size=size)


def _axis_style(title_text):
    """Shared axis styling (title, frame, grid, tick labels) used for both axes."""
    return dict(
        title=dict(text=title_text, font=_font(20)),
        tickfont=_font(20),
        showline=True,
        mirror=True,
        linecolor='black',
        gridcolor='lightgrey',
        zerolinewidth=1,
        zerolinecolor='lightgrey',
    )


fig = go.Figure()

fig.add_trace(go.Scatter(
    x=X_0[:, 0],
    y=X_0[:, 1],
    name='x_0',
    mode='markers',
    marker=dict(size=5),
))

for x_0, x_r, name, color in zip(X_0, X_r, names, colors):
    # Markers-only trace: no line styling is needed.
    fig.add_trace(go.Scatter(
        x=[x_r[0]],
        y=[x_r[1]],
        mode='markers',
        marker=dict(color=color),
        name=name,
        text=name,
    ))
    # Arrow tail at the original point (ax, ay), head at the recourse (x, y),
    # both in data coordinates.
    fig.add_annotation(
        x=x_r[0], y=x_r[1], ax=x_0[0], ay=x_0[1],
        xref='x', yref='y', axref='x', ayref='y',
        arrowhead=3, arrowsize=2, arrowwidth=1, arrowcolor='green', standoff=5,
    )

fig.update_xaxes(**_axis_style(x_val.capitalize()))
fig.update_yaxes(**_axis_style(y_val.capitalize()))

fig.update_layout(
    legend=dict(
        x=0.975,
        y=0.975,
        xanchor='right',
        font=_font(15),
        bgcolor='rgba(255, 255, 255, 0.7)',
        bordercolor='lightgrey',
        borderwidth=1,
        entrywidth=0.1,
        entrywidthmode='pixels',
    ),
    margin=dict(t=5, b=0, l=1, r=5),
    width=width,
    height=height,
    plot_bgcolor='white',
    paper_bgcolor='white',
)

fig.show()