In [None]:
%pylab inline
from mpl_toolkits import mplot3d
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import os
import pandas as pd
import numpy as np

from importlib import reload
import sys
sys.path.append('../../code/scripts')
import utils
import plotting as p

import time
import fit_scaling_law


In [None]:
# 0. setup: grouping variable, group ids and the genre behind each id
group_key = 'genre'
groups = [0, 1]
genre_pair = ['history', 'fantasy']

# plotting
scale = 5

# group id -> genre name, in the order given by genre_pair
genre_id_dict = dict(enumerate(genre_pair))

# generic alias for the same mapping object
group_id_dict = genre_id_dict

## 1. read the results of u-plot and additional data collection

In [None]:
# Experiment configuration
n_train_per_group = 50000
obj_str = 'ERM'
pred_fxn_name = 'logistic_regression'
acc_key = 'mae'
param_dict = {'penalty': ['l2'], 'C': [1.0], 'solver':['lbfgs']}
num_seeds_eval = 10
num_seeds_additional = 2

# Location of the stored results. The descriptor is built for the 'subsetting'
# runs; later cells swap that substring for the other subset types.
results_general_path = '../../results/subset_results'
results_descriptor = 'goodreads_2k_history_fantasy_' + 'subsetting' + '_'
# Only group_key goes into the base name; pred_fxn_name enters the path as its
# own component below.
pred_fxn_base_name = 'subset_' + group_key

this_results_path = os.path.join(results_general_path,
                                 results_descriptor + obj_str)
results_path_this_pred_fxn = os.path.join(this_results_path,
                                          pred_fxn_base_name,
                                          pred_fxn_name)

# Key used to index the loaded results: the first value listed for every
# hyperparameter, in param_dict order.
these_keys = tuple(values[0] for values in param_dict.values())

In [None]:
# Read the per-seed results of every data source and merge them.
reload(utils)

# data sources to combine, each paired with the number of seeds to read for it
subset_types = ['subsetting',
                'additional',
                'additional_equal_group_sizes']
num_seeds_by_subset_type = [10, 10, 10]

# True when every data source was read with the same number of seeds
need_to_tile_data = len(set(num_seeds_by_subset_type)) == 1

subset_sizes, accs_by_group, accs_total = [], [], []
for subset_type, n_seeds in zip(subset_types, num_seeds_by_subset_type):
    # the result directories differ only in the subset type embedded in the path
    results_path_this = results_path_this_pred_fxn.replace('subsetting', subset_type)
    results_all_params = utils.read_subset_results_nonimage(results_path_this,
                                                            param_dict,
                                                            by_seed=True,
                                                            seed_start=0,
                                                            num_seeds=n_seeds,
                                                            acc_keys=['mae', 'mse'])
    # keep only the entry for the hyperparameter setting selected above
    r = results_all_params[these_keys]

    subset_sizes.append(r['subset_sizes'])
    accs_by_group.append(r['accs_by_group'])
    accs_total.append(r['accs_total'])

# merge the three sources into one set of arrays
r_both = utils.combine_data_results(subset_sizes, accs_by_group, accs_total)
subset_sizes_both, accs_by_group_both, accs_total_both = r_both

In [None]:
# Sanity check that the scaling pattern looks right: one dot per training-set
# allocation (group A size vs. group B size), coloured by the overall error
# averaged over axis 1 (presumably the seeds -- confirm).
fig, ax = plt.subplots()
mean_acc_total = accs_total_both[acc_key].mean(axis=1)
pts = ax.scatter(subset_sizes_both[0], subset_sizes_both[1], c=mean_acc_total)
ax.set_aspect('equal')
ax.set_xlabel('# training samples group A')
# Row 1 of subset_sizes_both is the second group's count; the total would be
# subset_sizes_both.sum(axis=0).
ax.set_ylabel('# training samples group B')
fig.colorbar(pts, ax=ax)
# the colours are read with acc_key, so the title always names the plotted metric
ax.set_title(acc_key);

## 2a. fit scaling rules to data

In [None]:
# Threshold handed to the fit as `min_pts` (and to the plot below as
# `n_thresh_for_scaling`).
min_pts_fit = 2500

# Upper bound for delta: minimum along axis 1 of the per-group error array,
# then the largest of those minima.
errs_by_group = accs_by_group_both[acc_key]
upper_bound_delta = errs_by_group.min(axis=1).max()

fit_kwargs = dict(group_pair=genre_pair,
                  accs_by_group=accs_by_group_both,
                  subset_sizes=subset_sizes_both,
                  acc_key=acc_key,
                  delta_bounds=[0, upper_bound_delta],
                  min_pts=min_pts_fit,
                  # NOTE(review): the original comment here read "already tiled";
                  # the flag itself is computed from the seed counts when the
                  # results are read -- confirm the intended meaning.
                  need_to_tile_data=need_to_tile_data)
popts, pcovs = fit_scaling_law.get_group_fits(**fit_kwargs)

# table rows to paste into the Overleaf document
p.print_table_rows('goodreads', genre_pair, popts, pcovs, min_pts_fit)


## 2b. plot the scaling law fits

In [None]:
# Load the results shown in the scaling plot: only the
# 'additional_equal_group_sizes' runs.
num_seeds_eval = 10
subset_type = 'additional_equal_group_sizes'

results_path_this = results_path_this_pred_fxn.replace('subsetting', subset_type)
r = utils.read_subset_results_nonimage(results_path_this,
                                       param_dict,
                                       by_seed=True,
                                       seed_start=0,
                                       num_seeds=num_seeds_eval,
                                       acc_keys=['mae', 'mse'])[these_keys]

subset_sizes_plot, accs_by_group_plot, accs_total_plot = (
    r['subset_sizes'], r['accs_by_group'], r['accs_total'])
        

In [None]:
# Display the subset sizes that were loaded for the scaling plot.
subset_sizes_plot

In [None]:
# Reload the plotting module *before* using it, so that the axis setup and the
# plotting call both come from the same, current version of the code.
reload(p)
fig, ax = p.setup_scaling_plot_ax()

# Plot the equal-group-size results together with the fitted scaling laws.
p.plot_scaling_fits(subset_sizes_plot,
                    accs_by_group_plot,
                    group_names = genre_pair,
                    n_thresh_for_scaling = min_pts_fit,
                    n_thresh_for_plotting = 0,
                    acc_key = acc_key,
                    popts=popts,
                    loglog=True,
                    show_data_not_fitted = True,
                    show_fitted_line = True,
                    max_one_group=False,
                    dot_legend = True,
                    full_legend = False,
                    ax=ax)

#ax.set_xlim(500)
ax.set_xlabel(r'\# training points from each group ($n_A = n_B$)');
ax.set_ylabel(r'$\ell_1$ loss');
ax.set_title('Goodreads')

plt.savefig('../../figures/scaling_goodreads.pdf', bbox_inches='tight')

## Extra: 3D plot of the fitted scaling laws

In [None]:


%pylab widget


show_fitted_line = True

for g in [0,1]:
    plt.figure(figsize=(10,8))
    ax = plt.axes(projection='3d')
    
    if need_to_tile_data:
        ns, y = fit_scaling_law.tile_data(subset_sizes_both.sum(axis=0), 
                          accs_by_group_both[acc_key][g])

        njs, _ = fit_scaling_law.tile_data(subset_sizes_both[g], 
                          accs_by_group_both[acc_key][g])
        
        y_means = accs_by_group_both[acc_key][g].mean(axis=1)
        ax.scatter3D(subset_sizes_both.sum(axis=0), subset_sizes_both[g], 
                y_means, c=y_means, s=100,cmap='inferno')
        

    else:
        ns = subset_sizes_both.sum(axis=0)
        njs = subset_sizes_both[g]
        y = accs_by_group_both[acc_key][g]
        
        ax.scatter3D(subset_sizes_both.sum(axis=0), subset_sizes_both[g], 
                y, c=y, s=100,cmap='inferno')
            
    
    ax.set_xlabel('n')
    ax.set_ylabel('n_j')
    ax.set_zlabel('err')

    if show_fitted_line:
        x_fit, y_fit = np.meshgrid(np.linspace(ns.min(), ns.max(),100),
                                   np.linspace(njs.min(), njs.max(),50))
        
        def f_fit(x_0,y_0):
            return fit_scaling_law.modified_ipl((x_0, y_0), *popts[g])
        
        z_fit = f_fit(y_fit, x_fit)
        
        ax.plot_surface(x_fit, y_fit, z_fit, 
                        rstride=1, cstride=1,cmap='inferno', 
                        edgecolor='none',
                        alpha=0.5,
                        zorder=3)
        
        ax.set_title(genre_id_dict[g])
