In [None]:
import io
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Apply seaborn's dark-grid look to every figure drawn below.
sns.set_style('darkgrid')

# Folder that receives the saved figures; created here if it does not exist yet.
FIGS_DIR = 'figures_arxiv'
os.makedirs(FIGS_DIR, exist_ok=True)

In [None]:
# Model name -> matplotlib colour used for that model's line and std band.
# Insertion order matters: the plotting functions iterate this mapping, so it
# also fixes the draw order and therefore the legend order.
COLOR_MAP = dict([
    ('ETO', 'tab:purple'),
    ('ETO-JC', 'tab:green'),
    ('ETO-SLL', 'tab:blue'),
    ('E2E', 'tab:red'),
])

In [None]:
def plot_task_loss(df: pd.DataFrame, filename: str, optimal: float | tuple[float, float]):
    """Plot mean ± std task loss per model, one panel per uncertainty-set shape.

    Builds a 1x3 figure (panels 'box', 'ellipse', 'picnn'). Each panel gets one
    line plus a shaded ±1 std band for every model in ``COLOR_MAP`` that has
    rows for that shape, and a dotted black 'optimal' reference. The figure is
    saved as ``<FIGS_DIR>/<filename>.pdf`` and ``<FIGS_DIR>/<filename>.png``.

    Args:
        df: Table indexed by ``(model, set, val)`` with ``val`` equal to
            ``'mean'`` or ``'std'`` and string columns '0.01', '0.05', '0.1', '0.2'.
        filename: Output file stem, without extension.
        optimal: Reference task loss. A scalar draws a dotted line only; a
            ``(mean, std)`` pair additionally draws a shaded ±std band.

    Only a missing ``(model, shape)`` row is skipped. Any other problem (for
    example a missing alpha column) now raises instead of being silently
    swallowed.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 3), sharey=True, layout='constrained')
    alphas = ['0.01', '0.05', '0.1', '0.2']
    xs = range(len(alphas))  # evenly spaced tick positions, one per alpha
    for shape, ax in zip(('box', 'ellipse', 'picnn'), axs):
        for model, color in COLOR_MAP.items():
            mean_key = (model, shape, 'mean')
            std_key = (model, shape, 'std')
            # Not every model is evaluated on every shape; skip absent combinations.
            if mean_key not in df.index or std_key not in df.index:
                continue
            means = df.loc[mean_key, alphas].to_numpy(dtype=float)
            stds = df.loc[std_key, alphas].to_numpy(dtype=float)
            ax.plot(xs, means, label=model, color=color)
            ax.fill_between(xs, means - stds, means + stds, alpha=0.3, color=color)

        x_span = (xs[0], xs[-1])
        if isinstance(optimal, (float, int)):
            ax.plot(x_span, (optimal, optimal), color='black', ls=':', label='optimal')
        else:
            mean, std = optimal
            ax.plot(x_span, (mean, mean), color='black', ls=':', label='optimal')
            ax.fill_between(x_span, mean - std, mean + std, color='black', alpha=0.3)
        ax.set(xticks=xs, xticklabels=alphas, xlabel='uncertainty level $\\alpha$', title=shape)
        if shape == 'box':
            # The y-axis is shared, so only the left-most panel is labelled.
            ax.set(ylabel='task loss')

    # One figure-level legend, built from the 'ellipse' panel's handles. In the
    # tables of this notebook that is the only shape with an ETO-JC row, so it is
    # the panel whose handles cover every model.
    handles, labels = axs[1].get_legend_handles_labels()
    legend = fig.legend(handles, labels, loc='outside right center')
    for text in legend.get_texts():
        if text.get_text().startswith('E2E'):
            text.set_fontweight('bold')
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.pdf'))
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.png'))


def plot_task_loss_regret(df: pd.DataFrame, filename: str, optimal: float):
    """Plot mean regret (mean task loss minus ``optimal``) ± std, one panel per shape.

    Builds a 1x3 figure (panels 'box', 'ellipse', 'picnn'). Each panel gets one
    line plus a shaded ±1 std band for every model in ``COLOR_MAP`` that has
    rows for that shape, and its own legend. The figure is saved as
    ``<FIGS_DIR>/<filename>.pdf`` and ``<FIGS_DIR>/<filename>.png``.

    Args:
        df: Table indexed by ``(model, set, val)`` with ``val`` equal to
            ``'mean'`` or ``'std'`` and string columns '0.01', '0.05', '0.1', '0.2'.
        filename: Output file stem, without extension.
        optimal: Scalar task loss subtracted from every mean to obtain regret.

    Only a missing ``(model, shape)`` row is skipped. Any other problem (for
    example a missing alpha column) now raises instead of being silently
    swallowed.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 3), sharey=True, tight_layout=True)
    alphas = ['0.01', '0.05', '0.1', '0.2']
    xs = range(len(alphas))  # evenly spaced tick positions, one per alpha
    for shape, ax in zip(('box', 'ellipse', 'picnn'), axs):
        for model, color in COLOR_MAP.items():
            mean_key = (model, shape, 'mean')
            std_key = (model, shape, 'std')
            # Not every model is evaluated on every shape; skip absent combinations.
            if mean_key not in df.index or std_key not in df.index:
                continue
            mean_regret = df.loc[mean_key, alphas].to_numpy(dtype=float) - optimal
            stds = df.loc[std_key, alphas].to_numpy(dtype=float)
            ax.plot(xs, mean_regret, label=model, color=color)
            ax.fill_between(xs, mean_regret - stds, mean_regret + stds, alpha=0.3, color=color)

        ax.set(xticks=xs, xticklabels=alphas, xlabel='$\\alpha$', title=shape)
        if shape == 'box':
            # The y-axis is shared, so only the left-most panel is labelled.
            ax.set(ylabel='regret')
        ax.legend()
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.pdf'))
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.png'))


def plot_coverage(df: pd.DataFrame, filename: str):
    """Plot mean ± std coverage per model, one panel per uncertainty-set shape.

    Builds a 1x3 figure (panels 'box', 'ellipse', 'picnn'). Each panel gets one
    line plus a shaded ±1 std band for every model in ``COLOR_MAP`` that has
    rows for that shape, and a dotted black 'optimal' line at ``1 - alpha``.
    The figure is saved as ``<FIGS_DIR>/<filename>.pdf`` and
    ``<FIGS_DIR>/<filename>.png``.

    Args:
        df: Table indexed by ``(model, set, val)`` with ``val`` equal to
            ``'mean'`` or ``'std'`` and string columns '0.01', '0.05', '0.1', '0.2'.
        filename: Output file stem, without extension.

    Only a missing ``(model, shape)`` row is skipped. Any other problem (for
    example a missing alpha column) now raises instead of being silently
    swallowed.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 3), sharey=True, layout='constrained')
    alphas = ['0.01', '0.05', '0.1', '0.2']
    xs = range(len(alphas))  # evenly spaced tick positions, one per alpha
    # Reference coverage at each level is 1 - alpha (0.99, 0.95, 0.9, 0.8);
    # derived from `alphas` so the two cannot drift apart.
    target_coverage = [1 - float(alpha) for alpha in alphas]
    for shape, ax in zip(('box', 'ellipse', 'picnn'), axs):
        for model, color in COLOR_MAP.items():
            mean_key = (model, shape, 'mean')
            std_key = (model, shape, 'std')
            # Not every model is evaluated on every shape; skip absent combinations.
            if mean_key not in df.index or std_key not in df.index:
                continue
            means = df.loc[mean_key, alphas].to_numpy(dtype=float)
            stds = df.loc[std_key, alphas].to_numpy(dtype=float)
            ax.plot(xs, means, label=model, color=color)
            ax.fill_between(xs, means - stds, means + stds, alpha=0.3, color=color)

        ax.plot(xs, target_coverage, ls=':', color='black', label='optimal')
        ax.set(xticks=xs, xticklabels=alphas, xlabel='uncertainty level $\\alpha$', title=shape)
        if shape == 'box':
            # The y-axis is shared, so only the left-most panel is labelled.
            ax.set(ylabel='coverage')

    # One figure-level legend, built from the 'ellipse' panel's handles. In the
    # tables of this notebook that is the only shape with an ETO-JC row, so it is
    # the panel whose handles cover every model.
    handles, labels = axs[1].get_legend_handles_labels()
    legend = fig.legend(handles, labels, loc='outside right center')
    for text in legend.get_texts():
        if text.get_text().startswith('E2E'):
            text.set_fontweight('bold')
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.pdf'))
    fig.savefig(os.path.join(FIGS_DIR, f'{filename}.png'))

## Storage dist-shift task loss

In [None]:
# Storage dist-shift task loss: one 'mean' and one 'std' row per (model, set),
# one column per uncertainty level alpha. ETO-JC has only 'ellipse' rows and
# ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,-17.450167,-21.299612,-22.91361,-24.726097
ETO,box,std,2.158718,0.372012,0.57891,0.766728
ETO,ellipse,mean,-8.637755,-13.995689,-14.642697,-15.240083
ETO,ellipse,std,2.587513,1.264177,1.305358,1.368691
ETO,picnn,mean,-6.766127,-11.066663,-12.151778,-12.950701
ETO,picnn,std,3.075998,4.241612,4.582858,4.27471
ETO-SLL,box,mean,-10.867236,-15.960863,-17.455506,-20.217888
ETO-SLL,box,std,3.304253,1.158833,0.917278,1.211853
ETO-SLL,ellipse,mean,-18.244101,-21.993499,-23.686633,-24.497868
ETO-SLL,ellipse,std,2.049533,1.599452,1.661991,1.170572
ETO-JC,ellipse,mean,-4.772606,-16.551971,-20.367132,-22.658602
ETO-JC,ellipse,std,5.942747,2.157322,1.57894,1.204185
E2E,box,mean,-20.4442,-23.5843,-24.5894,-25.7206
E2E,box,std,3.9686,0.4579,0.5078,0.899
E2E,ellipse,mean,-28.6046,-29.7657,-30.4725,-31.0868
E2E,ellipse,std,1.2628,1.7897,1.4968,1.2415
E2E,picnn,mean,-27.377843,-29.244593,-30.546455,-30.815375
E2E,picnn,std,2.142963,2.064457,0.465653,0.206176
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel task-loss figure for the table above and save it as
# storage_distshift_taskloss.pdf/.png under FIGS_DIR. The scalar `optimal` is
# drawn as a dotted reference line.
plot_task_loss(df, filename='storage_distshift_taskloss', optimal=-32.50498588)
# Regret variant of the same figure, currently disabled:
# plot_task_loss_regret(df, filename='storage_distshift_taskloss_regret', optimal=-32.50498588)

## Storage shuffle task loss

In [None]:
# Storage shuffle task loss: one 'mean' and one 'std' row per (model, set),
# one column per uncertainty level alpha. ETO-JC has only 'ellipse' rows and
# ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,-26.370909,-32.739807,-34.50698,-36.813602
ETO,box,std,2.063016,1.03021,0.765397,0.765509
ETO,ellipse,mean,-8.617506,-18.147022,-18.982848,-19.766314
ETO,ellipse,std,5.701845,2.445104,2.417572,2.765179
ETO,picnn,mean,-10.69,-16.25,-18.33,-19.77
ETO,picnn,std,3.349,3.806,3.699,3.432
ETO-SLL,box,mean,-12.880297,-24.462366,-28.001287,-31.287236
ETO-SLL,box,std,7.369936,2.411725,2.636298,1.679142
ETO-SLL,ellipse,mean,-24.20926,-28.872667,-30.773642,-32.595965
ETO-SLL,ellipse,std,3.175815,2.860649,2.523238,1.593281
ETO-JC,ellipse,mean,-9.041665,-27.972709,-33.052322,-35.686457
ETO-JC,ellipse,std,7.61495,1.90364,2.425706,2.212662
E2E,box,mean,-30.8399,-34.5764,-35.9106,-37.8004
E2E,box,std,1.5141,0.5529,0.7764,0.6791
E2E,ellipse,mean,-37.7843,-39.454,-39.7673,-40.9565
E2E,ellipse,std,4.5091,4.3277,3.049,3.8543
E2E,picnn,mean,-39.785395,-40.889996,-42.965304,-43.171177
E2E,picnn,std,1.842494,3.473853,0.101937,0.100933
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel task-loss figure for the table above and save it as
# storage_shuffle_taskloss.pdf/.png under FIGS_DIR. The scalar `optimal` is
# drawn as a dotted reference line.
plot_task_loss(df, filename='storage_shuffle_taskloss', optimal=-45.4994702)
# Regret variant of the same figure, currently disabled:
# plot_task_loss_regret(df, filename='storage_shuffle_taskloss_regret', optimal=-45.4994702)

## Storage dist-shift coverage

In [None]:
# Storage dist-shift coverage: one 'mean' and one 'std' row per (model, set),
# one column per uncertainty level alpha. ETO-JC has only 'ellipse' rows and
# ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,0.996347,0.959589,0.905023,0.775571
ETO,box,std,0.003082,0.019164,0.035487,0.050356
ETO,ellipse,mean,0.98379,0.874658,0.750457,0.561187
ETO,ellipse,std,0.010404,0.063967,0.079775,0.081383
ETO,picnn,mean,0.98516,0.865297,0.763699,0.633333
ETO,picnn,std,0.015936,0.122482,0.17486,0.195958
ETO-SLL,box,mean,0.995434,0.931507,0.787443,0.637671
ETO-SLL,box,std,0.009864,0.03209,0.075204,0.094761
ETO-SLL,ellipse,mean,0.933105,0.808447,0.640183,0.510959
ETO-SLL,ellipse,std,0.066694,0.100597,0.095656,0.064718
ETO-JC,ellipse,mean,1,0.995662,0.939726,0.814384
ETO-JC,ellipse,std,0,0.004623,0.043053,0.064505
E2E,box,mean,0.992,0.9347,0.8683,0.7441
E2E,box,std,0.0068,0.03,0.0372,0.0511
E2E,ellipse,mean,0.9918,0.8842,0.7813,0.6532
E2E,ellipse,std,0.0081,0.0734,0.0767,0.1022
E2E,picnn,mean,0.978082,0.893379,0.850228,0.689954
E2E,picnn,std,0.027654,0.076635,0.061094,0.104074
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel coverage figure for the table above and save it as
# storage_distshift_coverage.pdf/.png under FIGS_DIR.
plot_coverage(df, filename='storage_distshift_coverage')

## Storage shuffle coverage

In [None]:
# Storage shuffle coverage: one 'mean' and one 'std' row per (model, set),
# one column per uncertainty level alpha. ETO-JC has only 'ellipse' rows and
# ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,0.984932,0.931963,0.875571,0.778539
ETO,box,std,0.007996,0.011975,0.018368,0.017815
ETO,ellipse,mean,0.989498,0.943151,0.886986,0.781507
ETO,ellipse,std,0.006107,0.012618,0.019794,0.027101
ETO,picnn,mean,0.9879,0.9454,0.9018,0.8055
ETO,picnn,std,0.007765,0.01598,0.01685,0.01913
ETO-SLL,box,mean,0.992009,0.951142,0.9,0.778082
ETO-SLL,box,std,0.007159,0.024334,0.032889,0.038185
ETO-SLL,ellipse,mean,0.957078,0.871461,0.799772,0.705479
ETO-SLL,ellipse,std,0.017906,0.038113,0.040457,0.035549
ETO-JC,ellipse,mean,0.991781,0.938356,0.875342,0.768721
ETO-JC,ellipse,std,0.002454,0.009566,0.018617,0.01611
E2E,box,mean,0.9797,0.9199,0.8582,0.7658
E2E,box,std,0.0091,0.0167,0.0235,0.0168
E2E,ellipse,mean,0.9925,0.9516,0.91,0.7995
E2E,ellipse,std,0.0052,0.0131,0.0138,0.0227
E2E,picnn,mean,0.9879,0.949543,0.874201,0.761872
E2E,picnn,std,0.004698,0.010679,0.016409,0.023704
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel coverage figure for the table above and save it as
# storage_shuffle_coverage.pdf/.png under FIGS_DIR.
plot_coverage(df, filename='storage_shuffle_coverage')

## Portfolio synthetic task loss

In [None]:
# Portfolio synthetic task loss: one 'mean' and one 'std' row per (model, set),
# one column per uncertainty level alpha. ETO-JC has only 'ellipse' rows and
# ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,-1.161178,-1.372686,-1.391445,-1.411191
ETO,box,std,0.417279,0.118968,0.134566,0.118071
ETO,ellipse,mean,-0.907793,-1.092547,-1.189434,-1.279274
ETO,ellipse,std,0.081758,0.102121,0.098324,0.093977
ETO,picnn,mean,-0.884407,-1.059457,-1.154048,-1.259902
ETO,picnn,std,0.213987,0.168555,0.159744,0.11653
ETO-SLL,box,mean,-1.4054,-1.416325,-1.420571,-1.436291
ETO-SLL,box,std,0.132714,0.119232,0.116984,0.110968
ETO-SLL,ellipse,mean,-0.840694,-1.015641,-1.2442,-1.386948
ETO-SLL,ellipse,std,0.146894,0.173704,0.14569,0.127257
ETO-JC,ellipse,mean,-0.86239,-1.10404,-1.315613,-1.424665
ETO-JC,ellipse,std,0.094252,0.118496,0.126147,0.106042
E2E,box,mean,-1.188,-1.387,-1.413,-1.425
E2E,box,std,0.4253,0.1119,0.1121,0.09528
E2E,ellipse,mean,-1.4558,-1.4653,-1.4646,-1.4703
E2E,ellipse,std,0.1217,0.1147,0.1139,0.1017
E2E,picnn,mean,-1.468,-1.466,-1.467,-1.472
E2E,picnn,std,0.09566,0.09204,0.114,0.1017
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel task-loss figure for the table above and save it as
# portfolio_syn_taskloss.pdf/.png under FIGS_DIR. `optimal` is passed as a
# (mean, std) pair, so the dotted reference line also gets a shaded ±std band.
plot_task_loss(df, filename='portfolio_syn_taskloss', optimal=(-1.933641, 0.084673))

In [None]:
# Portfolio synthetic coverage (this cell has no markdown header of its own):
# one 'mean' and one 'std' row per (model, set), one column per uncertainty
# level alpha. ETO-JC has only 'ellipse' rows and ETO-SLL has no 'picnn' rows.
buf = io.StringIO(
"""
model,set,val,0.01,0.05,0.1,0.2
ETO,box,mean,0.9841,0.9472,0.902,0.7862
ETO,box,std,0.00743,0.017358,0.016806,0.020308
ETO,ellipse,mean,0.9883,0.9438,0.8941,0.794
ETO,ellipse,std,0.003974,0.019977,0.021605,0.026733
ETO,picnn,mean,0.9911,0.953,0.9047,0.8071
ETO,picnn,std,0.00599,0.013482,0.018968,0.017195
ETO-SLL,box,mean,0.9848,0.9452,0.8846,0.7959
ETO-SLL,box,std,0.012363,0.020682,0.02971,0.029248
ETO-SLL,ellipse,mean,0.9911,0.9464,0.8868,0.7955
ETO-SLL,ellipse,std,0.009073,0.028005,0.040052,0.03601
ETO-JC,ellipse,mean,0.9914,0.953,0.902,0.7958
ETO-JC,ellipse,std,0.006433,0.016951,0.02563,0.026102
E2E,box,mean,0.9851,0.9472,0.9065,0.7873
E2E,box,std,0.007295,0.01807,0.0123,0.01729
E2E,ellipse,mean,0.9904,0.9577,0.8987,0.7926
E2E,ellipse,std,0.0071,0.0144,0.0256,0.0351
E2E,picnn,mean,0.9906,0.9541,0.8957,0.805
E2E,picnn,std,0.006415,0.01051,0.01347,0.02662
""")
# Parse straight into the (model, set, val) MultiIndex the plotting functions look up.
df = pd.read_csv(buf, index_col=['model', 'set', 'val'])
df

In [None]:
# Draw the three-panel coverage figure for the table above and save it as
# portfolio_syn_coverage.pdf/.png under FIGS_DIR.
plot_coverage(df, filename='portfolio_syn_coverage')