In [1]:
import sys
import os
import torchvision
from torchvision import transforms
# Make the parent of the notebook's working directory importable so the project-local
# modules imported just below (sanity_get_data, install_packages) can be resolved.
# NOTE(review): the original sliced the last 7 characters off the cwd string; this assumes
# that slice was removing exactly one trailing directory component — confirm.
sys.path.append(os.path.dirname(os.getcwd()))
from sanity_get_data import get_data, CIFAR10_Wrapper, MNIST_Wrapper
from install_packages import install_packages

In [85]:
# Run the project's install_packages helper; its output lists package names
# (scikit-learn ... pandas), presumably the packages it installs.
# NOTE(review): this cell's execution count (85) is out of order, and the first cell already
# imports torchvision before this runs, so on a fresh environment the install may come too late.
install_packages()


scikit-learn

numpy

matplotlib

torch

pytorch-lightning

wandb

einops

torchvision

pandas


In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import seaborn as sns
from get_grok import get_data
import wandb
# Take the W&B API key from the WANDB_API_KEY environment variable rather than embedding
# it in the notebook. If the variable is unset KEY is None, and wandb.login falls back to
# its own credential lookup (e.g. the netrc entry mentioned in the login output).
# NOTE(review): the key that was previously committed here should be revoked and rotated.
KEY = os.environ.get('WANDB_API_KEY')
wandb.login(key = KEY)
from helper_process import process_merge, top_x_percent_per_label, bottom_x_percent_per_label

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jovyan/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjmryan[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


This notebook analyzes flattened CIFAR and MNIST images passed through a simple MLP with a varying number of hidden layers (x in [2,3,4,5,6], so x+2 total layers), where each hidden layer maps dim 512 to 512. Each configuration is run across 3 seeds (515, 650, 713) with the same initialization hyperparameters. The final section analyzes how the Jaccard similarity correlates with the total number of trainable model parameters.

# CIFAR Across MLP Complexities

In [3]:
# Accumulators for the CIFAR Jaccard scores; later cells append one value per MLP depth.
# The *_cos lists hold the cosine-metric results, the others the default-metric results.
cifar_jaccards_top, cifar_jaccards_bottom = [], []
cifar_jaccards_top_cos, cifar_jaccards_bottom_cos = [], []

## 2 Hidden Layers

In [4]:
# Per-seed path prefixes for the 2-hidden-layer CIFAR-10 MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_2_data_cifar10_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [5]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [6]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.6221183298403998
Size of Union: 6203
Size of Intersection: 3859


In [7]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.6335865524486827
Size of Union: 6187
Size of Intersection: 3920


### Unnormalized Cos

In [8]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [9]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.39371358894458747
Size of Union: 7381
Size of Intersection: 2906


In [10]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.37593283582089554
Size of Union: 7504
Size of Intersection: 2821


## 3 Hidden Layers

In [11]:
# Per-seed path prefixes for the 3-hidden-layer CIFAR-10 MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_3_data_cifar10_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [12]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [13]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.5177824267782427
Size of Union: 6692
Size of Intersection: 3465


In [14]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.54300015167602
Size of Union: 6593
Size of Intersection: 3580


### Unnormalized Cos

In [15]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [16]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.2634274145528165
Size of Union: 8397
Size of Intersection: 2212


In [17]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.27206946454413894
Size of Union: 8292
Size of Intersection: 2256


## 4 Hidden Layers

In [18]:
# Per-seed path prefixes for the 4-hidden-layer CIFAR-10 MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_4_data_cifar10_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [19]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [20]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.4828288707799767
Size of Union: 6872
Size of Intersection: 3318


In [21]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.4702834940279177
Size of Union: 6949
Size of Intersection: 3268


### Unnormalized Cos

In [22]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [23]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.20237698544929467
Size of Union: 9003
Size of Intersection: 1822


In [24]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.2211702977073115
Size of Union: 8767
Size of Intersection: 1939


## 5 Hidden Layers

In [25]:
# Per-seed path prefixes for the 5-hidden-layer CIFAR-10 MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_5_data_cifar10_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [26]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [27]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.41157822191592003
Size of Union: 7255
Size of Intersection: 2986


In [28]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.3730528558114765
Size of Union: 7511
Size of Intersection: 2802


### Unnormalized Cos

In [29]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [30]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.16920293710758752
Size of Union: 9397
Size of Intersection: 1590


In [31]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.1824785861433373
Size of Union: 9223
Size of Intersection: 1683


## 6 Hidden Layers

In [32]:
# Per-seed path prefixes for the 6-hidden-layer CIFAR-10 MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_6_data_cifar10_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [33]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [34]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.3215100076007094
Size of Union: 7894
Size of Intersection: 2538


In [35]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.2995157084316404
Size of Union: 8053
Size of Intersection: 2412


### Unnormalized Cos

In [36]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [37]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the CIFAR top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
cifar_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.12065217391304348
Size of Union: 10120
Size of Intersection: 1221


In [38]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the CIFAR bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
cifar_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.135046919624643
Size of Union: 9804
Size of Intersection: 1324


# MNIST Across MLP Complexities

In [39]:
# Accumulators for the MNIST Jaccard scores; later cells append one value per MLP depth.
# The *_cos lists hold the cosine-metric results, the others the default-metric results.
mnist_jaccards_top, mnist_jaccards_bottom = [], []
mnist_jaccards_top_cos, mnist_jaccards_bottom_cos = [], []

## 2 Hidden Layers

In [40]:
# Per-seed path prefixes for the 2-hidden-layer MNIST MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_2_data_mnist_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [41]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [42]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.5126519474075911
Size of Union: 8062
Size of Intersection: 4133


In [43]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.41901245962159667
Size of Union: 8668
Size of Intersection: 3632


### Unnormalized Cos

In [44]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [45]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.5912917653249145
Size of Union: 7602
Size of Intersection: 4495


In [46]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.45082938388625593
Size of Union: 8440
Size of Intersection: 3805


## 3 Hidden Layers

In [47]:
# Per-seed path prefixes for the 3-hidden-layer MNIST MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_3_data_mnist_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [48]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [49]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.462850327966607
Size of Union: 8385
Size of Intersection: 3881


In [50]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.3410642679093963
Size of Union: 9227
Size of Intersection: 3147


### Unnormalized Cos

In [51]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [52]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.5413040736536764
Size of Union: 7929
Size of Intersection: 4292


In [53]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.4178074312665363
Size of Union: 8693
Size of Intersection: 3632


## 4 Hidden Layers

In [54]:
# Per-seed path prefixes for the 4-hidden-layer MNIST MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_4_data_mnist_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [55]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [56]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.4401404330017554
Size of Union: 8545
Size of Intersection: 3761


In [57]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.40840086748088117
Size of Union: 8761
Size of Intersection: 3578


### Unnormalized Cos

In [58]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [59]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.45264030310206016
Size of Union: 8446
Size of Intersection: 3823


In [60]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.3464864864864865
Size of Union: 9250
Size of Intersection: 3205


## 5 Hidden Layers

In [61]:
# Per-seed path prefixes for the 5-hidden-layer MNIST MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_5_data_mnist_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [62]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [63]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.4455026455026455
Size of Union: 8505
Size of Intersection: 3789


In [64]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.3837170129140932
Size of Union: 8905
Size of Intersection: 3417


### Unnormalized Cos

In [65]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [66]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.40819578827546954
Size of Union: 8785
Size of Intersection: 3586


In [67]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.2529440870856012
Size of Union: 10105
Size of Intersection: 2556


## 6 Hidden Layers

In [68]:
# Per-seed path prefixes for the 6-hidden-layer MNIST MLP runs (seeds 515, 650, 713);
# the epoch number is passed separately to process_merge.
seeds = (515, 650, 713)
paths = [
    f'class_means_distance/low_complexity_6_data_mnist_mlp_seed_{seed}_iter_0_max_epochs_400/train/train_raw_activations_epoch_'
    for seed in seeds
]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [69]:
# Build merged_df from the three per-seed path prefixes in `paths` at epoch 355.
# No distance_metric is passed; the section heading labels this the "Unnormalized L2"
# variant, so presumably L2 is process_merge's default — confirm in helper_process.
epoch = 355
merged_df = process_merge(paths, epoch)

In [70]:
# x = 0.1 top-percent selection on the merged_df left by the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% L2 list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top.append(jaccard_top)

Jaccard Similarity: 0.435268116784925
Size of Union: 8597
Size of Intersection: 3742


In [71]:
# x = 0.1 bottom-percent selection on the merged_df left by the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% L2 list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom.append(jaccard_bottom)

Jaccard Similarity: 0.41947826086956524
Size of Union: 8625
Size of Intersection: 3618


### Unnormalized Cos

In [72]:
# Rebuild merged_df from the same per-seed `paths` at epoch 355, requesting the cosine
# distance metric. This overwrites the previous merged_df, so the next two cells
# operate on the cosine variant.
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [73]:
# x = 0.1 top-percent selection on the cosine merged_df from the preceding cell; append this
# depth's Jaccard score (|intersection| / |union|) to the MNIST top-10% cosine list.
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)
jaccard_top = len(intersection_all) / len(union_all)
mnist_jaccards_top_cos.append(jaccard_top)

Jaccard Similarity: 0.31007350657417954
Size of Union: 9659
Size of Intersection: 2995


In [74]:
# x = 0.1 bottom-percent selection on the cosine merged_df from the preceding merge cell; append
# this depth's Jaccard score (|intersection| / |union|) to the MNIST bottom-10% cosine list.
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)
jaccard_bottom = len(intersection_all) / len(union_all)
mnist_jaccards_bottom_cos.append(jaccard_bottom)

Jaccard Similarity: 0.14884932696482847
Size of Union: 11515
Size of Intersection: 1714


# Analysis

In [75]:
from scipy.stats import pearsonr

In [76]:
# Trainable-parameter count for each MLP depth (2-6 hidden 512->512 layers), obtained by
# instantiating the architecture and summing the sizes of parameters with requires_grad set.
# The input layer uses the flattened CIFAR width (32*32*3). The same list is reused for the
# MNIST correlations; if the MNIST models use a different input width, every count shifts
# by the same constant, which leaves Pearson r unchanged — TODO confirm the MNIST input size.
all_trainable_param_sizes = []

for la in (2, 3, 4, 5, 6):
    # Input projection + ReLU, then `la` hidden blocks, then the 10-way output layer.
    layers = [nn.Linear(32*32*3, 512), nn.ReLU()]
    for hidden_idx in range(la):
        layers.extend([nn.Linear(512, 512), nn.ReLU()])
    layers.append(nn.Linear(512,10))
    model = nn.Sequential(*layers)

    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    all_trainable_param_sizes.append(trainable)

In [77]:
# Pearson correlation between the five per-depth CIFAR top-10% L2 Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
print('CIFAR Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(cifar_jaccards_top, all_trainable_param_sizes)
print(corr_result)

CIFAR Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9901667015072153), pvalue=np.float64(0.0011688023385244464))


In [78]:
# Pearson correlation between the five per-depth CIFAR bottom-10% L2 Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
print('CIFAR Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(cifar_jaccards_bottom, all_trainable_param_sizes)
print(corr_result)

CIFAR Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9990958036279844), pvalue=np.float64(3.263390806716465e-05))


In [79]:
# Pearson correlation between the five per-depth CIFAR top-10% cosine Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
print('CIFAR Top 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(cifar_jaccards_top_cos, all_trainable_param_sizes)
print(corr_result)

CIFAR Top 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9616980528454018), pvalue=np.float64(0.008946520177047286))


In [80]:
# Pearson correlation between the five per-depth CIFAR bottom-10% cosine Jaccard scores and
# the per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
print('CIFAR Bottom 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(cifar_jaccards_bottom_cos, all_trainable_param_sizes)
print(corr_result)

CIFAR Bottom 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.977735785738284), pvalue=np.float64(0.003974561305189631))


In [81]:
# Pearson correlation between the five per-depth MNIST top-10% L2 Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
# NOTE(review): all_trainable_param_sizes is built with a 32*32*3 input layer; a different
# MNIST input width would shift every count by a constant, which does not change Pearson r.
print('MNIST Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(mnist_jaccards_top, all_trainable_param_sizes)
print(corr_result)

MNIST Top 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.861220610457522), pvalue=np.float64(0.06075298205715873))


In [82]:
# Pearson correlation between the five per-depth MNIST bottom-10% L2 Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
# NOTE(review): all_trainable_param_sizes is built with a 32*32*3 input layer; a different
# MNIST input width would shift every count by a constant, which does not change Pearson r.
print('MNIST Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(mnist_jaccards_bottom, all_trainable_param_sizes)
print(corr_result)

MNIST Bottom 10% of Unnormed L2 Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(0.20802807277800647), pvalue=np.float64(0.7370534299562757))


In [83]:
# Pearson correlation between the five per-depth MNIST top-10% cosine Jaccard scores and the
# per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
# NOTE(review): all_trainable_param_sizes is built with a 32*32*3 input layer; a different
# MNIST input width would shift every count by a constant, which does not change Pearson r.
print('MNIST Top 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(mnist_jaccards_top_cos, all_trainable_param_sizes)
print(corr_result)

MNIST Top 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9930153226769322), pvalue=np.float64(0.0006999996850230254))


In [84]:
# Pearson correlation between the five per-depth MNIST bottom-10% cosine Jaccard scores and
# the per-depth trainable-parameter counts (n = 5, so treat the p-value with caution).
# NOTE(review): all_trainable_param_sizes is built with a 32*32*3 input layer; a different
# MNIST input width would shift every count by a constant, which does not change Pearson r.
print('MNIST Bottom 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:')
corr_result = pearsonr(mnist_jaccards_bottom_cos, all_trainable_param_sizes)
print(corr_result)

MNIST Bottom 10% of Unnormed Cosine Distance to Class Means Correlation with Model Trainable Parameter Size:
PearsonRResult(statistic=np.float64(-0.9834431881278443), pvalue=np.float64(0.002551040794614565))
