In [7]:
import sys
sys.path.append('..')

In [8]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from src.utils import pareto_frontier

In [32]:
# Experiment selection: which result pickles to load.
dataset = 'synthetic'
base_model = 'nn'
algs = ['roar']
seeds = range(5)

# One pickle per (base_model, dataset, alg, seed) run lives in this directory.
results_dir = Path('../results/rob_con_tradeoff/output')

# Read every run and concatenate once. Growing a DataFrame with `pd.concat`
# inside a loop is quadratic, and concatenating onto an empty frame triggers
# pandas' empty-entry FutureWarning. `ignore_index=True` avoids repeating each
# pickle's index labels in the combined frame.
# NOTE(review): `pd.read_pickle` can execute arbitrary code — only load pickles
# produced by this project's own experiment scripts.
df_results = pd.concat(
    [
        pd.read_pickle(results_dir / f'{base_model}_{dataset}_{alg}_{seed}.pkl')
        for alg in algs
        for seed in seeds
    ],
    ignore_index=True,
)

In [33]:
# Average every metric over seeds for each (alg, p, beta) configuration.
# `numeric_only=True` keeps this working on pandas >= 2.0, where `mean()` raises
# on non-numeric columns instead of silently dropping them.
# NOTE(review): the pickle schema is not visible here; this only matters if the
# results carry non-numeric columns besides the three group keys.
df_agg = df_results.groupby(['alg', 'p', 'beta'], as_index=False).mean(numeric_only=True)

# Working copy consumed by the plotting cell.
df_im = df_agg.copy()

In [35]:
# Robustness-vs-consistency trade-off: one curve per (algorithm, p) pair,
# with points ordered by beta (the aggregation cell sorts by alg, p, beta).
colors = px.colors.qualitative.Plotly
nc = len(colors)
font_family = 'Times New Roman'
font_color = 'black'
width, height = 720, 540

# Legend label and marker style per value of the `alg` column; algorithms not
# listed here fall back to `default_style` with their raw name as label.
params = {
    'ROAR': {'name': 'ROAR', 'symbol': 'star', 'size': 5},
    'OPT': {'name': 'Alg1', 'symbol': 'circle', 'size': 3},
}
default_style = {'symbol': 'circle', 'size': 3}

x_val = 'consistency'
y_val = 'robustness'

# Set True to keep only the Pareto-optimal points of each curve.
pareto_only = False

title_font = dict(family=font_family, color=font_color, size=25)
tick_font = dict(family=font_family, color=font_color, size=20)
legend_font = dict(family=font_family, color=font_color, size=15)

fig = go.Figure()
# Iterate only the (alg, p) pairs that occur in the data. Crossing the unique
# values of each column would add empty traces (and legend entries) for
# combinations that were never run.
for (alg, p), df_alg in df_im.groupby(['alg', 'p']):
    if pareto_only:
        df_alg = df_alg.iloc[pareto_frontier(df_alg[y_val], df_alg[x_val])]

    style = params.get(alg, {'name': str(alg), **default_style})
    # Brace the subscript so multi-digit p renders as a single subscript.
    name = rf'$\hat{{\theta}}_{{{int(p)}}} ({style["name"]})$'

    fig.add_trace(go.Scatter(
        x=df_alg[x_val],
        y=df_alg[y_val],
        # int(): list indices must be integers even when p is stored as float;
        # % nc: wrap around the palette instead of raising IndexError for p >= nc.
        marker=dict(color=colors[int(p) % nc], symbol=style['symbol'], size=style['size']),
        mode='markers+lines',
        name=name,
        customdata=df_alg['beta'],
        hovertemplate=f'{x_val}: %{{x}}<br>{y_val}: %{{y}}<br>beta: %{{customdata}}',
    ))

# Identical frame/grid styling for both axes.
axis_style = dict(
    showline=True,
    mirror=True,
    linecolor='black',
    gridcolor='lightgrey',
    zerolinewidth=1,
    zerolinecolor='lightgrey',
    tickfont=tick_font,
)
fig.update_xaxes(title=dict(text=x_val.capitalize(), font=title_font), **axis_style)
fig.update_yaxes(title=dict(text=y_val.capitalize(), font=title_font), **axis_style)

fig.update_layout(
    legend=dict(
        # Top-right corner, inside the plotting area.
        x=0.975,
        y=0.975,
        xanchor='right',
        font=legend_font,
        bgcolor='rgba(255, 255, 255, 0.7)',
        bordercolor='lightgrey',
        borderwidth=1,
        # NOTE(review): with entrywidthmode='pixels' this is 0.1 px; confirm
        # whether entrywidthmode='fraction' (or entrywidth=0) was intended.
        entrywidth=0.1,
        entrywidthmode='pixels',
    ),
    margin=dict(t=5, b=0, l=1, r=5),
    width=width,
    height=height,
    plot_bgcolor='white',
    paper_bgcolor='white',
)

print(base_model.upper(), dataset.upper())
fig.show()

NN SYNTHETIC


In [28]:
# Export the figure built in the plotting cell as a PDF. Plotly's static export
# needs an image backend (e.g. kaleido) installed.
# NOTE(review): this cell's execution count (In [28]) precedes the plotting
# cell's (In [35]); re-run it after the figure cell so the saved file matches.
fig_path = Path(f'../figures/rob_con_tradeoff/rob_con_{base_model}_{dataset}.pdf')
# Create the output directory so a fresh checkout does not fail with FileNotFoundError.
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(fig_path)