In [3]:
import os
from os.path import join as oj
import sys, time
sys.path.insert(1, oj(sys.path[0], '..'))  # insert parent path
import seaborn as sns
from sklearn.model_selection import train_test_split
from regression_dsets_large_names import regression_dsets_large_names
import numpy as np
import matplotlib.pyplot as plt
import pmlb
from tqdm import tqdm
from copy import deepcopy
import pickle as pkl
import pandas as pd
import data
import fit

%matplotlib inline
%load_ext autoreload
%autoreload 2

# load results from a directory

In [4]:
# depending on how much is saved, this may take a while
# pmlb, gaussian
out_dir = '/scratch/users/vision/yu_dl/raaz.rsk/double_descent/pmlb'
save_dir = 'results'
fnames = sorted([fname for fname in os.listdir(out_dir)])
results_list = [pd.Series(pkl.load(open(oj(out_dir, fname), "rb"))) for fname in tqdm(fnames) 
                if not fname.startswith('weights') and not fname.startswith('idx')]
results = pd.concat(results_list, axis=1).T.infer_objects()

100%|██████████| 434/434 [00:00<00:00, 1082.07it/s]


In [7]:
r = results[['dset', 'noise_mult', 'dset_num',
             'n_test', 
             'n_train', 'n_train_over_num_features', 'num_features', 
             'test_mse', 'seed', 'preds_test', 'wnorm', 'H_trace']]

r2 = r.groupby(['dset', 'noise_mult', 'dset_num'])
# plt.plot(r.num_features / r.n_train, np.log(r.test_mse), 'o')

# loop over multiple seeds / curves
R, C = 2, 3

for name, gr in r2:
    curve = gr.groupby(['n_train'])
    dset = gr.dset.values[0]
    dset_num = gr.dset_num.values[0]
    num_features = gr.num_features.values[0] # assume this was held constant (for each dset)
    
    
    # loop over seeds and average / calculate bias stuff
    plt.figure(figsize=(C * 4, R * 4))
    for curve_name, gr2 in tqdm(curve):
        ratio = gr2.num_features.values[0] / gr2.n_train.values[0]        
        preds = gr2.preds_test.values
        preds = np.stack(preds) # num_seeds x n_test
        
        preds_mean = preds.mean(axis=0).reshape(1, -1) # 1 x n_test
        
        if dset == 'gaussian':
            dset_name = ''
            _, y_true = data.get_data(gr2.n_test.values[0], num_features, # this assumes that num_features was held constant
                                      noise_mult=0, seed=703858704)
            y_true = y_true.reshape(1, -1) # 1 x n_test
        elif dset == 'pmlb':
            dset_name = regression_dsets_large_names[dset_num]
            X, y = pmlb.fetch_data(dset_name, return_X_y=True)
            fit.seed(703858704)
            _, _, _, y_true = train_test_split(X, y) # get test set
        
        
        bias = np.mean(preds_mean - y_true)
        
        var = np.mean(np.square(preds - preds_mean))
#         plt.plot(1 / gr2.n_train_over_num_features.mean(), 
#                  gr2.test_mse.mean(), 'o')
        
        plt.subplot(R, C, 1)
        plt.plot(ratio, bias**2 + var, 'o')
        plt.xlabel('p / n')
        plt.ylabel('test mse')    
        
        plt.subplot(R, C, 2)
        plt.plot(ratio, bias, 'o')
        plt.xlabel('p / n')
        plt.ylabel('bias')
        
        plt.subplot(R, C, 3)
        plt.plot(ratio, var, 'o')
        plt.xlabel('p / n')
        plt.ylabel('var')
        
        
        plt.subplot(R, C, 4)
        plt.plot(gr2.wnorm.mean(), bias**2 + var, 'o')
        plt.xlabel('$||\hat{w}||_2$')
        plt.ylabel('test mse')
        
        plt.subplot(R, C, 5)
        plt.plot(gr2.H_trace.mean(), bias**2 + var, 'o')
        plt.xlabel('$tr(H)$')
        plt.ylabel('test mse')        
    
    for i in range(1, 6):
        plt.subplot(R, C, i)
        plt.xscale('log')
        plt.yscale('log')
    
    s = f'dset={dset}_{dset_name}+p={num_features}'
    plt.suptitle(s)
    plt.tight_layout()
    plt.savefig(oj(save_dir, s + '.pdf'))
    plt.close()

100%|██████████| 4/4 [00:02<00:00,  1.68it/s]
100%|██████████| 7/7 [00:02<00:00,  2.86it/s]
100%|██████████| 10/10 [00:04<00:00,  2.97it/s]
100%|██████████| 4/4 [00:05<00:00,  1.31s/it]
100%|██████████| 4/4 [00:01<00:00,  2.34it/s]
100%|██████████| 5/5 [00:02<00:00,  2.41it/s]
100%|██████████| 10/10 [00:05<00:00,  2.07it/s]
100%|██████████| 9/9 [00:26<00:00,  2.61s/it]
100%|██████████| 9/9 [00:02<00:00,  3.87it/s]
  "Data has no positive values, and therefore cannot be "
