In [84]:
import sys
import os
import torchvision
from torchvision import transforms
sys.path.append(os.getcwd()[:-7])
from sanity_get_data import get_data, CIFAR10_Wrapper, MNIST_Wrapper
from install_packages import install_packages

In [85]:
install_packages()


scikit-learn

numpy

matplotlib

torch

pytorch-lightning

wandb

einops

torchvision

pandas


In [174]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import seaborn as sns
from get_grok import get_data
import wandb
KEY = '8b81e715f744716c02701d1b0a23c4342e62ad45'
wandb.login(key = KEY)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jovyan/.netrc


True

This notebook will go through traditional CIFAR Resnet18 Experiment and traditional MNIST Resnet18 across 3 seeds (515, 650, 713) with the same initialization hyperparameters. At the top of each experiment, there will be a confirmation that iteration 0 and 1 of seed 515 have identical class means across epochs to ensure reproducibility.

### Workhorse Functions

In [87]:
LABEL_INDEX = -3
Y_HAT_INDEX = -2
IDX_INDEX = -1

In [88]:
paths = [f'class_means_distance/grok_toy_test_2_data_0.5_train_fraction_0.5_iter_{x}_max_epochs_11/train/train_raw_activations_epoch_' for x in list(range(5))]

In [89]:
def process_merge(paths, epoch, distance_metric = 'l2', normalization = 'none'):
    all_dfs = []
    for i, path in enumerate(paths):
        dat = torch.load(path + str(epoch), map_location=torch.device('cpu'), weights_only=True)
        labels = dat[: , LABEL_INDEX]
        y_hats = dat[: , Y_HAT_INDEX]
        idx = dat[: , IDX_INDEX]
        activations = dat[: , :-3]
        uniq_labels = np.unique(labels)
        
        global_mean = activations.mean(axis=0)
        ordered_l2 = []
        ordered_labels = []
        ordered_idx = []
        ordered_y_hats = []
        
        for j in uniq_labels:
            labels_indexing = (labels == j)
            class_activations = activations[labels_indexing, :]
            class_mean = class_activations.mean(axis=0) - global_mean

            if normalization == 'class':
                class_std = class_activations.std(axis=0) + 0.0001
                class_activations = class_activations / repeat(class_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / class_std
                # print(class_std)

            elif normalization == 'global':
                global_std = activations.std(axis=0) + 0.0001
                class_activations = class_activations - repeat(global_mean, 'u -> n u', n = class_activations.shape[0])
                class_activations = class_activations / repeat(global_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / global_std

            if distance_metric == 'cosine':
                cos_dist = 1 - torch.nn.functional.cosine_similarity(class_activations, repeat(class_mean, 'u -> n u', n = class_activations.shape[0]))
                _ = [ordered_l2.append(diff) for diff in cos_dist]
            
            elif distance_metric == 'l2':
                class_diff = class_activations - repeat(class_mean, 'u -> n u', n = class_activations.shape[0])
                _ = [ordered_l2.append(torch.linalg.vector_norm(diff)) for diff in class_diff]
            
            _ = [ordered_labels.append(lab) for lab in labels[labels_indexing]]
            _ = [ordered_idx.append(ids) for ids in idx[labels_indexing]]
            _ = [ordered_y_hats.append(y_hat) for y_hat in y_hats[labels_indexing]]
        
        all_together = torch.vstack([torch.tensor(x) for x in [ordered_l2, ordered_labels, ordered_idx, ordered_y_hats]]).T
        df = pd.DataFrame(all_together, columns = [f'dist_{i}', f'labels', f'idx', f'y_hats_{i}'])
        if i != 0:
            df = df.drop(columns=['labels'])
        
        all_dfs.append(df)
    merged_df = reduce(lambda left, right: pd.merge(left, right, on='idx', how='outer'), all_dfs)
    return merged_df

In [90]:
def graph_hists(merged_df):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]
    num_cols = len(l2_columns)
    
    # Create a grid: 2 rows (general + label-wise), N columns (1 per l2_x)
    fig, axes = plt.subplots(2, num_cols, figsize=(5 * num_cols, 10), sharey='row')
    
    # Row 1: General histograms
    for i, col in enumerate(l2_columns):
        ax = axes[0, i]
        sns.histplot(merged_df[col], bins=30, kde=False, ax=ax, color='skyblue')
        ax.set_title(f"{col} - Overall")
        ax.set_xlabel("")
        ax.set_ylabel("Count")
    
    # Row 2: Label-wise histograms
    for i, col in enumerate(l2_columns):
        ax = axes[1, i]
        for label in sorted(merged_df['labels'].unique()):
            subset = merged_df[merged_df['labels'] == label]
            sns.histplot(subset[col], bins=30, kde=False, label=f"Label {label}", alpha=0.5, ax=ax)
        ax.set_title(f"{col} - By Label")
        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Group by 'labels' and calculate mean and std
    grouped_stats = merged_df.groupby('labels')[l2_columns].agg(['mean', 'std'])
    
    # Optional: flatten MultiIndex columns
    grouped_stats.columns = [f"{col}_{stat}" for col, stat in grouped_stats.columns]
    
    # Display the result
    display(grouped_stats)

In [91]:
def top_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    top_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        top_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(quant)
            idxs = subset[subset[col] >= threshold]['idx'].to_numpy()
            top_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(top_idxs))
        top_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, top_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, top_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return top_10_percent_idxs, union_all, intersection_all

In [92]:
def bottom_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    bottom_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        bottom_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(percent)
            idxs = subset[subset[col] <= threshold]['idx'].to_numpy()
            bottom_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(bottom_idxs))
        bottom_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, bottom_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, bottom_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return bottom_10_percent_idxs, union_all, intersection_all

# CIFAR Across MLP Complexities

In [93]:
cifar_jaccards_top = []
cifar_jaccards_bottom = []

## 2 Hidden Layers

In [94]:
paths = [f'class_means_distance/low_complexity_2_data_cifar10_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [95]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [97]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.6221183298403998
Size of Union: 6203
Size of Intersection: 3859


In [98]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.6335865524486827
Size of Union: 6187
Size of Intersection: 3920


### Unnormalized Cos

In [99]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [100]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.39371358894458747
Size of Union: 7381
Size of Intersection: 2906


In [101]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.37593283582089554
Size of Union: 7504
Size of Intersection: 2821


## 3 Hidden Layers

In [102]:
paths = [f'class_means_distance/low_complexity_3_data_cifar10_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [103]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [104]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.5177824267782427
Size of Union: 6692
Size of Intersection: 3465


In [105]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.54300015167602
Size of Union: 6593
Size of Intersection: 3580


### Unnormalized Cos

In [106]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [107]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.2634274145528165
Size of Union: 8397
Size of Intersection: 2212


In [108]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.27206946454413894
Size of Union: 8292
Size of Intersection: 2256


## 4 Hidden Layers

In [109]:
paths = [f'class_means_distance/low_complexity_4_data_cifar10_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [110]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [111]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.4828288707799767
Size of Union: 6872
Size of Intersection: 3318


In [112]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.4702834940279177
Size of Union: 6949
Size of Intersection: 3268


### Unnormalized Cos

In [113]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [114]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.20237698544929467
Size of Union: 9003
Size of Intersection: 1822


In [115]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.2211702977073115
Size of Union: 8767
Size of Intersection: 1939


## 5 Hidden Layers

In [116]:
paths = [f'class_means_distance/low_complexity_5_data_cifar10_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [117]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [118]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.41157822191592003
Size of Union: 7255
Size of Intersection: 2986


In [119]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.3730528558114765
Size of Union: 7511
Size of Intersection: 2802


### Unnormalized Cos

In [120]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [121]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.16920293710758752
Size of Union: 9397
Size of Intersection: 1590


In [122]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.1824785861433373
Size of Union: 9223
Size of Intersection: 1683


## 6 Hidden Layers

In [123]:
paths = [f'class_means_distance/low_complexity_6_data_cifar10_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [124]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [125]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.3215100076007094
Size of Union: 7894
Size of Intersection: 2538


In [126]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
cifar_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.2995157084316404
Size of Union: 8053
Size of Intersection: 2412


### Unnormalized Cos

In [127]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [128]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.12065217391304348
Size of Union: 10120
Size of Intersection: 1221


In [129]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.135046919624643
Size of Union: 9804
Size of Intersection: 1324


# MNIST Across MLP Complexities

In [130]:
mnist_jaccards_top = []
mnist_jaccards_bottom = []

## 2 Hidden Layers

In [131]:
paths = [f'class_means_distance/low_complexity_2_data_mnist_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [132]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [133]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.5126519474075911
Size of Union: 8062
Size of Intersection: 4133


In [134]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.41901245962159667
Size of Union: 8668
Size of Intersection: 3632


### Unnormalized Cos

In [135]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [136]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.5912917653249145
Size of Union: 7602
Size of Intersection: 4495


In [137]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.45082938388625593
Size of Union: 8440
Size of Intersection: 3805


## 3 Hidden Layers

In [138]:
paths = [f'class_means_distance/low_complexity_3_data_mnist_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [139]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [140]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.462850327966607
Size of Union: 8385
Size of Intersection: 3881


In [141]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.3410642679093963
Size of Union: 9227
Size of Intersection: 3147


### Unnormalized Cos

In [142]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [143]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.5413040736536764
Size of Union: 7929
Size of Intersection: 4292


In [144]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.4178074312665363
Size of Union: 8693
Size of Intersection: 3632


## 4 Hidden Layers

In [145]:
paths = [f'class_means_distance/low_complexity_4_data_mnist_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [146]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [147]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.4401404330017554
Size of Union: 8545
Size of Intersection: 3761


In [148]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.40840086748088117
Size of Union: 8761
Size of Intersection: 3578


### Unnormalized Cos

In [149]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [150]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.45264030310206016
Size of Union: 8446
Size of Intersection: 3823


In [151]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.3464864864864865
Size of Union: 9250
Size of Intersection: 3205


## 5 Hidden Layers

In [152]:
paths = [f'class_means_distance/low_complexity_5_data_mnist_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [153]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [154]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.4455026455026455
Size of Union: 8505
Size of Intersection: 3789


In [155]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.3837170129140932
Size of Union: 8905
Size of Intersection: 3417


### Unnormalized Cos

In [156]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [157]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.40819578827546954
Size of Union: 8785
Size of Intersection: 3586


In [158]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.2529440870856012
Size of Union: 10105
Size of Intersection: 2556


## 6 Hidden Layers

In [159]:
paths = [f'class_means_distance/low_complexity_6_data_mnist_mlp_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [160]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [161]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_top.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.435268116784925
Size of Union: 8597
Size of Intersection: 3742


In [162]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
mnist_jaccards_bottom.append(len(intersection_all)/len(union_all))

Jaccard Similarity: 0.41947826086956524
Size of Union: 8625
Size of Intersection: 3618


### Unnormalized Cos

In [163]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [164]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.31007350657417954
Size of Union: 9659
Size of Intersection: 2995


In [165]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.14884932696482847
Size of Union: 11515
Size of Intersection: 1714


# Analysis

In [179]:
from scipy.stats import pearsonr

In [175]:
all_trainable_param_sizes = []

for la in [2,3,4,5,6]:
    layers = []
    layers.append(nn.Linear(32*32*3, 512))
    layers.append(nn.ReLU())
    for i in range(la):
        layers.append(nn.Linear(512, 512))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(512,10))
    model = nn.Sequential(*layers)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_trainable_param_sizes.append(trainable)

In [184]:
print('CIFAR Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
print(pearsonr(cifar_jaccards_top, all_trainable_param_sizes))

CIFAR Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9901667015072153), pvalue=np.float64(0.0011688023385244464))


In [185]:
print('CIFAR Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
print(pearsonr(cifar_jaccards_bottom, all_trainable_param_sizes))

CIFAR Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9990958036279844), pvalue=np.float64(3.263390806716465e-05))


In [186]:
print('MNIST Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
print(pearsonr(mnist_jaccards_top, all_trainable_param_sizes))

MNIST Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.861220610457522), pvalue=np.float64(0.06075298205715873))


In [187]:
print('MNIST Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
print(pearsonr(mnist_jaccards_bottom, all_trainable_param_sizes))

MNIST Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(0.20802807277800647), pvalue=np.float64(0.7370534299562757))


In [188]:
merged_df

Unnamed: 0,dist_0,labels,idx,y_hats_0,dist_1,y_hats_1,dist_2,y_hats_2
0,0.243336,5.0,0.0,5.0,0.312373,5.0,0.273135,5.0
1,0.508444,0.0,1.0,0.0,0.421109,0.0,0.555723,0.0
2,0.666010,4.0,2.0,4.0,0.405823,4.0,0.568715,4.0
3,0.454452,1.0,3.0,1.0,0.415198,1.0,0.421538,1.0
4,0.553099,9.0,4.0,9.0,0.471692,9.0,0.511535,9.0
...,...,...,...,...,...,...,...,...
59995,0.061186,8.0,59995.0,8.0,0.086263,8.0,0.076206,8.0
59996,0.307935,3.0,59996.0,3.0,0.348779,3.0,0.266522,3.0
59997,0.238002,5.0,59997.0,5.0,0.221697,5.0,0.252045,5.0
59998,0.555131,6.0,59998.0,6.0,0.400045,6.0,0.782932,6.0
