In [1]:
import sys
# Put the parent directory on the import path.
# NOTE(review): no cell here imports a local module directly; presumably this is
# needed so pd.read_pickle can resolve project classes stored in the result
# pickles — confirm before removing.
sys.path.append('..')

In [2]:
import warnings
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# NOTE(review): this installs a global 'ignore' filter, silencing every warning
# for the rest of the kernel session (including pandas deprecation warnings that
# would flag API drift) — consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

In [24]:
# Experiment selection: which base model / dataset / algorithms / seeds to load.
dataset = 'synthetic'
base_model = 'nn'
algs = ['alg1', 'roar']
seeds = range(5)

# One pickle per (algorithm, seed) run.
results_dir = '../results/smoothness/output'

# Read every run first and concatenate once: calling pd.concat inside the loop
# re-copies the growing accumulator on every iteration (quadratic).
frames = [
    pd.read_pickle(f'{results_dir}/{base_model}_{dataset}_{alg}_{seed}.pkl')
    for alg in algs
    for seed in seeds
]
# pd.concat([]) raises ValueError; fall back to an empty frame as the original
# accumulator-based loop did when there was nothing to load.
df_results = pd.concat(frames) if frames else pd.DataFrame()

In [25]:
# Average every numeric column over all rows sharing (alg, beta, p) — presumably
# i.e. across seeds. Pass numeric_only by keyword rather than as a bare
# positional boolean so the intent is visible.
df_agg = df_results.groupby(['alg', 'beta', 'p'], as_index=False).mean(numeric_only=True)
# The plotting cell reads from this copy, leaving df_agg itself untouched.
df_im = df_agg.copy()

In [26]:
# Trace colour for each perturbation, indexed by position in `pnames`.
colors = ['#FF97FF','#19D3F3','#EF553B','#636EFA','#FECB52','#AB63FA','#FFA15A','#636EFA']
# Values of the `p` column to plot, one line per value. Raw strings keep the
# literal backslash in '\epsilon' without relying on Python's deprecated handling
# of the unknown escape '\e' (the resulting string values are unchanged).
pnames = [r'theta_s-2\epsilon', r'theta_s-\epsilon', 'theta_s', r'theta_s+\epsilon', r'theta_s+2\epsilon']

font_family = 'Times New Roman'
font_color = 'black'
width, height = 1200, 540

# NOTE(review): not referenced below (markers are fixed to size-5 circles);
# remove if no other cell uses them.
symbols = ['circle', 'x', 'x', 'circle', 'triangle-up', 'star']
size = [7, 7, 7, 5, 8, 10]
show_errors = False

# Shared font specs for axis titles and tick labels.
title_font = dict(family=font_family, color=font_color, size=25)
tick_font = dict(family=font_family, color=font_color, size=20)

# Styling common to every axis in both panels.
axis_style = dict(
    showline=True,
    mirror=True,
    linecolor='black',
    gridcolor='lightgrey',
    zerolinewidth=1,
    zerolinecolor='lightgrey',
    tickfont=tick_font,
)

# Left panel: Algorithm 1; right panel: every other algorithm (ROAR).
fig = make_subplots(cols=2, rows=1, subplot_titles=['Algorithm 1', 'ROAR'])

for i, p in enumerate(pnames):
    # Legend label: hatted theta_s plus the perturbation suffix of `p`,
    # e.g. '$\hat\theta_s-2\epsilon$' followed by padding spaces.
    t = '\\theta_s' + p.split('_s')[-1]
    label = r'$\hat{theta}$       '.format(theta=t)

    for alg in df_im['alg'].unique():
        # Case-insensitive so the left panel is used whether the results label
        # the algorithm 'Alg1' or 'alg1' (the result filenames use 'alg1').
        is_alg1 = str(alg).lower() == 'alg1'
        df_alg = df_im[(df_im['alg'] == alg) & (df_im['p'] == p)].sort_values('beta')

        fig.add_trace(
            go.Scatter(
                x=df_alg['beta'],
                y=df_alg['J'],
                marker=dict(color=colors[i], symbol='circle', size=5),
                mode='lines+markers',
                name=label,
                # Hover labels match the plotted axes (beta vs. smoothness J).
                hovertemplate='beta: %{x}<br>Smoothness: %{y}',
                # Only the left panel adds legend entries so each p appears once.
                showlegend=is_alg1,
            ),
            row=1,
            col=1 if is_alg1 else 2,
        )

fig.update_xaxes(title=dict(text=r'$\beta$', font=title_font), **axis_style)
# Only the left panel carries the y-axis title; both y-ranges start at zero.
fig.update_yaxes(
    title=dict(text='Smoothness', font=title_font),
    rangemode='tozero',
    col=1,
    **axis_style,
)
fig.update_yaxes(rangemode='tozero', col=2, **axis_style)

fig.update_layout(
    legend=dict(
        # Right-anchored at x=0.440, y=0.025 (near the lower right of the left
        # panel). Alternative top-left placement: x=0.025, y=0.975.
        x=0.440,
        y=0.025,
        orientation='v',
        xanchor='right',
        font=dict(family=font_family, color=font_color, size=15),
        bgcolor='rgba(255, 255, 255, 0.7)',
        bordercolor='lightgrey',
        borderwidth=1,
    ),
    width=width,
    height=height,
    plot_bgcolor='white',
    paper_bgcolor='white',
    margin=dict(r=10, l=5, t=35, b=1),
)

# Subplot titles are layout annotations; restyle them to match the axes.
fig.update_annotations(font_size=25, font_family=font_family)

print(base_model.upper(), dataset.upper())
fig.show()

NN SYNTHETIC


In [27]:
# Export the figure built above as a PDF.
figure_path = Path(f'../figures/smoothness/smoothness_{base_model}_{dataset}.pdf')
# Create the target folder first so a fresh checkout doesn't fail on a missing directory.
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(figure_path)