In [1]:
# Factors of Clusterability Analysis: CIFAR-10 VGGs

In [2]:
import os
import sys
sys.path.append('..')
import pandas as pd
from tqdm import tqdm
from src.visualization import run_spectral_cluster
from src.experiment_tagging import get_model_path
from src.utils import get_weights_paths

Using TensorFlow backend.


In [3]:
# Clustering configuration. n_clust / n_samples / n_workers are passed
# straight through to run_spectral_cluster; n_reps is how many trained
# replicates of each model are analysed.
n_clust = 12
n_samples = 50
n_workers = 5
n_reps = 5

model_tags = ('CNN-VGG-CIFAR10', 'CNN-VGG-CIFAR10+L1REG', 'CNN-VGG-CIFAR10+L2REG',
              'CNN-VGG-CIFAR10+DROPOUT', 'CNN-VGG-CIFAR10+DROPOUT+L2REG', 'CNN-VGG-CIFAR10+MOD-INIT')
# Human-readable label for each tag; becomes the 'Network' column of the results.
tag_to_net = {'CNN-VGG-CIFAR10': 'Control', 'CNN-VGG-CIFAR10+L1REG': 'L1 Reg',
              'CNN-VGG-CIFAR10+L2REG': 'L2 Reg', 'CNN-VGG-CIFAR10+DROPOUT': 'Dropout',
              'CNN-VGG-CIFAR10+DROPOUT+L2REG': 'Dropout, L2 Reg',
              'CNN-VGG-CIFAR10+MOD-INIT': 'Clusterable Init'}

# Fail fast: the aggregation step looks up tag_to_net[tag] only after the
# (multi-hour) clustering run, so a missing label would surface far too late.
unnamed_tags = [tag for tag in model_tags if tag not in tag_to_net]
assert not unnamed_tags, 'No display name in tag_to_net for: {}'.format(unnamed_tags)

# Keep the last n_reps paths that get_model_path returns for each tag.
model_paths = {tag: get_model_path(tag, filter_='all')[-n_reps:] for tag in model_tags}
# Report exactly which tags are short of replicates instead of a bare AssertionError.
short_tags = {tag: len(mps) for tag, mps in model_paths.items() if len(mps) != n_reps}
assert not short_tags, 'Expected {} replicates per tag, got: {}'.format(n_reps, short_tags)

# Filled in by the clustering loop: {tag: {rep: result}}.
clustering_results = {}
clustering_results_pruned = {}

In [4]:
# Run spectral clustering on both weight files of every replicate of every
# model. get_weights_paths returns a mapping indexed by a bool; the True entry
# feeds clustering_results and the False entry feeds clustering_results_pruned.
# NOTE(review): presumably True = unpruned weights, False = pruned — confirm in src.utils.
cluster_kwargs = dict(n_clusters=n_clust, n_samples=n_samples,
                      n_workers=n_workers, eigen_solver='arpack')
store_for_key = {True: clustering_results, False: clustering_results_pruned}

for tag, paths in tqdm(model_paths.items()):
    # Start each tag with fresh per-replicate dicts in both stores.
    for store in store_for_key.values():
        store[tag] = {}

    for rep in range(n_reps):
        weight_paths = get_weights_paths(paths[rep])
        # Dict order guarantees the True (unpruned) run happens before the False one.
        for key, store in store_for_key.items():
            store[tag][rep] = run_spectral_cluster(weight_paths[key], **cluster_kwargs)

# Flatten {tag: {rep: (labels, metrics)}} into one row per (model, replicate,
# pruning state); pruned runs get ', Pruning' appended to their 'Network' label.
all_results = []
for res, network_suffix in ((clustering_results, ''),
                            (clustering_results_pruned, ', Pruning')):
    for tag in res:
        network = tag_to_net[tag] + network_suffix

        for rep in res[tag]:
            # Only the metrics are tabulated; the first element of the pair is unused.
            _labels, metrics = res[tag][rep]
            result = {'model': tag,
                      'network_type': 'cnn' if 'CNN' in tag else 'mlp',
                      'Network': network}
            result.update(metrics)
            all_results.append(pd.Series(result))

result_df = pd.DataFrame(all_results)
savepath = '../results/clustering_factors_cnn_vgg.csv'
# Make sure the output folder exists so the save cannot fail with
# FileNotFoundError after the long clustering run above.
os.makedirs(os.path.dirname(savepath), exist_ok=True)
result_df.to_csv(savepath)
result_df

  0%|          | 0/32 [00:00<?, ?it/s]

  3%|▎         | 1/32 [46:57<24:15:35, 2817.26s/it]

  6%|▋         | 2/32 [1:15:55<20:46:48, 2493.61s/it]

  9%|▉         | 3/32 [1:37:18<17:09:43, 2130.47s/it]

 12%|█▎        | 4/32 [2:04:58<15:28:17, 1989.21s/it]

 16%|█▌        | 5/32 [2:25:25<13:12:16, 1760.61s/it]

 19%|█▉        | 6/32 [3:08:56<14:33:26, 2015.63s/it]

 22%|██▏       | 7/32 [3:34:46<13:01:39, 1875.97s/it]

 25%|██▌       | 8/32 [3:54:53<11:10:06, 1675.28s/it]

 28%|██▊       | 9/32 [4:20:03<10:23:14, 1625.85s/it]

 31%|███▏      | 10/32 [4:39:40<9:06:44, 1491.13s/it]

 34%|███▍      | 11/32 [5:25:18<10:52:47, 1865.14s/it]

 38%|███▊      | 12/32 [5:48:50<9:36:20, 1729.04s/it] 

 41%|████      | 13/32 [6:07:44<8:11:00, 1550.53s/it]

 44%|████▍     | 14/32 [6:31:55<7:36:12, 1520.69s/it]

 47%|████▋     | 15/32 [6:54:28<6:56:39, 1470.54s/it]

 50%|█████     | 16/32 [7:14:06<6:08:44, 1382.76s/it]

 53%|█████▎    | 17/32 [7:29:00<5:09:00, 1236.02s/it]

 56%|█████▋    | 18/32 [7:29:13<3:22:50, 869.34s/it] 

 59%|█████▉    | 19/32 [7:29:21<2:12:18, 610.69s/it]

 62%|██████▎   | 20/32 [7:29:27<1:25:51, 429.30s/it]

 66%|██████▌   | 21/32 [7:29:34<55:28, 302.58s/it]  

 69%|██████▉   | 22/32 [7:29:41<35:39, 213.90s/it]

 72%|███████▏  | 23/32 [7:29:54<23:04, 153.84s/it]

 75%|███████▌  | 24/32 [7:30:01<14:38, 109.76s/it]

 78%|███████▊  | 25/32 [7:30:08<09:11, 78.78s/it] 

 81%|████████▏ | 26/32 [7:30:15<05:43, 57.26s/it]

 84%|████████▍ | 27/32 [7:30:22<03:30, 42.18s/it]

 88%|████████▊ | 28/32 [7:30:35<02:14, 33.61s/it]

 91%|█████████ | 29/32 [7:30:42<01:16, 25.63s/it]

 94%|█████████▍| 30/32 [7:30:49<00:40, 20.01s/it]

 97%|█████████▋| 31/32 [7:30:56<00:16, 16.10s/it]

100%|██████████| 32/32 [7:31:03<00:00, 13.30s/it]

100%|██████████| 32/32 [7:31:03<00:00, 845.73s/it]


