In [1]:
# Re-import edited project modules (e.g. vae_interp) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from vae_interp.sae import SAE
from vae_interp.dataset import NpyDataset
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from vae_interp.analysis import get_similar_features, get_activations_info

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the SAE from its JSON config file plus its saved weights file.
config_path, weights_path = (
    "./checkpoints/sae/sae_config.json",
    "./checkpoints/sae/sae.pth",
)
sae = SAE.load_from_checkpoint(config_path, weights_path)

In [4]:
# Dataset backed by the embeddings array saved in a .npy file on disk.
embeddings_path = "./vae_embeddings.npy"
dataset = NpyDataset(embeddings_path)

In [5]:
# Library helper: indices and cosine similarities of the 10 most similar
# features for every SAE feature (fields of the SimilarFeatures result).
similar_features = get_similar_features(sae, k=10)
similar_features

SimilarFeatures(k=10, top_k_indices_per_feature=tensor([[287, 415,  76,  ..., 286, 372, 351],
        [ 88, 297, 279,  ...,  38, 503, 293],
        [130, 489,  36,  ..., 177, 464, 394],
        ...,
        [363, 423, 136,  ..., 168, 267, 281],
        [326,  30, 182,  ..., 244, 193, 386],
        [319, 354, 234,  ..., 475, 289, 128]]), top_k_cosine_sim_per_feature=tensor([[0.3491, 0.2737, 0.2525,  ..., 0.2058, 0.2057, 0.2028],
        [0.3528, 0.3357, 0.3184,  ..., 0.2759, 0.2707, 0.2610],
        [0.2859, 0.2828, 0.2743,  ..., 0.2353, 0.2328, 0.2305],
        ...,
        [0.4083, 0.3586, 0.3488,  ..., 0.2819, 0.2682, 0.2635],
        [0.4350, 0.4086, 0.3748,  ..., 0.3532, 0.3213, 0.3033],
        [0.4316, 0.3546, 0.3500,  ..., 0.2843, 0.2796, 0.2766]]))

In [6]:
# Manual re-implementation of get_similar_features: find the top-k most
# cosine-similar features for each feature.
k = 10
num_features = len(sae.features)
# F.normalize divides each row by max(norm, eps), so an all-zero feature row
# stays zero instead of becoming NaN (0 / 0). NaNs would otherwise poison the
# similarity ranking computed in the next cells.
features_norm = torch.nn.functional.normalize(sae.features, p=2, dim=1)
features_norm.shape

torch.Size([512, 64])

In [7]:
# Pairwise cosine similarity between the unit-normalised feature vectors.
cosine_similarity = features_norm @ features_norm.T
# Ask for k + 1 neighbours because each feature is normally its own best match;
# the self-match is stripped in the following cells. One topk call yields both
# values and indices (it returns them as a (values, indices) pair), instead of
# running the identical top-k twice and relying on the two runs agreeing.
topk_cosine_sim, topk_indices = torch.topk(cosine_similarity, k=k + 1, dim=1)
# Column vector of each row's own feature index, for broadcasting against topk_indices.
feature_indices = torch.arange(0, num_features).view(-1, 1)
feature_indices.shape, topk_indices.shape

(torch.Size([512, 1]), torch.Size([512, 11]))

In [8]:
# Remove each feature's self-match from its k + 1 nearest neighbours.
# Normally the self index appears exactly once per row. If it is missing
# (e.g. NaN similarities or more than k exact ties), drop the least-similar
# (last) column instead, so every row keeps exactly k entries and the
# .view() below cannot fail.
is_self = topk_indices == feature_indices
keep_mask = ~is_self
keep_mask[~is_self.any(dim=1), -1] = False
topk = topk_indices[keep_mask].view(num_features, k)
topk

tensor([[287, 415,  76,  ..., 286, 372, 351],
        [ 88, 297, 279,  ...,  38, 503, 293],
        [130, 489,  36,  ..., 177, 464, 394],
        ...,
        [363, 423, 136,  ..., 168, 267, 281],
        [326,  30, 182,  ..., 244, 193, 386],
        [319, 354, 234,  ..., 475, 289, 128]])

In [9]:
# Cosine similarities of the k nearest neighbours, with each feature's
# self-match removed. If a row's own index is absent from its k + 1 results,
# its last (least-similar) column is dropped instead, so every row keeps
# exactly k values and the .view() below cannot fail.
is_self = topk_indices == feature_indices
keep_mask = ~is_self
keep_mask[~is_self.any(dim=1), -1] = False
topk_cosine_sim[keep_mask].view(num_features, k)

tensor([[0.3491, 0.2737, 0.2525,  ..., 0.2058, 0.2057, 0.2028],
        [0.3528, 0.3357, 0.3184,  ..., 0.2759, 0.2707, 0.2610],
        [0.2859, 0.2828, 0.2743,  ..., 0.2353, 0.2328, 0.2305],
        ...,
        [0.4083, 0.3586, 0.3488,  ..., 0.2819, 0.2682, 0.2635],
        [0.4350, 0.4086, 0.3748,  ..., 0.3532, 0.3213, 0.3033],
        [0.4316, 0.3546, 0.3500,  ..., 0.2843, 0.2796, 0.2766]])

In [10]:
# Encode the whole dataset into SAE feature activations.
batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Inference only:
# - no_grad stops autograd from recording a graph across every batch (the
#   previous run's downstream outputs carried a grad_fn).
# - Batches are collected in a list and concatenated once at the end, rather
#   than re-copying an ever-growing tensor each iteration (quadratic in the
#   number of batches).
embedding_batches = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        embedding_batches.append(sae.encode(batch))

if not embedding_batches:
    raise ValueError("dataset is empty; nothing to encode")
sparse_embeddings = torch.cat(embedding_batches, dim=0)

# rows = samples, columns = features; a column's non-zero fraction is that feature's activation density
sparse_embeddings.shape

  0%|          | 0/780 [00:00<?, ?it/s]

100%|██████████| 780/780 [00:05<00:00, 136.29it/s]


torch.Size([49859, 512])

In [11]:
# Library helper: per-feature activations and activation densities.
# Use CUDA when available and fall back to CPU otherwise, rather than failing
# on machines without a GPU (assumes the helper accepts "cpu" — TODO confirm).
# NOTE(review): activation_density_per_feature in this output appears to be
# (number of active samples) / 512 rather than a fraction of the dataset.
# For example, feature 10 fires on ~200 samples: it reports 0.3906 = 200/512,
# while the manual density computed later gives ~200/49859 = 0.0040.
# 512 equals num_features here; check the normalisation inside get_activations_info.
device = "cuda" if torch.cuda.is_available() else "cpu"
get_activations_info(sae, dataset, batch_size=256, top_k=10, device=device)

 14%|█▍        | 28/195 [00:00<00:00, 279.84it/s]

100%|██████████| 195/195 [00:01<00:00, 133.28it/s]


ActivationsInfo(activations_per_feature=tensor([[0.0015, 0.0192, 0.0374,  ..., 0.0000, 0.0031, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0085, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0048, 0.0250, 0.0157,  ..., 0.0230, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]), activation_density_per_feature=tensor([4.6865e+01, 0.0000e+00, 4.9936e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 4.7746e+01, 4.3838e+01, 4.6842e+01, 3.9062e-01, 4.3377e+01,
        5.3908e+01, 0.0000e+00, 0.0000e+00, 4.7957e+01, 4.8824e+01, 4.0234e-01,
        2.6461e+01, 0.0000e+00, 0.0000e+00, 3.9648e-01, 4.7590e+01, 4.6355e+01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 5.3203e+00, 1.8496e+00, 0.0000e+00,
        7.9297e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.2285e+01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 4.7070

In [52]:
# rows = sparse embedding of one sample, column i = activations of feature i.
# Top-k along dim 0 gives, for each feature, the indices of the k samples that
# activate it most strongly (k is the neighbour count defined earlier).
# The indices get their own name; `topk_image_activations` holds the values.
# NOTE: for features that never fire (all-zero columns) these indices are
# arbitrary ties, not meaningful examples.
topk_image_indices = torch.topk(sparse_embeddings, k=k, dim=0).indices
topk_image_indices.shape

torch.Size([10, 512])

In [53]:
# Activation values of each feature's top-k samples (column i = feature i),
# sorted in descending order down each column.
topk_result = torch.topk(sparse_embeddings, k=k, dim=0)
topk_image_activations = topk_result.values
topk_image_activations

tensor([[0.0772, 0.0000, 0.0511,  ..., 0.0491, 0.0641, 0.0000],
        [0.0654, 0.0000, 0.0497,  ..., 0.0474, 0.0611, 0.0000],
        [0.0621, 0.0000, 0.0486,  ..., 0.0456, 0.0608, 0.0000],
        ...,
        [0.0582, 0.0000, 0.0462,  ..., 0.0426, 0.0601, 0.0000],
        [0.0579, 0.0000, 0.0460,  ..., 0.0424, 0.0601, 0.0000],
        [0.0575, 0.0000, 0.0454,  ..., 0.0423, 0.0600, 0.0000]],
       grad_fn=<TopkBackward0>)

In [57]:
# Activation density per feature: fraction of samples (in [0, 1]) on which the
# feature's activation is non-zero.
binary_embeddings = sparse_embeddings.ne(0).float()
num_samples = binary_embeddings.shape[0]
total_activations = binary_embeddings.sum(dim=0)
activation_densities = total_activations / num_samples
activation_densities

tensor([4.8126e-01, 0.0000e+00, 5.1279e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 4.9030e-01, 4.5017e-01, 4.8102e-01, 4.0113e-03, 4.4544e-01,
        5.5358e-01, 0.0000e+00, 0.0000e+00, 4.9247e-01, 5.0137e-01, 4.1317e-03,
        2.7173e-01, 0.0000e+00, 0.0000e+00, 4.0715e-03, 4.8870e-01, 4.7602e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 5.4634e-02, 1.8994e-02, 0.0000e+00,
        8.1430e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.3422e-01,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 4.8336e-03, 4.3902e-01, 4.6307e-01,
        8.4639e-03, 3.7506e-03, 0.0000e+00, 4.9407e-01, 3.9311e-03, 3.1942e-01,
        4.7656e-01, 2.0337e-02, 3.3051e-01, 0.0000e+00, 8.9653e-03, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 4.3382e-02, 0.0000e+00, 0.0000e+00, 8.5642e-03,
        0.0000e+00, 4.4124e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3836e-03,
        0.0000e+00, 4.9688e-01, 6.0791e-02, 3.9913e-03, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [62]:
# Number of dead features: zero activation density, i.e. never non-zero on any sample.
int((activation_densities == 0).sum())

277