In [1]:
# Factors of Clusterability Analysis: MLPs

In [2]:
import sys
from pathlib import Path

import pandas as pd
from tqdm import tqdm

# Make the project root importable so the local `src` package resolves.
sys.path.append('..')
from src.visualization import run_spectral_cluster
from src.experiment_tagging import get_model_path
from src.utils import get_weights_paths

Using TensorFlow backend.


In [3]:
# Spectral-clustering settings, forwarded to run_spectral_cluster in the next cell.
n_clust = 12    # passed as n_clusters
n_samples = 50  # passed as n_samples
n_workers = 10  # passed as n_workers
n_reps = 5      # number of trained networks analysed per experiment tag

# Experiment tags to analyse, grouped by dataset.
model_tags = ('MNIST', 'MNIST+DROPOUT', 'MNIST+L1REG', 'MNIST+L2REG', 'MNIST+MOD-INIT',
              'HALVES-SAME-MNIST', 'HALVES-MNIST',
              'FASHION', 'FASHION+DROPOUT', 'FASHION+L1REG', 'FASHION+L2REG', 'FASHION+MOD-INIT',
              'HALVES-SAME-FASHION', 'HALVES-FASHION',
              'POLY', 'POLY+L1REG', 'POLY+L2REG')
# Human-readable network label for each tag (the dataset is derived from the tag itself).
tag_to_net = {'MNIST': 'Control', 'MNIST+DROPOUT': 'Dropout', 'MNIST+L1REG': 'L1 Reg',
              'MNIST+L2REG': 'L2 Reg', 'MNIST+MOD-INIT': 'Clusterable Init',
              'HALVES-SAME-MNIST': 'Halves Same', 'HALVES-MNIST': 'Halves Diff',
              'FASHION': 'Control', 'FASHION+DROPOUT': 'Dropout', 'FASHION+L1REG': 'L1 Reg',
              'FASHION+L2REG': 'L2 Reg', 'FASHION+MOD-INIT': 'Clusterable Init',
              'HALVES-SAME-FASHION': 'Halves Same', 'HALVES-FASHION': 'Halves Diff',
              'POLY': 'Control', 'POLY+L1REG': 'L1 Reg', 'POLY+L2REG': 'L2 Reg'}

# Fail fast: a tag without a label would otherwise only raise a KeyError when the results
# are aggregated, i.e. after the multi-hour clustering loop has already run.
missing_labels = [tag for tag in model_tags if tag not in tag_to_net]
assert not missing_labels, 'No tag_to_net entry for tags: {}'.format(missing_labels)

# Keep the last n_reps paths returned for each tag (presumably the most recent runs --
# TODO confirm get_model_path's ordering).
model_paths = {tag: get_model_path(tag, filter_='all')[-n_reps:] for tag in model_tags}
# Report exactly which tags are short instead of a bare AssertionError.
short_tags = {tag: len(mps) for tag, mps in model_paths.items() if len(mps) != n_reps}
assert not short_tags, 'Expected {} models per tag, found: {}'.format(n_reps, short_tags)

# Filled by the clustering cell: {tag: {rep: run_spectral_cluster result}}.
clustering_results = {}
clustering_results_pruned = {}

In [4]:
# Cluster every replicate of every tag, then collect one row of metrics per network into
# `result_df` and save it as CSV.
savepath = '../results/clustering_factors_mlp.csv'
cluster_kwargs = dict(n_clusters=n_clust, n_samples=n_samples,
                      n_workers=n_workers, eigen_solver='arpack')


def _dataset_for(tag):
    """Return the display name of the dataset encoded in a model tag.

    Substrings are tested in order (MNIST, CIFAR10, FASHION); any other tag is 'Polynomials'.
    """
    for marker, name in (('MNIST', 'MNIST'), ('CIFAR10', 'CIFAR-10'), ('FASHION', 'FASHION')):
        if marker in tag:
            return name
    return 'Polynomials'


# Fail fast, before hours of computation: both problems below would otherwise surface only
# after the clustering loop finishes.
missing_labels = sorted(set(model_paths) - set(tag_to_net))
if missing_labels:
    raise KeyError('tag_to_net has no entry for: {}'.format(missing_labels))
# DataFrame.to_csv raises if the target directory does not exist.
Path(savepath).parent.mkdir(parents=True, exist_ok=True)

for tag, paths in tqdm(model_paths.items()):
    # POLY tags are only clustered from weight_paths[True] and get no entry in the second dict.
    run_pruned = 'POLY' not in tag
    clustering_results[tag] = {}
    if run_pruned:
        clustering_results_pruned[tag] = {}

    for rep, path in enumerate(paths):
        weight_paths = get_weights_paths(path)
        clustering_results[tag][rep] = run_spectral_cluster(weight_paths[True], **cluster_kwargs)
        if run_pruned:
            # NOTE(review): these results are labelled ', Pruning' below; confirm that key False
            # of get_weights_paths really is the pruned variant.
            clustering_results_pruned[tag][rep] = run_spectral_cluster(weight_paths[False],
                                                                       **cluster_kwargs)

# One row per (variant, tag, replicate). Each stored result unpacks as (labels, metrics);
# only the metrics mapping is kept.
all_results = []
for res, suffix in ((clustering_results, ''), (clustering_results_pruned, ', Pruning')):
    for tag, reps in res.items():
        network = tag_to_net[tag] + suffix
        dset = _dataset_for(tag)
        for _labels, metrics in reps.values():
            record = {'model': tag,
                      'network_type': 'cnn' if 'CNN' in tag else 'mlp',
                      'Network': network,
                      'Dataset': dset}
            record.update(metrics)
            all_results.append(record)

result_df = pd.DataFrame(all_results)
result_df.to_csv(savepath)
result_df

  0%|          | 0/17 [00:00<?, ?it/s]

  6%|▌         | 1/17 [41:08<10:58:14, 2468.38s/it]

 12%|█▏        | 2/17 [1:17:12<9:54:15, 2377.02s/it]

 18%|█▊        | 3/17 [1:44:56<8:24:44, 2163.17s/it]

 24%|██▎       | 4/17 [2:10:43<7:08:37, 1978.30s/it]

 29%|██▉       | 5/17 [2:41:12<6:26:43, 1933.62s/it]

 35%|███▌      | 6/17 [3:21:27<6:20:59, 2078.09s/it]

 41%|████      | 7/17 [3:57:08<5:49:29, 2096.97s/it]

 47%|████▋     | 8/17 [4:31:51<5:13:54, 2092.71s/it]

 53%|█████▎    | 9/17 [5:10:47<4:48:46, 2165.79s/it]

 59%|█████▉    | 10/17 [5:38:17<3:54:36, 2010.93s/it]

 65%|██████▍   | 11/17 [6:10:19<3:18:25, 1984.18s/it]

 71%|███████   | 12/17 [6:42:22<2:43:49, 1965.99s/it]

 76%|███████▋  | 13/17 [7:14:56<2:10:48, 1962.21s/it]

 82%|████████▏ | 14/17 [7:48:15<1:38:40, 1973.35s/it]

 88%|████████▊ | 15/17 [8:03:51<55:24, 1662.25s/it]  

 94%|█████████▍| 16/17 [8:17:19<23:25, 1405.77s/it]

100%|██████████| 17/17 [8:30:20<00:00, 1218.35s/it]

100%|██████████| 17/17 [8:30:20<00:00, 1801.20s/it]




Unnamed: 0,model,network_type,Network,Dataset,ncut,ave_in_out,n_samples,mean,stdev,z_score,percentile,train_acc,train_loss,test_acc,test_loss
0,MNIST,mlp,Control,MNIST,10.042213,0.097478,50,10.151821,0.021391,-5.124001,0.019608,0.996617,0.011576,0.9784,0.106379
1,MNIST,mlp,Control,MNIST,10.004540,0.099728,50,10.185344,0.067399,-2.682591,0.019608,0.996800,0.009840,0.9795,0.111574
2,MNIST,mlp,Control,MNIST,10.022471,0.098655,50,10.175329,0.033044,-4.625901,0.019608,0.997400,0.009474,0.9781,0.105555
3,MNIST,mlp,Control,MNIST,10.039690,0.097628,50,10.162452,0.021965,-5.588989,0.019608,0.996933,0.010426,0.9826,0.091993
4,MNIST,mlp,Control,MNIST,10.012309,0.099262,50,10.150767,0.031958,-4.332525,0.019608,0.996600,0.010693,0.9823,0.086990
5,MNIST+DROPOUT,mlp,Dropout,MNIST,10.038707,0.097687,50,10.294030,0.111634,-2.287151,0.019608,0.972350,0.099336,0.9815,0.073181
6,MNIST+DROPOUT,mlp,Dropout,MNIST,9.934570,0.103952,50,10.296853,0.117277,-3.089139,0.019608,0.973117,0.098400,0.9797,0.080660
7,MNIST+DROPOUT,mlp,Dropout,MNIST,9.857192,0.108693,50,10.306801,0.160214,-2.806302,0.019608,0.972567,0.099975,0.9787,0.078930
8,MNIST+DROPOUT,mlp,Dropout,MNIST,9.962812,0.102240,50,10.298360,0.105600,-3.177528,0.019608,0.973200,0.099699,0.9798,0.080059
9,MNIST+DROPOUT,mlp,Dropout,MNIST,10.061293,0.096345,50,10.279509,0.061847,-3.528343,0.019608,0.972917,0.098393,0.9798,0.075725
