In [None]:
%pylab inline

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager

import seaborn as sns

import sys
sys.path.append('../../code/scripts')
import utils
import plotting as p
from dataset_params import dataset_params

from importlib import reload

#reload(dataset_params)

In [None]:
# Display labels for the two age groups, keyed by integer group id.
group_id_dict = {
    0: r' age $< 55$',
    1: r' age $\geq 55$',
}
# Labels in group-id (insertion) order.
group_names = [*group_id_dict.values()]

# gamma for group 0 is read from the dataset config; group 1 gets the complement.
# NOTE(review): presumably gamma is a population proportion — confirm in dataset_params.
gamma0 = dataset_params['isic']['gamma']
gammas = [gamma0, 1 - gamma0]
gammas

# 1. read in results

In [None]:
# Load subsetting results for each training objective (ERM, IS, GDRO).
# NOTE(review): group_key says 'over_50' but group_id_dict labels use 55 —
# confirm which age threshold the group ids actually encode.
reload(utils)

# Arguments identical across all three objectives.
shared_read_kwargs = dict(group_key='age_over_50_id',
                          results_type='subset',
                          num_seeds=10)

# '< experiment name >' appears to be a placeholder to be filled in per run.
r_ERM = utils.read_in_results(
    results_identifier='isic_subsetting_< experiment name >_ERM',
    obj='ERM',
    sgd_params={'lr': 0.001, 'weight_decay': 0.0001, 'momentum': 0.9},
    **shared_read_kwargs)

r_IS = utils.read_in_results(
    results_identifier='isic_subsetting_< experiment name >_IS',
    obj='IS',
    sgd_params={'lr': 0.001, 'weight_decay': 0.001, 'momentum': 0.9},
    **shared_read_kwargs)

r_GDRO = utils.read_in_results(
    results_identifier='isic_subsetting_< experiment name >_GDRO',
    obj='GDRO_group_adj_1.0_gdro_stepsize_0.1',
    sgd_params={'lr': 0.001, 'weight_decay': 0.0001, 'momentum': 0.9},
    num_epochs='20',
    **shared_read_kwargs)


# 2. plot results

In [None]:
# Plot '1 - AUROC' by age group (via p.plot_by_group), one line style per
# training objective, and save the figure.
reload(p)  # pick up edits to the plotting module; setup also preloads font sizes, etc.
fig, ax = p.setup_uplot_ax()

ls = ['-', '--', ':']
obj_names = ['ERM', 'IS', 'GDRO']
acc_key_plot = '1 - auc_roc'

results_to_plot = [r_ERM, r_IS, r_GDRO]

for results, linestyle, obj_name in zip(results_to_plot, ls, obj_names):
    _, subset_sizes, accs_total, accs_by_group = results

    # Divide by totals along axis 0 to get fractions
    # (presumably subset_sizes is groups x subsets — TODO confirm).
    subset_fracs = subset_sizes / subset_sizes.sum(axis=0)

    p.plot_by_group(accs_by_group,
                    accs_total,
                    subset_fracs,
                    acc_key_plot,
                    group_id_dict,
                    gammas=gammas,
                    pop_weights='by_eval_set',
                    ls=linestyle,
                    label_append=obj_name,
                    range_type='stddev',
                    title='ISIC benign/mal. prediction',
                    ylim=None,
                    plot_gamma=True,
                    plot_alpha_star=True,
                    lw=2,
                    gamma_annot_offset=(0.01, 0.1),  # per-figure annotation placement
                    ax=ax)

ax.set_ylabel('1 - AUROC')

# Save the figure that owns `ax` explicitly rather than pyplot's "current"
# figure, which could differ if a plotting helper opened another figure.
ax.figure.savefig('../../figures/uplot_isic_aucroc.pdf', bbox_inches='tight')


In [None]:
# Plot 0/1 loss ('1 - acc') by age group (via p.plot_by_group), one line
# style per training objective, and save the figure.
fig, ax = p.setup_uplot_ax()

ls = ['-', '--', ':']
obj_names = ['ERM', 'IS', 'GDRO']
acc_key_plot = '1 - acc'

results_to_plot = [r_ERM, r_IS, r_GDRO]

for results, linestyle, obj_name in zip(results_to_plot, ls, obj_names):
    _, subset_sizes, accs_total, accs_by_group = results

    # Divide by totals along axis 0 to get fractions
    # (presumably subset_sizes is groups x subsets — TODO confirm).
    subset_fracs = subset_sizes / subset_sizes.sum(axis=0)

    p.plot_by_group(accs_by_group,
                    accs_total,
                    subset_fracs,
                    acc_key_plot,
                    group_id_dict,
                    gammas=gammas,
                    pop_weights='by_eval_set',
                    ls=linestyle,
                    label_append=obj_name,
                    range_type='stddev',
                    title='ISIC benign/mal. prediction',
                    ylim=None,
                    plot_gamma=True,
                    plot_alpha_star=True,
                    lw=2,  # match the line width used by every other figure in this notebook
                    gamma_annot_offset=(-0.26, 0.055),  # per-figure annotation placement
                    ax=ax)

ax.set_ylabel('0/1 loss')

# Save the figure that owns `ax` explicitly rather than pyplot's "current" figure.
ax.figure.savefig('../../figures/uplot_isic_acc.pdf', bbox_inches='tight')


# 3. Flip group results and plot

In [None]:
# Plot '1 - AUROC' by age group after flipping the group results with
# utils.flip_group_results, one line style per training objective.
reload(p)  # pick up edits to the plotting module; setup also preloads font sizes, etc.
fig, ax = p.setup_uplot_ax()

ls = ['-', '--', ':']
obj_names = ['ERM', 'IS', 'GDRO']
acc_key_plot = '1 - auc_roc'

results_to_plot = [r_ERM, r_IS, r_GDRO]

for results, linestyle, obj_name in zip(results_to_plot, ls, obj_names):
    _, subset_sizes, accs_total, accs_by_group = results

    # Flip from the ORIGINAL group_names every time. The flipped names go into a
    # new variable: rebinding the global `group_names` here would feed already
    # flipped names into the next iteration (and into later cells).
    (accs_by_group_flipped, subset_size_flipped, group_id_dict_flipped,
     gammas_flipped, group_names_flipped) = utils.flip_group_results(
        accs_by_group, subset_sizes, group_id_dict, gammas, group_names)

    subset_fracs_flipped = subset_size_flipped / subset_size_flipped.sum(axis=0)

    # accs_total is passed unflipped, as in the original analysis.
    p.plot_by_group(accs_by_group_flipped,
                    accs_total,
                    subset_fracs_flipped,
                    acc_key_plot,
                    group_id_dict_flipped,
                    gammas=gammas_flipped,
                    pop_weights='by_eval_set',
                    ls=linestyle,
                    label_append=obj_name,
                    range_type='stddev',
                    title='ISIC benign/mal. prediction',
                    ylim=None,
                    plot_gamma=True,
                    plot_alpha_star=True,
                    lw=2,
                    gamma_annot_offset=(-0.26, 0.1),  # per-figure annotation placement
                    ax=ax)

ax.set_ylabel('1 - AUROC')

# Save the figure that owns `ax` explicitly rather than pyplot's "current" figure.
ax.figure.savefig('../../figures/uplot_isic_aucroc_flipped.pdf', bbox_inches='tight')


In [None]:
# Plot 0/1 loss ('1 - acc') by age group after flipping the group results
# with utils.flip_group_results, one line style per training objective.
reload(p)  # pick up edits to the plotting module; setup also preloads font sizes, etc.
fig, ax = p.setup_uplot_ax()

ls = ['-', '--', ':']
obj_names = ['ERM', 'IS', 'GDRO']
acc_key_plot = '1 - acc'

results_to_plot = [r_ERM, r_IS, r_GDRO]

for results, linestyle, obj_name in zip(results_to_plot, ls, obj_names):
    _, subset_sizes, accs_total, accs_by_group = results

    # Flip from the ORIGINAL group_names every time. The flipped names go into a
    # new variable: rebinding the global `group_names` here would feed already
    # flipped names into the next iteration (and into later cells).
    (accs_by_group_flipped, subset_size_flipped, group_id_dict_flipped,
     gammas_flipped, group_names_flipped) = utils.flip_group_results(
        accs_by_group, subset_sizes, group_id_dict, gammas, group_names)

    subset_fracs_flipped = subset_size_flipped / subset_size_flipped.sum(axis=0)

    # accs_total is passed unflipped, as in the original analysis.
    p.plot_by_group(accs_by_group_flipped,
                    accs_total,
                    subset_fracs_flipped,
                    acc_key_plot,
                    group_id_dict_flipped,
                    gammas=gammas_flipped,
                    pop_weights='by_eval_set',
                    ls=linestyle,
                    label_append=obj_name,
                    range_type='stddev',
                    title='ISIC benign/mal. prediction',
                    ylim=None,
                    plot_gamma=True,
                    plot_alpha_star=True,
                    lw=2,
                    gamma_annot_offset=(0.01, 0.055),  # per-figure annotation placement
                    ax=ax)

ax.set_ylabel('0/1 loss')

# Save the figure that owns `ax` explicitly rather than pyplot's "current" figure.
ax.figure.savefig('../../figures/uplot_isic_acc_flipped.pdf', bbox_inches='tight')


In [None]:
# Smaller-axes variant of the flipped '1 - AUROC' plot, passing
# group_labels_only=True to p.plot_by_group.
reload(p)  # pick up edits to the plotting module; setup also preloads font sizes, etc.
fig, ax = p.setup_uplot_ax_smaller()

ls = ['-', '--', ':']
obj_names = ['ERM', 'IS', 'GDRO']
acc_key_plot = '1 - auc_roc'

results_to_plot = [r_ERM, r_IS, r_GDRO]

for results, linestyle, obj_name in zip(results_to_plot, ls, obj_names):
    _, subset_sizes, accs_total, accs_by_group = results

    # Flip from the ORIGINAL group_names every time. The flipped names go into a
    # new variable: rebinding the global `group_names` here would feed already
    # flipped names into the next iteration (and into later cells).
    (accs_by_group_flipped, subset_size_flipped, group_id_dict_flipped,
     gammas_flipped, group_names_flipped) = utils.flip_group_results(
        accs_by_group, subset_sizes, group_id_dict, gammas, group_names)

    subset_fracs_flipped = subset_size_flipped / subset_size_flipped.sum(axis=0)

    # accs_total is passed unflipped, as in the original analysis.
    p.plot_by_group(accs_by_group_flipped,
                    accs_total,
                    subset_fracs_flipped,
                    acc_key_plot,
                    group_id_dict_flipped,
                    gammas=gammas_flipped,
                    pop_weights='by_eval_set',
                    ls=linestyle,
                    label_append=obj_name,
                    range_type='stddev',
                    title='ISIC benign/mal. prediction',
                    ylim=None,
                    plot_gamma=True,
                    plot_alpha_star=True,
                    lw=2,
                    gamma_annot_offset=(-0.26, 0.1),  # per-figure annotation placement
                    group_labels_only=True,
                    ax=ax)

ax.set_ylabel('1 - AUROC')

# Save the figure that owns `ax` explicitly rather than pyplot's "current" figure.
ax.figure.savefig('../../figures/uplot_isic_aucroc_flipped_wide.pdf', bbox_inches='tight')
