In [1]:
%matplotlib inline

In [2]:
from matplotlib.pylab import *

In [3]:
# URL of the upstream style-sheet collection. It is not referenced again in the
# visible cells; the style sheet actually applied below is a local file.
repo = "https://raw.githubusercontent.com/nicoguaro/matplotlib_styles/master"
style.use("results/style_sheet.mplstyle")

# Notebook-wide figure defaults, applied on top of the style sheet.
plt.rcParams.update({
    'figure.figsize': (8, 4),
    'savefig.bbox': 'tight',
    'font.size': 16,
})

import numpy as np
import torch
import torch.nn as nn
from models import resnet18_narrow as resnet18
from utils import get_loader
from utils.train_utils import AverageMeter, accuracy
import argparse
from sklearn.model_selection import ParameterGrid
import pickle
from tqdm import tqdm 
import copy
import glob
import numpy as np
import scipy
import copy

In [4]:
# Earlier four-parameter grid, kept for reference
# (presumably the MNIST/LeNet sweep -- TODO confirm):
# param_grid = {'mo': [0.0, 0.5, 0.9],  # momentum
#               'wd': [0.0, 1e-4, 5e-4],  # weight decay
#               'lr': [7e-3, 0.0085, 1e-2],  # learning rate
#               'bs': [32, 128, 512],  # batch size
#               }

# Hyper-parameter sweep: one list of candidate values per hyper-parameter.
# The key order matters: it fixes the column order of the printed tables.
param_grid = dict(
    mo=[0.0, 0.5, 0.9],          # momentum
    width=[4, 6, 8],             # network width
    wd=[0.0, 1e-4, 5e-4],        # weight decay
    lr=[0.01, 0.0075, 0.005],    # learning rate
    bs=[32, 128, 512],           # batch size
    skip=[False, True],          # skip
    batchnorm=[False, True],     # batchnorm
)

In [8]:
# Mean Kendall tau between each sharpness measure and the generalization gap,
# reported per hyper-parameter (original note: same results as fantastic papers).
# Header row: hyper-parameter names in param_grid order.
for x in param_grid.keys():
    print(x, end=', ')
print(' ')

# Measure key -> plot label. Keeping the pairing in one mapping prevents the
# labels from drifting out of step with the measure list below. Raw string:
# "\e" is not a valid escape sequence, so the plain literal triggers a warning.
measure_labels = {
    "eps_flat": r"$\epsilon$ sharpness",
    "pac_bayes": "Pac Bayes",
    "fro_norm": "$||H||_{F}$",
    "fim": "Fisher norm",
    "eig_trace": "Trace",
    "local_entropy": "Local entropy",
    "low_pass": "Low pass filter",
}
labels = list(measure_labels.values())
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# pick a measure ("eig_trace" has a label above but is not evaluated in this cell)
for meas in ["eps_flat", "pac_bayes", "fro_norm", "fim", "local_entropy", "low_pass"]:
    print(f"{meas} & ", end='')
#     fig, ax = plt.subplots()
    # pooled (measure, gap) samples; only consumed by the commented-out scatter plot
    plotting_needs = [[], []]
    # pick a hyper-parameter
    for key, value in param_grid.items():
        # every combination of the *other* hyper-parameters
        other_params = {k: vals for k, vals in param_grid.items() if k != key}
        corr = []
        # loop over all other set of hyper-parameters
        for params in ParameterGrid(other_params):
            flat_measure = []
            gen_gap = []
            # and just vary a single hyper-parameter that we picked
            for v in value:
                params[key] = v
                #mnist
#                 name = f"checkpoints/mnist/lenet/" \
#                        f"*_0_{params['mo']}_{params['wd']}" \
#                        f"_{params['lr']}_{params['bs']}_{False}"
                # cifar
                name = f"checkpoints/cifar10/resnet/" \
                       f"*_0_{params['mo']}_{params['width']}_{params['wd']}_" \
                       f"{params['lr']}_{params['bs']}_{params['skip']}_{params['batchnorm']}"

                # fail with the offending pattern instead of a bare IndexError;
                # if several folders match, the first one returned by glob is used
                matches = glob.glob(name)
                if not matches:
                    raise FileNotFoundError(f"no checkpoint folder matches {name!r}")
                fol = matches[0]

                # NOTE: pickle can execute arbitrary code -- only load trusted files
                with open(f"{fol}/run_ms_0/measures.pkl", 'rb') as f:
                    measures = pickle.load(f)
                # skip runs whose training loss is above 0.01
                if measures["train_loss"] > 0.01:
                    continue

                # record flatness and gen_gap for it
                # (gap = validation error - training error, in accuracy points)
                flat_measure.append(measures[meas])
                gen_gap.append((100 - measures["val_acc"]) - (100 - measures["train_acc"]))
            # compute tau and append (this is inner tau in equation 4 of fantastic);
            # the coefficient used here is Kendall's tau
            if len(gen_gap) > 1:
                plotting_needs[0] += flat_measure
                plotting_needs[1] += gen_gap
                c = scipy.stats.kendalltau(flat_measure, gen_gap)[0]
                # np.isnan instead of math.isnan: `math` is only in scope if the
                # pylab star import happens to re-export it
                if not np.isnan(c):
                    corr.append(c)
        # this is mean over a picked hyper-parameter (nan when nothing qualified)
        mean_tau = np.mean(corr) if corr else float('nan')
        print(f"{mean_tau:0.4f} & ", end='')

#     ax.scatter([p/np.max(plotting_needs[0]) for p in plotting_needs[0]], plotting_needs[1], label=key)
#     ax.set_xlabel(f"{measure_labels[meas]}")
#     ax.set_ylabel("Generalization gap")
#     ax.set_title(f"Generalization vs Sharpness Measure")
#     fig.savefig(f"results/deep_learning/resnet_cifar/figure_{measure_labels[meas]}.png")
    print(' ')

mo, width, wd, lr, bs, skip, batchnorm,  
eps_flat & 0.7922 & -0.0253 & 0.2414 & 0.5652 & 0.9342 & -0.2146 & -0.0789 &  
pac_bayes & 0.9486 & -0.6245 & 0.2706 & 0.7640 & 0.9918 & -0.1545 & -0.0704 &  
fro_norm & 0.8621 & 0.0527 & 0.2351 & 0.6273 & 0.9733 & 0.0687 & -0.0704 &  
fim & 0.7449 & 0.1835 & 0.1787 & 0.3747 & 0.7819 & -0.0601 & -0.0789 &  
local_entropy & -0.9625 & 0.6456 & -0.2428 & -0.7975 & -0.9795 & 0.2575 & 0.0789 &  
low_pass & 0.9877 & -0.5211 & 0.3459 & 0.8302 & 0.9938 & 0.5880 & -0.0618 &  


In [9]:
# "Empirical order" row: mean Kendall tau between the value of each
# hyper-parameter and the validation error, averaged over all settings of the
# remaining hyper-parameters.
print("Empirical order & ", end=' ')
for key in param_grid:
    print(f"{key} &", end=' ')
print(' ')

for key, value in param_grid.items():

    # every combination of the *other* hyper-parameters
    other_params = {k: vals for k, vals in param_grid.items() if k != key}
    corr = []
    # loop over all other set of hyper-parameters
    for params in ParameterGrid(other_params):
        val_error = []
        hyp = []
        # and just vary a single hyper-parameter that we picked
        for v in value:
            params[key] = v
            #mnist
#             name = f"checkpoints/mnist/lenet/" \
#                    f"*_0_{params['mo']}_{params['wd']}" \
#                    f"_{params['lr']}_{params['bs']}_{False}"
            # cifar
            name = f"checkpoints/cifar10/resnet/" \
                   f"*_0_{params['mo']}_{params['width']}_{params['wd']}_" \
                   f"{params['lr']}_{params['bs']}_{params['skip']}_{params['batchnorm']}"

            # A missing folder is reported and skipped, like an unreadable
            # measures file below, instead of aborting with an IndexError.
            matches = glob.glob(name)
            if not matches:
                print(name)
                continue
            fol = matches[0]

            # NOTE: pickle can execute arbitrary code -- only load trusted files
            try:
                with open(f"{fol}/run_ms_0/measures.pkl", 'rb') as f:
                    measures = pickle.load(f)
            except (OSError, pickle.UnpicklingError, EOFError) as err:
                # missing / truncated / corrupt file: report and skip this run,
                # without also swallowing KeyboardInterrupt or programming errors
                print(f"{fol}: {err!r}")
                continue

            # skip runs whose training loss is above 0.01
            if measures["train_loss"] > 0.01:
                continue

            # record validation error and hyper-parameter value for this run
            # NOTE(review): this is the validation error alone, not
            # (val error - train error) as in the sharpness table -- confirm
            # that this is intended.
            val_error.append(100 - measures["val_acc"])
            # booleans are encoded as 0/1 so they can be rank-correlated
            hyp.append(int(v) if isinstance(v, bool) else v)

        # compute tau and append (this is inner tau in equation 4 of fantastic);
        # the coefficient used here is Kendall's tau
        if len(val_error) > 1:
            c = scipy.stats.kendalltau(hyp, val_error)[0]
            # np.isnan instead of math.isnan: `math` is only in scope if the
            # pylab star import happens to re-export it
            if not np.isnan(c):
                corr.append(c)
    # this is mean over a picked hyper-parameter (nan when nothing qualified)
    mean_tau = np.mean(corr) if corr else float('nan')
    print(f"{mean_tau:0.4f} & ", end='')

Empirical order &  mo & width & wd & lr & bs & skip & batchnorm &  
-0.9856 & -0.6492 & -0.2989 & -0.8264 & 0.9938 & -0.2532 & -0.0789 & 

In [None]:
# Mean Pearson correlation between each hyper-parameter's value and each
# sharpness measure on the MNIST/LeNet checkpoints, averaged over all settings
# of the remaining hyper-parameters.
print("measure, momentum, weight decay, learning rate, batch size")
# `idx` and `labels` are not used below; kept because later cells may read them.
idx = 0
# Raw string: "\e" is not a valid escape sequence in a plain string literal.
labels = [r"$\epsilon$ sharpness", "Pac Bayes", "frobenius norm", "Fisher norm", "Trace", "Local entropy", "Low pass filter"]

# Only these hyper-parameters are encoded in the MNIST checkpoint name below.
# Sweeping the remaining param_grid keys would reload identical files and
# produce constant-input (NaN) correlations.
# NOTE(review): the values still come from the current param_grid; the
# commented-out grid in the configuration cell lists different learning rates
# -- confirm which grid the MNIST checkpoints were trained with.
mnist_keys = ('mo', 'wd', 'lr', 'bs')
mnist_grid = {k: param_grid[k] for k in mnist_keys}


def _has_nan(values):
    """Return True if any float (Python or NumPy scalar) in `values` is NaN.

    `np.nan in values` is not a reliable test: NaN compares unequal to itself,
    so membership only succeeds for the very same `np.nan` object. Non-float
    entries (ints, strings, arrays, ...) are ignored here.
    """
    return any(isinstance(val, (float, np.floating)) and np.isnan(val) for val in values)


for meas in ["eps_flat", "pac_bayes", "fro_norm", "fim", "eig_trace", "local_entropy", "low_pass"]:
    print(f"{meas} & ", end='')
    for key, value in mnist_grid.items():
        # every combination of the *other* hyper-parameters
        other_params = {k: vals for k, vals in mnist_grid.items() if k != key}
        corr = []
        for params in ParameterGrid(other_params):
            flat_measure = []
            hyper_param = []
            for v in value:
                params[key] = v
                #mnist
                name = f"checkpoints/mnist/lenet/" \
                       f"*_0_{params['mo']}_{params['wd']}" \
                       f"_{params['lr']}_{params['bs']}_{False}"
                # cifar
#                 name = f"checkpoints/cifar10/resnet/" \
#                        f"*_0_{params['mo']}_{params['width']}_{params['wd']}_" \
#                        f"{params['lr']}_{params['bs']}_{params['skip']}_{params['batchnorm']}"

                # fail with the offending pattern instead of a bare IndexError
                matches = glob.glob(name)
                if not matches:
                    raise FileNotFoundError(f"no checkpoint folder matches {name!r}")
                fol = matches[0]

                # NOTE: pickle can execute arbitrary code -- only load trusted files
                with open(f"{fol}/run_ms_0/measures.pkl", 'rb') as f:
                    measures = pickle.load(f)

                # skip runs for which any recorded measure is NaN
                if _has_nan(measures.values()):
                    continue

                # skip runs whose training loss is above 0.01
                if measures['train_loss'] > 0.01:
                    continue

                flat_measure.append(measures[meas])
                # booleans are encoded as 0/1 so they can be correlated
                hyper_param.append(int(v) if isinstance(v, bool) else v)
            if len(hyper_param) > 1:
                c = scipy.stats.pearsonr(hyper_param, flat_measure)[0]
                # pearsonr yields NaN for constant input; drop those so one
                # degenerate setting does not turn the whole mean into NaN
                if not np.isnan(c):
                    corr.append(c)

        # mean over the picked hyper-parameter (nan when nothing qualified)
        mean_corr = np.mean(corr) if corr else float('nan')
        print(f"{mean_corr:0.3f} & ", end='')
    print('')

In [None]:
# Leftover inspection cell: displays `name`, the checkpoint glob pattern most
# recently assigned inside one of the loops above. It only works if one of
# those cells has been run first in the current kernel.
name