In [None]:
# All imports in one place: standard library first, then third-party.
import io
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Global plot styling applied to every figure in this notebook.
sns.set_style('darkgrid')

# Directory where figures are saved (as PDF and PNG); created if it does not exist.
FIGS_DIR = 'figures_ccai'
os.makedirs(FIGS_DIR, exist_ok=True)

In [None]:
# Colour used for each model's curve and shaded band. Iteration order is the
# order in which models are drawn (and therefore listed in the legend); keys are
# looked up in the `model` level of the results table's index.
COLOR_MAP = dict([
    ('ETO-SLL (box)', 'tab:blue'),
    ('ETO-SLL (ellipse)', 'tab:orange'),
    ('ETO-JC (ellipse)', 'tab:green'),
    ('E2E (picnn)', 'tab:red'),
])

In [None]:
def plot_task_loss(
    df: pd.DataFrame,
    filename: str,
    optimal: dict[str, float],
    alphas: tuple[str, ...] = ('0.01', '0.05', '0.1', '0.2'),
    settings: tuple[str, ...] = ('no distribution shift', 'with distribution shift'),
) -> None:
    """Plot mean ± 1 std task loss against alpha for each model, one panel per setting.

    Each model in COLOR_MAP is drawn as a line (mean) with a shaded band (± std),
    and a dotted black line marks the ``optimal`` value for that panel's setting.
    The figure is saved in FIGS_DIR as both ``{filename}.pdf`` and ``{filename}.png``.

    Args:
        df: Results table indexed by (setting, model, val), where ``val`` is
            'mean' or 'std', with one column per alpha label.
        filename: Output file stem (no extension), written inside FIGS_DIR.
        optimal: Reference task loss for each entry of ``settings``.
        alphas: Column labels of ``df`` to plot, in x-axis order.
        settings: Values of the ``setting`` index level; one panel each, left to right.

    Raises:
        KeyError: if an alpha column is missing from ``df`` or a setting is missing
            from ``optimal``. A (setting, model) pair with no rows in ``df`` is
            skipped instead, so a model may be absent from a panel.
    """
    # .loc needs a list of column labels; a tuple is not treated as a list of labels.
    alphas = list(alphas)
    missing = [a for a in alphas if a not in df.columns]
    if missing:
        # Without this check every model lookup below would fail and be skipped,
        # silently producing an empty plot.
        raise KeyError(f'alpha columns not found in df: {missing}')
    xs = range(len(alphas))

    fig, axs = plt.subplots(1, len(settings), figsize=(8, 2.5), layout='constrained', squeeze=False)
    panels = axs[0]  # squeeze=False keeps a 2-D array even for a single setting
    for i, (ax, setting) in enumerate(zip(panels, settings)):
        if i == 0:
            ax.set(ylabel='task loss')

        for model, color in COLOR_MAP.items():
            try:
                means = df.loc[(setting, model, 'mean'), alphas].to_numpy(dtype=float)
                stds = df.loc[(setting, model, 'std'), alphas].to_numpy(dtype=float)
            except KeyError:
                # No results for this model under this setting: leave it out of the panel.
                continue
            ax.plot(xs, means, label=model, color=color)
            ax.fill_between(xs, means - stds, means + stds, alpha=0.3, color=color)

        optimal_val = optimal[setting]
        ax.plot((xs[0], xs[-1]), (optimal_val, optimal_val), color='black', ls=':', label='optimal')
        ax.set(xticks=xs, xticklabels=alphas, xlabel='$\\alpha$', title=setting)

    # Gather legend entries from every panel (first handle wins per label), so a
    # model missing from one panel still gets a legend entry.
    handles_by_label: dict[str, object] = {}
    for ax in panels:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            handles_by_label.setdefault(label, handle)
    legend = fig.legend(
        list(handles_by_label.values()), list(handles_by_label), loc='outside right center'
    )
    # Highlight the end-to-end model(s) in the legend.
    for text in legend.get_texts():
        if text.get_text().startswith('E2E'):
            text.set_fontweight('bold')
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.pdf'))
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.png'))

## Storage task loss

In [None]:
# Mean and std of the storage task loss per (setting, model) for each alpha.
# TODO(review): record which run/commit these numbers were copied from.
buf = io.StringIO(
"""
setting,model,val,0.01,0.05,0.1,0.2
with distribution shift,ETO-SLL (box),mean,-10.867236,-15.960863,-17.455506,-20.217888
with distribution shift,ETO-SLL (box),std,3.304253,1.158833,0.917278,1.211853
with distribution shift,ETO-SLL (ellipse),mean,-18.244101,-21.993499,-23.686633,-24.497868
with distribution shift,ETO-SLL (ellipse),std,2.049533,1.599452,1.661991,1.170572
with distribution shift,ETO-JC (ellipse),mean,-4.772606,-16.551971,-20.367132,-22.658602
with distribution shift,ETO-JC (ellipse),std,5.942747,2.157322,1.57894,1.204185
with distribution shift,E2E (picnn),mean,-27.377843,-29.244593,-30.546455,-30.815375
with distribution shift,E2E (picnn),std,2.142963,2.064457,0.465653,0.206176
no distribution shift,ETO-SLL (box),mean,-12.880297,-24.462366,-28.001287,-31.287236
no distribution shift,ETO-SLL (box),std,7.369936,2.411725,2.636298,1.679142
no distribution shift,ETO-SLL (ellipse),mean,-24.20926,-28.872667,-30.773642,-32.595965
no distribution shift,ETO-SLL (ellipse),std,3.175815,2.860649,2.523238,1.593281
no distribution shift,ETO-JC (ellipse),mean,-9.041665,-27.972709,-33.052322,-35.686457
no distribution shift,ETO-JC (ellipse),std,7.61495,1.90364,2.425706,2.212662
no distribution shift,E2E (picnn),mean,-39.785395,-40.889996,-42.965304,-43.171177
no distribution shift,E2E (picnn),std,1.842494,3.473853,0.101937,0.100933
""")
df = pd.read_csv(buf).set_index(['setting', 'model', 'val'])

# Sanity-check the hand-entered table. A typo here would otherwise show up only
# as a silently missing or wrong curve in the plot.
assert df.index.is_unique, 'duplicate (setting, model, val) rows'
assert set(df.index.get_level_values('val')) == {'mean', 'std'}, "val must be 'mean' or 'std'"
assert df.groupby(level=['setting', 'model']).size().eq(2).all(), 'each (setting, model) needs a mean and a std row'
assert not df.isna().to_numpy().any(), 'missing values in the table'
assert (df.xs('std', level='val') >= 0).to_numpy().all(), 'negative standard deviation'
df

In [None]:
# Value drawn as the dotted 'optimal' reference line in each panel, keyed by the
# same setting names used in the `setting` level of `df`.
optimal_task_loss = {
    'no distribution shift': -45.4994702,
    'with distribution shift': -32.50498588,
}
plot_task_loss(df, filename='taskloss', optimal=optimal_task_loss);