In [6]:
import altair as alt

In [7]:
# CIFAR-10 scores for the FOT method, one list per architecture.
# Each list is a flat sequence of interleaved (r_sqr, rho) pairs; format_results
# pairs them up and assigns the sample sizes noted on each row.
cifar10_fot_resnet18 = [
    0.971, 0.993,  # n = 10000
    0.969, 0.993,  # n = 5000
    0.970, 0.992,  # n = 2000
    0.972, 0.991,  # n = 1000
    0.970, 0.988,  # n = 500
    0.939, 0.968,  # n = 100
]
cifar10_fot_resnet50 = [
    0.978, 0.993,  # n = 10000
    0.975, 0.992,  # n = 5000
    0.971, 0.991,  # n = 2000
    0.971, 0.988,  # n = 1000
    0.971, 0.984,  # n = 500
    0.919, 0.931,  # n = 100
]
cifar10_fot_vgg11 = [
    0.986, 0.993,  # n = 10000
    0.985, 0.992,  # n = 5000
    0.985, 0.992,  # n = 2000
    0.984, 0.990,  # n = 1000
    0.982, 0.990,  # n = 500
    0.971, 0.978,  # n = 100
]

In [8]:
# CIFAR-10 scores for the ProjNorm method, one list per architecture.
# Each list is a flat sequence of interleaved (r_sqr, rho) pairs; format_results
# pairs them up and assigns the sample sizes noted on each row.
cifar10_pn_resnet18 = [
    0.971, 0.984,  # n = 10000
    0.957, 0.982,  # n = 5000
    0.971, 0.975,  # n = 2000
    0.969, 0.982,  # n = 1000
    0.965, 0.980,  # n = 500
    0.940, 0.963,  # n = 100
]
cifar10_pn_resnet50 = [
    0.958, 0.981,  # n = 10000
    0.965, 0.989,  # n = 5000
    0.966, 0.985,  # n = 2000
    0.944, 0.984,  # n = 1000
    0.938, 0.978,  # n = 500
    0.905, 0.942,  # n = 100
]
cifar10_pn_vgg11 = [
    0.881, 0.945,  # n = 10000
    0.741, 0.950,  # n = 5000
    0.964, 0.963,  # n = 2000
    0.954, 0.966,  # n = 1000
    0.950, 0.965,  # n = 500
    0.860, 0.956,  # n = 100
]

In [9]:
# CIFAR-100 scores for the FOT method, one list per architecture.
# Each list is a flat sequence of interleaved (r_sqr, rho) pairs; format_results
# pairs them up and assigns the sample sizes noted on each row.
cifar100_fot_resnet18 = [
    0.978, 0.993,  # n = 10000
    0.971, 0.991,  # n = 5000
    0.953, 0.987,  # n = 2000
    0.939, 0.981,  # n = 1000
    0.921, 0.980,  # n = 500
    0.819, 0.926,  # n = 100
]
cifar100_fot_resnet50 = [
    0.954, 0.990,  # n = 10000
    0.934, 0.988,  # n = 5000
    0.919, 0.986,  # n = 2000
    0.909, 0.986,  # n = 1000
    0.901, 0.984,  # n = 500
    0.842, 0.932,  # n = 100
]
cifar100_fot_vgg11 = [
    0.927, 0.991,  # n = 10000
    0.881, 0.985,  # n = 5000
    0.832, 0.976,  # n = 2000
    0.811, 0.965,  # n = 1000
    0.783, 0.955,  # n = 500
    0.644, 0.807,  # n = 100
]

In [10]:
# CIFAR-100 scores for the ProjNorm method, one list per architecture.
# Each list is a flat sequence of interleaved (r_sqr, rho) pairs; format_results
# pairs them up and assigns the sample sizes noted on each row.
cifar100_pn_resnet18 = [
    0.977, 0.983,  # n = 10000
    0.972, 0.979,  # n = 5000
    0.974, 0.983,  # n = 2000
    0.974, 0.989,  # n = 1000
    0.956, 0.982,  # n = 500
    0.882, 0.950,  # n = 100
]
cifar100_pn_resnet50 = [
    0.916, 0.945,  # n = 10000
    0.919, 0.954,  # n = 5000
    0.965, 0.987,  # n = 2000
    0.953, 0.988,  # n = 1000
    0.947, 0.985,  # n = 500
    0.867, 0.939,  # n = 100
]
cifar100_pn_vgg11 = [
    0.863, 0.909,  # n = 10000
    0.809, 0.861,  # n = 5000
    0.962, 0.976,  # n = 2000
    0.956, 0.980,  # n = 1000
    0.914, 0.973,  # n = 500
    0.772, 0.953,  # n = 100
]

In [11]:
def format_results(arr, ds, method, model, n_samples=(10000, 5000, 2000, 1000, 500, 100)):
    """Turn a flat list of interleaved score pairs into a list of record dicts.

    ``arr`` is read two values at a time: the first value of each pair is
    stored under ``'r_sqr'`` and the second under ``'rho'`` (presumably an R²
    score and a correlation coefficient -- confirm against whatever produced
    the numbers). The k-th pair is labelled with ``n_samples[k]``.

    Args:
        arr: Flat sequence ``[r_sqr_0, rho_0, r_sqr_1, rho_1, ...]``.
        ds: Dataset label copied into every record under ``'ds'``.
        method: Method label copied into every record under ``'method'``.
        model: Model label copied into every record under ``'model'``.
        n_samples: Sample size attached to each pair, in the same order as the
            pairs in ``arr``. Defaults to the schedule that used to be
            hard-coded here. It may be longer than the number of pairs; only
            the leading entries are used.

    Returns:
        A list with one dict per pair, with keys ``'ds'``, ``'method'``,
        ``'model'``, ``'r_sqr'``, ``'rho'`` and ``'n_sample'``.

    Raises:
        ValueError: If ``arr`` has an odd number of values, or contains more
            pairs than there are entries in ``n_samples``.
    """
    if len(arr) % 2 != 0:
        # A dangling value would otherwise blow up midway with a bare IndexError.
        raise ValueError(
            f"arr must hold (r_sqr, rho) pairs, got an odd number of values: {len(arr)}"
        )

    n_pairs = len(arr) // 2
    if n_pairs > len(n_samples):
        raise ValueError(
            f"arr holds {n_pairs} pairs but only {len(n_samples)} sample sizes were given"
        )

    return [
        {
            'ds': ds,
            'method': method,
            'model': model,
            'r_sqr': arr[2 * k],
            'rho': arr[2 * k + 1],
            'n_sample': n_samples[k],
        }
        for k in range(n_pairs)
    ]

### CIFAR-10 Results

In [12]:
# Label each raw CIFAR-10 score list with its method and architecture.
cifar10_fot_resnet18_results = format_results(
    cifar10_fot_resnet18, ds='CIFAR-10', method='FOT', model='ResNet18')
cifar10_fot_resnet50_results = format_results(
    cifar10_fot_resnet50, ds='CIFAR-10', method='FOT', model='ResNet50')
cifar10_fot_vgg11_results = format_results(
    cifar10_fot_vgg11, ds='CIFAR-10', method='FOT', model='VGG11')

cifar10_pn_resnet18_results = format_results(
    cifar10_pn_resnet18, ds='CIFAR-10', method='ProjNorm', model='ResNet18')
cifar10_pn_resnet50_results = format_results(
    cifar10_pn_resnet50, ds='CIFAR-10', method='ProjNorm', model='ResNet50')
cifar10_pn_vgg11_results = format_results(
    cifar10_pn_vgg11, ds='CIFAR-10', method='ProjNorm', model='VGG11')

# Flatten per method (ResNet18, ResNet50, VGG11 order) ...
cifar10_fot_results = [
    *cifar10_fot_resnet18_results,
    *cifar10_fot_resnet50_results,
    *cifar10_fot_vgg11_results,
]
cifar10_pn_results = [
    *cifar10_pn_resnet18_results,
    *cifar10_pn_resnet50_results,
    *cifar10_pn_vgg11_results,
]

# ... then into the single record list the chart consumes (FOT first, then ProjNorm).
cifar10_results = [*cifar10_fot_results, *cifar10_pn_results]

In [13]:
# r_sqr against sample count for CIFAR-10: one line per method, one facet column per model.
alt.Chart(
    alt.Data(values=cifar10_results)
).mark_line(
    # Hollow point markers drawn on top of each line.
    point=alt.OverlayMarkDef(size=50, filled=False, fill='white'),
).encode(
    x=alt.X(
        'n_sample:Q',
        title='number of samples',
        # Log scale with the direction reversed; ticks pinned to the sample sizes used.
        scale=alt.Scale(type="log", reverse=True),
        axis=alt.Axis(values=[100, 500, 1000, 2000, 5000, 10000]),
    ),
    y=alt.Y(
        'r_sqr:Q',
        title='R\u00b2',
        scale=alt.Scale(domain=[0.7, 1]),
    ),
    color=alt.Color('method:N'),
).properties(
    width=150,
    height=150,
).facet(
    column='model:N',
)

### CIFAR-100 Results

In [14]:
# Label each raw CIFAR-100 score list with its method and architecture.
cifar100_fot_resnet18_results = format_results(
    cifar100_fot_resnet18, ds='CIFAR-100', method='FOT', model='ResNet18')
cifar100_fot_resnet50_results = format_results(
    cifar100_fot_resnet50, ds='CIFAR-100', method='FOT', model='ResNet50')
cifar100_fot_vgg11_results = format_results(
    cifar100_fot_vgg11, ds='CIFAR-100', method='FOT', model='VGG11')

cifar100_pn_resnet18_results = format_results(
    cifar100_pn_resnet18, ds='CIFAR-100', method='ProjNorm', model='ResNet18')
cifar100_pn_resnet50_results = format_results(
    cifar100_pn_resnet50, ds='CIFAR-100', method='ProjNorm', model='ResNet50')
cifar100_pn_vgg11_results = format_results(
    cifar100_pn_vgg11, ds='CIFAR-100', method='ProjNorm', model='VGG11')

# Flatten per method (ResNet18, ResNet50, VGG11 order) ...
cifar100_fot_results = [
    *cifar100_fot_resnet18_results,
    *cifar100_fot_resnet50_results,
    *cifar100_fot_vgg11_results,
]
cifar100_pn_results = [
    *cifar100_pn_resnet18_results,
    *cifar100_pn_resnet50_results,
    *cifar100_pn_vgg11_results,
]

# ... then into the single record list the chart consumes (FOT first, then ProjNorm).
cifar100_results = [*cifar100_fot_results, *cifar100_pn_results]

In [17]:
# r_sqr against sample count for CIFAR-100: one line per method, one facet column per model.
alt.Chart(
    alt.Data(values=cifar100_results)
).mark_line(
    # Hollow point markers drawn on top of each line.
    point=alt.OverlayMarkDef(size=50, filled=False, fill='white'),
).encode(
    x=alt.X(
        'n_sample:Q',
        title='number of samples',
        # Log scale with the direction reversed; ticks pinned to the sample sizes used.
        scale=alt.Scale(type="log", reverse=True),
        axis=alt.Axis(values=[100, 500, 1000, 2000, 5000, 10000]),
    ),
    y=alt.Y(
        'r_sqr:Q',
        title='R\u00b2',
        # Lower bound is 0.6 here (the CIFAR-10 chart uses 0.7).
        scale=alt.Scale(domain=[0.6, 1]),
    ),
    color=alt.Color('method:N'),
).properties(
    width=150,
    height=150,
).facet(
    column='model:N',
)