In [38]:
import numpy as np
import os
import pandas as pd
import plotly.express as px
px.defaults.template = 'plotly_white'
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots


In [39]:
dir = '/scratch/fmager/fge'
model = ['VGG16', 'PreResNet164', 'WideResNet28x10'] # VGG16, PreResNet164, WideResNet28x10
dataset = 'CIFAR100' # CIFAR10, CIFAR100
curve = ['Bezier', 'PolyChain'] # Bezier, PolyChain
start_seed = [0, 0, 1]
end_seed = [1, 3, 3]
projection = 'cosine'
sampling_method = 'linear'
fig_path = os.path.join('./figures', f'{dataset}', f'{projection}', f'{sampling_method}')
os.makedirs(fig_path, exist_ok=True)

In [59]:
conc_results = pd.DataFrame()

for m in model:
    for s_init, s_end in zip(start_seed, end_seed):
        for c in curve:

            path = os.path.join(dir, dataset.lower(), m.lower(), f'seed_{s_init}_to_{s_end}', c.lower())

            if not os.path.exists(os.path.join(path, f'{sampling_method}_{projection}_rel_eval.npz')):
                continue

            results = dict(np.load(os.path.join(path, f'{sampling_method}_{projection}_rel_eval.npz'), allow_pickle=True))
            results = pd.DataFrame(results)
            results['idx'] = results.index


            results_acc = results.melt(id_vars=['ts', 'idx'], value_vars=['test_accuracy', 'ensemble_accuracy'], var_name='accuracy type', value_name='accuracy (%)')
            results_acc['accuracy type'] = results_acc['accuracy type'].str.split('_').str[0]

            results_fld = results.melt(id_vars=['ts', 'idx'], value_vars=['cumulative_relative_sample_fld', 'cumulative_relative_class_fld',
                                                                          'cumulative_absolute_sample_fld', 'cumulative_absolute_class_fld',
                                                                          'dt_relative_sample_fld', 'dt_relative_class_fld',
                                                                          'dt_absolute_sample_fld', 'dt_absolute_class_fld'], var_name='fld type', value_name='rho')

            results_fld['space'] = results_fld['fld type'].str.split('_').str[1]
            results_fld['fld type'] = results_fld['fld type'].str.split('_').str[0] + ' ' + results_fld['fld type'].str.split('_').str[-2]
            results_fld['fld type'] = results_fld['fld type'].str.replace('cumulative', 'cum.')
            

            results = pd.merge(results_acc, results_fld, on=['ts', 'idx'])

            results['model'] = m
            results['curve'] = c
            results['seed'] = f'seed {s_init} to {s_end}'

            conc_results = pd.concat([conc_results, results])

conc_results = conc_results.sort_values(by=['ts', 'model', 'curve', 'space'])
conc_results = conc_results.rename(columns={'ts': 't'})

print(conc_results)



       t  idx accuracy type  accuracy (%)     fld type       rho     space  \
2    0.0    0          test         78.14  cum. sample  1.000000  absolute   
3    0.0    0          test         78.14   cum. class  0.501298  absolute   
6    0.0    0          test         78.14    dt sample  0.996244  absolute   
7    0.0    0          test         78.14     dt class  0.502450  absolute   
170  0.0    0      ensemble         78.14  cum. sample  1.000000  absolute   
..   ...  ...           ...           ...          ...       ...       ...   
165  1.0   20          test         81.65     dt class  0.674194  relative   
328  1.0   20      ensemble         82.59  cum. sample  0.952905  relative   
329  1.0   20      ensemble         82.59   cum. class  0.691944  relative   
332  1.0   20      ensemble         82.59    dt sample  0.998754  relative   
333  1.0   20      ensemble         82.59     dt class  0.674194  relative   

               model      curve         seed  
2       PreResNe

In [58]:
# Create two subplots
def update_figure(fig, y_label, width=None, height=400):

    fig.update_layout(
        width =width, height=height,
        font_family="Serif", font_size=14, 
        margin_l=5, margin_b=10, margin_r=5,
        yaxis_title=y_label,
    )
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True, showgrid=True, gridwidth=.1, gridcolor='LightGray', zeroline=False)
    fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True, showgrid=True, gridwidth=.1, gridcolor='LightGray', zeroline=False)
    # set marker size
    fig.for_each_trace(lambda t: t.update(marker=dict(size=3)))
    fig.update_traces(opacity=.7)
    fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="left",
    x=0.0,
))
    return fig
    
    
fig1 = px.line(conc_results, x='t', y='accuracy (%)', color='model', line_dash='curve', markers=True, line_shape='spline',
            line_dash_map = {'Bezier': 'solid', 'PolyChain': 'dashdot'},
            range_x=[0, 1],
            range_y = [70, 82],
            line_group='seed',
            facet_col='accuracy type',
            render_mode='svg')

fig1 = update_figure(fig1, 'Accuracy')
fig1.show()
fig1.write_image(os.path.join(fig_path, 'accuracy.pdf'))


fig2 = px.line(conc_results, x='t', y='rho', color='space', line_dash='curve', markers=True, line_shape='spline',
            line_dash_map = {'Bezier': 'solid', 'PolyChain': 'dashdot'},
            range_x=[0, 1],
            line_group='seed',
            facet_row='fld type', facet_col='model',
            render_mode='svg')
fig2 = update_figure(fig2, height=600)
fig2 = fig2.update_yaxes(matches=None)
fig2.show()
fig2.write_image(os.path.join(fig_path, 'fld.pdf'))

