In [11]:
# autoreload mode 2: reload all imported modules before executing each cell, so edits
# to the project package are picked up without restarting the kernel.
# Re-running this cell in a live kernel prints a harmless "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from vae_interp.sae import SAE
from vae_interp.dataset import NpyDataset
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from vae_interp.analysis import get_similar_features, get_activations_info, get_features_pca

In [3]:
# Alternative checkpoint location (disabled):
# config_path = "./checkpoints/sae/sae_config.json"
# weights_path = "./checkpoints/sae/sae.pth"

# Load the SAE from the experiment run whose directory name encodes lr=5.0e-04, l1=1.0e-03.
sae_config_file = "./checkpoints/sae_experiments2/lr=5.0e-04_l1=1.0e-03/sae_config.json"
sae_weights_file = "./checkpoints/sae_experiments2/lr=5.0e-04_l1=1.0e-03/sae.pth"
sae = SAE.load_from_checkpoint(sae_config_file, sae_weights_file)

In [4]:
# Dataset backed by a .npy file; presumably precomputed VAE embeddings (going by the
# filename) — TODO confirm against NpyDataset.
dataset = NpyDataset("./vae_embeddings.npy")

In [5]:
# Library helper: for every SAE feature, the indices and cosine similarities of its
# k=10 most similar features (see the SimilarFeatures repr displayed below).
get_similar_features(sae, k=10)

SimilarFeatures(k=10, top_k_indices_per_feature=tensor([[287, 415,  76,  ..., 286, 372, 351],
        [ 88, 297, 279,  ...,  38, 503, 293],
        [130, 489,  36,  ..., 177, 464, 394],
        ...,
        [363, 423, 136,  ..., 168, 267, 281],
        [326,  30, 182,  ..., 244, 193, 386],
        [319, 354, 234,  ..., 475, 289, 128]]), top_k_cosine_sim_per_feature=tensor([[0.3491, 0.2737, 0.2525,  ..., 0.2058, 0.2057, 0.2028],
        [0.3528, 0.3357, 0.3184,  ..., 0.2759, 0.2707, 0.2610],
        [0.2859, 0.2828, 0.2743,  ..., 0.2353, 0.2328, 0.2305],
        ...,
        [0.4083, 0.3586, 0.3488,  ..., 0.2819, 0.2682, 0.2635],
        [0.4350, 0.4086, 0.3748,  ..., 0.3532, 0.3213, 0.3033],
        [0.4316, 0.3546, 0.3500,  ..., 0.2843, 0.2796, 0.2766]]))

In [6]:
# Manual version of the similar-feature search: top-k most similar features per feature.
k = 10
feature_matrix = sae.features
num_features = len(feature_matrix)

# Scale every feature vector (row) to unit L2 norm so that dot products between rows
# are cosine similarities.
row_norms = torch.linalg.norm(feature_matrix, dim=1, keepdim=True)
features_norm = feature_matrix / row_norms
features_norm.shape

torch.Size([512, 64])

In [7]:
# Pairwise cosine similarity between all features: features_norm @ features_norm.T
# yields a (num_features, num_features) matrix.
cosine_similarity = features_norm @ features_norm.T

# Request k + 1 neighbours per row: each feature is expected to rank itself first
# (self-similarity of 1), and that self-match is removed in a later step.
# One topk call returns both values and indices, so the ranking is computed once
# instead of twice.
topk_cosine_sim, topk_indices = torch.topk(cosine_similarity, k=k + 1, dim=1)

# Column vector of row ids, shaped (num_features, 1) so it broadcasts against topk_indices.
feature_indices = torch.arange(0, num_features).view(-1, 1)
feature_indices.shape, topk_indices.shape

(torch.Size([512, 1]), torch.Size([512, 11]))

In [8]:
# Remove each feature's own index from its neighbour list. Boolean-mask indexing
# flattens the result, so it is reshaped back to (num_features, k); this relies on
# every row containing exactly one self-match.
not_self = topk_indices != feature_indices
topk = topk_indices[not_self].view(num_features, k)
topk

tensor([[287, 415,  76,  ..., 286, 372, 351],
        [ 88, 297, 279,  ...,  38, 503, 293],
        [130, 489,  36,  ..., 177, 464, 394],
        ...,
        [363, 423, 136,  ..., 168, 267, 281],
        [326,  30, 182,  ..., 244, 193, 386],
        [319, 354, 234,  ..., 475, 289, 128]])

In [9]:
# Drop the self-match from the similarity values as well, leaving the k neighbour
# similarities per feature (relies on exactly one self-match per row).
is_self = topk_indices == feature_indices
topk_cosine_sim[~is_self].view(num_features, k)

tensor([[0.3491, 0.2737, 0.2525,  ..., 0.2058, 0.2057, 0.2028],
        [0.3528, 0.3357, 0.3184,  ..., 0.2759, 0.2707, 0.2610],
        [0.2859, 0.2828, 0.2743,  ..., 0.2353, 0.2328, 0.2305],
        ...,
        [0.4083, 0.3586, 0.3488,  ..., 0.2819, 0.2682, 0.2635],
        [0.4350, 0.4086, 0.3748,  ..., 0.3532, 0.3213, 0.3033],
        [0.4316, 0.3546, 0.3500,  ..., 0.2843, 0.2796, 0.2766]])

In [10]:
# Encode the whole dataset with the SAE to obtain the sparse activations of every sample.
batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Inference only: without no_grad each encoded batch keeps its autograd graph alive
# (the concatenated result would carry a grad_fn), which wastes memory.
encoded_batches = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        encoded_batches.append(sae.encode(batch))

# Concatenate once at the end; re-concatenating inside the loop copies the growing
# tensor on every iteration (quadratic in the number of batches).
sparse_embeddings = torch.cat(encoded_batches, dim=0)

sparse_embeddings.shape  # (num_samples, num_features): column i holds the activations of feature i

  0%|          | 0/780 [00:00<?, ?it/s]

100%|██████████| 780/780 [00:05<00:00, 136.29it/s]


torch.Size([49859, 512])

In [5]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the cell
# also runs on machines without CUDA.
# NOTE(review): assumes get_activations_info accepts any torch device string — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
activations_info = get_activations_info(sae, dataset, batch_size=256, top_k=10, device=device)

 17%|█▋        | 34/195 [00:00<00:00, 334.38it/s]

100%|██████████| 195/195 [00:01<00:00, 180.45it/s]


In [6]:
# One value per SAE feature; presumably the largest activation seen over the dataset
# (inferred from the attribute name — TODO confirm). Entries of 0.0000 in the output
# would then be features that never activated.
activations_info.max_activation_per_feature

tensor([0.8164, 3.8927, 0.3528, 3.5169, 0.1126, 0.2878, 2.5195, 3.6785, 4.3074,
        5.6759, 0.6988, 0.0000, 0.0000, 3.3058, 0.0000, 3.5640, 0.0000, 3.3782,
        4.1229, 3.5375, 0.5085, 3.9410, 3.8771, 0.0000, 0.0000, 3.6403, 3.5874,
        3.3486, 3.6475, 3.9769, 5.4837, 0.0000, 4.0419, 3.6554, 4.4705, 0.0000,
        0.0000, 0.0000, 0.6897, 1.0008, 2.8607, 4.0323, 0.6343, 4.2257, 0.9271,
        0.0000, 3.1493, 4.1227, 4.0980, 0.0000, 3.4585, 1.0210, 4.5268, 0.0000,
        0.0000, 0.0000, 3.4571, 4.3208, 2.5995, 3.6250, 3.7554, 4.7817, 0.0000,
        5.0083, 0.0000, 4.3049, 0.0000, 0.0000, 2.8333, 3.2704, 0.6040, 0.0000,
        0.0000, 3.7971, 4.0668, 4.0322, 0.0000, 3.5896, 5.4781, 0.0000, 4.2766,
        6.1268, 0.2186, 3.8158, 0.4553, 0.4632, 0.0000, 4.8671, 1.3258, 0.0000,
        3.5555, 1.9127, 0.8776, 4.6681, 2.2922, 4.7673, 0.4356, 0.0000, 0.0000,
        3.0070, 1.4221, 3.3932, 3.3355, 0.0000, 4.5514, 0.0000, 3.1328, 5.1226,
        0.0000, 0.0000, 4.6807, 5.0369, 

In [12]:
# NOTE(review): at the time of this run the helper only printed "PCA!!!" and displayed
# no result — it looks like a placeholder; confirm its return value before relying on it.
get_features_pca(sae)

PCA!!!


In [52]:
# sparse_embeddings has one row per dataset sample and one column per SAE feature
# (column i = activations of feature i).
# For every feature, find the row indices of the k samples that activate it most.
# Stored under an "indices" name: these are sample positions, not activation values.
topk_image_indices = torch.topk(sparse_embeddings, k=k, dim=0).indices
topk_image_indices.shape

torch.Size([10, 512])

In [53]:
# The k largest activation values of every feature, taken down the sample axis (dim 0);
# the result has shape (k, num_features).
topk_result = sparse_embeddings.topk(k, dim=0)
topk_image_activations = topk_result.values
topk_image_activations

tensor([[0.0772, 0.0000, 0.0511,  ..., 0.0491, 0.0641, 0.0000],
        [0.0654, 0.0000, 0.0497,  ..., 0.0474, 0.0611, 0.0000],
        [0.0621, 0.0000, 0.0486,  ..., 0.0456, 0.0608, 0.0000],
        ...,
        [0.0582, 0.0000, 0.0462,  ..., 0.0426, 0.0601, 0.0000],
        [0.0579, 0.0000, 0.0460,  ..., 0.0424, 0.0601, 0.0000],
        [0.0575, 0.0000, 0.0454,  ..., 0.0423, 0.0600, 0.0000]],
       grad_fn=<TopkBackward0>)

In [57]:
# Activation density of a feature = fraction of samples on which it is non-zero.
binary_embeddings = sparse_embeddings.ne(0).float()  # 1.0 where the feature fired, else 0.0
num_samples = binary_embeddings.shape[0]
total_activations = binary_embeddings.sum(dim=0)  # per-feature count of firing samples
activation_densities = total_activations / num_samples
activation_densities

tensor([4.8126e-01, 0.0000e+00, 5.1279e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 4.9030e-01, 4.5017e-01, 4.8102e-01, 4.0113e-03, 4.4544e-01,
        5.5358e-01, 0.0000e+00, 0.0000e+00, 4.9247e-01, 5.0137e-01, 4.1317e-03,
        2.7173e-01, 0.0000e+00, 0.0000e+00, 4.0715e-03, 4.8870e-01, 4.7602e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 5.4634e-02, 1.8994e-02, 0.0000e+00,
        8.1430e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.3422e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 4.8336e-03, 4.3902e-01, 4.6307e-01,
        8.4639e-03, 3.7506e-03, 0.0000e+00, 4.9407e-01, 3.9311e-03, 3.1942e-01,
        4.7656e-01, 2.0337e-02, 3.3051e-01, 0.0000e+00, 8.9653e-03, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 4.3382e-02, 0.0000e+00, 0.0000e+00, 8.5642e-03,
        0.0000e+00, 4.4124e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3836e-03,
        0.0000e+00, 4.9688e-01, 6.0791e-02, 3.9913e-03, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [62]:
# Number of dead features, i.e. features with zero activation density over the dataset.
int((activation_densities == 0).sum())

277