In [None]:
# Factors of Clusterability Analysis: small CNNs

In [None]:
import os
import sys

# Make the repository root importable so the `src` package resolves from this notebook.
sys.path.append('..')

import pandas as pd
from tqdm import tqdm

from src.visualization import run_spectral_cluster
from src.experiment_tagging import get_model_path
from src.utils import get_weights_paths

In [None]:
# Settings forwarded unchanged to every run_spectral_cluster call.
n_clust = 12     # passed as n_clusters
n_samples = 50   # passed as n_samples
n_workers = 10   # passed as n_workers
# Number of path entries kept per tag; also the number of replicates looped over.
n_reps = 5

# Experiment tags to analyse: the MNIST variants followed by the Fashion variants.
model_tags = (
    'CNN-MNIST',
    'CNN-MNIST+DROPOUT',
    'CNN-MNIST+L1REG',
    'CNN-MNIST+L2REG',
    'CNN-MNIST+MOD-INIT',
    'CNN-STACKED-MNIST',
    'CNN-STACKED-SAME-MNIST',
    'CNN-FASHION',
    'CNN-FASHION+DROPOUT',
    'CNN-FASHION+L1REG',
    'CNN-FASHION+L2REG',
    'CNN-FASHION+MOD-INIT',
    'CNN-STACKED-FASHION',
    'CNN-STACKED-SAME-FASHION',
)

# Short display label for each tag; the MNIST and Fashion versions of a
# variant deliberately share the same label.
tag_to_net = {
    'CNN-MNIST': 'Control',
    'CNN-MNIST+DROPOUT': 'Dropout',
    'CNN-MNIST+L1REG': 'L1 Reg',
    'CNN-MNIST+L2REG': 'L2 Reg',
    'CNN-MNIST+MOD-INIT': 'Clusterable Init',
    'CNN-STACKED-SAME-MNIST': 'Stacked Same',
    'CNN-STACKED-MNIST': 'Stacked Diff',
    'CNN-FASHION': 'Control',
    'CNN-FASHION+DROPOUT': 'Dropout',
    'CNN-FASHION+L1REG': 'L1 Reg',
    'CNN-FASHION+L2REG': 'L2 Reg',
    'CNN-FASHION+MOD-INIT': 'Clusterable Init',
    'CNN-STACKED-SAME-FASHION': 'Stacked Same',
    'CNN-STACKED-FASHION': 'Stacked Diff',
}

# For each tag keep only the last n_reps entries of what get_model_path returns.
# NOTE(review): presumably these are the most recent trained runs - confirm the
# ordering guaranteed by get_model_path(filter_='all').
model_paths = {
    model_tag: get_model_path(model_tag, filter_='all')[-n_reps:]
    for model_tag in model_tags
}
# Fail early if any tag has fewer than n_reps entries available.
assert all(len(tag_paths) == n_reps for tag_paths in model_paths.values())

# Filled in by the clustering loop: tag -> {replicate index -> run_spectral_cluster output}.
clustering_results_pruned = {}
clustering_results = {}

In [None]:
# Keyword arguments shared by every run_spectral_cluster call below.
spectral_kwargs = dict(n_clusters=n_clust, n_samples=n_samples,
                       n_workers=n_workers, eigen_solver='arpack')

# Cluster every replicate of every model tag twice: once on the weights stored
# under key True and once on the weights stored under key False.
for tag, paths in tqdm(model_paths.items()):

    # Per-replicate outputs for this tag, keyed by replicate index. Each alias
    # refers to the very dict stored in the corresponding module-level container.
    unpruned_by_rep = clustering_results[tag] = {}
    pruned_by_rep = clustering_results_pruned[tag] = {}

    for rep in range(n_reps):

        # get_weights_paths returns a container indexed by a boolean.
        # NOTE(review): presumably True selects the unpruned checkpoint and
        # False the pruned one - confirm against get_weights_paths.
        weight_paths = get_weights_paths(paths[rep])

        unpruned_by_rep[rep] = run_spectral_cluster(weight_paths[True], **spectral_kwargs)
        pruned_by_rep[rep] = run_spectral_cluster(weight_paths[False], **spectral_kwargs)

# Flatten both result containers into one table with a row per
# (model tag, pruning setting, replicate). Rows are collected as plain dicts
# and turned into a DataFrame once, avoiding a pd.Series allocation per row.
all_results = []
for res, is_pruned in ((clustering_results, False), (clustering_results_pruned, True)):
    for tag, reps in res.items():

        # Display label; rows from the pruned container get a ', Pruning' suffix.
        network = tag_to_net[tag]
        if is_pruned:
            network += ', Pruning'

        # Dataset is inferred from the tag text; anything that is neither
        # MNIST nor CIFAR10 is treated as Fashion.
        if 'MNIST' in tag:
            dset = 'MNIST'
        elif 'CIFAR10' in tag:
            dset = 'CIFAR-10'
        else:
            dset = 'FASHION'

        # Each stored result unpacks into two items; only the second (the
        # metrics) is used here, the first (cluster labels) is discarded.
        for _, metrics in reps.values():
            row = {'model': tag,
                   'network_type': 'cnn' if 'CNN' in tag else 'mlp',
                   'Network': network,
                   'Dataset': dset}
            row.update(metrics)  # one column per metric entry
            all_results.append(row)

result_df = pd.DataFrame(all_results)
savepath = '../results/clustering_factors_cnn.csv'
# to_csv does not create missing directories; make sure the target exists so
# the (expensive) clustering run above is not lost to an OSError at save time.
os.makedirs(os.path.dirname(savepath), exist_ok=True)
result_df.to_csv(savepath)
result_df