In [1]:
import sys
import os
import torchvision
from torchvision import transforms
sys.path.append(os.getcwd()[:-7])
from sanity_get_data import get_data, CIFAR10_Wrapper, MNIST_Wrapper
from install_packages import install_packages

In [2]:
install_packages()


scikit-learn

numpy

matplotlib

torch

pytorch-lightning

wandb

einops

torchvision

pandas


In [2]:
import torch
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import seaborn as sns
from get_grok import get_data
import wandb
KEY = '8b81e715f744716c02701d1b0a23c4342e62ad45'
wandb.login(key = KEY)
from helper_process import process_merge, top_x_percent_per_label, bottom_x_percent_per_label

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jovyan/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjmryan[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


This notebook will go through 3 experiments of 'subsetting' the CIFAR data to see how it effects Jaccard similarity. 1k Total train samples takes a random (same across seeds) 1k images (100 from each class) and trains Resnet on these 1k images. 2 Classes combines the 10 classes into 2 based on whether the image is a vehicle or an animal (20k images vs 30k images) and trains Resnet on these new classes. Label Change grabs 10 random images (same across seeds) and randomly selects a new class (purposefully mislabeling) and trains a Resnet on all the data.

In [3]:
LABEL_INDEX = -3
Y_HAT_INDEX = -2
IDX_INDEX = -1

In [4]:
paths = [f'class_means_distance/grok_toy_test_2_data_0.5_train_fraction_0.5_iter_{x}_max_epochs_11/train/train_raw_activations_epoch_' for x in list(range(5))]

In [5]:
def process_merge(paths, epoch, distance_metric = 'l2', normalization = 'none'):
    all_dfs = []
    for i, path in enumerate(paths):
        dat = torch.load(path + str(epoch), map_location=torch.device('cpu'), weights_only=True)
        labels = dat[: , LABEL_INDEX]
        y_hats = dat[: , Y_HAT_INDEX]
        idx = dat[: , IDX_INDEX]
        activations = dat[: , :-3]
        uniq_labels = np.unique(labels)
        
        global_mean = activations.mean(axis=0)
        ordered_l2 = []
        ordered_labels = []
        ordered_idx = []
        ordered_y_hats = []
        
        for j in uniq_labels:
            labels_indexing = (labels == j)
            class_activations = activations[labels_indexing, :]
            class_mean = class_activations.mean(axis=0) - global_mean

            if normalization == 'class':
                class_std = class_activations.std(axis=0) + 0.0001
                class_activations = class_activations / repeat(class_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / class_std
                # print(class_std)

            elif normalization == 'global':
                global_std = activations.std(axis=0) + 0.0001
                class_activations = class_activations - repeat(global_mean, 'u -> n u', n = class_activations.shape[0])
                class_activations = class_activations / repeat(global_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / global_std

            if distance_metric == 'cosine':
                cos_dist = 1 - torch.nn.functional.cosine_similarity(class_activations, repeat(class_mean, 'u -> n u', n = class_activations.shape[0]))
                _ = [ordered_l2.append(diff) for diff in cos_dist]
            
            elif distance_metric == 'l2':
                class_diff = class_activations - repeat(class_mean, 'u -> n u', n = class_activations.shape[0])
                _ = [ordered_l2.append(torch.linalg.vector_norm(diff)) for diff in class_diff]
            
            _ = [ordered_labels.append(lab) for lab in labels[labels_indexing]]
            _ = [ordered_idx.append(ids) for ids in idx[labels_indexing]]
            _ = [ordered_y_hats.append(y_hat) for y_hat in y_hats[labels_indexing]]
        
        all_together = torch.vstack([torch.tensor(x) for x in [ordered_l2, ordered_labels, ordered_idx, ordered_y_hats]]).T
        df = pd.DataFrame(all_together, columns = [f'dist_{i}', f'labels', f'idx', f'y_hats_{i}'])
        if i != 0:
            df = df.drop(columns=['labels'])
        
        all_dfs.append(df)
    merged_df = reduce(lambda left, right: pd.merge(left, right, on='idx', how='outer'), all_dfs)
    return merged_df

In [6]:
def graph_hists(merged_df):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]
    num_cols = len(l2_columns)
    
    # Create a grid: 2 rows (general + label-wise), N columns (1 per l2_x)
    fig, axes = plt.subplots(2, num_cols, figsize=(5 * num_cols, 10), sharey='row')
    
    # Row 1: General histograms
    for i, col in enumerate(l2_columns):
        ax = axes[0, i]
        sns.histplot(merged_df[col], bins=30, kde=False, ax=ax, color='skyblue')
        ax.set_title(f"{col} - Overall")
        ax.set_xlabel("")
        ax.set_ylabel("Count")
    
    # Row 2: Label-wise histograms
    for i, col in enumerate(l2_columns):
        ax = axes[1, i]
        for label in sorted(merged_df['labels'].unique()):
            subset = merged_df[merged_df['labels'] == label]
            sns.histplot(subset[col], bins=30, kde=False, label=f"Label {label}", alpha=0.5, ax=ax)
        ax.set_title(f"{col} - By Label")
        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Group by 'labels' and calculate mean and std
    grouped_stats = merged_df.groupby('labels')[l2_columns].agg(['mean', 'std'])
    
    # Optional: flatten MultiIndex columns
    grouped_stats.columns = [f"{col}_{stat}" for col, stat in grouped_stats.columns]
    
    # Display the result
    display(grouped_stats)

In [7]:
def top_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    top_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        top_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(quant)
            idxs = subset[subset[col] >= threshold]['idx'].to_numpy()
            top_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(top_idxs))
        top_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, top_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, top_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return top_10_percent_idxs, union_all, intersection_all

In [8]:
def bottom_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    bottom_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        bottom_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(percent)
            idxs = subset[subset[col] <= threshold]['idx'].to_numpy()
            bottom_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(bottom_idxs))
        bottom_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, bottom_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, bottom_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return bottom_10_percent_idxs, union_all, intersection_all

# CIFAR 1k Total Train Samples ResNet18

In [3]:
paths = [f'class_means_distance/sanity_no_shuffle_data_cifar10_1k_resnet18_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [4]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [5]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.09174311926605505
Size of Union: 218
Size of Intersection: 20


In [6]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.02032520325203252
Size of Union: 246
Size of Intersection: 5


### Unnormalized Cos

In [7]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [8]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.12376237623762376
Size of Union: 202
Size of Intersection: 25


In [9]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.07476635514018691
Size of Union: 214
Size of Intersection: 16


# CIFAR 2 Classes (Vehicle vs Animal) Resnet

In [10]:
paths = [f'class_means_distance/sanity_no_shuffle_data_cifar10_2class_resnet18_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [11]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [12]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.13083071941718102
Size of Union: 9883
Size of Intersection: 1293


In [13]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.02896746347941567
Size of Union: 12048
Size of Intersection: 349


### Unnormalized Cos

In [14]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [15]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.10944438961807955
Size of Union: 10133
Size of Intersection: 1109


In [16]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.15636942675159235
Size of Union: 9420
Size of Intersection: 1473


# CIFAR Label Chance Resnet18

In [17]:
paths = [f'class_means_distance/sanity_no_shuffle_data_cifar10_label_change_resnet18_seed_{x}_iter_0_max_epochs_400/train/train_raw_activations_epoch_' for x in [515,650,713]]

## Look at Class Means Across Seeds (Epoch 355)

### Unnormalized L2

In [18]:
epoch = 355
merged_df = process_merge(paths, epoch)

In [19]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.054776626317164505
Size of Union: 11483
Size of Intersection: 629


In [20]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Intersection: {changed_idx in intersection_all}')

Index 9749 in Intersection: False
Index 28898 in Intersection: False
Index 49412 in Intersection: False
Index 9282 in Intersection: False
Index 36567 in Intersection: False
Index 4740 in Intersection: False
Index 39328 in Intersection: False
Index 16927 in Intersection: False
Index 44496 in Intersection: False
Index 2688 in Intersection: False


In [22]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Union: {changed_idx in union_all}')

Index 9749 in Union: False
Index 28898 in Union: False
Index 49412 in Union: True
Index 9282 in Union: False
Index 36567 in Union: True
Index 4740 in Union: False
Index 39328 in Union: True
Index 16927 in Union: False
Index 44496 in Union: True
Index 2688 in Union: False


In [23]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.020079025885009275
Size of Union: 12401
Size of Intersection: 249


In [24]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Intersection: {changed_idx in intersection_all}')

Index 9749 in Intersection: False
Index 28898 in Intersection: False
Index 49412 in Intersection: False
Index 9282 in Intersection: False
Index 36567 in Intersection: False
Index 4740 in Intersection: False
Index 39328 in Intersection: False
Index 16927 in Intersection: False
Index 44496 in Intersection: False
Index 2688 in Intersection: False


In [25]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Union: {changed_idx in union_all}')

Index 9749 in Union: False
Index 28898 in Union: False
Index 49412 in Union: False
Index 9282 in Union: False
Index 36567 in Union: True
Index 4740 in Union: False
Index 39328 in Union: False
Index 16927 in Union: False
Index 44496 in Union: False
Index 2688 in Union: False


In [26]:
changed_idxs = [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]
merged_df[merged_df['idx'].isin(changed_idxs)]

Unnamed: 0,dist_0,labels,idx,y_hats_0,dist_1,y_hats_1,dist_2,y_hats_2
2688,26.927023,2.0,2688.0,2.0,27.282066,2.0,29.58423,2.0
4740,30.494396,0.0,4740.0,0.0,29.411949,0.0,31.859819,0.0
9282,29.290165,3.0,9282.0,3.0,29.947004,3.0,27.852316,3.0
9749,32.424957,5.0,9749.0,5.0,28.669069,5.0,28.992195,5.0
16927,28.671465,5.0,16927.0,5.0,28.765451,5.0,30.716932,5.0
28898,29.686842,6.0,28898.0,6.0,28.213039,6.0,27.726608,6.0
36567,28.830286,5.0,36567.0,5.0,25.062857,5.0,33.718384,5.0
39328,29.588921,9.0,39328.0,9.0,33.928448,9.0,31.063854,9.0
44496,29.698242,1.0,44496.0,1.0,31.317736,1.0,35.274971,1.0
49412,29.55755,9.0,49412.0,9.0,32.537693,9.0,32.055241,9.0


### Unnormalized Cos

In [27]:
epoch = 355
merged_df = process_merge(paths, epoch, distance_metric = 'cosine')

In [28]:
top_10_percent_idxs, union_all, intersection_all = top_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.08595072409253339
Size of Union: 10634
Size of Intersection: 914


In [29]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Intersection: {changed_idx in intersection_all}')

Index 9749 in Intersection: False
Index 28898 in Intersection: False
Index 49412 in Intersection: True
Index 9282 in Intersection: False
Index 36567 in Intersection: False
Index 4740 in Intersection: True
Index 39328 in Intersection: True
Index 16927 in Intersection: False
Index 44496 in Intersection: True
Index 2688 in Intersection: False


In [30]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Union: {changed_idx in union_all}')

Index 9749 in Union: False
Index 28898 in Union: True
Index 49412 in Union: True
Index 9282 in Union: True
Index 36567 in Union: False
Index 4740 in Union: True
Index 39328 in Union: True
Index 16927 in Union: False
Index 44496 in Union: True
Index 2688 in Union: True


In [31]:
bottom_10_percent_idxs, union_all, intersection_all = bottom_x_percent_per_label(merged_df, 0.1)

Jaccard Similarity: 0.07984755530767801
Size of Union: 10758
Size of Intersection: 859


In [32]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Intersection: {changed_idx in intersection_all}')

Index 9749 in Intersection: False
Index 28898 in Intersection: False
Index 49412 in Intersection: False
Index 9282 in Intersection: False
Index 36567 in Intersection: False
Index 4740 in Intersection: False
Index 39328 in Intersection: False
Index 16927 in Intersection: False
Index 44496 in Intersection: False
Index 2688 in Intersection: False


In [33]:
for changed_idx in [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]:
    print(f'Index {changed_idx} in Union: {changed_idx in union_all}')

Index 9749 in Union: False
Index 28898 in Union: False
Index 49412 in Union: False
Index 9282 in Union: False
Index 36567 in Union: False
Index 4740 in Union: False
Index 39328 in Union: False
Index 16927 in Union: False
Index 44496 in Union: False
Index 2688 in Union: False


In [34]:
changed_idxs = [9749, 28898, 49412, 9282, 36567, 4740, 39328, 16927, 44496, 2688]
merged_df[merged_df['idx'].isin(changed_idxs)]

Unnamed: 0,dist_0,labels,idx,y_hats_0,dist_1,y_hats_1,dist_2,y_hats_2
2688,0.683455,2.0,2688.0,2.0,0.568842,2.0,0.714764,2.0
4740,0.741977,0.0,4740.0,0.0,0.785641,0.0,0.764125,0.0
9282,0.666783,3.0,9282.0,3.0,0.779143,3.0,0.695057,3.0
9749,0.658805,5.0,9749.0,5.0,0.622609,5.0,0.559387,5.0
16927,0.581304,5.0,16927.0,5.0,0.631494,5.0,0.626594,5.0
28898,0.711015,6.0,28898.0,6.0,0.726841,6.0,0.645662,6.0
36567,0.563574,5.0,36567.0,5.0,0.54717,5.0,0.550981,5.0
39328,0.635402,9.0,39328.0,9.0,0.694663,9.0,0.722221,9.0
44496,0.650794,1.0,44496.0,1.0,0.656589,1.0,0.702334,1.0
49412,0.703564,9.0,49412.0,9.0,0.691691,9.0,0.666334,9.0
