In [1]:
import sys
import os
import torchvision
from torchvision import transforms
sys.path.append(os.getcwd()[:-7])
from sanity_get_data import get_data, CIFAR10_Wrapper, MNIST_Wrapper
from install_packages import install_packages

In [2]:
install_packages()


scikit-learn

numpy

matplotlib

torch

pytorch-lightning

wandb

einops

torchvision

pandas


In [3]:
import torch
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import seaborn as sns
from get_grok import get_data
import wandb
KEY = '8b81e715f744716c02701d1b0a23c4342e62ad45'
wandb.login(key = KEY)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jovyan/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjmryan[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Toy Test 2: Ensuring Grok Manual Seeding Works

This sanity check does 2 things. First, it makes sure that the data and the 'original' indices and the labels are organized and tracked properly through training and processing. Second, it makes sure that the manual seeding for grok works so that the train val split is the same for every initialization.

In [15]:
LABEL_INDEX = -3
Y_HAT_INDEX = -2
IDX_INDEX = -1

In [4]:
paths = [f'class_means_distance/grok_toy_test_2_data_0.5_train_fraction_0.5_iter_{x}_max_epochs_11/train/train_raw_activations_epoch_' for x in list(range(5))]

In [8]:
all_idxs = []
epoch = 10
for i, path in enumerate(paths):
    dat = torch.load(path + str(epoch), map_location=torch.device('cpu'), weights_only=True)
    print(dat.shape)
    labels = dat[: , LABEL_INDEX]
    y_hats = dat[: , Y_HAT_INDEX]
    idx = dat[: , IDX_INDEX]
    all_idxs.append(idx)
    activations = dat[: , :-3]

intersection_all = reduce(np.intersect1d, all_idxs)
union_all = reduce(np.union1d, all_idxs)
len(intersection_all) == len(union_all)

torch.Size([33, 131])
torch.Size([33, 131])
torch.Size([33, 131])
torch.Size([33, 131])
torch.Size([33, 131])


In [13]:
intersection_all = reduce(np.intersect1d, all_idxs)
union_all = reduce(np.union1d, all_idxs)

In [14]:
len(intersection_all) == len(union_all)

True

In [16]:
def process_merge(paths, epoch, distance_metric = 'l2', normalization = 'none'):
    all_dfs = []
    for i, path in enumerate(paths):
        dat = torch.load(path + str(epoch), map_location=torch.device('cpu'), weights_only=True)
        labels = dat[: , LABEL_INDEX]
        y_hats = dat[: , Y_HAT_INDEX]
        idx = dat[: , IDX_INDEX]
        activations = dat[: , :-3]
        uniq_labels = np.unique(labels)
        
        global_mean = activations.mean(axis=0)
        ordered_l2 = []
        ordered_labels = []
        ordered_idx = []
        ordered_y_hats = []
        
        for j in uniq_labels:
            labels_indexing = (labels == j)
            class_activations = activations[labels_indexing, :]
            class_mean = class_activations.mean(axis=0) - global_mean

            if normalization == 'class':
                class_std = class_activations.std(axis=0) + 0.0001
                class_activations = class_activations / repeat(class_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / class_std
                # print(class_std)

            elif normalization == 'global':
                global_std = activations.std(axis=0) + 0.0001
                class_activations = class_activations - repeat(global_mean, 'u -> n u', n = class_activations.shape[0])
                class_activations = class_activations / repeat(global_std, 'u -> n u', n = class_activations.shape[0])
                class_mean = class_mean / global_std

            if distance_metric == 'cosine':
                cos_dist = 1 - torch.nn.functional.cosine_similarity(class_activations, repeat(class_mean, 'u -> n u', n = class_activations.shape[0]))
                _ = [ordered_l2.append(diff) for diff in cos_dist]
            
            elif distance_metric == 'l2':
                class_diff = class_activations - repeat(class_mean, 'u -> n u', n = class_activations.shape[0])
                _ = [ordered_l2.append(torch.linalg.vector_norm(diff)) for diff in class_diff]
            
            _ = [ordered_labels.append(lab) for lab in labels[labels_indexing]]
            _ = [ordered_idx.append(ids) for ids in idx[labels_indexing]]
            _ = [ordered_y_hats.append(y_hat) for y_hat in y_hats[labels_indexing]]
        
        all_together = torch.vstack([torch.tensor(x) for x in [ordered_l2, ordered_labels, ordered_idx, ordered_y_hats]]).T
        df = pd.DataFrame(all_together, columns = [f'dist_{i}', f'labels', f'idx', f'y_hats_{i}'])
        if i != 0:
            df = df.drop(columns=['labels'])
        
        all_dfs.append(df)
    merged_df = reduce(lambda left, right: pd.merge(left, right, on='idx', how='outer'), all_dfs)
    return merged_df

In [17]:
def graph_hists(merged_df):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]
    num_cols = len(l2_columns)
    
    # Create a grid: 2 rows (general + label-wise), N columns (1 per l2_x)
    fig, axes = plt.subplots(2, num_cols, figsize=(5 * num_cols, 10), sharey='row')
    
    # Row 1: General histograms
    for i, col in enumerate(l2_columns):
        ax = axes[0, i]
        sns.histplot(merged_df[col], bins=30, kde=False, ax=ax, color='skyblue')
        ax.set_title(f"{col} - Overall")
        ax.set_xlabel("")
        ax.set_ylabel("Count")
    
    # Row 2: Label-wise histograms
    for i, col in enumerate(l2_columns):
        ax = axes[1, i]
        for label in sorted(merged_df['labels'].unique()):
            subset = merged_df[merged_df['labels'] == label]
            sns.histplot(subset[col], bins=30, kde=False, label=f"Label {label}", alpha=0.5, ax=ax)
        ax.set_title(f"{col} - By Label")
        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
        ax.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Group by 'labels' and calculate mean and std
    grouped_stats = merged_df.groupby('labels')[l2_columns].agg(['mean', 'std'])
    
    # Optional: flatten MultiIndex columns
    grouped_stats.columns = [f"{col}_{stat}" for col, stat in grouped_stats.columns]
    
    # Display the result
    display(grouped_stats)

In [18]:
def top_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    top_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        top_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(quant)
            idxs = subset[subset[col] >= threshold]['idx'].to_numpy()
            top_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(top_idxs))
        top_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, top_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, top_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return top_10_percent_idxs, union_all, intersection_all

In [19]:
def bottom_x_percent_per_label(merged_df, percent):
    l2_columns = [col for col in merged_df.columns if col.startswith('dist_')]

    # List to store the final output
    bottom_10_percent_idxs = []
    quant = 1 - percent
    
    for col in l2_columns:
        bottom_idxs = []
        for label in merged_df['labels'].unique():
            subset = merged_df[merged_df['labels'] == label]
            threshold = subset[col].quantile(percent)
            idxs = subset[subset[col] <= threshold]['idx'].to_numpy()
            bottom_idxs.append(idxs)
        
        # Combine all labels' top idxs into a single array (e.g., union)
        combined = np.unique(np.concatenate(bottom_idxs))
        bottom_10_percent_idxs.append(combined)

    union_all = reduce(np.union1d, bottom_10_percent_idxs)
    intersection_all = reduce(np.intersect1d, bottom_10_percent_idxs)

    print(f'Jaccard Similarity: {len(intersection_all)/len(union_all)}')
    print(f'Size of Union: {len(union_all)}')
    print(f'Size of Intersection: {len(intersection_all)}')
    
    return bottom_10_percent_idxs, union_all, intersection_all

### CIFAR Sanity Lightning 1

#### Get Seed

In [6]:
run_path = "jmryan/emergence_redo/uhgk2tas"  # Example: "jmryan19/my-cool-project/abcd1234"

# Fetch the run
api = wandb.Api()
run = api.run(run_path)

# Access the config
config = run.config

In [7]:
config

{'base': False,
 'data': 'cifar10',
 'iter': 4,
 'loss': 'cross_entropy',
 'seed': 3721966775248150500,
 'model': 'resnet18',
 'prime': 67,
 'openai': True,
 'repeat': False,
 'dim_model': 256,
 'num_heads': 4,
 'num_steps': 100000,
 'batch_size': 64,
 'experiment': 'cifar_baseline_final_t',
 'max_epochs': 400,
 'num_layers': 2,
 'data_budget': 1000,
 'kind_custom': 'top',
 'custom_cifar': False,
 'train_budget': 5000,
 'weight_decay': 0.0005,
 'learning_rate': 0.001,
 'max_train_val': 67,
 'wandb_offline': False,
 'train_data_len': 50000,
 'fcn_hidden_width': 512,
 'kernel_bandwidth': 1,
 'training_fraction': 0.5,
 'plot_inner_products': False,
 'validation_data_len': 10000,
 'save_dist_class_means': True}

#### Get

In [7]:
sanity_path_0 = '../class_means_distance/cifar_sanity_no_shuffle_data_cifar10_iter_0_max_epochs_25/train/train_raw_activations_epoch_'
sanity_path_1 = '../class_means_distance/cifar_sanity_no_shuffle_data_cifar10_iter_1_max_epochs_25/train/train_raw_activations_epoch_'
lightweight_path = '../class_means_distance/cifar_lightweight_no_shuffle_data_cifar10_iter_0_max_epochs_25/train/train_raw_activations_epoch_'

In [8]:
epoch = 10
sanity_dat_0 = torch.load(sanity_path_0 + str(epoch), map_location=torch.device('cpu'), weights_only=True)
sanity_dat_1 = torch.load(sanity_path_1 + str(epoch), map_location=torch.device('cpu'), weights_only=True)
lightweight_dat = torch.load(lightweight_path + str(epoch), map_location=torch.device('cpu'), weights_only=True)

In [10]:
same_indices = True
for i in range(len(sanity_dat_0[:,-1])):
    if sanity_dat_0[i,-1] != lightweight_dat[i, -1]:
        same_indices = False
        break
print(same_indices)

True


In [11]:
same_labels = True
for i in range(len(sanity_dat_0[:,-3])):
    if sanity_dat_0[i,-3] != lightweight_dat[i, -3]:
        same_labels = False
        break
print(same_labels)

True


In [12]:
same_yhats = True
for i in range(len(sanity_dat_0[:,-2])):
    if sanity_dat_0[i,-2] != lightweight_dat[i, -2]:
        same_yhats = False
        break
print(same_yhats)

True


Using same seeding, simplified sanity model gives same outputs as original lightweight model.

### Sanity No Shuffle MNIST Resnet18 Seed 515

In [24]:
mnist_515 = 'class_means_distance/sanity_no_shuffle_data_mnist_resnet18_seed_515_iter_0_max_epochs_400/train/train_raw_activations_epoch_'

In [25]:
epoch = 5
mnist_515_data = torch.load(sanity_path_0 + str(epoch), map_location=torch.device('cpu'), weights_only=True)

In [26]:
merged_df = process_merge([mnist_515], 5)

In [27]:
merged_df

Unnamed: 0,dist_0,labels,idx,y_hats_0
0,29.665813,0.0,31202.0,0.0
1,28.374050,0.0,36749.0,0.0
2,27.017120,0.0,47156.0,0.0
3,28.395668,0.0,9717.0,0.0
4,27.514494,0.0,42752.0,0.0
...,...,...,...,...
59995,28.558506,9.0,43870.0,9.0
59996,24.637981,9.0,21378.0,9.0
59997,27.728691,9.0,30701.0,9.0
59998,26.201992,9.0,49226.0,9.0
