In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2

import os
import pickle
import sys
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits import mplot3d

sys.path.append('../../code/scripts')
import utils
import fit_scaling_law
import plotting as p

# 1. Aggregate data

In [None]:
# Metric names and the attribute that defines the two groups.
group_key = 'animal'
acc_key = 'acc'
acc_keys = ['auc_roc', 'acc']

# Raw group id -> display name (suffix `_r`: before the groups are flipped).
group_id_dict_r = {0: 'vehicle', 1: 'animal'}
group_names_r = [*group_id_dict_r.values()]

In [None]:
# Read the ERM results of every subsetting experiment and merge them into one
# set of arrays (suffix `_r`: raw, before the groups are flipped).
subset_types = ['cifar4_additional_< experiment name >_ERM',
                'cifar4_additional_equal_group_sizes_< experiment name >_ERM', 
                'cifar4_subsetting_< experiment name >_ERM'
               ]

num_seeds_by_subset_type = [10, 10, 10]
# zip() would silently drop experiments if these lists ever drift apart.
assert len(num_seeds_by_subset_type) == len(subset_types), \
    'need exactly one seed count per subset type'

subset_sizes, accs_by_group, accs_total = [], [], []
for subset_type, num_seeds in zip(subset_types, num_seeds_by_subset_type):
    print('{}: {} seeds'.format(subset_type, num_seeds))
    r = utils.read_in_results(group_key,
                              results_type='subset',
                              results_identifier=subset_type,
                              num_seeds=num_seeds,
                              obj='ERM',
                              acc_keys=['acc', 'auc_roc'],
                              sgd_params={'lr': 0.001, 'weight_decay': 0.0001, 'momentum': 0.9},
                              num_epochs=20,
                              add_reverse_accs=True)
    # r[0] is not needed here; r[1:4] are sizes, total accs, per-group accs.
    subset_sizes.append(r[1])
    accs_total.append(r[2])
    accs_by_group.append(r[3])

r_both = utils.combine_data_results(subset_sizes,
                                    accs_by_group,
                                    accs_total)

# NOTE: `accs_total` is rebound from the per-experiment list to the combined result.
subset_sizes_r, accs_by_group_r, accs_total = r_both

In [None]:
# Loss metric used by the fits and plots below; '1 - acc' is presumably the
# reversed accuracy added by `add_reverse_accs=True` (i.e. 0/1 error).
acc_key = '1 - acc'

# Optional sanity check: raw (pre-flip) group sizes coloured by mean total error.
# Uses `subset_sizes_r` (combined, per group) rather than `subset_sizes`, which at
# this point is still the per-experiment list built in the aggregation loop.
# NOTE(review): assumes subset_sizes_r is indexed by group id - confirm.
show_raw_size_scatter = False
if show_raw_size_scatter:
    plt.scatter(subset_sizes_r[0], subset_sizes_r[1],
                c=accs_total[acc_key].mean(axis=1))
    plt.colorbar()

# 2. Fit scaling laws to the data (flip the group results first)

In [None]:
# Flip the raw per-group results. The un-suffixed outputs are the names the
# fitting and plotting cells below read. The fourth positional argument
# ([0, 0]) is forwarded unchanged; its meaning is defined in utils.
flipped_results = utils.flip_group_results(
    accs_by_group_r,
    subset_sizes_r,
    group_id_dict_r,
    [0, 0],
    group_names_r,
)

(accs_by_group,
 subset_sizes,
 group_id_dict,
 gammas,
 group_names) = flipped_results

In [None]:
# Inspect the group-id mapping returned by utils.flip_group_results.
group_id_dict

In [None]:
# Inspect the (possibly reordered) group names returned by utils.flip_group_results.
group_names

In [None]:
# Fit one scaling law per group on the flipped data. `min_pts_fit` is passed
# to the fitter as `min_pts` and reused as the fit threshold when plotting.
min_pts_fit = 500

acc_key = '1 - acc'
# Indices of the two groups to fit; popts[g] / pcovs[g] presumably follow
# the same indexing (the 3-D plot cell reads popts[g] for g in [0, 1]).
group_pair = [0, 1]

popts, pcovs = fit_scaling_law.get_group_fits(group_pair=group_pair,
                                              accs_by_group=accs_by_group,
                                              subset_sizes=subset_sizes,
                                              acc_key=acc_key,
                                              delta_bounds=[0, 1],
                                              min_pts=min_pts_fit,
                                              need_to_tile_data=True)

p.print_table_rows('CIFAR-4', group_names, popts, pcovs, min_pts_fit)


In [None]:
# Re-read only the equal-group-size experiment (n_A = n_B); these are the
# points drawn in the scaling plot.
plot_subset_type = 'cifar4_additional_equal_group_sizes_< experiment name >_ERM'
r_plot = utils.read_in_results(group_key,
                               results_type='subset',
                               results_identifier=plot_subset_type,
                               num_seeds=10,
                               obj='ERM',
                               acc_keys=['acc', 'auc_roc'],
                               sgd_params={'lr': 0.001, 'weight_decay': 0.0001, 'momentum': 0.9},
                               num_epochs=20,
                               add_reverse_accs=True)

_, subset_sizes_plot_r, accs_total_plot_r, accs_by_group_plot_r = r_plot

# Apply the same flip (same arguments) as for the fitted data, so the group
# indices of the plotted points line up with `popts`.
flipped_results_plot = utils.flip_group_results(accs_by_group_plot_r,
                                                subset_sizes_plot_r,
                                                group_id_dict_r,
                                                [0, 0],
                                                group_names_r)

accs_by_group_plot, subset_sizes_plot, _, _, _ = flipped_results_plot
        

In [None]:
# Per-group 0/1 loss vs. per-group training size on log-log axes, with the
# curves fitted above (`popts`) overlaid; saved to ../../figures.
fig, ax = p.setup_scaling_plot_ax()
p.plot_scaling_fits(subset_sizes_plot,
                    accs_by_group_plot,
                    group_names=group_names,
                    n_thresh_for_scaling=min_pts_fit,
                    n_thresh_for_plotting=0,
                    acc_key=acc_key,
                    popts=popts,
                    loglog=True,
                    show_data_not_fitted=True,
                    show_fitted_line=True,
                    max_one_group=False,
                    dot_legend=True,
                    full_legend=False,
                    ax=ax)

ax.set_xlabel(r'\# training points from each group ($n_A = n_B$)')
ax.set_ylabel(r'0/1 loss')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_title('CIFAR-4')
ax.set_ylim(bottom=.008)

fig_path = '../../figures/scaling_cifar4_flipped.pdf'
# Create the output directory on a fresh checkout instead of failing in savefig.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
# Save the figure that owns `ax` explicitly rather than pyplot's "current" figure.
ax.figure.savefig(fig_path, bbox_inches='tight')

In [None]:
%pylab widget

show_fitted_line = True
need_to_tile_data = True

for g in [0,1]:
    plt.figure(figsize=(10,8))
    ax = plt.axes(projection='3d')
    
    if need_to_tile_data:
        ns, y = fit_scaling_law.tile_data(subset_sizes.sum(axis=0), 
                          accs_by_group[acc_key][g])

        njs, _ = fit_scaling_law.tile_data(subset_sizes[g], 
                          accs_by_group[acc_key][g])
        
        y_means = accs_by_group[acc_key][g].mean(axis=1)
        ax.scatter3D(subset_sizes.sum(axis=0), subset_sizes[g], 
                np.log(y_means), c=y_means, s=100,cmap='inferno')
        

    else:
        ns = subset_sizes.sum(axis=0)
        njs = subset_sizes[g]
        y = accs_by_group[acc_key][g]
        
        ax.scatter3D(subset_sizes.sum(axis=0), subset_sizes[g], 
                np.log(y), c=y, s=100,cmap='inferno')
            
    
    ax.set_xlabel('n')
    ax.set_ylabel('n_j')
    ax.set_zlabel('err')

    if show_fitted_line:
        x_fit, y_fit = np.meshgrid(np.linspace(ns.min(), ns.max(),100),
                                   np.linspace(njs.min(), njs.max(),50))
        
        def f_fit(x_0,y_0):
            
            fs_grid = np.ones(x_0.shape)*np.inf
            valid_idxs = np.where(x_0 > 0)
            fs =  fit_scaling_law.modified_ipl((x_0[valid_idxs],
                                                y_0[valid_idxs]), 
                                                *popts[g])
            fs_grid[valid_idxs] = fs
            return np.log(fs_grid)
        
        z_fit = f_fit(y_fit, x_fit)
        
        ax.plot_surface(x_fit, y_fit, z_fit, 
                        rstride=1, cstride=1,cmap='inferno', 
                        edgecolor='none',
                        alpha=0.5,
                        zorder=3)
        
        #ax.set_title(group_id_dict[g])

        #ax.set_zscale('log')