In [1]:
import matplotlib.pyplot as plt
from os import listdir
import os
from os.path import isfile, join
import json
import pandas as pd
import seaborn as sns
import numpy as np
sns.set(rc={'figure.figsize':(1.33*11.7,1.33*8.27)})

In [2]:
paths = [
    '../logfiles/baseline/',
    '../logfiles/experiments/re_pruning/',
    '../logfiles/experiments/gd_top_k_mc_ac_dk/',
    
    '../logfiles/ablation_study/alexnet_mixed/',
    '../logfiles/ablation_study/resnet_mixed/',
    
    '../logfiles/ablation_study/admm_intra/',
    '../logfiles/ablation_study/admm_retrain/',
    
    '../logfiles/ablation_study/gd_top_k/',
    '../logfiles/ablation_study/gd_top_k_mc/',
    '../logfiles/ablation_study/gd_top_k_mc_ac/',
    '../logfiles/ablation_study/gd_top_k_mc_ac_dk/',
    '../logfiles/ablation_study/gd_top_k_mc_ac_dk_admm_intra/',
    '../logfiles/ablation_study/gd_top_k_mc_ac_dk_admm_retrain/',
    
    '../logfiles/ablation_study/re_pruning/',
    '../logfiles/ablation_study/re_pruning_admm_retrain/',
    '../logfiles/ablation_study/re_pruning_admm_intra/',
    '../logfiles/ablation_study/re_pruning_ac/',
    '../logfiles/ablation_study/re_pruning_ac_admm_intra/',
    '../logfiles/ablation_study/re_pruning_ac_admm_retrain/',
    '../logfiles/ablation_study/re_pruning_gd_top_k_mc_ac_dk_admm_intra/',
    '../logfiles/ablation_study/re_pruning_gd_top_k_mc_ac_dk_admm_retrain/'
        ]

In [3]:
def listdirs(rootdir, dirs):
    #https://www.techiedelight.com/list-all-subdirectories-in-directory-python/
    for file in os.listdir(rootdir):
        d = os.path.join(rootdir, file)
        if os.path.isdir(d):
            dirs.append(d+'/')
            #print(d)
            listdirs(d, dirs)
    return dirs
rootdir = '../logfiles/'
dirs = []
dirs = listdirs(rootdir, dirs)
dirs.append(rootdir)

In [4]:
paths = dirs
logs = []
for path in paths:
    print(path, flush=True)
    fnames = [f for f in listdir(path) if isfile(join(path, f))]
    for fname in fnames:
        #if 'vgg8' in fname:
        #    print(fname)
        if 'json' in fname:
            with open(path+fname, 'r') as f:
                logs.append(json.load(f))
        

../logfiles/ablation_study/
../logfiles/ablation_study/admm_intra/
../logfiles/ablation_study/admm_retrain/
../logfiles/ablation_study/alexnet_mixed/
../logfiles/ablation_study/gd_top_k/
../logfiles/ablation_study/gd_top_k_mc/
../logfiles/ablation_study/gd_top_k_mc_ac/
../logfiles/ablation_study/gd_top_k_mc_ac_dk/
../logfiles/ablation_study/gd_top_k_mc_ac_dk_admm_intra/
../logfiles/ablation_study/gd_top_k_mc_ac_dk_admm_retrain/
../logfiles/ablation_study/resnet_mixed/
../logfiles/ablation_study/re_pruning/
../logfiles/ablation_study/re_pruning_ac/
../logfiles/ablation_study/re_pruning_ac_admm_intra/
../logfiles/ablation_study/re_pruning_ac_admm_retrain/
../logfiles/ablation_study/re_pruning_admm_intra/
../logfiles/ablation_study/re_pruning_admm_retrain/
../logfiles/ablation_study/re_pruning_gd_top_k_mc_ac_dk_admm_intra/
../logfiles/ablation_study/re_pruning_gd_top_k_mc_ac_dk_admm_retrain/
../logfiles/baseline/
../logfiles/experiments/
../logfiles/experiments/gd_top_k_mc_ac_dk/
../logfi

In [5]:
def single_eval(dataset, model, name, specs_to_print, results_to_print, plt_corr = True):
    pd_dict = {}
    for log in logs:
        if (log['METADATA']['EXPERIMENT']['dataset'] == dataset and 
            log['METADATA']['EXPERIMENT']['name'] == name and
            log['METADATA']['EXPERIMENT']['model'] == model):
            #if log['LOGDATA']['test_accuracy'][-1] <= 0.1:
            #    continue
            outstring = model.upper() + ' ' + dataset.upper() + ' ' + name.upper() + '\n'
            for key in specs_to_print:
                section = None
                if key in log['METADATA']['SPECIFICATION']:
                    section = 'SPECIFICATION'
                if key in log['METADATA']['EXPERIMENT']:
                    section = 'EXPERIMENT'
                
                outstring += key + ':' + log['METADATA'][section][key] + '\n'
                if key not in pd_dict:
                    pd_dict[key] = []
                pd_dict[key].append(float(log['METADATA'][section][key]))
            #outstring += '\n'
            for key in results_to_print:
                if type(log['LOGDATA'][key]) == type([]):
                    outstring += key + ':' + str(round(log['LOGDATA'][key][-1], 2)) + '\n'
                    if key not in pd_dict:
                        pd_dict[key] = []
                    pd_dict[key].append(log['LOGDATA'][key][-1])
                else:
                    outstring += key + ':' + str(round(log['LOGDATA'][key], 2)) + '\n'
                    if key not in pd_dict:
                        pd_dict[key] = []
                    pd_dict[key].append(log['LOGDATA'][key])
                    
            #    if type(log['LOGDATA'][key]) == type([]):
            #        plt.plot(log['LOGDATA'][key], label=key)
            #plt.title(model.upper() + ' ' + dataset.upper() + ' ' + name.upper())
            #plt.legend()
            #plt.show()
            outstring+='\n'
            print(outstring)
                
    if plt_corr:
        pd_df = pd.DataFrame(pd_dict)
        if len(pd_df) > 0:
            pd_df = pd_df.loc[:, (pd_df != pd_df.iloc[0]).any()] #drop const cols
            sns.heatmap(pd_df.corr(), cbar=True, annot=True, cmap='RdBu')
            plt.title(model.upper() + ' ' + dataset.upper() + ' ' + name.upper())
            plt.xticks(rotation=45) 
            plt.show()
            
def cross_eval(best_results, datasets, models, name, specs_to_print, results_to_print, plt_corr = True):
    pd_dict = {}
    res_dict = {}
    best_config = {}
    prec = 2
    
    for model in models:
        for dataset in datasets:
            if model not in res_dict:
                res_dict[model] = {}
            if dataset not in res_dict[model]:
                res_dict[model][dataset] = {}
            if model not in best_config:
                best_config[model] = {}
            if dataset not in best_config[model]:
                best_config[model][dataset] = {}
            res_dict[model][dataset]['test_accuracy'] = 0.0
            res_dict[model][dataset]['total_su'] = 1.0
            res_dict[model][dataset]['current_su_fwd'] = 1.0
            res_dict[model][dataset]['current_sparsity'] = 0.0
            
            for log in logs:
                if (log['METADATA']['EXPERIMENT']['dataset'] == dataset and 
                    log['METADATA']['EXPERIMENT']['name'] == name and
                    log['METADATA']['EXPERIMENT']['model'] == model):
                    outstring = model.upper() + ' ' + dataset.upper() + ' ' + name.upper() + '\n'
                    #if log['LOGDATA']['test_accuracy'][-1] <= 0.1:
                    #    continue
                    if log['LOGDATA']['test_accuracy'][-1] > 1:
                        log['LOGDATA']['test_accuracy'] = [x * 1e-2 for x in log['LOGDATA']['test_accuracy']]
                        
                    #TODO = selectable . in rounding
                    
                    if round(res_dict[model][dataset]['test_accuracy'],prec) < round(log['LOGDATA']['test_accuracy'][-1],prec):
                        res_dict[model][dataset]['test_accuracy'] = log['LOGDATA']['test_accuracy'][-1]
                        if name != 'baseline':
                            res_dict[model][dataset]['current_su_fwd'] = log['LOGDATA']['current_su_fwd'][-1]
                            res_dict[model][dataset]['total_su'] = log['LOGDATA']['total_su'][-1]
                            res_dict[model][dataset]['current_sparsity'] = log['LOGDATA']['current_sparsity'][-1]
                            best_config[model][dataset]['METADATA'] = log['METADATA']
                    if round(res_dict[model][dataset]['test_accuracy'],prec) == round(log['LOGDATA']['test_accuracy'][-1],prec):
                        if name != 'baseline':
                            if round(res_dict[model][dataset]['total_su'],prec) < round(log['LOGDATA']['total_su'][-1],prec):
                                res_dict[model][dataset]['current_su_fwd'] = log['LOGDATA']['current_su_fwd'][-1]
                                res_dict[model][dataset]['test_accuracy'] = log['LOGDATA']['test_accuracy'][-1]
                                res_dict[model][dataset]['total_su'] = log['LOGDATA']['total_su'][-1]
                                res_dict[model][dataset]['current_sparsity'] = log['LOGDATA']['current_sparsity'][-1]
                                best_config[model][dataset]['METADATA'] = log['METADATA']
                    
                    for key in specs_to_print:
                        section = None
                        if key in log['METADATA']['SPECIFICATION']:
                            section = 'SPECIFICATION'
                        if key in log['METADATA']['EXPERIMENT']:
                            section = 'EXPERIMENT'

                        outstring += key + ':' + log['METADATA'][section][key] + '\n'
                        if key not in pd_dict:
                            pd_dict[key] = []
                        pd_dict[key].append(float(log['METADATA'][section][key]))
                    #outstring += '\n'
                    for key in results_to_print:
                        if type(log['LOGDATA'][key]) == type([]):
                            outstring += key + ':' + str(round(log['LOGDATA'][key][-1], prec)) + '\n'
                            if key not in pd_dict:
                                pd_dict[key] = []
                            pd_dict[key].append(log['LOGDATA'][key][-1])
                        else:
                            outstring += key + ':' + str(round(log['LOGDATA'][key], prec)) + '\n'
                            if key not in pd_dict:
                                pd_dict[key] = []
                            pd_dict[key].append(log['LOGDATA'][key])

                    #    if type(log['LOGDATA'][key]) == type([]):
                    #        plt.plot(log['LOGDATA'][key], label=key)
                    #plt.title(model.upper() + ' ' + dataset.upper() + ' ' + name.upper())
                    #plt.legend()
                    #plt.show()
                    outstring+='\n'
                    #print(outstring)
    for model in res_dict:
        for dataset in res_dict[model]:
            if res_dict[model][dataset]['test_accuracy'] > 0.0:
                #print(res_dict[model][dataset].keys())
                print('EXP.: {}, MODEL: {}, DATA: {}'.format(name, model, dataset))
                print('ACC.: {}, TRAIN SU: {}, INF. SU: {}, SP.: {}'.format(
                      round(res_dict[model][dataset]['test_accuracy'],prec),
                      round(res_dict[model][dataset]['total_su'],prec),
                      round(res_dict[model][dataset]['current_su_fwd'],prec),
                      round(res_dict[model][dataset]['current_sparsity'],prec)))
                #print(best_config[model][dataset])
    print('\n\n')
    if plt_corr:
        pd_df = pd.DataFrame(pd_dict)
        if len(pd_df) > 0:
            '''
            corr = pd_df.corr()
            mask = np.zeros_like(corr, dtype=np.bool)
            mask[np.triu_indices_from(mask)] = True
            pd_df = pd_df.loc[:, (pd_df != pd_df.iloc[0]).any()] #drop const cols
            sns.heatmap(corr, mask=mask, cbar=True, square=True, annot=True, cmap='RdBu')
            plt.title(name.upper())
            plt.xticks(rotation=90) 
            plt.show()
            '''
            #for key in pd_df:
            #    pd_df.boxplot(column=key)
            #    plt.show()
            #pd_df.boxplot()
            #print(pd_df.describe(), flush=True)
            display(pd_df.describe())
            pass
            
    best_results[name] = {'cfg' : best_config, 'res': res_dict}

In [6]:
#TODO visualize hyperparameter distributions
#TODO correlate percentages ADMM as avaerages

In [7]:
datasets = [
    'cifar10', 
    'cifar100', 
    'mnist',
    'imagenet_tiny',
    'imagenet_full',
]
models = [
    #'resnet18', 
    #'resnet20',
    #'resnet32',
    #'resnet34',
    #'resnet50',
    #'alexnet_s', 
    #'alexnet',
    'lenet', 
    #'mobilenet_v2', 
    #'mobilenet_v3_s', 
    #'vgg8',
    #'vgg11', 
    #'vgg13', 
    #'vgg16'
         ]
best_results = {}
plt_corr = True
sns.set(rc={'figure.figsize':(1.33*11.7,1.33*8.27)})
sns.set_theme(style="white")


cross_eval(best_results, datasets, models, 'baseline', ['lr', 'epochs', 'train_batch_size'], ['test_accuracy'], plt_corr)


cross_eval(best_results, datasets, models, 'admm_intra', 
     ['lr', 'pre_epochs', 'epochs', 're_epochs', 'repeat', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'admm_retrain', 
     ['lr', 'pre_epochs', 'epochs', 're_epochs', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)

cross_eval(best_results, datasets, models, 'gd_top_k', ['lr', 'k', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'gd_top_k_mc', ['lr', 'k', 'se', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'gd_top_k_mc_ac', ['lr', 'k', 'se', 'ac', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'gd_top_k_mc_ac_dk', ['lr', 'k', 'se', 'ac', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'gd_top_k_mc_ac_dk_admm_intra', 
     ['lr', 'k', 'se', 'ac', 'pre_epochs', 'epochs', 're_epochs', 'repeat', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 'gd_top_k_mc_ac_dk_admm_retrain',  
     ['lr', 'k', 'se', 'ac', 'pre_epochs', 'epochs', 're_epochs', 'train_batch_size'], 
     ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
      'current_relative_overhead'], plt_corr)


cross_eval(best_results, datasets, models, 're_pruning', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'l1', 'l2', 
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
      'current_sparsity', 'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 're_pruning_ac', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'l1', 'l2', 'ac', 
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
      'current_sparsity', 'current_relative_overhead'], plt_corr)


cross_eval(best_results, datasets, models, 're_pruning_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'repeat', 'pre_epochs', 'epochs', 're_epochs',
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
      'current_sparsity', 'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 're_pruning_admm_retrain', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'pre_epochs', 'epochs', 're_epochs',
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
      'current_sparsity', 'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 're_pruning_ac_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'ac', 'repeat',
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd',  'pre_epochs', 'epochs', 're_epochs',
      'current_sparsity', 'current_relative_overhead'], plt_corr)
cross_eval(best_results, datasets, models, 're_pruning_ac_admm_retrain', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
      'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'ac', 'pre_epochs', 'epochs', 're_epochs',
      'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
      'current_sparsity', 'current_relative_overhead'], plt_corr)



cross_eval(best_results, datasets, models, 're_pruning_gd_top_k_mc_ac_dk_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 
      'metric_q_c', 'scale_l', 'scale_c', 'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 
      'magnitude_t_l', 'ac', 'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 
      'total_su_bwd', 'current_sparsity', 'current_relative_overhead'], plt_corr)

cross_eval(best_results, datasets, models, 're_pruning_gd_top_k_mc_ac_dk_admm_retrain', ['lr', 'prune_epochs', 'metric_q_l', 
      'metric_q_c', 'scale_l', 'scale_c', 'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 
      'magnitude_t_l', 'ac', 'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 
      'total_su_bwd', 'current_sparsity', 'current_relative_overhead'], plt_corr)

EXP.: baseline, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 1.0, INF. SU: 1.0, SP.: 0.0





Unnamed: 0,lr,epochs,train_batch_size,test_accuracy
count,2.0,2.0,2.0,2.0
mean,0.1,10.0,256.0,0.99365
std,0.0,0.0,0.0,0.000354
min,0.1,10.0,256.0,0.9934
25%,0.1,10.0,256.0,0.993525
50%,0.1,10.0,256.0,0.99365
75%,0.1,10.0,256.0,0.993775
max,0.1,10.0,256.0,0.9939


EXP.: admm_intra, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.13, INF. SU: 12.01, SP.: 0.99





Unnamed: 0,lr,pre_epochs,epochs,re_epochs,repeat,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0,31.0
mean,0.001,1.387097,1.387097,1.387097,2.419355,163.096774,0.989848,1.722696,1.451732,1.920139,0.949317,0.03892
std,2.2042479999999997e-19,0.495138,0.495138,0.495138,0.50161,97.536098,0.00275,0.219689,0.114819,0.356833,0.079452,0.037846
min,0.001,1.0,1.0,1.0,2.0,64.0,0.9819,1.370012,1.262184,1.390325,0.701568,0.002847
25%,0.001,1.0,1.0,1.0,2.0,64.0,0.9878,1.568072,1.362519,1.646363,0.949042,0.00755
50%,0.001,1.0,1.0,1.0,2.0,256.0,0.9904,1.671245,1.465781,1.80726,0.985947,0.024962
75%,0.001,2.0,2.0,2.0,3.0,256.0,0.99185,1.870736,1.526635,2.242672,0.985947,0.065417
max,0.001,2.0,2.0,2.0,3.0,256.0,0.9941,2.13142,1.658899,2.598491,0.985947,0.101839


EXP.: admm_retrain, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 1.8, INF. SU: 12.01, SP.: 0.99





Unnamed: 0,lr,pre_epochs,epochs,re_epochs,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,22.0,22.0,22.0,22.0,22.0,22.0,22.0,22.0,22.0,22.0,22.0
mean,0.001,4.409091,7.5,4.090909,151.272727,0.991773,1.369581,1.327132,1.392365,0.9819,0.050091
std,2.219433e-19,2.423282,2.988072,2.091003,97.852261,0.001087,0.19227,0.163022,0.209476,0.00823,0.038347
min,0.001,3.0,3.0,3.0,64.0,0.9897,1.201755,1.189529,1.207955,0.959141,0.00421
25%,0.001,3.0,6.0,3.0,64.0,0.990775,1.221166,1.207571,1.230057,0.983797,0.024935
50%,0.001,3.0,10.0,3.0,64.0,0.99195,1.238255,1.207571,1.254189,0.985947,0.038189
75%,0.001,5.0,10.0,5.0,256.0,0.99265,1.455195,1.401513,1.490844,0.985947,0.099875
max,0.001,10.0,10.0,10.0,256.0,0.9937,1.800106,1.662755,1.877657,0.985947,0.101347


EXP.: gd_top_k, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.07, INF. SU: 1.0, SP.: 0.0





Unnamed: 0,lr,k,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0
mean,0.001,2.333333,128.0,0.691517,1.707067,1.0,2.984932,0.0,0.003813
std,2.264824e-19,0.984732,94.534265,0.438344,0.372986,7.177563e-11,1.315722,0.0,0.002012
min,0.001,1.0,64.0,0.098,1.211814,1.0,1.355356,0.0,0.001242
25%,0.001,1.0,64.0,0.098,1.241419,1.0,1.411841,0.0,0.002636
50%,0.001,3.0,64.0,0.98635,1.833823,1.0,3.187171,0.0,0.00341
75%,0.001,3.0,256.0,0.9889,2.051198,1.0,4.323762,0.0,0.004651
max,0.001,3.0,256.0,0.9909,2.073161,1.0,4.473619,0.0,0.007813


EXP.: gd_top_k_mc, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.06, INF. SU: 1.0, SP.: 0.0





Unnamed: 0,lr,k,se,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0
mean,0.001,2.333333,2.666667,128.0,0.690733,1.569743,1.0,2.426123,0.0,0.004473
std,2.264824e-19,0.984732,0.492366,94.534265,0.437783,0.352069,1.333594e-11,1.044245,0.0,0.002897
min,0.001,1.0,2.0,64.0,0.098,1.070817,1.0,1.110125,0.0,0.002179
25%,0.001,1.0,2.0,64.0,0.098,1.139493,1.0,1.224955,0.0,0.002318
50%,0.001,3.0,3.0,64.0,0.9862,1.693044,1.0,2.590905,0.0,0.003391
75%,0.001,3.0,3.0,256.0,0.99005,1.797027,1.0,2.993625,0.0,0.004829
max,0.001,3.0,3.0,256.0,0.9907,2.055838,1.0,4.354838,0.0,0.010292


EXP.: gd_top_k_mc_ac, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.33, INF. SU: 1.0, SP.: 0.0





Unnamed: 0,lr,k,se,ac,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0
mean,0.001,3.0,3.0,6.5,112.0,0.988619,2.221818,1.0,5.802051,0.0,0.007738
std,0.0,0.0,0.0,3.966527,85.86501,0.001842,0.090178,0.0,0.848892,0.0,0.003429
min,0.001,3.0,3.0,2.0,64.0,0.9852,2.035482,1.0,4.220724,0.0,0.002121
25%,0.001,3.0,3.0,3.5,64.0,0.98745,2.178059,1.0,5.31071,0.0,0.006381
50%,0.001,3.0,3.0,6.0,64.0,0.98895,2.256514,1.0,6.071423,0.0,0.00816
75%,0.001,3.0,3.0,9.0,112.0,0.99005,2.280501,1.0,6.339752,0.0,0.010355
max,0.001,3.0,3.0,12.0,256.0,0.991,2.333686,1.0,7.004769,0.0,0.011683


EXP.: gd_top_k_mc_ac_dk, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.21, INF. SU: 1.0, SP.: 0.0





Unnamed: 0,lr,k,se,ac,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,17.0,17.0,17.0,17.0,17.0,17.0,17.0,17.0,17.0,17.0,17.0
mean,0.001,2.764706,2.764706,8.0,131.764706,0.9665,1.918469,1.0,4.268949,0.0,0.005237
std,0.0,1.09141,0.437237,0.0,94.577699,0.07147,0.388331,8.124844e-10,2.301938,0.0,0.003353
min,0.001,1.0,2.0,8.0,64.0,0.696,1.258558,1.0,1.44542,0.0,0.001131
25%,0.001,3.0,3.0,8.0,64.0,0.9877,1.757525,1.0,2.829072,0.0,0.002193
50%,0.001,3.0,3.0,8.0,64.0,0.9891,1.988371,1.0,3.931031,0.0,0.004571
75%,0.001,3.0,3.0,8.0,256.0,0.9902,2.212975,1.0,5.623648,0.0,0.008742
max,0.001,4.0,3.0,8.0,256.0,0.9915,2.508275,1.0,10.201946,0.0,0.011688


EXP.: gd_top_k_mc_ac_dk_admm_intra, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 4.27, INF. SU: 4.42, SP.: 0.83





Unnamed: 0,lr,k,se,ac,pre_epochs,epochs,re_epochs,repeat,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0
mean,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,146.285714,0.984557,3.45347,1.87216,6.229555,0.957364,0.035163
std,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,102.628317,0.004515,0.635004,0.166327,2.250004,0.058303,0.028431
min,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,64.0,0.978,2.881362,1.68975,4.278066,0.830688,0.005989
25%,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,64.0,0.9819,2.950951,1.755477,4.513933,0.963538,0.014209
50%,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,64.0,0.9866,3.279859,1.804247,5.096503,0.985947,0.031695
75%,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,256.0,0.9874,3.858506,1.990099,7.863354,0.985947,0.046375
max,0.001,3.0,3.0,8.0,1.0,1.0,1.0,3.0,256.0,0.9887,4.394155,2.119972,9.477742,0.985947,0.087291


EXP.: gd_top_k_mc_ac_dk_admm_retrain, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 2.89, INF. SU: 4.16, SP.: 0.92





Unnamed: 0,lr,k,se,ac,pre_epochs,epochs,re_epochs,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0
mean,0.001,3.0,3.0,8.0,2.0,6.0,2.0,128.0,0.974883,2.329637,1.217308,4.860839,0.976017,0.065514
std,2.264824e-19,0.0,0.0,0.0,0.0,0.0,0.0,94.534265,0.024987,0.332491,0.016858,2.440573,0.023462,0.041812
min,0.001,3.0,3.0,8.0,2.0,6.0,2.0,64.0,0.9065,1.77411,1.179107,2.28744,0.918026,0.006305
25%,0.001,3.0,3.0,8.0,2.0,6.0,2.0,64.0,0.978925,2.114851,1.224516,3.322865,0.985947,0.026392
50%,0.001,3.0,3.0,8.0,2.0,6.0,2.0,64.0,0.98455,2.304161,1.224516,4.243915,0.985947,0.076212
75%,0.001,3.0,3.0,8.0,2.0,6.0,2.0,256.0,0.988175,2.474343,1.224516,5.075883,0.985947,0.092646
max,0.001,3.0,3.0,8.0,2.0,6.0,2.0,256.0,0.99,2.88778,1.224516,10.484353,0.985947,0.12277


EXP.: re_pruning, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 4.69, INF. SU: 4.05, SP.: 0.14





Unnamed: 0,lr,prune_epochs,metric_q_l,metric_q_c,scale_l,scale_c,sample_l,sample_c,softness_l,softness_c,...,magnitude_t_l,l1,l2,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,...,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0
mean,0.1,0.4,0.0352,0.185,0.17,0.076,820.0,820.0,0.18,0.18,...,0.001,0.0001,8.0002e-05,160.0,0.97352,13.104644,8.218006,19.987263,0.582957,0.026784
std,0.0,0.843274,0.020714,0.08756,0.075277,0.098793,379.473319,379.473319,0.379473,0.379473,...,2.2856989999999996e-19,0.0,4.215949e-05,101.192885,0.023407,13.221078,7.166728,24.444213,0.419718,0.029152
min,0.1,0.0,0.001,0.05,0.1,0.01,100.0,100.0,0.0,0.0,...,0.001,0.0001,1e-08,64.0,0.9145,2.209322,1.319949,3.176753,0.08722,0.002751
25%,0.1,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,0.001,0.0001,0.0001,64.0,0.971,3.308599,2.982198,3.553956,0.160415,0.011368
50%,0.1,0.0,0.05,0.25,0.15,0.01,1000.0,1000.0,0.0,0.0,...,0.001,0.0001,0.0001,160.0,0.9845,5.119611,3.963887,6.014473,0.641943,0.015797
75%,0.1,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,0.001,0.0001,0.0001,256.0,0.986525,22.265483,14.229037,31.028729,0.992416,0.022041
max,0.1,2.0,0.05,0.25,0.25,0.25,1000.0,1000.0,0.9,0.9,...,0.001,0.0001,0.0001,256.0,0.9905,39.953028,20.09964,78.938928,0.994771,0.091882


EXP.: re_pruning_ac, MODEL: lenet, DATA: mnist
ACC.: 0.98, TRAIN SU: 31.33, INF. SU: 14.29, SP.: 0.99





Unnamed: 0,lr,prune_epochs,metric_q_l,metric_q_c,scale_l,scale_c,sample_l,sample_c,softness_l,softness_c,...,l1,l2,ac,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0,...,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0,13.0
mean,0.1,0.0,0.042308,0.211538,0.146154,0.037692,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,2.923077,152.615385,0.957992,30.214258,13.496152,92.753558,0.992089,0.042538
std,1.444446e-17,0.0,0.01201,0.060048,0.072058,0.043235,0.0,0.0,0.0,0.0,...,1.410592e-20,1.410592e-20,1.037749,99.623908,0.019678,17.771963,6.236558,92.893506,0.003215,0.038859
min,0.1,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,2.0,64.0,0.9334,9.585555,5.85337,14.071709,0.984123,0.007092
25%,0.1,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,2.0,64.0,0.9387,18.324637,9.437538,34.629528,0.991503,0.011997
50%,0.1,0.0,0.05,0.25,0.1,0.01,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,2.0,64.0,0.9541,23.257225,11.252508,50.4533,0.993326,0.02758
75%,0.1,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,4.0,256.0,0.9793,31.329714,14.628076,85.579965,0.994012,0.064893
max,0.1,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,0.0001,0.0001,4.0,256.0,0.9845,69.198344,27.078503,311.333727,0.994987,0.111399


EXP.: re_pruning_admm_intra, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 13.7, INF. SU: 12.02, SP.: 0.99





Unnamed: 0,lr,prune_epochs,metric_q_l,metric_q_c,scale_l,scale_c,sample_l,sample_c,softness_l,softness_c,...,pre_epochs,epochs,re_epochs,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0,...,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0,36.0
mean,0.001,0.0,0.041667,0.208333,0.15,0.04,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,160.0,0.939108,8.799212,6.523557,11.244347,0.986744,inf
std,0.0,0.0,0.011952,0.059761,0.071714,0.043028,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,97.36177,0.206897,9.651804,8.931164,10.594989,0.003261,
min,0.001,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,64.0,0.098,4.658405,3.286416,5.887296,0.985947,0.032019
25%,0.001,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,64.0,0.987475,4.980152,3.601271,6.081007,0.985949,0.032022
50%,0.001,0.0,0.05,0.25,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,160.0,0.989,5.540613,4.078805,7.208564,0.985954,0.080065
75%,0.001,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,256.0,0.9894,8.64846,5.556454,10.991124,0.985972,0.128273
max,0.001,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,256.0,0.9904,49.075764,44.193139,51.945322,1.0,inf


EXP.: re_pruning_admm_retrain, MODEL: lenet, DATA: mnist
ACC.: 0.99, TRAIN SU: 8.18, INF. SU: 12.14, SP.: 0.99





Unnamed: 0,lr,prune_epochs,metric_q_l,metric_q_c,scale_l,scale_c,sample_l,sample_c,softness_l,softness_c,...,pre_epochs,epochs,re_epochs,train_batch_size,test_accuracy,total_su,total_su_fwd,total_su_bwd,current_sparsity,current_relative_overhead
count,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0,...,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0,48.0
mean,0.001,0.0,0.0375,0.1875,0.1375,0.0325,1000.0,1000.0,0.0,0.0,...,1.5,3.5,1.5,184.0,0.909877,7.162793,6.055424,7.963043,0.987244,inf
std,0.0,0.0,0.012632,0.063161,0.065639,0.039384,0.0,0.0,0.0,0.0,...,0.505291,2.526456,0.505291,93.935243,0.247452,4.132728,3.46944,4.731617,0.003891,
min,0.001,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,64.0,0.098,4.601207,3.9721,4.996917,0.985954,0.032032
25%,0.001,0.0,0.025,0.125,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.0,1.0,1.0,64.0,0.9772,5.289321,4.690808,5.553624,0.985993,0.032082
50%,0.001,0.0,0.0375,0.1875,0.1,0.01,1000.0,1000.0,0.0,0.0,...,1.5,3.5,1.5,256.0,0.98325,5.559343,4.937125,5.817858,0.986058,0.032255
75%,0.001,0.0,0.05,0.25,0.1375,0.0325,1000.0,1000.0,0.0,0.0,...,2.0,6.0,2.0,256.0,0.9893,6.560154,5.16047,7.804492,0.986145,0.129386
max,0.001,0.0,0.05,0.25,0.25,0.1,1000.0,1000.0,0.0,0.0,...,2.0,6.0,2.0,256.0,0.9909,22.117961,18.511103,24.505378,1.0,inf


KeyError: 'pre_epochs'

In [None]:
search_n, search_m, search_d = 're_pruning_gd_top_k_mc_ac_dk_admm_intra', 'resnet50', 'cifar10'
for name in best_results:
    if name == search_n:
        for model in best_results[name]['cfg']:
            if search_m == model:
                for data in best_results[name]['cfg'][model]:
                    if search_d == data:
                        if 'METADATA' in best_results[name]['cfg'][model][data]:
                            print(best_results[name]['cfg'][model][data]['METADATA'])
        #print(best_results[name]['cfg'])

In [None]:
import collections

In [None]:
new_logdict = {}
for log in logs:
    for key in log['LOGDATA']:
        new_key = log['METADATA']['EXPERIMENT']['dataset'] + '+'
        new_key += log['METADATA']['EXPERIMENT']['name'] + '+'
        new_key += log['METADATA']['EXPERIMENT']['model']
        if log['LOGDATA']['test_accuracy'][-1] <= 0.1:
                continue
        if new_key not in new_logdict:
            new_logdict[new_key] = {}
        if key not in new_logdict[new_key]:
            new_logdict[new_key][key] = []
            
        if 'overhead' in key:
            #print(sum(log['LOGDATA'][key])/len(log['LOGDATA'][key]), flush=True)
            if ('_mc' in log['METADATA']['EXPERIMENT']['name'] and
                'gd' in log['METADATA']['EXPERIMENT']['name'] and not
                'admm' in log['METADATA']['EXPERIMENT']['name']):
                
                    idx = int((len(log['LOGDATA'][key])*
                           float(log['METADATA']['SPECIFICATION']['se'])/
                           float(log['METADATA']['SPECIFICATION']['epochs'])))+1
                    new_logdict[new_key][key].append(sum(log['LOGDATA'][key][:idx])/len(log['LOGDATA'][key]))
                    #print(log['METADATA']['EXPERIMENT']['name'], 
                    #      sum(log['LOGDATA'][key][:idx])/len(log['LOGDATA'][key]),
                    #      sum(log['LOGDATA'][key])/len(log['LOGDATA'][key]))
            else:
                new_logdict[new_key][key].append(sum(log['LOGDATA'][key])/len(log['LOGDATA'][key]))
        elif 'gradient' in key:
            new_logdict[new_key][key].append(sum(log['LOGDATA'][key])/len(log['LOGDATA'][key]))
        else:
            new_logdict[new_key][key].append(log['LOGDATA'][key][-1])
        
n = 5
prec = 2
x = []
y = []
area = []
names = []
colors = []
for key in new_logdict:
    if (not 'baseline' in key and '+lenet' in key and not 'cifar100' in key):
       #and 're_pruning_gd_top_k_mc_ac_dk_admm_intra' in key):#('+resnet' in key or '+alexnet' in key or '+lenet' in key):
        print(key)
        idx = sorted(range(len(new_logdict[key]['test_accuracy'])), 
                     key=lambda i: new_logdict[key]['test_accuracy'][i])[-n:]
        #idx = idx[:10]
        idx.reverse()
        print('Acc, C/F, P, FLOPs(I), FLOPs(T), G')
        
        y.append(round((1-1/new_logdict[key]['total_su'][0])/
                 (1-new_logdict[key]['test_accuracy'][0]),prec))
        
        x.append(round(1-1/new_logdict[key]['total_su'][0],prec))
        
        area.append(round(new_logdict[key]['test_accuracy'][0]/
                 (1-new_logdict[key]['current_sparsity'][0])*100+1,prec))
        
        colors.append(int(round(new_logdict[key]['current_relative_overhead'][0]/
                                max(new_logdict[key]['current_relative_overhead'])*100,prec)))
        
        name_candidate = key.split('+')[1]
        name_candidate = name_candidate.replace('re_pruning', 'REP')
        name_candidate = name_candidate.replace('gd_top_k_mc', 'GDTopKMC')
        name_candidate = name_candidate.replace('gd_top_k', 'GDTopK')
        name_candidate = name_candidate.replace('ac', 'AC')
        name_candidate = name_candidate.replace('dk', 'DK')
        name_candidate = name_candidate.replace('admm_retrain', 'ADMMR')
        name_candidate = name_candidate.replace('admm_intra', 'ADMMI')
        name_candidate = name_candidate.replace('_', '+')
        names.append(name_candidate)
        for i in idx:
            #print(new_logdict[key].keys()) #current_channel_sparsity', 'current_linear_sparsity #current_relative_overhead'
            print(round(new_logdict[key]['test_accuracy'][i],prec), 
                  round(new_logdict[key]['current_channel_sparsity'][i],prec),
                  round(new_logdict[key]['current_sparsity'][i],prec),
                  round(1-1/new_logdict[key]['current_su_fwd'][i],prec),
                  round(1-1/new_logdict[key]['total_su'][i],prec),
                  round(new_logdict[key]['current_gradient_sparsity'][i],prec),
                  #round(new_logdict[key]['current_relative_overhead'][i]*1e5, prec),
                  #round(1-1/new_logdict[key]['current_su_bwd'][i],prec),
                  #round(new_logdict[key]['total_su'][i],prec),
                  #round(new_logdict[key]['current_sparsity'][i],prec))+
                  )
                  #round(new_logdict[key]['current_relative_overhead'][i],prec))
        print('\n')
    #for subkey in new_logdict[key]:
    #    if not 'features' in subkey and not 'weight' in subkey:
            
sns.set(font_scale=2)
sns.set_style("whitegrid")

plt.scatter(x, y, s=area, c=colors, alpha=0.5)
for i, txt in enumerate(names):
    plt.annotate(txt, (x[i], y[i]))
plt.ylabel('FLOPs(T) per Delta Acc.')
plt.xlabel('FLOPs(T)')
#plt.legend()
plt.show()

datasets = [
    'cifar10', 
    'cifar100', 
    'mnist',
    'imagenet'
]
models = [
    'resnet18', 
    'alexnet_s', 
    'lenet', 
    'mobilenet_v2', 
    'mobilenet_v3', 
    'vgg8',
    'vgg11', 
    'vgg13', 
    'vgg16'
         ]
plt_corr = False

for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 'baseline', ['lr', 'epochs', 'train_batch_size'], ['test_accuracy'], plt_corr)


for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 'admm_intra', 
             ['lr', 'pre_epochs', 'epochs', 're_epochs', 'repeat', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'admm_retrain', 
             ['lr', 'pre_epochs', 'epochs', 're_epochs', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)

for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 'gd_top_k', ['lr', 'k', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'gd_top_k_mc', ['lr', 'k', 'se', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'gd_top_k_mc_ac', ['lr', 'k', 'se', 'ac', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'gd_top_k_mc_ac_dk', ['lr', 'k', 'se', 'ac', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'gd_top_k_mc_ac_dk_admm_intra', 
             ['lr', 'k', 'se', 'ac', 'pre_epochs', 'epochs', 're_epochs', 'repeat', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 'gd_top_k_mc_ac_dk_admm_retrain',  
             ['lr', 'k', 'se', 'ac', 'pre_epochs', 'epochs', 're_epochs', 'train_batch_size'], 
             ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 'current_sparsity', 
              'current_relative_overhead'], plt_corr)


for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 're_pruning', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'l1', 'l2', 
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 're_pruning_ac', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'l1', 'l2', 'ac', 
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)


for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 're_pruning_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'repeat',
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 're_pruning_admm_retrain', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 're_pruning_ac_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'ac', 'repeat',
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)
        single_eval(dataset, model, 're_pruning_ac_admm_retrain', ['lr', 'prune_epochs', 'metric_q_l', 'metric_q_c', 'scale_l', 'scale_c',
              'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 'magnitude_t_l', 'ac', 
              'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 'total_su_bwd', 
              'current_sparsity', 'current_relative_overhead'], plt_corr)


for model in models:
    for dataset in datasets:
        single_eval(dataset, model, 're_pruning_gd_top_k_mc_ac_dk_admm_intra', ['lr', 'prune_epochs', 'metric_q_l', 
              'metric_q_c', 'scale_l', 'scale_c', 'sample_l', 'sample_c', 'softness_l', 'softness_c', 'magnitude_t_c', 
              'magnitude_t_l', 'l1', 'l2', 'ac', 'train_batch_size'], ['test_accuracy', 'total_su', 'total_su_fwd', 
              'total_su_bwd', 'current_sparsity', 'current_relative_overhead'], False)