In [1]:
import numpy as np
import math
import json
import glob
from functools import reduce

In [2]:
def load_results(dataset, model, pretrained, ckpt, metric, n_sample, seed):
    """Load and concatenate every JSON result file for one experiment configuration.

    Files are read from
    ``../results/{dataset}/{pretrained|scratch}/{model}_{seed}-{ckpt}/{metric}_{n_sample}/*.json``
    (relative to the current working directory).

    Args:
        dataset: dataset name (first path component).
        model: architecture name.
        pretrained: True selects the ``pretrained/`` subtree, False ``scratch/``.
        ckpt: checkpoint identifier used in the run directory name.
        metric: estimator name used in the leaf directory name.
        n_sample: sample-count suffix of the leaf directory name.
        seed: training seed used in the run directory name.

    Returns:
        A list of the items from all matching files, in sorted file-path order.
        Each file is expected to hold a JSON list (its items are appended with
        ``list.extend``). A missing directory yields an empty list. The number
        of loaded items is printed as ``"{metric}: {count}"``.
    """
    training = 'pretrained' if pretrained else 'scratch'
    result_dir = f'../results/{dataset}/{training}/{model}_{seed}-{ckpt}/{metric}_{n_sample}'

    # glob returns files in arbitrary, filesystem-dependent order; sort so the
    # concatenated results (and everything derived from them) are reproducible.
    result_fs = sorted(glob.glob(f'{result_dir}/*.json'))
    results = []
    for file in result_fs:
        with open(file, 'r') as f:
            results.extend(json.load(f))

    print(f'{metric}: {len(results)}')
    return results

In [3]:
# Per-dataset experiment settings, one row per dataset:
#   n_sample   -> suffix of the `{metric}_{n_sample}` results directory
#                 (-1 presumably means "all samples" -- TODO confirm with the eval script)
#   n_epoch    -> checkpoint identifier, used as `ckpt` in `{model}_{seed}-{ckpt}`
#   pretrained -> selects the `pretrained/` vs `scratch/` results subtree
# A single table guarantees the three lookups below cover exactly the same
# datasets and cannot drift apart.
_dataset_config = {
    # dataset:        (n_sample, n_epoch, pretrained)
    'CIFAR-10':       (-1, 300, False),
    'CIFAR-100':      (-1, 300, False),
    'Living-17':      (-1, 450, False),
    'Nonliving-26':   (-1, 450, False),
    'Entity-13':      (-1, 300, False),
    'Entity-30':      (-1, 300, False),
    'ImageNet':       (-1, 10, True),
    'RxRx1':          (-1, 90, True),
    'FMoW':           (-1, 50, True),
    'Amazon':         (-1, 3, True),
    'CivilComments':  (-1, 5, True),
}

n_sample_dict = {name: cfg[0] for name, cfg in _dataset_config.items()}
n_epoch_dict = {name: cfg[1] for name, cfg in _dataset_config.items()}
pretrained_dict = {name: cfg[2] for name, cfg in _dataset_config.items()}

In [88]:
# Estimator names; each record's 'ref' field is matched against these.
metrics = ['AC', 'DoC', 'IM', 'ATC', 'GDE', 'COT']
dataset = 'ImageNet'
arch = 'resnet50'  # alternative noted here: "distilbert-base-uncased" (presumably for the text datasets)
n_sample = n_sample_dict[dataset]
seeds = [0, 1, 10]
model_ckpt = n_epoch_dict[dataset]
pretrained = pretrained_dict[dataset]

# results[k] is the flat list of records for seeds[k], with all estimators
# concatenated in `metrics` order. The nested comprehension replaces
# `sum(lists, [])`, which copies the accumulator on every addition.
results = []
for seed in seeds:
    print(f'seed: {seed}')
    results.append([
        record
        for metric in metrics
        for record in load_results(dataset, arch, pretrained, model_ckpt, metric, n_sample, seed)
    ])
    print()

seed: 0
AC: 100
DoC: 100
IM: 100
ATC: 100
GDE: 100
COT: 20

seed: 1
AC: 100
DoC: 100
IM: 100
ATC: 100
GDE: 100
COT: 20

seed: 10
AC: 100
DoC: 100
IM: 100
ATC: 100
GDE: 100
COT: 20



In [89]:
import altair as alt
import pandas as pd

In [90]:
def get_corr_chart(data, subpop, scale_min=0, scale_max=1):
    """Scatter plot of each record's 'metric' value against its actual test error.

    Args:
        data: list of record dicts with at least 'metric', 'error' and 'ref'
            keys; 'ref' (the estimator name) drives both color and shape.
        subpop: chart title.
        scale_min, scale_max: domain shared by both axes, so a y = x overlay
            is a true diagonal.

    Returns:
        A 200x200 altair point chart.
    """
    domain = [scale_min, scale_max]
    corr = alt.Chart(alt.Data(values=data), title=subpop).mark_point(size=40, filled=True).encode(
        x=alt.X('metric:Q', scale=alt.Scale(domain=domain)),
        # Use the channel class that matches each slot (Y for y, Shape for
        # shape) rather than relying on X/Color objects being accepted there.
        y=alt.Y('error:Q', scale=alt.Scale(domain=domain), title='Test Error'),
        color=alt.Color('ref:N'),
        shape=alt.Shape('ref:N'),
    ).properties(
        width=200,
        height=200
    )
    return corr

In [91]:
# Split the chosen seed's records by their 'subpopulation' field,
# producing one list each for 'same', 'natural' and 'novel'.
seed_ind = 0
same_pop_results, natural_pop_results, novel_pop_results = (
    [rec for rec in results[seed_ind] if rec['subpopulation'] == pop]
    for pop in ('same', 'natural', 'novel')
)

In [92]:
# Shared axis range: span every 'metric' and 'error' value for the chosen seed,
# padded by 0.1 on each side and clipped to [0, 1].
scale_min = max(min(v for rec in results[seed_ind] for v in (rec['metric'], rec['error'])) - 0.1, 0)
scale_max = min(max(v for rec in results[seed_ind] for v in (rec['metric'], rec['error'])) + 0.1, 1)

In [93]:
# Dashed y = x reference line over the shared range; a point on it has metric == error.
line = pd.DataFrame(
    {'metric': [scale_min, scale_max], 'error': [scale_min, scale_max]}
)
line_plot = (
    alt.Chart(line)
    .mark_line(color='black', strokeDash=[5, 8])
    .encode(x='metric', y='error')
)

In [94]:
# One correlation panel per subpopulation, each overlaid with the y = x line,
# laid out side by side (equivalent to chaining the panels with `|`).
# Note: `plt` here is an altair chart, not matplotlib.pyplot.
subpop_results = {
    'same': same_pop_results,
    'natural': natural_pop_results,
    'novel': novel_pop_results,
}
plt = alt.hconcat(*[
    get_corr_chart(pop_results, subpop, scale_min=scale_min, scale_max=scale_max) + line_plot
    for subpop, pop_results in subpop_results.items()
])

In [95]:
# Enlarge axis and legend fonts for the rendered figure. (configure_axis was
# previously applied twice with identical arguments; once is enough.)
plt.configure_axis(
    labelFontSize=14,
    titleFontSize=16,
).configure_legend(
    titleFontSize=14,
    labelFontSize=16,
)


  for col_name, dtype in df.dtypes.iteritems():


In [96]:
# Every record of the chosen seed in a single panel, overlaid with the y = x line.
alt.layer(
    get_corr_chart(results[seed_ind], dataset, scale_min=scale_min, scale_max=scale_max),
    line_plot,
)

In [97]:
def polyfit(x, y, degree=1):
    """Least-squares polynomial fit of ``y`` on ``x`` and its coefficient of determination.

    Args:
        x, y: equal-length sequences of numbers.
        degree: polynomial degree passed to ``np.polyfit``.

    Returns:
        dict with
          'polynomial'    -- fitted coefficients as a list, highest power first
                             (``np.polyfit`` order); for degree 1: [slope, intercept].
          'determination' -- SS_reg / SS_tot (R^2 for a least-squares fit with an
                             intercept). nan when ``y`` is constant (SS_tot == 0),
                             instead of a divide-by-zero warning.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    coeffs = np.polyfit(x, y, degree)
    yhat = np.polyval(coeffs, x)
    ybar = y.mean()
    ssreg = np.sum((yhat - ybar) ** 2)
    sstot = np.sum((y - ybar) ** 2)

    return {
        'polynomial': coeffs.tolist(),
        'determination': ssreg / sstot if sstot > 0 else float('nan'),
    }

In [98]:
from scipy.stats import spearmanr
import math

In [99]:
def compute_corr_stats(all_results, metric, return_stats=False):
    """Mean absolute gap between one estimator's 'metric' values and the true 'error'.

    Args:
        all_results: list of record dicts with 'ref', 'metric' and 'error' keys.
        metric: estimator name, matched against each record's 'ref'.
        return_stats: if True, return a dict with 'mae' plus, when at least two
            records match, the linear-fit 'r2', 'slope' and 'bias' and the
            Spearman 'spearman' / 'p' between the metric values and errors.

    Returns:
        None if ``all_results`` is empty. Otherwise the MAE scaled by 100
        (percentage points if the values are fractions), rounded to 2 decimals,
        or nan when no record matches ``metric``. If ``return_stats`` is True,
        the dict described above is returned instead.
    """
    if len(all_results) == 0:
        return

    results = [r for r in all_results if r['ref'] == metric]
    pred = np.array([r['metric'] for r in results], dtype=float)
    err = np.array([r['error'] for r in results], dtype=float)

    # Explicit nan avoids numpy's "Mean of empty slice" warning for an absent estimator.
    mae = round(np.abs(pred - err).mean() * 100, 2) if len(results) else float('nan')
    if not return_stats:
        return mae

    stats = {'mae': mae}
    if len(results) > 1:
        fit = polyfit(pred, err)  # fit once and reuse for r2, slope and bias
        stats['r2'] = fit['determination']
        stats['slope'], stats['bias'] = fit['polynomial'][0], fit['polynomial'][1]
        stats['spearman'], stats['p'] = spearmanr(pred, err)
    return stats

In [100]:
def generate_summary(results, seeds, subpop, metric_names=None):
    """Print each estimator's MAE on one subpopulation as mean ± std over seeds.

    Args:
        results: results[i] is the list of record dicts for seeds[i]; records
            need 'subpopulation', 'ref', 'metric' and 'error' keys.
        seeds: seeds aligned positionally with ``results``.
        subpop: value of the records' 'subpopulation' field to summarise.
        metric_names: estimator names (matched against 'ref'); defaults to the
            notebook-level ``metrics`` list.

    Returns:
        None. Prints one ``name: mean ± std`` line per estimator. If some seed
        has no records for ``subpop``, prints a notice and no summary.
    """
    names = metrics if metric_names is None else metric_names

    pop_sum = []
    for i, seed in enumerate(seeds):
        # The subpopulation filter does not depend on the estimator: do it once per seed.
        pop_result = [r for r in results[i] if r['subpopulation'] == subpop]
        if not pop_result:
            print(f'no {subpop!r} results for seed {seed}; nothing to summarise')
            return
        pop_sum.append([compute_corr_stats(pop_result, name) for name in names])

    # Transpose so row j holds estimator j's MAE across all seeds.
    pop_sum = np.array(pop_sum).T
    for j, name in enumerate(names):
        print(f'{name}:', round(pop_sum[j].mean(), 2), u"\u00B1", round(pop_sum[j].std(), 2))
    

In [101]:
# Per-estimator MAE summary for the 'same' subpopulation.
print('-' * 5, 'same pop results', '-' * 5)
generate_summary(results, seeds, subpop='same')

----- same pop results -----
AC: 15.61 ± 0.26
DoC: 14.12 ± 0.24
IM: 17.01 ± 0.16
ATC: 1.55 ± 0.06
GDE: 14.9 ± 0.2
COT: 4.83 ± 0.13


In [102]:
# Per-estimator MAE summary for the 'natural' subpopulation.
print('-' * 5, 'natural pop results', '-' * 5)
generate_summary(results, seeds, subpop='natural')

----- natural pop results -----
AC: 11.45 ± 0.14
DoC: 9.97 ± 0.17
IM: 11.16 ± 0.25
ATC: 3.49 ± 0.16
GDE: 13.65 ± 0.15
COT: 5.89 ± 0.21


In [103]:
# Per-estimator MAE summary for the 'novel' subpopulation.
print('-' * 5, 'novel pop results', '-' * 5)
generate_summary(results, seeds, subpop='novel')

----- novel pop results -----
