In [None]:
import io
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Every saved figure goes under this directory; create it up front so the
# later savefig calls always have somewhere to write.
FIGS_DIR = 'figures_ccai'
os.makedirs(name=FIGS_DIR, exist_ok=True)

sns.set_style('darkgrid')

In [None]:
# Fixed model -> matplotlib color assignment, so each model keeps the same
# color in every figure. Insertion order is also the order models are drawn.
COLOR_MAP = dict(zip(
    ['ETO-SLL (box)', 'ETO-SLL (ellipse)', 'ETO-JC (ellipse)', 'E2E (picnn)'],
    ['tab:blue', 'tab:orange', 'tab:green', 'tab:red'],
))

In [None]:
def plot_task_loss(
    df: pd.DataFrame,
    filename: str,
    optimal: dict[str, float],
    alphas: tuple[str, ...] = ('0.01', '0.05', '0.1', '0.2'),
) -> None:
    """Plot mean +/- std task loss against alpha, one panel per setting.

    Left panel is 'no distribution shift', right panel is 'with distribution
    shift'. Each model in COLOR_MAP is drawn as a line with a shaded +/-1 std
    band; a model with no rows for a setting is skipped in that panel. A dotted
    black line marks ``optimal[setting]``. The figure is saved under FIGS_DIR
    as both ``<filename>.pdf`` and ``<filename>.png``.

    Args:
        df: results indexed by (setting, model, val) with val in {'mean', 'std'},
            and one column per alpha whose labels are the strings in ``alphas``.
        filename: output file stem, without extension.
        optimal: reference task loss for each setting name.
        alphas: column labels to plot, in x-axis order.

    Raises:
        ValueError: if ``alphas`` is empty.
        KeyError: if ``optimal`` has no entry for one of the settings.
    """
    # .loc treats a tuple as a single multi-level key, so pass columns as a list.
    alphas = list(alphas)
    if not alphas:
        raise ValueError('alphas must contain at least one column label')
    xs = range(len(alphas))
    settings = ('no distribution shift', 'with distribution shift')

    fig, axs = plt.subplots(1, len(settings), figsize=(8, 2.5), layout='constrained')
    for ax, setting in zip(axs, settings):
        for model, color in COLOR_MAP.items():
            try:
                means = df.loc[(setting, model, 'mean'), alphas].to_numpy(dtype=float)
                stds = df.loc[(setting, model, 'std'), alphas].to_numpy(dtype=float)
            except KeyError:
                # No results for this model in this setting: leave it out of the panel.
                continue
            ax.plot(xs, means, label=model, color=color)
            ax.fill_between(xs, means - stds, means + stds, alpha=0.3, color=color)

        optimal_val = optimal[setting]
        ax.plot((xs[0], xs[-1]), (optimal_val, optimal_val), color='black', ls=':', label='optimal')
        ax.set(xticks=xs, xticklabels=alphas, xlabel='$\\alpha$', title=setting)
    axs[0].set(ylabel='task loss')

    # Merge legend entries from both panels (a model may appear in only one),
    # ordered as in COLOR_MAP with the 'optimal' line last.
    handles_by_label = {}
    for ax in axs:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            handles_by_label.setdefault(label, handle)
    labels = [label for label in [*COLOR_MAP, 'optimal'] if label in handles_by_label]
    legend = fig.legend([handles_by_label[label] for label in labels], labels,
                        loc='outside right center')
    # Highlight the end-to-end method(s) in the legend.
    for text in legend.get_texts():
        if text.get_text().startswith('E2E'):
            text.set_fontweight('bold')

    for ext in ('pdf', 'png'):
        fig.savefig(os.path.join(FIGS_DIR, f'{filename}.{ext}'))

## Storage task loss

In [None]:
# Storage task-loss summary: for each (setting, model) pair, a 'mean' row and a
# 'std' row, with one column per alpha (column labels are strings).
# NOTE(review): provenance of these numbers (runs / seeds) is not recorded here — TODO document.
buf = io.StringIO(
"""
setting,model,val,0.01,0.05,0.1,0.2
with distribution shift,ETO-SLL (box),mean,-10.867236,-15.960863,-17.455506,-20.217888
with distribution shift,ETO-SLL (box),std,3.304253,1.158833,0.917278,1.211853
with distribution shift,ETO-SLL (ellipse),mean,-12.770146,-17.391244,-19.935627,-20.871551
with distribution shift,ETO-SLL (ellipse),std,2.172402,1.752889,1.807191,2.315126
with distribution shift,ETO-JC (ellipse),mean,-0.004849,-8.52487,-14.341357,-16.935207
with distribution shift,ETO-JC (ellipse),std,0.015332,4.948228,3.723592,3.447717
with distribution shift,E2E (picnn),mean,-27.377843,-29.244593,-30.546455,-30.815375
with distribution shift,E2E (picnn),std,2.142963,2.064457,0.465653,0.206176
no distribution shift,ETO-SLL (box),mean,-12.880297,-24.462366,-28.001287,-31.287236
no distribution shift,ETO-SLL (box),std,7.369936,2.411725,2.636298,1.679142
no distribution shift,ETO-SLL (ellipse),mean,-17.889201,-24.415428,-27.500599,-31.131884
no distribution shift,ETO-SLL (ellipse),std,4.362387,4.208463,4.455954,2.612563
no distribution shift,ETO-JC (ellipse),mean,-2.638006,-24.262258,-29.659861,-33.462275
no distribution shift,ETO-JC (ellipse),std,1.310785,5.448612,4.68547,3.822394
no distribution shift,E2E (picnn),mean,-39.785395,-41.294381,-42.965304,-43.171177
no distribution shift,E2E (picnn),std,1.842494,2.592749,0.101937,0.100933
""")
df = pd.read_csv(buf).set_index(['setting', 'model', 'val'])
# plot_task_loss looks rows up by the full (setting, model, val) key; a duplicated
# key would make .loc return a frame instead of a single row and corrupt the plot.
assert df.index.is_unique, 'duplicate (setting, model, val) rows in task-loss data'
df

In [None]:
plot_task_loss(
    df,
    filename='taskloss',
    # Reference task loss per setting, drawn as the dotted 'optimal' line in each panel.
    optimal={
        'no distribution shift': -45.4994702,
        'with distribution shift': -32.50498588,
    },
)